In [1]:

from data.datasets import ADNIDataset
from torch.utils.data import DataLoader
from utils.eval import load_model, create_slice_plot

In [2]:
from pytorch_grad_cam import GradCAM, EigenCAM, GradCAMPlusPlus, LayerCAM,ScoreCAM, ShapleyCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

import plotly.graph_objects as go
import numpy as np
from IPython.display import HTML
import plotly.io as pio

import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import pandas as pd

import matplotlib.pyplot as plt

In [3]:
def compute_gradcam(model, sample, target_layer_name, target_class=1):
    
    model.eval()
    out = model(sample)
    
    target_layer = None
    for name, module in model.named_modules():
        if name == target_layer_name:
            target_layer = module
            break
    
    with ShapleyCAM(model=model, target_layers=[ target_layer ]) as cam:
       gcam_overlay = cam(sample, targets=[ ClassifierOutputTarget(target_class)])
       
    return gcam_overlay.squeeze()

#### Analyze 

In [4]:
path = '/project/aereditato/abhat/adni-mri-classification/experiments/mresnet18_20250821_190324_CN_vs_AD/'
fold = 4

In [5]:
model, cfg = load_model(path, 1)

FileNotFoundError: /project/aereditato/abhat/adni-mri-classification/experiments/mresnet18_20250821_190324_CN_vs_AD/model.py

In [3]:
import torch

# Load test dataset
test_dataset = ADNIDataset(cfg['data']['test_csv'])
loader = DataLoader(test_dataset,1,shuffle=False)

model.eval()

samples = []
labels = []
preds = []
for imgs,lbls in loader:
    samples.append(imgs)
    labels.append(lbls)
    pred = torch.round(torch.sigmoid(model(imgs))).squeeze(1)
    preds.append(pred.detach().numpy())

report = pd.DataFrame(classification_report(labels, preds, output_dict=True)).T
cm = pd.DataFrame(confusion_matrix(labels, preds))

print(report)
print(cm)

NameError: name 'ADNIDataset' is not defined

In [None]:
# ...existing code...
idx = 35

model.eval()
import torch
import torch.nn.functional as F

with torch.no_grad():
    out = model(samples[idx])                     # (1, num_classes)
    pred = torch.round(torch.sigmoid(out)).squeeze(1)   # predicted class as int

    fmap = model.feature_maps.squeeze() # get feature map from model                   

    print(f"Feature map shape: {fmap.shape}")
    
    # make sure fmap has a batch dim: (1, C, d, h, w) or (1, 1, d, h, w)
    if fmap.dim() == 4:       # (C, d, h, w) or (d, h, w)
        fmap = fmap.unsqueeze(0)
    if fmap.dim() == 3:       # (d, h, w)
        fmap = fmap.unsqueeze(0).unsqueeze(0)

    # select channel corresponding to predicted class when possible
    vis = pred
    if fmap.size(1) > 1 and vis < fmap.size(1):
        sel = fmap[:, vis:vis+1]   # (1,1,d',h',w')
    else:
        sel = fmap[:, :1]           # fallback to first channel

    target_size = samples[idx].shape[-3:]   # (D,H,W) e.g. (79,95,79)
    up = F.interpolate(sel.float(), size=target_size, mode='trilinear', align_corners=False)
    feature_map_upsampled = up.squeeze().cpu()   # (D,H,W)

print(f"True label: {labels[idx][0].item()}, Predicted: {pred}")

gradcam_vol = feature_map_upsampled
figs = [
    create_slice_plot(samples[idx], axis=0, gradcam_volume=gradcam_vol),
    create_slice_plot(samples[idx], axis=1, gradcam_volume=gradcam_vol),
    create_slice_plot(samples[idx], axis=2, gradcam_volume=gradcam_vol),
]

html_figs = [pio.to_html(fig, full_html=False, include_plotlyjs='cdn') for fig in figs]
HTML(f"<div style='display:flex;gap:0px;'>{''.join(html_figs)}</div>")


Feature map shape: torch.Size([16, 19, 23, 19])


TypeError: only integer tensors of a single element can be converted to an index

In [None]:
idx = 94

print(f"True label: {labels[idx]}")
print(f"Predicted label: {pred[idx]}")

gradcam_vol = compute_gradcam(model, samples[idx], 'b4', pred[idx])
figs = [
    create_slice_plot(samples[idx], axis=0, gradcam_volume=gradcam_vol),
    create_slice_plot(samples[idx], axis=1, gradcam_volume=gradcam_vol),
    create_slice_plot(samples[idx], axis=2, gradcam_volume=gradcam_vol),
]

html_figs = [pio.to_html(fig, full_html=False, include_plotlyjs='cdn') for fig in figs]
HTML(f"<div style='display:flex;gap:0px;'>{''.join(html_figs)}</div>")


In [None]:
# Initialize average Grad-CAM
avg_cn_cam = 0
avg_mci_cam = 0
avg_ad_cam = 0

# Initialize average volumes
avg_cn_vol = 0
avg_mci_vol = 0
avg_ad_vol = 0

# Predictions
pred = []

# Compute Grad-CAM for each sample and save to file
for i, sample in enumerate(samples):
    
    print("Processing sample", i+1, "/", len(samples), end='\r')
    
    cam_overlay = compute_gradcam(model, sample, 'b4', labels[i])
    
    # compile prediction
    pred.append(model(sample).max(1, keepdim=True)[1].detach().numpy())
    
    if labels[i] == 0:
        avg_cn_cam += cam_overlay
        avg_cn_vol += samples[i]
    elif labels[i] == 1:
        avg_mci_cam += cam_overlay
        avg_mci_vol += samples[i]
    elif labels[i] == 2:
        avg_ad_cam += cam_overlay
        avg_ad_vol += samples[i]
    
# Labels counts
num_cn = sum([1 for lbl in labels if lbl == 0])
num_mci = sum([1 for lbl in labels if lbl == 1])
num_ad = sum([1 for lbl in labels if lbl == 2])

# Average Grad-CAMs
if num_cn > 0:
    avg_cn_cam /= num_cn
    avg_cn_vol /= num_cn
    
if num_mci > 0:
    avg_mci_cam /= num_mci
    avg_mci_vol /= num_mci
    
if num_ad > 0:
    avg_ad_cam /= num_ad
    avg_ad_vol /= num_ad

In [None]:
vol = avg_ad_vol
cam_vol = avg_ad_cam

figs = [
    create_slice_plot(vol, axis=0, gradcam_volume=cam_vol),
    create_slice_plot(vol, axis=1, gradcam_volume=cam_vol),
    create_slice_plot(vol, axis=2, gradcam_volume=cam_vol),
]

html_figs = [pio.to_html(fig, full_html=False, include_plotlyjs='cdn') for fig in figs]
HTML(f"<div style='display:flex;gap:0px;'>{''.join(html_figs)}</div>")