## Model Interpretability with Grad-CAM

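Grad-CAM (Gradient-weighted Class Activation Mapping) is a visualization technique
that highlights the regions of an input image that most influence a model's prediction
for a given class. It weights the feature maps of a chosen convolutional layer by the
gradients of that class's score and combines them into a coarse heatmap.
This is especially important in domains such as medical imaging and other safety-critical
applications, where model decisions must be interpretable.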

In [13]:
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

ModuleNotFoundError: No module named 'torchvision'

In [None]:
# Evaluation-time preprocessing: resize to 128x128, convert to a tensor, then
# normalize each channel with the standard ImageNet mean/std values.
# NOTE(review): presumably this must match the preprocessing used when the
# checkpoint was fine-tuned -- confirm.
val_transforms = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# CIFAR-10 test split (train=False); download=True fetches it into ./data if missing.
val_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=val_transforms
)

# One image per batch, in shuffled order.
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=1)
class_names = val_dataset.classes  # label index -> class name

In [None]:
num_classes = 10

# Build the architecture only. load_state_dict below is strict (the default),
# so every parameter and buffer is overwritten by the fine-tuned checkpoint.
# Downloading the ImageNet weights first (weights=DEFAULT) would be wasted work.
model = models.resnet18(weights=None)
# Replace the classifier head so its output size matches the checkpoint.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# A state_dict holds only tensors and containers. weights_only=True restricts
# unpickling to those types instead of executing arbitrary pickled objects.
state_dict = torch.load(
    "resnet18_finetuned_cifar10.pth", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode: BatchNorm uses running stats, no dropout


In [None]:
class GradCAM:
    """Grad-CAM heatmaps for a CNN classifier.

    A forward hook on ``target_layer`` records its output feature maps, and a
    full backward hook records the gradient of the selected class score with
    respect to those feature maps. Each channel is weighted by its spatially
    averaged gradient, and the weighted channels are summed. The result is
    passed through ReLU and min-max normalized to [0, 1].

    Args:
        model: classifier whose output is indexed as ``output[sample, class]``.
        target_layer: submodule of ``model`` whose output is explained. Its
            output is assumed to be (batch, channels, H, W), because gradients
            are averaged over dims 2 and 3.

    Call :meth:`remove` when finished. Otherwise the hooks stay attached to
    ``target_layer`` and accumulate each time a new ``GradCAM`` is built.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached again (see remove()).
        # register_full_backward_hook replaces the deprecated
        # register_backward_hook, whose gradients are unreliable for modules
        # composed of several ops (such as a ResNet stage).
        self._handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: store the target layer's output (detached from the graph)."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store d(score)/d(target layer output)."""
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach both hooks from ``target_layer``. Safe to call more than once."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, input_tensor, class_idx=None):
        """Return the Grad-CAM heatmap(s) for ``input_tensor``.

        Args:
            input_tensor: batch passed straight to ``model``.
            class_idx: class to explain. ``None`` uses each sample's argmax
                prediction. An int applies the same class to every sample. A
                sequence or tensor gives one class per sample.

        Returns:
            A float numpy array at the target layer's spatial resolution,
            normalized to [0, 1] per sample (all zeros when a map is constant).
            ``squeeze()`` is applied, so a single-image batch yields (H, W)
            and a batch of B images yields (B, H, W).

        Raises:
            RuntimeError: if the hooks did not fire, for example because
                ``target_layer`` is not used in the forward pass or
                :meth:`remove` was already called.
        """
        self.gradients = None
        self.activations = None

        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_tensor)

            if class_idx is None:
                targets = output.argmax(dim=1)
            else:
                targets = torch.as_tensor(class_idx, device=output.device)
                targets = targets.reshape(-1).expand(output.shape[0])

            # Summing the selected per-sample scores gives backward() a scalar.
            # Each sample's gradient stays independent as long as the model has
            # no cross-sample ops (e.g. BatchNorm in training mode).
            score = output.gather(1, targets.long().unsqueeze(1)).sum()

            self.model.zero_grad()
            score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM hooks did not fire; is target_layer part of the "
                "model's forward pass, and has remove() been called?"
            )

        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * self.activations).sum(dim=1))

        # Min-max normalize each sample separately. A constant map would give
        # 0/0 = NaN, so a zero range is replaced by 1, which yields zeros.
        cam = cam - cam.amin(dim=(1, 2), keepdim=True)
        cam_max = cam.amax(dim=(1, 2), keepdim=True)
        cam = cam / cam_max.masked_fill(cam_max == 0, 1.0)

        return cam.squeeze().cpu().numpy()


In [None]:
images, labels = next(iter(val_loader))
images = images.to(device)

# Predict first so the figure can name the class being explained. Grad-CAM
# explains the predicted class, which need not equal the ground-truth label.
with torch.no_grad():
    pred_idx = model(images).argmax(dim=1).item()

gradcam = GradCAM(model, model.layer4)
heatmap = gradcam.generate(images, class_idx=pred_idx)

# Undo the Normalize step from val_transforms so colors display as loaded.
# Per-image min-max rescaling would shift the colors instead.
normalize = next(
    t for t in val_transforms.transforms if isinstance(t, transforms.Normalize)
)
mean = np.asarray(normalize.mean).reshape(1, 1, -1)
std = np.asarray(normalize.std).reshape(1, 1, -1)
img = images[0].permute(1, 2, 0).cpu().numpy()
img = np.clip(img * std + mean, 0.0, 1.0)

# The CAM is at layer4's spatial resolution, not the image's. Upsample it to
# the image size (cv2 dsize is (width, height)). Otherwise the second imshow
# rescales the axes to the smaller array and crops the photo.
heatmap = cv2.resize(
    heatmap, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR
)

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title(f"Original: {class_names[labels.item()]}")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(img)
plt.imshow(heatmap, cmap="jet", alpha=0.5)
plt.title(f"Grad-CAM: {class_names[pred_idx]}")
plt.axis("off")

plt.show()
