# MedMNIST DermaMNIST — KD Evaluation (Pretrained Checkpoints)
This notebook loads your **pretrained** MedMNIST checkpoints for teacher (ResNet-50) and students (ResNet-18, MobileNetV2, EfficientNet-B0), evaluates them, logs a single-point ablation record per student, computes efficiency stats, and produces Grad-CAM & t-SNE visualizations.

**It does not train.**


In [1]:

# --- Paths & configuration ---
import os
from pathlib import Path

PROJECT_ROOT = "."
REPORTS_ROOT = os.path.join(PROJECT_ROOT, "reports")
FIGS_ROOT    = os.path.join(PROJECT_ROOT, "figs", "medmnist")
TABLES_ROOT  = os.path.join(PROJECT_ROOT, "tables", "medmnist")

# User-provided pretrained checkpoints
CKPT_MAP = {
    "resnet50":       "MedMNIST-EdgeAI/models/resnet50_teacher_dermamnist.pth",
    "resnet18":       "MedMNIST-EdgeAI/models/resnet18/resnet18_dermamnist_student.pth",
    "mobilenet_v2":   "MedMNIST-EdgeAI/models/mobilenet_v2/mobilenet_v2_dermamnist_student.pth",
    "efficientnet_b0":"MedMNIST-EdgeAI/models/efficientnet_b0/efficientnet_b0_dermamnist_student.pth",
}

os.makedirs(REPORTS_ROOT, exist_ok=True)
os.makedirs(FIGS_ROOT, exist_ok=True)
os.makedirs(TABLES_ROOT, exist_ok=True)

print("REPORTS_ROOT:", REPORTS_ROOT)
print("FIGS_ROOT   :", FIGS_ROOT)
print("TABLES_ROOT :", TABLES_ROOT)


REPORTS_ROOT: .\reports
FIGS_ROOT   : .\figs\medmnist
TABLES_ROOT : .\tables\medmnist


In [2]:

# --- Dependencies ---
# If running in a fresh env, you may need:
# %pip install -q medmnist torch torchvision scikit-learn pandas matplotlib thop psutil

import math, time, json, random
import numpy as np
import pandas as pd
from pathlib import Path

import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision.transforms as T
import torchvision.models as tv

import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image

from sklearn.metrics import (classification_report, confusion_matrix,
                             precision_recall_curve, average_precision_score, f1_score)
from sklearn.manifold import TSNE

try:
    from thop import profile as thop_profile
except Exception:
    thop_profile = None

import psutil

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 123
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
print("DEVICE:", DEVICE)


DEVICE: cuda


In [3]:

# --- Dataset: DermaMNIST ---
from medmnist import DermaMNIST

def class_names_for_medmnist():
    return ["akiec","bcc","bkl","df","mel","nv","vasc"]

IMGSZ = 224
TRANS_TRAIN = T.Compose([
    T.Resize((IMGSZ,IMGSZ)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])
TRANS_TEST = T.Compose([
    T.Resize((IMGSZ,IMGSZ)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

trainset = DermaMNIST(split='train', transform=TRANS_TRAIN, download=True)
valset   = DermaMNIST(split='val',   transform=TRANS_TEST,  download=True)
testset  = DermaMNIST(split='test',  transform=TRANS_TEST,  download=True)

C = len(class_names_for_medmnist())
print("Train/Val/Test:", len(trainset), len(valset), len(testset), "Classes:", C)


Train/Val/Test: 7007 1003 2005 Classes: 7


In [4]:

# --- Model builders and checkpoint loader ---
def build_arch(tag: str, num_classes: int):
    tag = tag.lower()
    if tag == "resnet50":
        m = tv.resnet50(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if tag == "resnet18":
        m = tv.resnet18(weights=None)
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        return m
    if tag in ("mobilenet_v2","mbv2"):
        m = tv.mobilenet_v2(weights=None)
        m.classifier[-1] = nn.Linear(m.classifier[-1].in_features, num_classes)
        return m
    if tag in ("efficientnet_b0","effb0"):
        m = tv.efficientnet_b0(weights=None)
        m.classifier[-1] = nn.Linear(m.classifier[-1].in_features, num_classes)
        return m
    raise ValueError(f"Unknown arch: {tag}")

import re

def load_checkpoint_flexible(model: nn.Module, ckpt_path: str):
    state = torch.load(ckpt_path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state and isinstance(state["state_dict"], dict):
        state = state["state_dict"]
    new_sd = {}
    for k,v in state.items():
        nk = re.sub(r"^(module\.|model\.|student\.|net\.)", "", k)
        new_sd[nk] = v
    missing, unexpected = model.load_state_dict(new_sd, strict=False)
    print(f"[load] {os.path.basename(ckpt_path)}  missing={len(missing)}  unexpected={len(unexpected)}")
    return model

def load_user_model(arch_tag: str, ckpt_map: dict, num_classes: int, device: str):
    assert arch_tag in ckpt_map, f"No checkpoint mapping for {arch_tag}"
    path = ckpt_map[arch_tag]
    if not Path(path).exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    m = build_arch(arch_tag, num_classes).to(device).eval()
    print(f"[Load] {arch_tag} <- {path}")
    return load_checkpoint_flexible(m, path)


In [5]:

# --- Load teacher and students ---
num_classes = 7
teacher = load_user_model("resnet50", CKPT_MAP, num_classes, DEVICE)
student_resnet18 = load_user_model("resnet18", CKPT_MAP, num_classes, DEVICE)
student_mbv2 = load_user_model("mobilenet_v2", CKPT_MAP, num_classes, DEVICE)
student_effb0 = load_user_model("efficientnet_b0", CKPT_MAP, num_classes, DEVICE)

MODELS = {
    "resnet50": teacher,
    "resnet18": student_resnet18,
    "mobilenet_v2": student_mbv2,
    "efficientnet_b0": student_effb0,
}


[Load] resnet50 <- MedMNIST-EdgeAI/models/resnet50_teacher_dermamnist.pth
[load] resnet50_teacher_dermamnist.pth  missing=0  unexpected=0
[Load] resnet18 <- MedMNIST-EdgeAI/models/resnet18/resnet18_dermamnist_student.pth
[load] resnet18_dermamnist_student.pth  missing=0  unexpected=0
[Load] mobilenet_v2 <- MedMNIST-EdgeAI/models/mobilenet_v2/mobilenet_v2_dermamnist_student.pth
[load] mobilenet_v2_dermamnist_student.pth  missing=0  unexpected=0
[Load] efficientnet_b0 <- MedMNIST-EdgeAI/models/efficientnet_b0/efficientnet_b0_dermamnist_student.pth
[load] efficientnet_b0_dermamnist_student.pth  missing=0  unexpected=0


In [6]:

# --- Evaluation helpers ---
def eval_model(model: nn.Module, loader: DataLoader, classes: list):
    y_true, y_prob = [], []
    model.eval()
    with torch.inference_mode():
        for x,y in loader:
            x = x.to(DEVICE); y = y.long().to(DEVICE)
            p = torch.softmax(model(x), dim=1).cpu().numpy()
            y_prob.append(p)
            y_true.append(y.cpu().numpy())
    y_true = np.concatenate(y_true)
    y_prob = np.concatenate(y_prob)
    y_pred = y_prob.argmax(1)
    report = classification_report(y_true, y_pred, target_names=classes, output_dict=True, zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    return y_true, y_prob, y_pred, report, cm

def save_perclass_and_plots(tag: str, y_true, y_prob, classes: list, out_tab_dir: str, out_fig_dir: str):
    os.makedirs(out_tab_dir, exist_ok=True); os.makedirs(out_fig_dir, exist_ok=True)
    y_pred = y_prob.argmax(1)
    report = classification_report(y_true, y_pred, target_names=classes, output_dict=True, zero_division=0)
    df = pd.DataFrame(report).transpose()
    df.to_csv(os.path.join(out_tab_dir, f"{tag}_perclass_metrics.csv"))
    try:
        df.to_latex(os.path.join(out_tab_dir, f"{tag}_perclass_metrics.tex"), float_format="%.3f")
    except Exception:
        pass

    # Confusion matrix plot
    cmx = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    fig = plt.figure(figsize=(7,6)); plt.imshow(cmx, interpolation="nearest")
    plt.title(f"Confusion: {tag}"); plt.colorbar()
    plt.xticks(range(len(classes)), classes, rotation=45, ha="right")
    plt.yticks(range(len(classes)), classes)
    for i in range(cmx.shape[0]):
        for j in range(cmx.shape[1]):
            plt.text(j, i, str(cmx[i,j]), ha="center", va="center", fontsize=8)
    plt.tight_layout(); fig.savefig(os.path.join(out_fig_dir, f"{tag}_confmat.png"), dpi=200); plt.close(fig)

    # PR curves
    fig = plt.figure(figsize=(7,6))
    for c in range(len(classes)):
        gt = (y_true==c).astype(int)
        prec, rec, _ = precision_recall_curve(gt, y_prob[:,c])
        ap = average_precision_score(gt, y_prob[:,c])
        plt.plot(rec, prec, label=f"{classes[c]} (AP={ap:.3f})")
    plt.xlabel("Recall"); plt.ylabel("Precision"); plt.legend(); plt.title(f"PR Curves: {tag}")
    fig.savefig(os.path.join(out_fig_dir, f"{tag}_pr_curves.png"), dpi=200); plt.close(fig)


In [7]:

# --- Evaluate and save artifacts for all models ---
classes = class_names_for_medmnist()
test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

for tag, model in MODELS.items():
    y_true, y_prob, y_pred, report, cmx = eval_model(model, test_loader, classes)
    save_perclass_and_plots(tag, y_true, y_prob, classes, TABLES_ROOT, FIGS_ROOT)
print("[OK] Metrics, confusions, PR saved to", TABLES_ROOT, "and", FIGS_ROOT)


[OK] Metrics, confusions, PR saved to .\tables\medmnist and .\figs\medmnist


In [8]:

# --- Single-point ablation records (alpha, tau, beta defaults for pretrained students) ---
from sklearn.metrics import f1_score

def ece_score_torch(probs: torch.Tensor, y: torch.Tensor, n_bins=15):
    conf, preds = probs.max(1).values, probs.argmax(1)
    bins = torch.linspace(0,1,n_bins+1, device=probs.device)
    ece = torch.zeros(1, device=probs.device)
    for i in range(n_bins):
        m = (conf >= bins[i]) & (conf < bins[i+1])
        if m.sum() == 0: continue
        acc = (preds[m]==y[m]).float().mean()
        conf_m = conf[m].mean()
        ece += (m.float().mean()) * (conf_m - acc).abs()
    return ece.item()

ABL_DIR = Path(REPORTS_ROOT) / "ablation_medmnist"
ABL_DIR.mkdir(parents=True, exist_ok=True)

defaults = {
    "resnet18":       {"alpha":0.5,"tau":4.0,"beta":0.0},
    "mobilenet_v2":   {"alpha":0.5,"tau":4.0,"beta":0.0},
    "efficientnet_b0":{"alpha":0.5,"tau":4.0,"beta":0.0},
}

for st_tag, hp in defaults.items():
    m = MODELS[st_tag]
    ys=[]; ps=[]
    with torch.inference_mode():
        for x,y in test_loader:
            x = x.to(DEVICE); y = y.long().to(DEVICE)
            p = torch.softmax(m(x), dim=1).cpu()
            ys.append(y.cpu()); ps.append(p)
    y_true = torch.cat(ys).numpy()
    y_prob = torch.cat(ps).numpy()
    y_pred = y_prob.argmax(1)
    macf1 = f1_score(y_true, y_pred, average="macro")
    ece = ece_score_torch(torch.from_numpy(y_prob), torch.from_numpy(y_true))

    outdir = ABL_DIR / f"fixed_{st_tag}_medmnist"
    outdir.mkdir(parents=True, exist_ok=True)
    rec = {"epoch": 0, "macro_f1": float(macf1), "ece": float(ece),
           "alpha": float(hp["alpha"]), "tau": float(hp["tau"]), "beta": float(hp["beta"]),
           "student": st_tag, "dataset": "medmnist"}
    with open(outdir/"metrics.jsonl","w") as f: f.write(json.dumps(rec)+"\n")
print("[OK] Ablation jsonl written under", str(ABL_DIR))


[OK] Ablation jsonl written under reports\ablation_medmnist


In [12]:
# --- Efficiency: Params / FLOPs / Latency (CPU/GPU) / RAM ---
import copy, time, psutil, numpy as np, torch
from types import MethodType

def _params_m(model: torch.nn.Module) -> float:
    return sum(p.numel() for p in model.parameters()) / 1e6

def _strip_thop_artifacts(model: torch.nn.Module):
    """
    Remove THOP-added hooks/attrs so forward passes don't try to touch CPU tensors on CUDA.
    Works on the passed instance (typically a deepcopy used for timing).
    """
    for m in model.modules():
        # restore original forward if THOP wrapped it
        for attr in ("__original_forward__", "origin_forward", "_origin_forward"):
            if hasattr(m, attr):
                try:
                    m.forward = MethodType(getattr(m, attr), m)
                except Exception:
                    pass
        # clear forward/pre hooks
        if hasattr(m, "_forward_hooks") and isinstance(m._forward_hooks, dict):
            m._forward_hooks = {}
        if hasattr(m, "_forward_pre_hooks") and isinstance(m._forward_pre_hooks, dict):
            m._forward_pre_hooks = {}
        if hasattr(m, "_backward_hooks") and isinstance(m._backward_hooks, dict):
            m._backward_hooks = {}
        # remove THOP counters
        for a in ("total_ops", "total_params"):
            if hasattr(m, a):
                try:
                    delattr(m, a)
                except Exception:
                    pass
    return model

def _flops_g_cpu(model: torch.nn.Module, imgsz=224):
    # THOP on a CPU-only deepcopy; never re-use that instance.
    if thop_profile is None:
        return float("nan")
    try:
        m_cpu = copy.deepcopy(model).cpu().eval()
        x_cpu = torch.randn(1, 3, imgsz, imgsz)  # CPU tensor
        macs, _ = thop_profile(m_cpu, inputs=(x_cpu,), verbose=False)
        return macs / 1e9
    except Exception as e:
        print("[thop] warn:", e)
        return float("nan")

@torch.inference_mode()
def _latency_ms(model: torch.nn.Module, device: str, imgsz=224, reps=20, warmup=10):
    # Benchmark on a fresh, cleaned copy
    m = copy.deepcopy(model).eval()
    _strip_thop_artifacts(m)
    m = m.to(device)
    x = torch.randn(1, 3, imgsz, imgsz, device=device)

    # warmup
    if device == "cuda":
        for _ in range(warmup):
            _ = m(x)
        torch.cuda.synchronize()
    else:
        for _ in range(warmup):
            _ = m(x)

    # timed
    times = []
    for _ in range(reps):
        t0 = time.time()
        _ = m(x)
        if device == "cuda":
            torch.cuda.synchronize()
        times.append((time.time() - t0) * 1000.0)
    return float(np.mean(times)), float(np.std(times))

@torch.inference_mode()
def _gpu_peak_mb(model: torch.nn.Module, imgsz=224):
    if not torch.cuda.is_available():
        return float("nan")
    m = copy.deepcopy(model).eval()
    _strip_thop_artifacts(m)
    m = m.to("cuda")
    x = torch.randn(1, 3, imgsz, imgsz, device="cuda")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    _ = m(x)
    torch.cuda.synchronize()
    peak = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
    torch.cuda.empty_cache()
    return peak

def model_stats(model: torch.nn.Module, imgsz=224, reps=20, warmup=10):
    stats = {}
    stats["params_M"] = _params_m(model)
    stats["FLOPs_G"]  = _flops_g_cpu(model, imgsz=imgsz)

    # latency
    lat_cpu_mean, lat_cpu_std = _latency_ms(model, "cpu", imgsz=imgsz, reps=reps, warmup=warmup)
    stats["latency_ms_cpu_mean"] = lat_cpu_mean
    stats["latency_ms_cpu_std"]  = lat_cpu_std

    if torch.cuda.is_available():
        lat_gpu_mean, lat_gpu_std = _latency_ms(model, "cuda", imgsz=imgsz, reps=reps, warmup=warmup)
        stats["latency_ms_gpu_mean"] = lat_gpu_mean
        stats["latency_ms_gpu_std"]  = lat_gpu_std
        stats["gpu_peak_MB"]         = _gpu_peak_mb(model, imgsz=imgsz)
    else:
        stats["latency_ms_gpu_mean"] = float("nan")
        stats["latency_ms_gpu_std"]  = float("nan")
        stats["gpu_peak_MB"]         = float("nan")

    # process RAM (RSS)
    stats["ram_rss_MB"] = psutil.Process(os.getpid()).memory_info().rss / (1024.0 * 1024.0)
    return stats

# --- collect & save ---
rows = []
for tag, m in MODELS.items():
    s = model_stats(m, imgsz=224, reps=20, warmup=10)
    s["model"] = tag
    rows.append(s)

df_eff = pd.DataFrame(rows)[[
    "model","params_M","FLOPs_G","latency_ms_gpu_mean","latency_ms_cpu_mean",
    "ram_rss_MB","gpu_peak_MB"
]]
df_eff.to_csv(os.path.join(TABLES_ROOT, "efficiency_medmnist.csv"), index=False)
try:
    df_eff.to_latex(os.path.join(TABLES_ROOT, "efficiency_medmnist.tex"), float_format="%.3f", index=False)
except Exception:
    pass
df_eff


Unnamed: 0,model,params_M,FLOPs_G,latency_ms_gpu_mean,latency_ms_cpu_mean,ram_rss_MB,gpu_peak_MB
0,resnet50,23.522375,8.263418,8.54038,101.889646,2087.925781,424.933594
1,resnet18,11.180103,1.823525,2.758622,49.473834,2090.476562,376.927246
2,mobilenet_v2,2.232839,0.326216,2.789652,13.267028,2009.359375,335.947754
3,efficientnet_b0,4.016515,0.413874,5.286169,30.118823,2020.242188,342.546387


In [15]:
# --- Grad-CAM (no OpenCV) — THOP-safe ---
import copy, os
import numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from types import MethodType
from PIL import Image
from matplotlib import cm

def _strip_thop_artifacts(model: nn.Module):
    """Remove THOP-added wrappers/counters, but DO NOT replace hook dict objects."""
    for m in model.modules():
        # restore original forward if THOP wrapped it
        for attr in ("__original_forward__", "origin_forward", "_origin_forward"):
            if hasattr(m, attr):
                try:
                    m.forward = MethodType(getattr(m, attr), m)
                except Exception:
                    pass
        # clear hooks IN PLACE (do not assign a new dict)
        if hasattr(m, "_forward_hooks") and m._forward_hooks is not None:
            try: m._forward_hooks.clear()
            except Exception: pass
        if hasattr(m, "_forward_pre_hooks") and m._forward_pre_hooks is not None:
            try: m._forward_pre_hooks.clear()
            except Exception: pass
        if hasattr(m, "_backward_hooks") and m._backward_hooks is not None:
            try: m._backward_hooks.clear()
            except Exception: pass
        # remove THOP counters if present
        for a in ("total_ops", "total_params"):
            if hasattr(m, a):
                try: delattr(m, a)
                except Exception: pass
    return model

def pick_last_conv(model: nn.Module):
    # resnet-style
    if hasattr(model, "layer4"):
        return model.layer4[-1]
    # mobilenet/efficientnet-style
    if hasattr(model, "features") and len(list(model.features.children())) > 0:
        last = None
        for mod in model.features.modules():
            if isinstance(mod, nn.Conv2d):
                last = mod
        if last is not None:
            return last
        return list(model.features.children())[-1]
    # generic fallback
    last = None
    for mod in model.modules():
        if isinstance(mod, nn.Conv2d): last = mod
    if last is None:
        raise RuntimeError("No conv layer found for Grad-CAM")
    return last

def gradcam_map(model: nn.Module, x4d: torch.Tensor, target_layer: nn.Module, class_idx: int):
    feats, grads = {}, {}

    def fwd_hook(m, i, o): feats["z"] = o
    def bwd_hook(m, gi, go): grads["g"] = go[0]

    # register hooks (works because we didn't replace the dict objects)
    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)

    # ensure grad tracking is on
    was_training = model.training
    model.eval()
    for p in model.parameters():
        p.requires_grad_(True)

    with torch.enable_grad():
        logits = model(x4d)
        C = logits.shape[1]
        cls = int(max(0, min(class_idx, C - 1)))
        score = logits[:, cls].sum()
        model.zero_grad(set_to_none=True)
        score.backward(retain_graph=False)

    A = feats["z"].detach()          # [N,K,H',W']
    G = grads["g"].detach()          # [N,K,H',W']
    w = G.flatten(2).mean(2)         # [N,K]
    cam = (w[:, :, None, None] * A).sum(1, keepdim=True)  # [N,1,H',W']
    cam = torch.relu(cam)
    cam = cam / (cam.amax(dim=(2, 3), keepdim=True) + 1e-6)

    h1.remove(); h2.remove()
    model.train(was_training)
    return cam  # [N,1,H',W']

def denorm_to_uint8(x3: torch.Tensor):
    mean = torch.tensor([0.485, 0.456, 0.406], device=x3.device)[:, None, None]
    std  = torch.tensor([0.229, 0.224, 0.225], device=x3.device)[:, None, None]
    img = (x3*std + mean).clamp(0, 1).permute(1, 2, 0).contiguous().detach().cpu().numpy()
    return (img*255).astype(np.uint8)

def overlay_heatmap(rgb_u8: np.ndarray, cam_01: np.ndarray, alpha=0.35):
    H, W = rgb_u8.shape[:2]
    cam_t = torch.from_numpy(cam_01)[None, None, ...].float()
    cam_up = F.interpolate(cam_t, size=(H, W), mode="bilinear", align_corners=False)[0, 0].numpy()
    heat = (cm.get_cmap("jet")(cam_up)[..., :3] * 255).astype(np.uint8)
    out = (1 - alpha) * rgb_u8 + alpha * heat
    return np.clip(out, 0, 255).astype(np.uint8)

# --- Build class-balanced index list (graceful if a class has <K examples) ---
K = 8
classes = class_names_for_medmnist()
idxs_per_class = {c: [] for c in range(len(classes))}
for i in range(len(testset)):
    _, y = testset[i]
    y = int(np.array(y).squeeze())
    if 0 <= y < len(classes) and len(idxs_per_class[y]) < K:
        idxs_per_class[y].append(i)
    if all(len(v) >= K for v in idxs_per_class.values()):
        break

# --- Generate panels per model on a cleaned copy (CPU or CUDA) ---
os.makedirs(FIGS_ROOT, exist_ok=True)
for tag, model in MODELS.items():
    m = copy.deepcopy(model).eval()
    _strip_thop_artifacts(m)   # critical: keep hook dict objects, just clear
    m = m.to(DEVICE)
    target_layer = pick_last_conv(m)

    rows = []
    for c in range(len(classes)):
        tiles = []
        for idx in idxs_per_class[c]:
            x, _ = testset[idx]
            x = x.unsqueeze(0).to(DEVICE)
            with torch.enable_grad():
                cam = gradcam_map(m, x, target_layer, c)[0, 0].detach().cpu().numpy()
            rgb = denorm_to_uint8(x[0])
            tiles.append(overlay_heatmap(rgb, cam))
        if len(tiles) == 0:
            continue
        rows.append(np.concatenate(tiles, axis=1))
    if len(rows) == 0:
        print(f"[Grad-CAM] No rows for {tag}; skipping.")
        continue
    maxW = max(r.shape[1] for r in rows)
    rows_pad = []
    for r in rows:
        if r.shape[1] < maxW:
            pad = np.zeros((r.shape[0], maxW - r.shape[1], 3), dtype=r.dtype)
            r = np.concatenate([r, pad], axis=1)
        rows_pad.append(r)
    panel = np.concatenate(rows_pad, axis=0)
    Image.fromarray(panel).save(os.path.join(FIGS_ROOT, f"{tag}_gradcam_panel.png"))
print("[OK] Grad-CAM panels saved to", FIGS_ROOT)


  heat = (cm.get_cmap("jet")(cam_up)[..., :3] * 255).astype(np.uint8)


[OK] Grad-CAM panels saved to .\figs\medmnist


In [17]:
# --- t-SNE (teacher vs student) — THOP-safe, device-clean ---
import copy, os
import numpy as np, torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

def _clean_clone(model: nn.Module, device: str):
    m = copy.deepcopy(model).eval()
    _strip_thop_artifacts(m)      # from your Grad-CAM cell
    return m.to(device)

def penultimate_features(model: nn.Module, x4d: torch.Tensor):
    """
    Works for torchvision resnet*, mobilenet_v2, efficientnet_b0.
    """
    # ResNet family
    if hasattr(model, "layer4") and hasattr(model, "avgpool"):
        x = model.conv1(x4d); x = model.bn1(x); x = model.relu(x); x = model.maxpool(x)
        x = model.layer1(x); x = model.layer2(x); x = model.layer3(x); x = model.layer4(x)
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return x
    # MobileNetV2 / EfficientNet style (features -> GAP)
    if hasattr(model, "features"):
        x = model.features(x4d)
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return x
    # Fallback: last conv then GAP
    last = None
    for mod in model.modules():
        if isinstance(mod, nn.Conv2d): last = mod
    if last is None:
        raise RuntimeError("No conv block found for penultimate features")
    feats = {}
    def hk(_, __, o): feats["z"] = o
    h = last.register_forward_hook(hk)
    _ = model(x4d)
    h.remove()
    x = feats["z"]
    x = F.adaptive_avg_pool2d(x, 1).flatten(1)
    return x

@torch.inference_mode()
def collect_feats(model: nn.Module, loader: torch.utils.data.DataLoader, limit=1500):
    m = _clean_clone(model, DEVICE)   # <-- critical: deep copy + strip THOP + move to DEVICE
    Z, Y, n = [], [], 0
    for x, y in loader:
        x = x.to(DEVICE)
        z = penultimate_features(m, x).cpu().numpy()
        Z.append(z); Y.append(y.numpy()); n += len(y)
        if n >= limit:
            break
    Z = np.vstack(Z)[:limit]; Y = np.concatenate(Y)[:limit]
    return Z, Y

def procrustes_align(A, B):
    # Align A to B (orthogonal)
    A0 = A - A.mean(0); B0 = B - B.mean(0)
    U, _, Vt = np.linalg.svd(A0.T @ B0)
    R = U @ Vt
    return (A0 @ R) + B.mean(0)

# smaller loader to avoid pulling full test at once
small_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# Teacher embedding
Zt, Yt = collect_feats(teacher, small_loader, limit=1500)
T2 = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=30, random_state=0).fit_transform(Zt)

# Students
for st in ["resnet18","mobilenet_v2","efficientnet_b0"]:
    Zs, Ys = collect_feats(MODELS[st], small_loader, limit=1500)
    S2 = TSNE(n_components=2, init="pca", learning_rate="auto", perplexity=30, random_state=0).fit_transform(Zs)
    S2a = procrustes_align(S2, T2)
    fig = plt.figure(figsize=(6,5))
    plt.scatter(T2[:,0], T2[:,1], s=6, alpha=0.35, label="resnet50")
    plt.scatter(S2a[:,0], S2a[:,1], s=6, alpha=0.35, label=st)
    plt.title(f"t-SNE (DermaMNIST) — teacher vs {st}")
    plt.legend()
    fig.savefig(os.path.join(FIGS_ROOT, f"tsne_teacher_vs_{st}.png"), dpi=200, bbox_inches="tight")
    plt.close(fig)
print("[OK] t-SNE plots saved to", FIGS_ROOT)


[OK] t-SNE plots saved to .\figs\medmnist
