# Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

The next algorithm we are going to analyze is Grad-CAM. The first four cells are already known from the previous notebook.

In [None]:
import torch
from torchvision import models, transforms
import numpy as np
import medmnist
import matplotlib.cm as cm
from torch.utils.data import DataLoader

import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 whose final fully connected layer is replaced by a 7-output
# linear layer so that it matches the checkpoint loaded below.
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 7)
model.load_state_dict(torch.load("../Exercise/best_dermamnist_resnet_model.pth", map_location=device))

# The input batch is moved to `device` later in the notebook, so the model
# has to live on the same device; otherwise the forward pass fails on CUDA
# machines. Eval mode is used because the model is only run for inference.
model = model.to(device).eval()

In [None]:
# Resolve the dataset class that medmnist registers for the chosen flag.
data_flag = 'dermamnist'
info = medmnist.INFO[data_flag]
DataClass = getattr(medmnist, info['python_class'])

# Normalisation constants and the loader batch size.
mean = std = 0.5
batch_size = 128

# Preprocessing: convert each sample to a tensor, then normalise it with the
# mean / std defined above.
data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[mean], std=[std]),
    ]
)

# Test split at size 64, wrapped in a DataLoader with the batch size above.
test_dataset = DataClass(
    split='test',
    transform=data_transform,
    size=64,
    download=True,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)

In [None]:
# Take the first batch from the test loader; the second element of the batch
# is not needed here and is discarded.
images, _ = next(iter(test_loader))

# Move the batch to the selected compute device.
images = images.to(device=device)

## **Applying Grad-CAM**

The function that computes a Grad-CAM heatmap is already provided below:

In [None]:
def make_gradcam_heatmap(model, input_tensor, target_layer, pred_index=None):
    """
    Grad-CAM heatmap generation, for details check:
    https://arxiv.org/pdf/1610.02391

    Args:
        model: Network to explain. It is switched to eval mode, its
            gradients are zeroed and then repopulated by the backward pass.
        input_tensor: Batch that is passed to ``model`` unchanged. The model
            output is indexed as ``(batch, classes)``.
        target_layer: Module inside ``model`` whose output feature maps are
            explained. Its output is treated as ``(batch, channels, H, W)``.
        pred_index: Class to explain. An int or 0-dim tensor is used for
            every sample, a 1-D tensor supplies one class per sample, and
            ``None`` selects each sample's highest-scoring class.

    Returns:
        The ReLU-ed, per-sample max-normalised heatmap after ``squeeze()``:
        ``(H, W)`` for a single-sample batch, ``(batch, H, W)`` otherwise.
        A heatmap without positive values is returned as all zeros.
    """
    features = []
    gradients = []

    def forward_hook(module, inputs, output):
        features.append(output)

    def backward_hook(module, grad_in, grad_out):
        gradients.append(grad_out[0])

    # register hooks to store feature maps and gradients
    fwd_handle = target_layer.register_forward_hook(forward_hook)
    bwd_handle = target_layer.register_full_backward_hook(backward_hook)

    try:
        # forward pass
        model.eval()
        output = model(input_tensor)
        if pred_index is None:
            # per-sample argmax; a flattened argmax would only be a valid
            # class index for a batch of one
            pred_index = output.argmax(dim=1)

        # backward pass; the selected scores are summed so that backward()
        # always receives a scalar. Samples do not interact, so each sample
        # still gets the gradient of its own class score.
        model.zero_grad()
        if torch.is_tensor(pred_index) and pred_index.dim() > 0:
            index = pred_index.reshape(-1, 1).to(output.device).long()
            class_score = output.gather(1, index).sum()
        else:
            class_score = output[:, pred_index].sum()
        class_score.backward()
    finally:
        # always remove the hooks, even if the forward/backward pass failed,
        # so they do not pile up on the layer across calls
        fwd_handle.remove()
        bwd_handle.remove()

    # get feature maps and gradients from hooks
    feature_map = features[0].detach()
    grads = gradients[0].detach()

    # channel importance: average the gradients over height and width,
    # separately for every sample
    channel_weights = grads.mean(dim=(2, 3), keepdim=True)

    # weight the feature maps and sum over channels; this is out-of-place so
    # the activations captured by the hook are not overwritten
    heatmap = (feature_map * channel_weights).sum(dim=1)

    # apply ReLU to get just positive contributing features.
    heatmap = F.relu(heatmap)

    # normalize each heatmap by its own maximum; all-zero maps stay zero
    peak = heatmap.amax(dim=(1, 2), keepdim=True)
    heatmap = heatmap / torch.where(peak > 0, peak, torch.ones_like(peak))

    return heatmap.squeeze()

We want to overlay the heatmap on the image that we are explaining.

In [None]:
# Layer whose output feature maps Grad-CAM will explain: the last block of
# the model's `layer3`.
target_layer = model.layer3[-1]

# To Do

# 1. Select one image from images and add batch dimension
# 2. Run the function make_gradcam_heatmap
# 3. Use torch.nn.functional.interpolate to bring heatmap to image dimensions (64x64)

Lastly, the plot:

In [None]:
def rgb2gray(rgb):
    """Collapse a channels-last RGB image into one grayscale channel.

    The last axis of ``rgb`` is indexed at 0, 1 and 2 (red, green, blue) and
    the three planes are combined as a weighted sum.
    """
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue


# `image` and `heatmap` are expected to come from the To Do cell above.
# NOTE(review): assumes `image` is a single (C, H, W) image without batch
# dimension and `heatmap` is already resized to (H, W) -- confirm.
# Tensors are detached and copied to the CPU first so that CUDA tensors can
# be plotted as well.
image_np = image.detach().cpu().numpy() if torch.is_tensor(image) else np.asarray(image)
heatmap_np = heatmap.detach().cpu().numpy() if torch.is_tensor(heatmap) else np.asarray(heatmap)

# Move the channel axis last while keeping rows and columns in order:
# (C, H, W) -> (H, W, C). Reversing all axes would also swap height and
# width, so the image would no longer line up with the heatmap.
plt.imshow(rgb2gray(np.transpose(image_np, (1, 2, 0))), cmap="gray")
heatmap_plot = plt.imshow(heatmap_np, alpha=0.4, cmap='jet')
plt.colorbar(heatmap_plot, label='Relative Importance')

plt.title('Grad-CAM Heatmap Overlay')
plt.axis('off')
plt.show()