In [2]:
# Mount Google Drive at /content/drive; the checkpoint, the dataset zip and the
# output figure folder used by this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# List the contents of the Drive mount point.
!ls /content/drive

MyDrive


In [4]:
# Search the mounted Drive recursively for files ending in .pth.
!find /content/drive -name "*.pth"

/content/drive/MyDrive/oa_checkpoints/best_densenet121_res320.pth


In [6]:
# Path (on the mounted Drive) of the checkpoint file that the model-loading cell reads.
CKPT_PATH = "/content/drive/MyDrive/oa_checkpoints/best_densenet121_res320.pth"

In [7]:
import os
# Report whether CKPT_PATH points at an existing regular file (prints only; does not raise).
print("Exists:", os.path.isfile(CKPT_PATH), CKPT_PATH)

Exists: True /content/drive/MyDrive/oa_checkpoints/best_densenet121_res320.pth


In [8]:
#Set dataset paths + create output folder
DATASET_NAME = "knee-osteoarthritis-dataset-with-severity"
DATA_DIR  = f"/content/{DATASET_NAME}"
TRAIN_DIR = f"{DATA_DIR}/train"
VAL_DIR   = f"{DATA_DIR}/val"
TEST_DIR  = f"{DATA_DIR}/test"

# If dataset folder missing, unzip again from Drive:
ZIP_NAME = "knee_oa_dataset.zip"  # must match your zip name in Drive
if not os.path.isdir(DATA_DIR):
    print("Dataset not found. Unzipping from Drive...")
    zip_path = f"/content/drive/MyDrive/{ZIP_NAME}"
    !unzip -q "{zip_path}" -d /content/
print("Dataset ready ✅")

OUT_DIR = "/content/drive/MyDrive/oa_checkpoints/gradcam_figures"
os.makedirs(OUT_DIR, exist_ok=True)

print("TEST_DIR:", TEST_DIR)
print("OUT_DIR:", OUT_DIR)

Dataset not found. Unzipping from Drive...
Dataset ready ✅
TEST_DIR: /content/knee-osteoarthritis-dataset-with-severity/test
OUT_DIR: /content/drive/MyDrive/oa_checkpoints/gradcam_figures


In [9]:
# Model + loaders + Grad-CAM implementation (DenseNet)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt

# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Use the SAME resolution as the checkpoint training
IMAGE_SIZE = 320
BATCH_SIZE = 16

# Deterministic evaluation preprocessing: resize to a square, convert to a
# tensor, then normalize each channel with the mean/std listed below.
test_tfms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_ds = datasets.ImageFolder(TEST_DIR, transform=test_tfms)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,      # fixed iteration order
    num_workers=2,
    # Pinned host memory only helps host-to-GPU copies; requesting it on the
    # CPU fallback makes PyTorch emit a warning and brings no benefit.
    pin_memory=(device.type == "cuda"),
)

# The number of classes is taken from the sub-folder names found by ImageFolder.
NUM_CLASSES = len(test_ds.classes)
print("Classes:", test_ds.classes, "NUM_CLASSES:", NUM_CLASSES)

def build_densenet121(num_classes):
    """Create a DenseNet-121 whose classifier head outputs `num_classes` logits.

    The network is created with ``weights=None`` (no pretrained weights are
    requested) because the caller loads its own checkpoint afterwards.

    Args:
        num_classes: size of the output dimension of the new linear head.

    Returns:
        The torchvision DenseNet-121 module with its `classifier` replaced.
    """
    net = models.densenet121(weights=None)
    # Swap the stock head for one sized to this task, keeping its input width.
    net.classifier = nn.Linear(net.classifier.in_features, num_classes)
    return net

model = build_densenet121(NUM_CLASSES).to(device)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict holds, so a tampered checkpoint file cannot execute
# arbitrary code while being loaded.
model.load_state_dict(torch.load(CKPT_PATH, map_location=device, weights_only=True))
# Inference mode for layers that behave differently in training (e.g. batch norm).
model.eval()
print("Loaded model ✅")

# --- utilities for visualization ---
# Per-channel statistics used to undo the normalization for display.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

def denorm_img(t_chw):
    """Turn a normalized channel-first tensor into a displayable image array.

    Args:
        t_chw: tensor laid out as (channels, height, width) with 3 channels.

    Returns:
        NumPy array laid out as (height, width, channels), with the
        mean/std normalization reversed and values clipped to [0, 1].
    """
    hwc = np.transpose(t_chw.detach().cpu().numpy(), (1, 2, 0))
    restored = hwc * IMAGENET_STD + IMAGENET_MEAN
    return restored.clip(0, 1)

def overlay_heatmap_on_image(img_hwc, heatmap_hw, alpha=0.35):
    """Blend a "jet"-colored heatmap onto an image.

    Args:
        img_hwc: image array that the colored heatmap is blended with.
        heatmap_hw: heatmap values passed through the "jet" colormap.
        alpha: weight of the colored heatmap; the image gets ``1 - alpha``.

    Returns:
        The weighted blend, clipped to [0, 1].
    """
    jet_rgba = plt.get_cmap("jet")(heatmap_hw)
    jet_rgb = jet_rgba[..., :3]  # keep the first three channels, drop the fourth
    blended = img_hwc * (1 - alpha) + jet_rgb * alpha
    return np.clip(blended, 0, 1)

class GradCAM:
    """Grad-CAM heatmaps computed at the output of one module of a classifier.

    A forward hook captures the target module's output and a full backward
    hook captures the gradient of the selected class score with respect to
    that output. The hooks stay registered until `remove()` is called, or
    until the ``with`` block exits when the instance is used as a context
    manager.
    """

    def __init__(self, model, target_module):
        """Register the capture hooks.

        Args:
            model: module called as ``model(x)`` to obtain class logits.
            target_module: sub-module of `model` whose output feeds the CAM.
        """
        self.model = model
        self.target_module = target_module
        self.activations = None
        self.gradients = None

        self.fwd_hook = target_module.register_forward_hook(self._forward_hook)
        self.bwd_hook = target_module.register_full_backward_hook(self._backward_hook)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always release the hooks; never swallow an in-flight exception.
        self.remove()
        return False

    def _forward_hook(self, module, inp, out):
        # Only the values are needed later; storing the attached tensor would
        # keep the whole autograd graph alive between calls.
        self.activations = out.detach()

    def _backward_hook(self, module, grad_in, grad_out):
        self.gradients = grad_out[0].detach()

    def remove(self):
        """Detach both hooks from the target module."""
        self.fwd_hook.remove()
        self.bwd_hook.remove()

    @staticmethod
    def _normalize(cam):
        """Shift/scale `cam` into [0, 1]; the epsilon guards a constant map."""
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam

    def __call__(self, x, class_idx=None):
        """Compute the Grad-CAM heatmap for a single image.

        Args:
            x: input batch of shape [1, 3, H, W].
            class_idx: class whose score is explained; defaults to the
                predicted (arg-max) class.

        Returns:
            Tuple ``(cam, pred_idx, probs)``: `cam` is a NumPy array of shape
            [H, W] scaled to 0..1, `pred_idx` is the arg-max class index and
            `probs` is the softmax output for the image as a NumPy array.

        Raises:
            ValueError: if `x` is not a 4-D batch holding exactly one image.
            RuntimeError: if the hooks captured nothing during this call
                (for example after `remove()`).
        """
        if x.dim() != 4 or x.shape[0] != 1:
            raise ValueError(
                f"GradCAM expects a single image of shape [1,C,H,W], got {tuple(x.shape)}"
            )

        # Forget captures from an earlier call so a hook that does not fire is
        # detected below instead of silently reusing stale tensors.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs a backward pass, so force autograd on even when the
        # caller is inside a torch.no_grad() block.
        with torch.enable_grad():
            self.model.zero_grad(set_to_none=True)
            logits = self.model(x)
            probs = F.softmax(logits, dim=1)
            pred_idx = int(torch.argmax(probs, dim=1).item())

            if class_idx is None:
                class_idx = pred_idx

            score = logits[0, class_idx]
            score.backward()

        grads = self.gradients              # [1,C,h,w]
        acts = self.activations             # [1,C,h,w]
        if grads is None or acts is None:
            raise RuntimeError(
                "GradCAM hooks captured no activations/gradients; "
                "were the hooks removed, or is target_module unused by model.forward()?"
            )

        # Channel importance = spatial mean of the gradient for each channel.
        weights = grads.mean(dim=(2, 3), keepdim=True)  # [1,C,1,1]
        cam = F.relu((weights * acts).sum(dim=1, keepdim=True))

        # Upsample the coarse map to the input resolution before normalizing.
        cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
        cam = self._normalize(cam[0, 0].cpu().numpy())

        return cam, pred_idx, probs[0].detach().cpu().numpy()

# Good DenseNet target layer
# Grad-CAM is computed at the output of model.features.denseblock4: constructing
# GradCAM registers its forward and backward hooks on this module, and they stay
# attached until gradcam.remove() is called.
target_layer = model.features.denseblock4
gradcam = GradCAM(model, target_layer)
print("Grad-CAM ready ✅")

Device: cuda
Classes: ['0', '1', '2', '3', '4'] NUM_CLASSES: 5
Loaded model ✅
Grad-CAM ready ✅


In [10]:
# Pick correct/wrong samples, save overlays + summary grid
import os
from math import ceil
from sklearn.metrics import f1_score, classification_report, confusion_matrix

# 1) Run the checkpoint over the whole test loader, keeping every normalized
#    image tensor on the CPU so it can be indexed later by sample position.
all_preds, all_targets, cached_imgs = [], [], []

with torch.no_grad():
    for imgs, targets in test_loader:
        out = model(imgs.to(device))
        preds = out.argmax(dim=1).detach().cpu().numpy().tolist()
        all_preds += preds
        all_targets += targets.numpy().tolist()
        cached_imgs += [t.detach().cpu() for t in imgs]  # normalized tensors

all_preds = np.array(all_preds)
all_targets = np.array(all_targets)

# Boolean mask: True where the prediction matches the label.
is_hit = all_preds == all_targets

test_acc = is_hit.mean()
test_macro_f1 = f1_score(all_targets, all_preds, average="macro")
print("Checkpoint TEST acc:", test_acc)
print("Checkpoint TEST macro-F1:", test_macro_f1)

# 2) Draw a reproducible sample of hits and misses (6 of each when available).
correct_idx = np.flatnonzero(is_hit)
wrong_idx = np.flatnonzero(~is_hit)

np.random.seed(42)
N_CORRECT = 6
N_WRONG = 6

# The two draws happen in this order so the seeded random stream is consumed
# the same way on every run.
picked_correct = np.random.choice(correct_idx, size=min(N_CORRECT, correct_idx.size), replace=False)
picked_wrong = np.random.choice(wrong_idx, size=min(N_WRONG, wrong_idx.size), replace=False)

# (tag, sample index) pairs: all "correct" picks first, then all "wrong" picks.
picked = [
    (label, int(sample_idx))
    for label, chosen in (("correct", picked_correct), ("wrong", picked_wrong))
    for sample_idx in chosen
]
print("Picked:", len(picked), "examples")

def save_cam_example(tag, i, use_true_class=False):
    """Render and save a 3-panel Grad-CAM figure for one cached test sample.

    The panels are the de-normalized image, the raw CAM, and the CAM blended
    over the image.

    Args:
        tag: label used as the file-name prefix (e.g. "correct" / "wrong").
        i: index into `cached_imgs`, `all_targets` and `all_preds`.
        use_true_class: explain the true class when True, otherwise the
            predicted class.

    Returns:
        Path of the PNG written inside `OUT_DIR`.
    """
    sample = cached_imgs[i]
    true_y, pred_y = int(all_targets[i]), int(all_preds[i])

    # "true" / "pred" drives the explained class, the panel title and the file name.
    cam_kind = "true" if use_true_class else "pred"
    target_class = true_y if use_true_class else pred_y

    cam, _, _ = gradcam(sample.unsqueeze(0).to(device), class_idx=target_class)
    img = denorm_img(sample)
    overlay = overlay_heatmap_on_image(img, cam, alpha=0.35)

    # (picture, extra imshow kwargs, title) for each panel, left to right.
    panels = [
        (img, {}, f"Original\nTrue={true_y}, Pred={pred_y}"),
        (cam, {"cmap": "jet"}, f"CAM ({cam_kind})"),
        (overlay, {}, "Overlay"),
    ]

    plt.figure(figsize=(9,3))
    for position, (picture, imshow_kwargs, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(picture, **imshow_kwargs)
        plt.axis("off")
        plt.title(title)

    out_path = os.path.join(
        OUT_DIR, f"{tag}_idx{i}_T{true_y}_P{pred_y}_{cam_kind}cam.png"
    )
    plt.tight_layout()
    plt.savefig(out_path, dpi=220)
    plt.close()  # release the figure so repeated calls do not pile up open figures
    return out_path

saved = []
for tag, i in picked:
    # Every pick gets a predicted-class CAM; misclassified picks additionally
    # get a CAM for the true class.
    cam_modes = (False, True) if tag == "wrong" else (False,)
    for use_true in cam_modes:
        saved.append(save_cam_example(tag, i, use_true_class=use_true))

print("Saved images:", len(saved))

# 3) Summary grid (pred-cam overlays only)
grid_items = []
for tag, i in picked:
    true_y, pred_y = int(all_targets[i]), int(all_preds[i])
    x = cached_imgs[i].unsqueeze(0).to(device)
    cam, _, _ = gradcam(x, class_idx=pred_y)
    img = denorm_img(cached_imgs[i])
    overlay = overlay_heatmap_on_image(img, cam, alpha=0.35)
    grid_items.append((tag, i, true_y, pred_y, overlay))

# Fixed 4-column layout; the row count is the ceiling of items / columns.
cols = 4
rows = (len(grid_items) + cols - 1) // cols
plt.figure(figsize=(4*cols, 4*rows))
for k, item in enumerate(grid_items, start=1):
    tag, i, t, p, overlay = item
    plt.subplot(rows, cols, k)
    plt.imshow(overlay)
    plt.axis("off")
    plt.title(f"{tag} | T={t} P={p}")
plt.tight_layout()

grid_path = os.path.join(OUT_DIR, "gradcam_summary_grid.png")
plt.savefig(grid_path, dpi=240)
plt.close()

print("Saved summary grid ✅:", grid_path)
print("Open folder:", OUT_DIR)


Checkpoint TEST acc: 0.6183574879227053
Checkpoint TEST macro-F1: 0.6623441983437284
Picked: 12 examples
Saved images: 18
Saved summary grid ✅: /content/drive/MyDrive/oa_checkpoints/gradcam_figures/gradcam_summary_grid.png
Open folder: /content/drive/MyDrive/oa_checkpoints/gradcam_figures
