# Grad-CAM Explainability (Minimal)

This notebook:
- Loads a trained checkpoint
- Runs Grad-CAM on **one** sample image
- Saves `assets/gradcam_example.png`

Note: The confusion matrix and evaluation metrics are produced by `scripts/eval.py`, not by this notebook.


In [None]:
!pip -q install grad-cam opencv-python matplotlib
print("✅ grad-cam installed.")


In [None]:
import os, glob
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt

# Grad-CAM library
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Repo imports
# (Assumes you are running in the repo root or have repo cloned in Colab)
from src.models import create_model
from src.dataset import get_transforms


In [None]:
# ===== EDIT THESE =====
# Architecture to explain and the checkpoint holding its trained weights.
MODEL_NAME = "resnet50"  # resnet50 | densenet121 | mobilenet_v3_small
WEIGHTS_PATH = "./saved_models/resnet50_fold1.pth"

# Validation split of a single fold; the sample image is searched for here.
VAL_DIR = "./data/autism_unified_kfold/fold_1/val"

# Side length handed to get_transforms() and the number of output classes
# handed to create_model().
IMG_SIZE = 224
NUM_CLASSES = 2

# Destination of the rendered Grad-CAM overlay.
ASSETS_DIR = "./assets"
OUT_PATH = os.path.join(ASSETS_DIR, "gradcam_example.png")
# ======================

# Echo the settings that most often need editing so a run is self-describing.
for label, value in (
    ("MODEL_NAME:", MODEL_NAME),
    ("WEIGHTS_PATH:", WEIGHTS_PATH),
    ("VAL_DIR:", VAL_DIR),
):
    print(label, value)


In [None]:
# Pick one sample image from VAL_DIR/<class>/ (jpg, jpeg or png).
exts = ["jpg", "jpeg", "png"]
suffixes = tuple("." + e for e in exts)

# Compare against the lower-cased path so files such as IMG_001.JPG are found
# on case-sensitive filesystems, and sort the matches: glob returns paths in
# arbitrary, filesystem-dependent order, so without sorting the "first" image
# (and therefore the saved figure) would differ between runs and machines.
candidates = sorted(
    p
    for p in glob.glob(os.path.join(VAL_DIR, "*", "*"))
    if p.lower().endswith(suffixes)
)

if not candidates:
    raise RuntimeError(f"No images found under: {VAL_DIR}/<class>/*.(jpg|png|jpeg)")

SAMPLE_PATH = candidates[0]
print("Sample image:", SAMPLE_PATH)


In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)

# Build the architecture, move it to the chosen device, then restore the
# trained weights from the checkpoint (mapped onto the same device).
model = create_model(MODEL_NAME, NUM_CLASSES)
model = model.to(device)
state = torch.load(WEIGHTS_PATH, map_location=device)
model.load_state_dict(state)

# Switch to evaluation mode before running inference.
model.eval()

print("✅ Model loaded.")


In [None]:
def get_target_layer(model_name: str, model):
    """Return the layer of ``model`` that Grad-CAM should attach to.

    Args:
        model_name: Architecture name; matched case-insensitively and with
            surrounding whitespace ignored. One of ``resnet50``,
            ``densenet121`` or ``mobilenet_v3_small``.
        model: The instantiated network for that architecture.

    Returns:
        The last block of the convolutional feature extractor.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Selectors are lazy so only the attributes of the matching architecture
    # are ever touched.
    selectors = {
        "resnet50": lambda net: net.layer4[-1],               # last block
        "densenet121": lambda net: net.features.denseblock4,  # last dense block
        "mobilenet_v3_small": lambda net: net.features[-1],   # last feature block
    }
    key = model_name.lower().strip()
    pick = selectors.get(key)
    if pick is None:
        raise ValueError("Unsupported model_name for Grad-CAM: " + model_name)
    return pick(model)

# Resolve the Grad-CAM layer for the configured architecture and report its type.
target_layer = get_target_layer(MODEL_NAME, model)
layer_type_name = type(target_layer).__name__
print("Target layer:", layer_type_name)


In [None]:
# Decode the sample with OpenCV (BGR channel order); imread signals failure by
# returning None instead of raising, so check explicitly.
img_bgr = cv2.imread(SAMPLE_PATH)
if img_bgr is None:
    raise RuntimeError("Failed to read image: " + SAMPLE_PATH)

# Convert to RGB for the transform pipeline and for plotting.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

# Use the test-time transforms (second element returned by get_transforms).
# NOTE(review): the keyword call and ["image"] lookup suggest an
# albumentations-style pipeline that yields a CxHxW torch tensor - confirm.
_train_tf, test_tf = get_transforms(IMG_SIZE)
transformed = test_tf(image=img_rgb)
tensor = transformed["image"]

# Add the batch dimension (1xCxHxW) and move to the model's device.
input_tensor = tensor.unsqueeze(0)
input_tensor = input_tensor.to(device)

print("input_tensor:", tuple(input_tensor.shape))


In [None]:
# Plain forward pass; gradients are not needed just to find the predicted class.
with torch.no_grad():
    logits = model(input_tensor)

# Index of the highest-scoring class for the single image in the batch.
pred_class = int(logits.argmax(dim=1).item())
print("Predicted class index:", pred_class)

# Explain the class the model actually predicted.
targets = [ClassifierOutputTarget(pred_class)]


In [None]:
# GradCAM registers forward/backward hooks on the target layer. Using it as a
# context manager releases those hooks when the block exits, so re-running this
# cell does not keep stacking hooks on the model.
with GradCAM(model=model, target_layers=[target_layer]) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

# One map per batch item is returned; keep the map for our single image (HxW).
grayscale_cam = grayscale_cam[0]

print("✅ Grad-CAM computed:", grayscale_cam.shape)


In [None]:
os.makedirs(ASSETS_DIR, exist_ok=True)

# The CAM is produced at the spatial size of the network input, whereas img_rgb
# is still at its on-disk resolution. show_cam_on_image blends the two
# element-wise, so bring the image to the CAM's size first; otherwise any
# sample whose size differs from the network input fails with a shape mismatch.
# NOTE(review): if the test transform crops rather than resizes, the overlay is
# only approximately aligned with the original image - confirm against
# get_transforms.
cam_h, cam_w = grayscale_cam.shape[:2]
base_rgb = img_rgb
if base_rgb.shape[:2] != (cam_h, cam_w):
    base_rgb = cv2.resize(base_rgb, (cam_w, cam_h))  # dsize is (width, height)

# show_cam_on_image expects float RGB in [0,1]
img_float = base_rgb.astype(np.float32) / 255.0
cam_image = show_cam_on_image(img_float, grayscale_cam, use_rgb=True)

# OpenCV writes BGR, so convert back before saving. imwrite reports failure
# through its return value rather than an exception; surface it instead of
# printing a misleading success message.
cam_bgr = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
if not cv2.imwrite(OUT_PATH, cam_bgr):
    raise RuntimeError("Failed to write image: " + OUT_PATH)

print("✅ Saved:", OUT_PATH)


In [None]:
# Side-by-side view: input image, Grad-CAM overlay, and the raw heatmap.
plt.figure(figsize=(10,4))

# (title, image, colormap); cmap=None is imshow's default and leaves RGB data as is.
panels = [
    ("Original", img_rgb, None),
    ("Grad-CAM", cam_image, None),
    ("Heatmap", grayscale_cam, "jet"),
]
for position, (title, image, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    plt.axis("off")
    plt.imshow(image, cmap=cmap)

plt.tight_layout()
plt.show()
