# Saliency Maps Demo (Grad-CAM)

This notebook generates a **Grad-CAM** saliency map for an image classification model (**ResNet-18** pretrained on ImageNet).

## What you’ll learn
- How to compute a Grad-CAM heatmap for a CNN prediction
- How to visualize when a model may be focusing on the **background** instead of the object

---

## Setup
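You need: `torch`, `torchvision`, `Pillow`, `matplotlib`, `numpy`

If `torchvision` is missing (`ModuleNotFoundError: No module named 'torchvision'`), install it alongside a matching `torch` build, for example with `pip install torch torchvision`.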
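A quick way to confirm the required packages are importable before running the cells below, using only the standard library (`importlib.util.find_spec`). The list holds *import* names, which differ from the pip name for Pillow (imported as `PIL`); `missing_packages` is an illustrative helper, not part of any library.

```python
import importlib.util

# Import names of the packages this notebook uses (Pillow is imported as "PIL")
REQUIRED = ["torch", "torchvision", "PIL", "matplotlib", "numpy"]

def missing_packages(names):
    """Return the subset of top-level module `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]

missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```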


In [1]:
# Imports
import torch
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

print("Torch:", torch.__version__)



## 1) Get the data



In [None]:
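# Download the demo archive under a clean filename, then extract it
!wget -O demo_data_clf.zip "https://zenodo.org/records/15376499/files/demo_data_clf.zip?download=1"
!unzip -o demo_data_clf.zip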

### Get an image

Place an image file named **`animal.jpg`** in the same folder as this notebook.
- Good choices: wildlife photo, camera-trap image, landscape with an animal, etc.
- The demo is more interesting when the **background is dominant**, because the model may latch onto it.

If you prefer a different filename, just edit the cell below.

In [None]:
# Load image (edit filename as needed)
image_path = "animal.jpg"

img = Image.open(image_path).convert("RGB")
img


## 2) Load a pretrained model + define preprocessing

We use **ResNet-18** pretrained on ImageNet.  
Preprocessing must match the training normalization used for ImageNet models.
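What `ToTensor` + `Normalize` do to each pixel can be sketched in NumPy: scale 8-bit values to [0, 1], then subtract the per-channel ImageNet mean and divide by the per-channel std. The helpers `normalize`/`denormalize` are illustrative names, not torchvision functions, and this sketch uses (H, W, C) layout whereas torchvision tensors are (C, H, W).

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize(img_uint8):
    """(H, W, 3) uint8 image -> float array normalized with ImageNet stats."""
    arr = img_uint8.astype(np.float64) / 255.0     # like ToTensor: [0, 255] -> [0, 1]
    return (arr - IMAGENET_MEAN) / IMAGENET_STD    # like Normalize: per-channel broadcast

def denormalize(arr):
    """Inverse of `normalize`, clipped to [0, 1] for display."""
    return np.clip(arr * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
demo_img = rng.integers(0, 256, size=(4, 4, 3), dtype=np.uint8)
demo_x = normalize(demo_img)
print(demo_x.shape, demo_x.mean(axis=(0, 1)))
```

Names are prefixed with `demo_` so running this cell does not overwrite the notebook's `img` and `x`.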


In [None]:
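# Load pretrained model (weights= API needs torchvision >= 0.13; older versions use pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.eval()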

# ImageNet preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

x = transform(img).unsqueeze(0)  # shape: (1, 3, 224, 224)
x.shape


## 3) Grad-CAM implementation

**Idea:** take gradients of the predicted class score w.r.t. a convolutional feature map, then produce a weighted sum of channels to obtain a class-discriminative heatmap.
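In symbols, with $A^k \in \mathbb{R}^{H \times W}$ the $k$-th of the $C$ channels of the target layer's activations and $y^c$ the logit of class $c$, the two steps are:

```latex
\alpha_k^c = \frac{1}{Z} \sum_{i=1}^{H} \sum_{j=1}^{W} \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left( \sum_{k=1}^{C} \alpha_k^c \, A^k \right),
\qquad Z = H \cdot W
```

Here $\alpha_k^c$ corresponds to `weights` in the Grad-CAM code and the ReLU to `F.relu`.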

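We hook into the **last convolutional block** (`layer4`), a common choice because its features are the most class-specific while still keeping some spatial layout. For a 224×224 input its output is a 512×7×7 feature map, so the raw CAM is coarse (7×7) and has to be upsampled before it can be overlaid on the image.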
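The weighting step itself needs no deep-learning framework. A minimal NumPy sketch, assuming you already have activations and gradients of shape (C, H, W) for a single image; random arrays stand in for real ones, and `gradcam_from_arrays` is a hypothetical helper name.

```python
import numpy as np

def gradcam_from_arrays(activations, gradients, eps=1e-8):
    """Combine (C, H, W) activations and gradients into an (H, W) map in [0, 1]."""
    weights = gradients.mean(axis=(1, 2))               # alpha_k: one weight per channel
    cam = np.tensordot(weights, activations, axes=1)    # sum_k alpha_k * A^k -> (H, W)
    cam = np.maximum(cam, 0.0)                          # ReLU: keep positive evidence only
    cam = cam - cam.min()
    return cam / (cam.max() + eps)                      # scale to [0, 1]

rng = np.random.default_rng(0)
acts = rng.random((512, 7, 7))              # ResNet-18 layer4 size for a 224x224 input
grads = rng.standard_normal((512, 7, 7))
demo_cam = gradcam_from_arrays(acts, grads)
print(demo_cam.shape, demo_cam.min(), demo_cam.max())
```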


In [None]:
class GradCAM:
    """Minimal Grad-CAM: hooks one layer and returns a heatmap at the input resolution."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None

        # Forward hook: store activations and attach a tensor hook that
        # captures the gradient w.r.t. this layer's output during backward
        self.fwd_handle = self.target_layer.register_forward_hook(self._forward_hook)

    def _forward_hook(self, module, inputs, output):
        self.activations = output  # shape: (N, C, h, w)
        # Skip under torch.no_grad(), where the output cannot carry a gradient hook
        if output.requires_grad:
            output.register_hook(self._save_gradient)

    def _save_gradient(self, grad):
        self.gradients = grad  # shape: (N, C, h, w)

    def __call__(self, x, class_idx=None):
        self.model.zero_grad()

        # Forward
        logits = self.model(x)
        if class_idx is None:
            class_idx = int(logits.argmax(dim=1).item())

        # Backward: gradient of the selected logit
        score = logits[:, class_idx].sum()
        score.backward()

        # Compute weights: global-average-pool gradients over spatial dims
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)  # (N, C, 1, 1)

        # Weighted sum of activations, keeping only positive influence
        cam = (weights * self.activations).sum(dim=1, keepdim=True)  # (N, 1, h, w)
        cam = F.relu(cam)

        # Upsample from the coarse feature grid (7x7 for layer4) to the input size
        cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
        cam = cam.squeeze(1)  # (N, H, W)

        # Normalize each map in the batch to [0, 1] independently
        cam_min = cam.amin(dim=(1, 2), keepdim=True)
        cam_max = cam.amax(dim=(1, 2), keepdim=True)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-8)

        return cam.detach(), logits.detach(), class_idx

    def close(self):
        self.fwd_handle.remove()


In [None]:
# Choose a target layer: last conv block in ResNet-18 is model.layer4
target_layer = model.layer4
gradcam = GradCAM(model, target_layer)

cam, logits, pred_class = gradcam(x)
pred_class, logits.shape
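
`pred_class` is only an index into the 1,000 ImageNet classes. To see how confident the model is and which classes come next, a numerically stable softmax plus top-k is enough. This NumPy sketch assumes you pass it a 1-D logit vector such as `logits[0].numpy()`; `topk_softmax` is a hypothetical helper. Recent torchvision versions also expose the class names as `models.ResNet18_Weights.IMAGENET1K_V1.meta["categories"]`, which you can index with the returned indices.

```python
import numpy as np

def topk_softmax(logit_vec, k=5):
    """Return [(class_index, probability), ...] for the k most likely classes."""
    z = logit_vec - logit_vec.max()        # subtract the max for numerical stability
    probs = np.exp(z) / np.exp(z).sum()
    top = np.argsort(probs)[::-1][:k]      # indices by descending probability
    return [(int(i), float(probs[i])) for i in top]

demo_logits = np.array([1.0, 3.0, 0.5, 2.0])
print(topk_softmax(demo_logits, k=2))
```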


## 4) Visualize the Grad-CAM heatmap overlay

This overlay helps you see **where the model is focusing** for its prediction.

> Teaching angle: if the heatmap mostly lights up the **background** rather than the animal/object, the model may be “right for the wrong reasons.”
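`plt.imshow(..., alpha=...)` does the blending on screen. If you want the overlay as an image array (for example to save it), the same alpha blend can be sketched in NumPy. This assumes an (H, W, 3) float image in [0, 1] and an (H, W) CAM already resized to the image, and uses a crude blue-to-red ramp as a stand-in for matplotlib's `jet` colormap; `blend_heatmap` is a hypothetical helper.

```python
import numpy as np

def blend_heatmap(image, cam_2d, alpha=0.5):
    """Alpha-blend an (H, W) map in [0, 1] onto an (H, W, 3) float image in [0, 1]."""
    # Low saliency -> blue, high saliency -> red
    heat = np.stack([cam_2d, np.zeros_like(cam_2d), 1.0 - cam_2d], axis=-1)
    return (1.0 - alpha) * image + alpha * heat

demo_image = np.full((2, 2, 3), 0.5)
demo_map = np.array([[0.0, 1.0], [0.5, 0.25]])
blended = blend_heatmap(demo_image, demo_map, alpha=0.5)
print(blended[0, 1])   # pixel with saliency 1.0 is pushed toward red
```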


In [None]:
# Utility: overlay heatmap on original image
def overlay_cam_on_image(pil_img, cam_2d, alpha=0.5):
    # cam_2d: torch tensor (H, W) with values in [0, 1]; it is resized to the image size below
    cam_np = cam_2d.cpu().numpy()
    cam_np = np.uint8(255 * cam_np)

    # Resize cam to original image size
    cam_img = Image.fromarray(cam_np).resize(pil_img.size, resample=Image.BILINEAR)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(pil_img)
    plt.title("Original")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.imshow(pil_img)
    plt.imshow(cam_img, cmap="jet", alpha=alpha)
    plt.title("Grad-CAM Overlay")
    plt.axis("off")

    plt.tight_layout()
    plt.show()

overlay_cam_on_image(img, cam[0], alpha=0.5)


## 5) Show “incorrect attention” with a quick occlusion test (optional but powerful)

A simple sanity check:

- Mask out the **high-saliency** region → prediction *should drop* if that region truly matters.
- Mask out the **low-saliency** region → prediction *should change less*.

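If occluding the high-saliency region barely changes the prediction, or keeping only that region destroys it, the model may be relying on background or context cues.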
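The masking step is just a quantile threshold on the CAM. A NumPy sketch of that step, plus a simple relative "confidence drop" summary of the occlusion result; the helper names and the metric are illustrative, not a standard API.

```python
import numpy as np

def top_fraction_mask(cam_2d, frac=0.25):
    """Boolean (H, W) mask that is True on roughly the top `frac` of CAM values."""
    thresh = np.quantile(cam_2d, 1.0 - frac)
    return cam_2d >= thresh

def confidence_drop(base_prob, occluded_prob):
    """Relative drop in the predicted-class probability after occlusion."""
    return (base_prob - occluded_prob) / base_prob

demo_map = np.arange(16, dtype=float).reshape(4, 4) / 15.0
mask = top_fraction_mask(demo_map, frac=0.25)
print(mask.sum(), "of", mask.size, "pixels selected")
print(f"relative drop: {confidence_drop(0.80, 0.20):.2f}")
```

Ties at the threshold can select slightly more than `frac` of the pixels, which is why the mask is "roughly" the top fraction.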


In [None]:
# Optional: occlusion sanity check
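def occlude_by_cam(x, cam_2d, keep_high=True, frac=0.25):
    '''
    Occlude part of `x` using a quantile threshold on the CAM.

    keep_high=False: occlude the top-`frac` most salient pixels.
    keep_high=True:  keep the top-`frac` most salient pixels and occlude the rest.
    '''
    x2 = x.clone()

    # Make sure the CAM matches the input's spatial size (a raw layer4 CAM is only 7x7)
    if cam_2d.shape != x.shape[-2:]:
        cam_2d = F.interpolate(cam_2d[None, None], size=x.shape[-2:],
                               mode="bilinear", align_corners=False)[0, 0]

    cam_np = cam_2d.cpu().numpy()
    thresh = np.quantile(cam_np, 1 - frac)  # top frac
    mask = (cam_np >= thresh)

    if keep_high:
        # Keep high-saliency; occlude the rest
        occ = ~mask
    else:
        # Occlude high-saliency
        occ = mask

    # Expand to channels
    occ_t = torch.from_numpy(occ).to(x.device)
    occ_t = occ_t.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
    occ_t = occ_t.expand(-1, 3, -1, -1)      # (1,3,H,W)

    # Zero out occluded pixels in normalized space, i.e. replace them with the ImageNet mean color
    x2[occ_t] = 0.0
    return x2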

def softmax_prob(logits, idx):
    probs = F.softmax(logits, dim=1)
    return float(probs[0, idx].item())

# Baseline prediction
with torch.no_grad():
    base_logits = model(x)
base_prob = softmax_prob(base_logits, pred_class)

# Occlude high-saliency region
x_occ_high = occlude_by_cam(x, cam[0], keep_high=False, frac=0.25)
with torch.no_grad():
    logits_occ_high = model(x_occ_high)
prob_occ_high = softmax_prob(logits_occ_high, pred_class)

# Occlude low-saliency region (keep only high-saliency)
x_keep_high = occlude_by_cam(x, cam[0], keep_high=True, frac=0.25)
with torch.no_grad():
    logits_keep_high = model(x_keep_high)
prob_keep_high = softmax_prob(logits_keep_high, pred_class)

print(f"Predicted class index: {pred_class}")
print(f"Baseline prob (pred class):            {base_prob:.4f}")
print(f"Prob after occluding HIGH-saliency:    {prob_occ_high:.4f}")
print(f"Prob after keeping only HIGH-saliency: {prob_keep_high:.4f}")


### Visualize the occluded images (optional)

These tensors live in **normalized space**, so we de-normalize them for display. Occluded pixels were set to 0 in normalized space and therefore show up as flat gray patches.
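Why the occluded patches look gray: a value of 0 in normalized space maps back to the channel mean, since $0 \cdot \sigma + \mu = \mu$. A quick check in plain Python; `denorm_value` is an illustrative helper mirroring what `denorm` does per value.

```python
# ImageNet per-channel statistics used by the Normalize transform in this notebook
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def denorm_value(v, mean, std):
    """Map one normalized value back to an intensity clipped to [0, 1]."""
    return min(max(v * std + mean, 0.0), 1.0)

# 8-bit RGB color of a pixel that was set to 0 in normalized space
occluded_rgb = [round(255 * denorm_value(0.0, m, s)) for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)]
print(occluded_rgb)  # [124, 116, 104]
```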


In [None]:
# Visualize occluded versions (de-normalize for display)
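def denorm(t):
    # Undo ImageNet normalization (stats created on t's device) and clip to [0, 1]
    mean = torch.tensor([0.485, 0.456, 0.406], device=t.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=t.device).view(1, 3, 1, 1)
    t = t * std + mean
    return torch.clamp(t, 0, 1)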

def show_tensor_image(t, title):
    t = denorm(t).squeeze(0).permute(1,2,0).cpu().numpy()
    plt.figure(figsize=(5,5))
    plt.imshow(t)
    plt.title(title)
    plt.axis("off")
    plt.show()

show_tensor_image(x, "Input (de-normalized)")
show_tensor_image(x_occ_high, "Occlude HIGH-saliency region")
show_tensor_image(x_keep_high, "Keep only HIGH-saliency region")


## Interpretation notes (for your lecture)

- If occluding the “important” region **doesn’t** reduce the model’s confidence much,
  the saliency may be misleading or the model may be relying on broad context.
- If the highlighted region is mostly **background**, the model may not be learning the object.
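To make "mostly background" concrete, you can measure what fraction of the total CAM mass falls inside a bounding box drawn around the animal. This is a sketch under stated assumptions: the box coordinates are something you supply by hand (hypothetical here), the CAM must be at the same resolution as those coordinates, and `cam_fraction_in_box` is an illustrative helper.

```python
import numpy as np

def cam_fraction_in_box(cam_2d, box):
    """Fraction of total CAM mass inside box = (row0, col0, row1, col1), end-exclusive."""
    r0, c0, r1, c1 = box
    total = cam_2d.sum()
    if total == 0:
        return 0.0
    return float(cam_2d[r0:r1, c0:c1].sum() / total)

demo_map = np.zeros((4, 4))
demo_map[0, :] = 1.0   # saliency along the top edge ("background")
demo_map[2, 2] = 1.0   # a little saliency on the "object"
print(cam_fraction_in_box(demo_map, (2, 2, 4, 4)))
```

Here the box covers 25% of the image but holds only 20% of the CAM mass; a value well below the box's share of the image area suggests the heatmap is concentrated on the background.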

**Key takeaway:** interpretability tools help you spot *how a model might fail*, especially under domain shift (for example, the same animal photographed against an unfamiliar background).
