## Requirements

### Imports

In [1]:
import os
import torch
from pathlib import Path

In [2]:
import training

In [3]:
# Seed PyTorch's random number generator with a fixed value so repeated runs match.
torch.manual_seed(42)

<torch._C.Generator at 0x13a4128a110>

### Paths

In [4]:
# Project layout, resolved from the directory the notebook kernel runs in.
SRC_DIR = Path.cwd()
ROOT_DIR = SRC_DIR.parent

# Dataset locations (kept as plain strings).
DATA_DIR = str(ROOT_DIR / 'dataset')
PREPROCESSED_DIR = str(Path(DATA_DIR) / 'preprocessed')
CSV_PATH = str(Path(DATA_DIR) / 'csv_mappings' / 'train.csv')

# Model / result locations (kept as plain strings).
MODEL_DIR = str(ROOT_DIR / 'models')
BASELINE_DIR = str(Path(MODEL_DIR) / 'baselines_finetuned')
RESULT_DIR = str(Path(BASELINE_DIR) / 'results')

### Load Data

In [5]:
# Batch size handed to training.get_data_loaders when the loaders are built.
BATCH_SIZE = 32

In [6]:
# Class index -> label lookup; the position in the list is the class index.
CLASS_NAMES = dict(enumerate([
    "amanita",
    "boletus",
    "chantelle",
    "deterrimus",
    "rufus",
    "torminosus",
    "aurantiacum",
    "procera",
    "involutus",
    "russula",
]))

In [7]:
# Build the three loaders (bound here as train / val / test) through the project-local
# `training` helper, from the preprocessed directory and the CSV mapping file.
train_loader, val_loader, test_loader = training.get_data_loaders(PREPROCESSED_DIR, CSV_PATH, BATCH_SIZE)

### Training

##### Training config

In [8]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

NUM_CLASSES = 10
EPOCHS = 20
PATIENCE = 3
LEARNING_RATE = 1e-4

# Options noted by the author: StepLR, OneCycleLR, None
SCHEDULER = 'StepLR'

##### Load model to explain

In [9]:
model_type = 'alexnet'  

# Architecture names listed by the author as accepted values:
#   alexnet, resnet, vgg16, densenet, efficientnet, custom_alexnet, custom_resnet

In [10]:
# Load the selected architecture through the project-local helper and place it on DEVICE.
# NOTE(review): `finetuned=True` presumably selects the fine-tuned checkpoint - confirm in `training`.
model = training.load_model_for_explaining(model_type, NUM_CLASSES, DEVICE, finetuned=True)

  checkpoint = torch.load(model_path, map_location=device)


Model 'alexnet' loaded successfully from c:\Users\ilian\Documents\Projects\git_projects\university\mushroom_classification\models\baselines_finetuned\alexnet\results\alexnet.pth


##### Explain model

In [13]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import models
import cv2

def get_last_conv_layer(model):
    """Return the last ``torch.nn.Conv2d`` registered anywhere inside ``model``.

    The search walks ``model.modules()``, which yields the model itself and
    every nested sub-module in registration order. Convolutions nested inside
    deeper containers are therefore found as well as those sitting directly in
    a top-level ``Sequential``.

    Note that "last" means last *registered*; this matches the forward order
    only when the model registers its layers in the order it applies them.

    Args:
        model: The ``torch.nn.Module`` to search.

    Returns:
        The last ``Conv2d`` encountered, or ``None`` when the model contains
        no 2-D convolution.
    """
    last_conv = None
    for module in model.modules():
        # Keep overwriting so the final match wins; returning on the first hit
        # would hand back the first convolution instead of the last one.
        if isinstance(module, torch.nn.Conv2d):
            last_conv = module
    return last_conv

def register_hooks(model):
    """Attach a forward hook that records the target conv layer's output.

    The layer is the one returned by ``get_last_conv_layer(model)``. After
    every forward pass through that layer, a detached copy of its output is
    stored in the module-level global ``feature_maps``. Because the copy is
    detached, no gradient can be obtained through it.

    Args:
        model: The ``torch.nn.Module`` to instrument.

    Returns:
        The handle returned by ``register_forward_hook``; call ``.remove()``
        on it to detach the hook again.

    Raises:
        ValueError: If ``get_last_conv_layer`` finds no layer to hook.
    """
    last_conv_layer = get_last_conv_layer(model)
    if last_conv_layer is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("No Conv2d layer found in the model; cannot register a feature-map hook.")

    def hook_fn(module, inputs, output):
        global feature_maps
        feature_maps = output.detach()

    return last_conv_layer.register_forward_hook(hook_fn)

def grad_cam(model, input_tensor, class_idx=None, target_layer=None):
    """Compute a Grad-CAM heatmap for the first sample of ``input_tensor``.

    The class score is differentiated with respect to the activations of
    ``target_layer``. The per-channel mean of that gradient weights the
    activation channels, the weighted channels are averaged, negative values
    are clamped to zero, and the map is resized to the input's spatial size
    and scaled so its maximum is 1.

    Args:
        model: Classification model; it is switched to ``eval()`` mode.
        input_tensor: Batched input of shape ``(N, C, H, W)``. Only sample 0
            is explained. The tensor itself is not modified.
        class_idx: Index of the class to explain. ``None`` selects the class
            with the highest score for sample 0.
        target_layer: Module whose output is used as the feature maps.
            ``None`` falls back to ``get_last_conv_layer(model)``.

    Returns:
        A 2-D ``float`` NumPy array of shape ``(H, W)`` with values in
        ``[0, 1]`` (all zeros when nothing contributes positively).

    Raises:
        ValueError: If no target layer is given and none can be found.
        RuntimeError: If the target layer did not run during the forward pass.
    """
    if target_layer is None:
        target_layer = get_last_conv_layer(model)
    if target_layer is None:
        raise ValueError("No Conv2d layer found in the model; pass target_layer explicitly.")

    # Feed the model on the device its weights live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else input_tensor.device
    # Work on a detached copy so the caller's tensor is left untouched. Tracking
    # gradients from the input guarantees that the target activations are part
    # of the autograd graph even when the layers before them are frozen.
    inputs = input_tensor.detach().to(device).requires_grad_(True)

    captured = {}

    def _store_activations(module, hook_inputs, output):
        # Keep the live (non-detached) tensor: the gradient is taken w.r.t. it.
        captured["activations"] = output

    model.eval()
    handle = target_layer.register_forward_hook(_store_activations)
    try:
        with torch.enable_grad():
            output = model(inputs)
            if "activations" not in captured:
                raise RuntimeError("target_layer was not executed during the forward pass.")
            if class_idx is None:
                class_idx = int(output[0].argmax())
            class_score = output[0, class_idx]
            activations = captured["activations"]
            # Gradient of the class score w.r.t. the feature maps (not the input
            # image): one value per activation element.
            gradients = torch.autograd.grad(class_score, activations)[0]
    finally:
        # Always detach the hook, even if the forward/backward pass failed.
        handle.remove()

    # One weight per feature-map channel.
    pooled_gradients = gradients[0].mean(dim=(1, 2))
    weighted = activations[0].detach() * pooled_gradients[:, None, None]
    heatmap = torch.relu(weighted.mean(dim=0))

    # Resize to the input's (height, width).
    height, width = input_tensor.shape[2], input_tensor.shape[3]
    heatmap = F.interpolate(
        heatmap[None, None], size=(height, width), mode="bilinear", align_corners=False
    )[0, 0]

    # Avoid dividing by zero when the map is entirely zero.
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak
    return heatmap.cpu().numpy()

def overlay_heatmap_on_image(image, heatmap, alpha=0.6):
    """Blend a JET colour-mapped heatmap over an image.

    Args:
        image: Either a torch tensor in ``(C, H, W)`` layout, or a PIL image /
            NumPy array in ``(H, W)`` or ``(H, W, C)`` layout. Data that is not
            already ``uint8`` is assumed to lie in ``[0, 1]`` and is clipped to
            that range before scaling to 0-255.
        heatmap: 2-D array of values in ``[0, 1]``. It is resized to the
            image's height and width when the shapes differ.
        alpha: Blend weight of the image; the heatmap gets ``1 - alpha``.

    Returns:
        An ``(H, W, 3)`` ``uint8`` array in RGB channel order.
    """
    if torch.is_tensor(image):
        # Tensors use (C, H, W); OpenCV expects (H, W, C).
        pixels = image.detach().cpu().numpy().transpose(1, 2, 0)
    else:
        # PIL images and arrays are already (H, W[, C]).
        pixels = np.asarray(image)

    if pixels.dtype != np.uint8:
        # Clip first so out-of-range values saturate instead of wrapping around
        # in the uint8 cast.
        pixels = np.uint8(255 * np.clip(pixels, 0.0, 1.0))
    if pixels.ndim == 2:
        pixels = np.stack([pixels] * 3, axis=-1)
    elif pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)
    # Drop any alpha channel and hand OpenCV a contiguous buffer.
    pixels = np.ascontiguousarray(pixels[:, :, :3])

    heat = np.nan_to_num(np.asarray(heatmap, dtype=np.float32))
    height, width = pixels.shape[:2]
    if heat.shape != (height, width):
        # cv2.resize takes the target size as (width, height).
        heat = cv2.resize(heat, (width, height))
    heat = np.uint8(255 * np.clip(heat, 0.0, 1.0))

    colored = cv2.applyColorMap(heat, cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert so the result matches the RGB image.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    return cv2.addWeighted(pixels, alpha, colored, 1 - alpha, 0)

def prepare_input_image(image, transform=None):
    """Turn a single image into a batched model input.

    Args:
        image: The image to convert; it is passed straight to ``transform``.
        transform: Callable applied to ``image``. When omitted, the image is
            converted with ``ToTensor`` and normalised per channel with
            mean ``[0.485, 0.456, 0.406]`` and std ``[0.229, 0.224, 0.225]``.
            No resizing or cropping is performed.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    if transform is not None:
        return transform(image).unsqueeze(0)

    default_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return default_pipeline(image).unsqueeze(0)

def visualize_gradcam(image, model, class_idx=None):
    """Compute a Grad-CAM heatmap for ``image`` and display it over the image.

    Args:
        image: A PIL image or ``(H, W, C)`` NumPy array, i.e. something the
            ``ToTensor`` transform accepts.
        model: The classification model to explain.
        class_idx: Class index to explain; passed through to ``grad_cam``.

    Returns:
        None. The overlay is shown with matplotlib.
    """
    input_tensor = prepare_input_image(image)

    # The input has to sit on the same device as the model's weights.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    heatmap = grad_cam(model, input_tensor, class_idx)

    # The overlay helper works on a (C, H, W) tensor scaled to [0, 1]; ToTensor
    # produces that from a PIL image or uint8 array. The un-normalised image is
    # used so the displayed colours are the original ones.
    display_image = transforms.ToTensor()(image)
    overlay = overlay_heatmap_on_image(display_image, heatmap)

    plt.imshow(overlay)
    plt.title("Grad-CAM")
    plt.axis("off")
    plt.show()


In [None]:
# NOTE(review): `image` is not defined in any cell visible in this notebook. It has to be
# loaded before these calls can run on a fresh kernel, in a form the default transform of
# prepare_input_image accepts (PIL image or H x W x C array).
# First call: class_idx is left as None, so grad_cam picks the highest-scoring class.
# Second call: explains class index 0 explicitly.
visualize_gradcam(image, model)
visualize_gradcam(image, model, class_idx=0)