# MedMNIST (DermaMNIST) — End-to-End KD Pipeline Notebook
**Teacher**: ResNet-50  
**Students**: ResNet-18, MobileNetV2, EfficientNet-B0  
This notebook:
- fetches DermaMNIST
- trains teacher & students (with KD + optional Attention Transfer)
- evaluates per-class metrics, confusions, PR curves
- runs α/τ/β ablations, Grad-CAM, t-SNE
- measures params, FLOPs, latency CPU/GPU, RAM peak
- writes artifacts to `figs/`, `tables/`, and `reports/`
- logs to TensorBoard
It **reuses existing checkpoints** when available: the training helpers search the `models/` tree, and the load cell reads fixed checkpoint paths under `MedMNIST-EdgeAI/models/`.


In [1]:
import os
# --- Notebook config ---
# Every path hangs off PROJECT_ROOT; change it to run the notebook from another directory.
PROJECT_ROOT = "."  # set to your repo root if running elsewhere
DATA_ROOT    = os.path.join(PROJECT_ROOT, "MedMNIST-EdgeAI", "data")  # optional; medmnist downloads will pick a cache too
REPORTS_ROOT = os.path.join(PROJECT_ROOT, "reports")
FIGS_ROOT    = os.path.join(PROJECT_ROOT, "figs")
TABLES_ROOT  = os.path.join(PROJECT_ROOT, "tables")
MODELS_ROOT  = os.path.join(PROJECT_ROOT, "models")

# Output directories are created up front; MODELS_ROOT is only read from here, so it is left alone.
for _out_dir in (REPORTS_ROOT, FIGS_ROOT, TABLES_ROOT):
    os.makedirs(_out_dir, exist_ok=True)

print("Using:")
print("  PROJECT_ROOT:", PROJECT_ROOT)
print("  REPORTS_ROOT:", REPORTS_ROOT)
print("  FIGS_ROOT:", FIGS_ROOT)
print("  TABLES_ROOT:", TABLES_ROOT)
print("  MODELS_ROOT:", MODELS_ROOT)


Using:
  PROJECT_ROOT: .
  REPORTS_ROOT: .\reports
  FIGS_ROOT: .\figs
  TABLES_ROOT: .\tables
  MODELS_ROOT: .\models


In [None]:

# If running in a fresh env, uncomment the next lines.
# %pip install -q medmnist torch torchvision torchaudio scikit-learn pandas matplotlib tensorboard thop psutil
# %pip install -q timm  # optional if using timm models


In [2]:

import os, json, time, math, random, psutil
from pathlib import Path
import numpy as np
import pandas as pd
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torchvision.models as tv
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, average_precision_score, f1_score
from sklearn.manifold import TSNE
from torch.utils.tensorboard import SummaryWriter

try:
    from thop import profile as thop_profile
except Exception:
    thop_profile = None

SEED = 123
_HAS_CUDA = torch.cuda.is_available()

# Seed every RNG the notebook touches: Python, NumPy, torch (CPU) and, if present, all CUDA devices.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if _HAS_CUDA:
    torch.cuda.manual_seed_all(SEED)

DEVICE = "cuda" if _HAS_CUDA else "cpu"
print("DEVICE =", DEVICE)


DEVICE = cuda


In [3]:

from medmnist import DermaMNIST

def class_names_for(dataset):
    """Return the ordered class-name list for a supported dataset tag.

    The notebook treats index ``i`` of the returned list as the name of integer label ``i``.

    Args:
        dataset: dataset tag; only ``"medmnist"`` (DermaMNIST) is supported.

    Returns:
        A new list of the 7 DermaMNIST class abbreviations.

    Raises:
        ValueError: if ``dataset`` is not a supported tag (message names the bad tag).
    """
    if dataset == "medmnist":
        return ["akiec","bcc","bkl","df","mel","nv","vasc"]
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'medmnist'")

IMGSZ = 224
# ImageNet mean/std normalisation, shared by the train and test pipelines.
_NORM = T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
# Train adds a random horizontal flip; both pipelines upsample to IMGSZ x IMGSZ.
TRANS_TRAIN = T.Compose([T.Resize((IMGSZ,IMGSZ)), T.RandomHorizontalFlip(), T.ToTensor(), _NORM])
TRANS_TEST = T.Compose([T.Resize((IMGSZ,IMGSZ)), T.ToTensor(), _NORM])

# NOTE(review): medmnist appears to return each label as a length-1 array (the Grad-CAM cell
# squeezes it), so batched targets may be (B, 1) -- consumers should flatten them.
_split_transform = {"train": TRANS_TRAIN, "val": TRANS_TEST, "test": TRANS_TEST}
trainset, valset, testset = (
    DermaMNIST(split=split, transform=_split_transform[split], download=True)
    for split in ("train", "val", "test")
)

C = len(class_names_for("medmnist"))
print("Train/Val/Test:", len(trainset), len(valset), len(testset), "Classes:", C)


Train/Val/Test: 7007 1003 2005 Classes: 7


In [4]:

def build_arch(tag: str, num_classes: int):
    """Build an untrained torchvision classifier with a ``num_classes``-way head.

    Supported tags (case-insensitive): ``resnet50``, ``resnet18``,
    ``mbv2``/``mobilenet_v2`` and ``effb0``/``efficientnet_b0``. Backbones are created
    with ``weights=None`` (random init). ResNets get a new ``fc``; MobileNetV2 and
    EfficientNet-B0 get a new last ``classifier`` layer.

    Raises:
        ValueError: for an unknown tag (the lower-cased tag is the message).
    """
    key = tag.lower()
    resnet_factories = {"resnet50": tv.resnet50, "resnet18": tv.resnet18}
    classifier_factories = {
        "mbv2": tv.mobilenet_v2, "mobilenet_v2": tv.mobilenet_v2,
        "effb0": tv.efficientnet_b0, "efficientnet_b0": tv.efficientnet_b0,
    }
    if key in resnet_factories:
        net = resnet_factories[key](weights=None)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    elif key in classifier_factories:
        net = classifier_factories[key](weights=None)
        old_head = net.classifier[-1]
        net.classifier[-1] = nn.Linear(old_head.in_features, num_classes)
    else:
        raise ValueError(key)
    return net

def _load_state_smart(model, ckpt_path):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    sd = ckpt.get("state_dict", ckpt if isinstance(ckpt, dict) else ckpt)
    new_sd = {}
    for k,v in sd.items():
        k2 = k
        for pref in ["model.","module.","student.","net."]:
            if k2.startswith(pref): k2 = k2[len(pref):]
        new_sd[k2] = v
    miss, unexp = model.load_state_dict(new_sd, strict=False)
    if miss:  print("[load_state] missing keys:", len(miss))
    if unexp: print("[load_state] unexpected keys:", len(unexp))
    return model

def discover_student_ckpt(dataset, tag):
    """Find a checkpoint for student ``tag`` under ``MODELS_ROOT/students``.

    Patterns are tried in priority order:
      1. ``distilled_<arch>_<dataset>/ckpt-best.pth`` (the exact layout),
      2. ``*<dataset>*kdat*/ckpt-best.pth``,
      3. ``**/ckpt-best.pth``, then 4. ``**/ckpt-last.pth``.
    For the broad patterns 2-4, a hit is only accepted if its path relative to the students
    root names the architecture (ignoring ``_``/``-``). This prevents, e.g., a ResNet-18
    checkpoint from being returned for MobileNetV2. Hits within a pattern are sorted, so
    the result does not depend on filesystem order.

    Args:
        dataset: dataset tag used in directory names (e.g. ``"medmnist"``).
        tag: ``resnet18``, ``mbv2``/``mobilenet_v2`` or ``effb0``/``efficientnet_b0``
            (case-insensitive).

    Returns:
        POSIX path string of the first match, or ``None``.

    Raises:
        KeyError: for an unknown student tag.
    """
    tmap = {
        "resnet18": "resnet18",
        "mbv2": "mobilenetv2", "mobilenet_v2": "mobilenetv2",
        "effb0": "efficientnetb0", "efficientnet_b0": "efficientnetb0",
    }
    key = tag.lower()
    tagnorm = tmap[key]
    root = Path(MODELS_ROOT) / "students"

    def _squash(s):
        return s.lower().replace("_", "").replace("-", "")

    arch_tokens = {tagnorm, _squash(key)}

    def _names_arch(hit):
        rel = _squash(hit.relative_to(root).as_posix())
        return any(tok in rel for tok in arch_tokens)

    patterns = [
        (f"distilled_{tagnorm}_{dataset}/ckpt-best.pth", False),
        (f"*{dataset}*kdat*/ckpt-best.pth", True),
        ("**/ckpt-best.pth", True),
        ("**/ckpt-last.pth", True),
    ]
    for pattern, must_name_arch in patterns:
        hits = sorted(root.glob(pattern))
        if must_name_arch:
            hits = [h for h in hits if _names_arch(h)]
        if hits:
            return hits[0].as_posix()
    return None

def discover_teacher_ckpt(dataset):
    """Find a teacher checkpoint under ``MODELS_ROOT/teachers``.

    Tries ``<dataset>_resnet50*/ckpt-best.pth``, then any ``**/ckpt-best.pth``, then any
    ``**/ckpt-last.pth``. Hits within a pattern are sorted, so the choice is reproducible
    rather than dependent on filesystem enumeration order.

    Returns:
        POSIX path string of the first match, or ``None``.
    """
    root = Path(MODELS_ROOT) / "teachers"
    for pattern in (f"{dataset}_resnet50*/ckpt-best.pth", "**/ckpt-best.pth", "**/ckpt-last.pth"):
        hits = sorted(root.glob(pattern))
        if hits:
            return hits[0].as_posix()
    return None


In [5]:

def ece_score_torch(probs, y, n_bins=15):
    """Expected Calibration Error of top-1 predictions over equal-width confidence bins.

    Args:
        probs: (N, C) tensor of class probabilities.
        y: integer labels with N elements in any shape that flattens to (N,), e.g. the
            (N, 1) arrays produced by batched medmnist targets.
        n_bins: number of equal-width bins on [0, 1].

    Returns:
        Python float: sum over bins of (fraction of samples in bin) * |mean conf - accuracy|.
    """
    # (N, 1) labels would broadcast against (k,) predictions into a (k, k) comparison.
    y = y.reshape(-1)
    conf, preds = probs.max(1).values, probs.argmax(1)
    bins = torch.linspace(0,1,n_bins+1, device=probs.device)
    ece = torch.zeros(1, device=probs.device)
    for i in range(n_bins):
        if i == n_bins - 1:
            # Last bin is open-ended on the right so that conf == 1.0 is counted.
            m = conf >= bins[i]
        else:
            m = (conf >= bins[i]) & (conf < bins[i+1])
        if m.sum() == 0: continue
        acc = (preds[m]==y[m]).float().mean()
        conf_m = conf[m].mean()
        ece += (m.float().mean()) * (conf_m - acc).abs()
    return ece.item()

def kd_loss(student_logits, teacher_logits, y, alpha, tau):
    """Hinton-style distillation objective.

    Returns ``(1 - alpha) * CE(student, y) + alpha * tau^2 * KL(teacher_tau || student_tau)``.
    Here ``*_tau`` are temperature-softened softmax distributions and the KL is averaged
    over the batch. With ``alpha == 0.0`` the plain cross-entropy is returned and the
    teacher logits are ignored.
    """
    hard_loss = F.cross_entropy(student_logits, y)
    if alpha == 0.0:
        return hard_loss
    log_p_student = F.log_softmax(student_logits / tau, dim=1)
    p_teacher = F.softmax(teacher_logits / tau, dim=1)
    # The tau^2 factor keeps soft-target gradient magnitudes comparable across temperatures.
    soft_loss = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (tau * tau)
    return (1 - alpha) * hard_loss + alpha * soft_loss

def at_loss(student_feat, teacher_feat):
    """Attention-transfer loss between two (B, C, H, W) feature maps.

    Each map is reduced to a spatial attention map (mean of squared activations over
    channels), L2-normalised over H*W, and the two maps are compared with MSE. Channel
    counts may differ. If the spatial sizes differ, the student attention map is
    bilinearly resized to the teacher's H x W before normalisation. For equal sizes this
    is exactly the plain formula.
    """
    def _attention(x, size=None):
        a = x.pow(2).mean(dim=1, keepdim=True)  # (B, 1, H, W)
        if size is not None and a.shape[-2:] != size:
            a = F.interpolate(a, size=size, mode="bilinear", align_corners=False)
        return a / (a.norm(p=2, dim=(2,3), keepdim=True) + 1e-8)

    target = _attention(teacher_feat)
    return F.mse_loss(_attention(student_feat, size=teacher_feat.shape[-2:]), target)

def pick_hook_layer(m):
    """Choose the module whose output serves as the feature map for forward hooks.

    Preference order: ``m.layer4[-1]`` (ResNets), then the last child of ``m.features``
    (e.g. torchvision MobileNetV2 / EfficientNet), then the last ``nn.Conv2d`` in
    ``m.modules()`` order.

    Raises:
        ValueError: if ``m`` has no ``layer4``, no ``features`` and no ``nn.Conv2d``.
    """
    if hasattr(m, "layer4"): return m.layer4[-1]
    if hasattr(m, "features"): return list(m.features.children())[-1]
    convs = [mod for mod in m.modules() if isinstance(mod, nn.Conv2d)]
    if not convs:
        raise ValueError(f"No layer4/features/Conv2d found in {type(m).__name__}")
    return convs[-1]

def fit_supervised_teacher(dataset="medmnist", epochs=8, batch=64, lr=3e-4):
    """Train a ResNet-50 teacher with plain cross-entropy on the global ``trainset``.

    The best epoch is selected by validation macro-F1 on ``valset``. Macro-F1 and ECE are
    logged to TensorBoard each epoch. Writes ``ckpt-best.pth``, ``ckpt-last.pth`` and
    ``metrics.json`` under ``MODELS_ROOT/teachers/medmnist_resnet50_sup``.

    Args:
        dataset: kept for API compatibility; the function always uses the global
            ``trainset``/``valset`` and the ``"medmnist"`` class list.
        epochs, batch, lr: schedule (AdamW, weight decay 1e-4).

    Returns:
        ``(outdir_posix_path, best_val_macro_f1)``.
    """
    classes = class_names_for("medmnist")
    C = len(classes)
    dl_tr = DataLoader(trainset, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    dl_va = DataLoader(valset, batch_size=batch, shuffle=False, num_workers=2, pin_memory=True)

    m = build_arch("resnet50", C).to(DEVICE)
    opt = torch.optim.AdamW(m.parameters(), lr=lr, weight_decay=1e-4)

    outdir = Path(MODELS_ROOT)/"teachers"/"medmnist_resnet50_sup"
    outdir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(outdir/"tb"))
    best_f1 = -1.0

    try:
        for ep in range(1, epochs+1):
            m.train()
            for x, y in dl_tr:
                x = x.to(DEVICE)
                # medmnist yields per-sample label arrays, so batched targets can be (B, 1);
                # cross_entropy needs (B,) class indices.
                y = y.long().reshape(-1).to(DEVICE)
                loss = F.cross_entropy(m(x), y)
                opt.zero_grad(set_to_none=True); loss.backward(); opt.step()

            m.eval(); ys = []; ps = []
            with torch.inference_mode():
                for x, y in dl_va:
                    prob = torch.softmax(m(x.to(DEVICE)), dim=1).cpu()
                    ys.append(y.long().reshape(-1)); ps.append(prob)
            y_true = torch.cat(ys).numpy()
            y_prob = torch.cat(ps).numpy()
            y_pred = y_prob.argmax(1)
            macf1 = f1_score(y_true, y_pred, average="macro")
            ece = ece_score_torch(torch.from_numpy(y_prob), torch.from_numpy(y_true))
            writer.add_scalar("val/macro_f1", macf1, ep)
            writer.add_scalar("val/ece", ece, ep)

            if macf1 > best_f1:
                best_f1 = macf1
                torch.save(m.state_dict(), outdir/"ckpt-best.pth")
            torch.save(m.state_dict(), outdir/"ckpt-last.pth")
    finally:
        # Close the TensorBoard writer even if training is interrupted.
        writer.flush(); writer.close()

    with open(outdir/"metrics.json","w") as f:
        json.dump({"best_macro_f1":float(best_f1)}, f)
    return outdir.as_posix(), best_f1

def fit_student_kd(student_tag="resnet18", alpha=0.5, tau=4.0, beta=0.0, epochs=6, batch=64, lr=3e-4):
    """Distil the ResNet-50 teacher into ``student_tag`` with KD and optional attention transfer.

    The teacher checkpoint is discovered under ``MODELS_ROOT/teachers``. If none exists, a
    supervised teacher is trained first. The per-batch loss is ``kd_loss(alpha, tau)`` plus
    ``beta * at_loss`` on the hooked feature layers when ``beta > 0``. Each epoch,
    validation macro-F1 and ECE are logged to TensorBoard and ``metrics.jsonl``.
    ``model_best.pth`` (best val macro-F1) and ``model_last.pth`` are written to
    ``REPORTS_ROOT/ablation_medmnist/<tag>``.

    Args:
        student_tag: architecture tag accepted by ``build_arch``.
        alpha: KD weight (0 gives plain cross-entropy).
        tau: KD softmax temperature.
        beta: attention-transfer weight (0 disables the feature hooks).
        epochs, batch, lr: schedule (AdamW, weight decay 1e-4).

    Returns:
        ``(outdir_posix_path, best_val_macro_f1)``.
    """
    C = len(class_names_for("medmnist"))
    # teacher
    t_ckpt = discover_teacher_ckpt("medmnist")
    if t_ckpt is None:
        print("[train] No medmnist teacher ckpt found. Training a supervised teacher quickly...")
        t_dir, _ = fit_supervised_teacher(epochs=6, batch=batch)
        t_ckpt = os.path.join(t_dir, "ckpt-best.pth")
    t = build_arch("resnet50", C).to(DEVICE)
    t = _load_state_smart(t, t_ckpt).eval()

    s = build_arch(student_tag, C).to(DEVICE)
    opt = torch.optim.AdamW(s.parameters(), lr=lr, weight_decay=1e-4)

    # Feature hooks for attention transfer (only when beta > 0).
    feat_t, feat_s, hook_handles = {}, {}, []
    if beta > 0:
        def _capture_into(store):
            # Overwrite on every forward. With setdefault(), the student feature captured
            # during the last validation batch (an inference-mode tensor) would survive
            # into the first training step of the next epoch.
            def _hook(_mod, _inp, out):
                store["z"] = out
            return _hook
        hook_handles.append(pick_hook_layer(t).register_forward_hook(_capture_into(feat_t)))
        hook_handles.append(pick_hook_layer(s).register_forward_hook(_capture_into(feat_s)))

    dl_tr = DataLoader(trainset, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    dl_va = DataLoader(valset, batch_size=batch, shuffle=False, num_workers=2, pin_memory=True)

    tag = f"KD_a{alpha}_t{tau}_b{beta}_{student_tag}_medmnist"
    outdir = Path(REPORTS_ROOT)/"ablation_medmnist"/tag
    outdir.mkdir(parents=True, exist_ok=True)
    metrics_path = outdir/"metrics.jsonl"
    # One log per run: otherwise rows from earlier runs can outrank the current model_best.pth.
    metrics_path.write_text("")
    writer = SummaryWriter(log_dir=str(outdir/"tb"))
    best_f1 = -1.0

    try:
        for ep in range(1, epochs+1):
            s.train()
            for x, y in dl_tr:
                x = x.to(DEVICE)
                # Batched medmnist targets can be (B, 1); the losses need (B,) class indices.
                y = y.long().reshape(-1).to(DEVICE)
                with torch.no_grad():
                    t_logits = t(x)
                s_logits = s(x)
                loss = kd_loss(s_logits, t_logits, y, alpha, tau)
                if beta > 0:
                    loss = loss + beta * at_loss(feat_s["z"], feat_t["z"])
                    feat_t.clear(); feat_s.clear()
                opt.zero_grad(set_to_none=True); loss.backward(); opt.step()

            # val
            s.eval(); ys = []; ps = []
            with torch.inference_mode():
                for x, y in dl_va:
                    prob = torch.softmax(s(x.to(DEVICE)), dim=1).cpu()
                    ys.append(y.long().reshape(-1)); ps.append(prob)
            y_true = torch.cat(ys).numpy()
            y_prob = torch.cat(ps).numpy()
            y_pred = y_prob.argmax(1)
            macf1 = f1_score(y_true, y_pred, average="macro")
            ece = ece_score_torch(torch.from_numpy(y_prob), torch.from_numpy(y_true))
            writer.add_scalar("val/macro_f1", macf1, ep)
            writer.add_scalar("val/ece", ece, ep)

            rec = {"epoch": ep, "macro_f1": float(macf1), "ece": float(ece),
                   "alpha": float(alpha), "tau": float(tau), "beta": float(beta),
                   "student": student_tag, "dataset": "medmnist"}
            with open(metrics_path,"a") as f: f.write(json.dumps(rec)+"\n")

            if macf1 > best_f1:
                best_f1 = macf1
                torch.save(s.state_dict(), outdir/"model_best.pth")
            torch.save(s.state_dict(), outdir/"model_last.pth")
    finally:
        # Detach hooks and close the writer even if training fails midway.
        for handle in hook_handles:
            handle.remove()
        writer.flush(); writer.close()
    return outdir.as_posix(), best_f1


In [6]:

def eval_and_save(model, loader, classes, out_tab_dir, out_fig_dir, tag):
    """Evaluate ``model`` on ``loader`` and save per-class metrics and diagnostic figures.

    Outputs (all filenames prefixed with ``tag``):
      * ``out_tab_dir/<tag>_perclass_metrics.csv`` (and ``.tex`` when LaTeX export works),
      * ``out_fig_dir/<tag>_confmat.png`` (rows = true class, columns = predicted class),
      * ``out_fig_dir/<tag>_pr_curves.png`` (one-vs-rest PR curve and AP per class;
        classes with no positive samples are skipped).

    Args:
        model: classifier returning (B, len(classes)) logits; put in eval mode here.
        loader: yields ``(images, labels)`` batches.
        classes: class names; index ``i`` names label ``i``.
        out_tab_dir, out_fig_dir: output directories (created if missing).
        tag: filename prefix and plot-title suffix.

    Returns:
        ``(y_true, y_prob)`` numpy arrays: the labels as concatenated from the loader and
        the (N, C) softmax probabilities.
    """
    os.makedirs(out_tab_dir, exist_ok=True)
    os.makedirs(out_fig_dir, exist_ok=True)
    model.eval()  # BN/dropout must not run in training mode during evaluation
    ys = []; ps = []
    with torch.inference_mode():
        for x, y in loader:
            ps.append(torch.softmax(model(x.to(DEVICE)), dim=1).cpu())
            ys.append(y.long().cpu())
    y_true = torch.cat(ys).numpy()
    y_prob = torch.cat(ps).numpy()
    # Batched medmnist labels can be (N, 1); sklearn and the per-class masks want (N,).
    y_flat = y_true.reshape(-1)
    y_pred = y_prob.argmax(1)
    class_ids = list(range(len(classes)))

    # Passing labels= keeps the report valid even if a class is absent from this split.
    report = classification_report(y_flat, y_pred, labels=class_ids, target_names=classes,
                                   output_dict=True, zero_division=0)
    df = pd.DataFrame(report).transpose()
    df.to_csv(os.path.join(out_tab_dir, f"{tag}_perclass_metrics.csv"))
    try:
        df.to_latex(os.path.join(out_tab_dir, f"{tag}_perclass_metrics.tex"), float_format="%.3f")
    except Exception as e:  # best-effort export (e.g. optional pandas deps missing)
        print(f"[warn] LaTeX export skipped for {tag}: {e}")

    cm = confusion_matrix(y_flat, y_pred, labels=class_ids)
    fig, ax = plt.subplots(figsize=(7,6))
    im = ax.imshow(cm, interpolation='nearest')
    ax.set_title(f"Confusion: {tag}")
    fig.colorbar(im, ax=ax)
    ax.set_xticks(class_ids); ax.set_xticklabels(classes, rotation=45, ha="right")
    ax.set_yticks(class_ids); ax.set_yticklabels(classes)
    ax.set_xlabel("Predicted"); ax.set_ylabel("True")
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i,j]), ha="center", va="center")
    fig.tight_layout(); fig.savefig(os.path.join(out_fig_dir, f"{tag}_confmat.png"), dpi=200); plt.close(fig)

    # PR curves (one-vs-rest)
    fig, ax = plt.subplots(figsize=(7,6))
    for c, name in enumerate(classes):
        positives = (y_flat == c).astype(int)
        if positives.sum() == 0:
            continue  # PR curve / AP are undefined without positive samples
        prec, rec, _ = precision_recall_curve(positives, y_prob[:,c])
        ap = average_precision_score(positives, y_prob[:,c])
        ax.plot(rec, prec, label=f"{name} (AP={ap:.3f})")
    ax.set_xlabel("Recall"); ax.set_ylabel("Precision"); ax.legend(); ax.set_title(f"PR Curves: {tag}")
    fig.savefig(os.path.join(out_fig_dir, f"{tag}_pr_curves.png"), dpi=200); plt.close(fig)

    return y_true, y_prob


In [7]:

def model_stats(model, imgsz=IMGSZ, device=DEVICE, reps=20, warmup=10):
    """Report size, compute and batch-1 inference cost for ``model``.

    Args:
        model: network taking a (1, 3, imgsz, imgsz) tensor.
        imgsz: spatial size of the random probe input.
        device: device used for the thop profile.
        reps: timed forward passes per device.
        warmup: untimed forward passes before timing.

    Returns:
        dict with ``params_M`` (millions of parameters), ``FLOPs_G`` (thop's MAC count / 1e9;
        NaN if thop is unavailable or fails), GPU/CPU latency mean and std in ms (GPU values
        NaN without CUDA), ``ram_rss_MB`` (process RSS after benchmarking) and
        ``gpu_peak_MB`` (``torch.cuda.max_memory_allocated`` after one forward; NaN
        without CUDA).

    On return the model is in eval mode, on the device its parameters were on at call
    time, so later cells can keep using it.
    """
    first_param = next(model.parameters(), None)
    home_device = first_param.device if first_param is not None else torch.device(device)
    model.eval()
    x = torch.randn(1,3,imgsz,imgsz).to(device)
    # Params
    params_m = sum(p.numel() for p in model.parameters())/1e6

    # FLOPs if thop available (thop returns MACs; the key name is kept for compatibility)
    flops_g = float("nan")
    if thop_profile is not None:
        try:
            model.to(device)
            macs, _ = thop_profile(model, inputs=(x,), verbose=False)
            flops_g = macs/1e9
        except Exception as e:
            print(f"[model_stats] thop profiling failed: {e}")

    def bench_on(d):
        """Mean/std latency in ms of ``reps`` forwards on device ``d`` after ``warmup`` runs."""
        model.to(d)
        xm = x.to(d)
        on_cuda = d == "cuda"
        times = []
        with torch.inference_mode():
            for _ in range(warmup):
                model(xm)
                if on_cuda: torch.cuda.synchronize()
            if on_cuda: torch.cuda.empty_cache()
            for _ in range(reps):
                t0 = time.perf_counter()  # monotonic, high-resolution clock
                model(xm)
                if on_cuda: torch.cuda.synchronize()
                times.append((time.perf_counter()-t0)*1000.0)
        return np.mean(times), np.std(times)

    lat_gpu, std_gpu = (float("nan"), float("nan"))
    gpu_peak_mb = float("nan")
    if torch.cuda.is_available():
        lat_gpu, std_gpu = bench_on("cuda")
        # Measure GPU peak while the model is still on the GPU (the CPU benchmark moves it).
        torch.cuda.reset_peak_memory_stats()
        with torch.inference_mode():
            model(x.to("cuda"))
            torch.cuda.synchronize()
        gpu_peak_mb = torch.cuda.max_memory_allocated()/1024/1024
    lat_cpu, std_cpu = bench_on("cpu")

    # RAM (process RSS) after benchmarking
    rss_mb = psutil.Process(os.getpid()).memory_info().rss/1024/1024

    model.to(home_device)
    return {
        "params_M": params_m, "FLOPs_G": flops_g,
        "latency_ms_gpu_mean": lat_gpu, "latency_ms_gpu_std": std_gpu,
        "latency_ms_cpu_mean": lat_cpu, "latency_ms_cpu_std": std_cpu,
        "ram_rss_MB": rss_mb, "gpu_peak_MB": gpu_peak_mb
    }


In [8]:

def pick_last_layer(model: nn.Module) -> nn.Module:
    """Choose the Grad-CAM target layer (same selection rule as ``pick_hook_layer``).

    Preference order: ``model.layer4[-1]`` (ResNets), then the last child of
    ``model.features``, then the last ``nn.Conv2d`` in ``model.modules()`` order.

    Raises:
        ValueError: if none of the above exists.
    """
    if hasattr(model, "layer4"):
        return model.layer4[-1]
    if hasattr(model, "features"):
        return list(model.features.children())[-1]
    last = None
    for mod in model.modules():
        if isinstance(mod, nn.Conv2d):
            last = mod
    if last is None:
        raise ValueError(f"No layer4/features/Conv2d found in {type(model).__name__}")
    return last

def gradcam(model, x4d, target_layer, class_idx):
    """Grad-CAM heat map of ``class_idx`` for a single-image batch.

    Channel weights are the spatial mean of the gradients of the class logit with respect
    to ``target_layer``'s output. The map is the ReLU of the weighted channel sum, then
    min-max scaled.

    Args:
        model: classifier returning (1, num_classes) logits; set to eval mode here.
        x4d: input batch whose first sample is explained (the logit used is ``[0, class_idx]``).
        target_layer: module whose output activations (1, K, h, w) are weighted.
        class_idx: class whose logit is back-propagated.

    Returns:
        (h, w) numpy array scaled to [0, 1]; all zeros if the map is constant.
    """
    model.eval()
    feats, grads = {}, {}
    def fwd_hook(_m, _inp, out): feats["z"] = out
    def bwd_hook(_m, _gin, gout): grads["g"] = gout[0]
    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)
    try:
        logits = model(x4d)
        model.zero_grad(set_to_none=True)
        logits[0, class_idx].backward()
    finally:
        # Always detach the hooks so repeated calls do not stack them on the layer.
        h1.remove(); h2.remove()
    A = feats["z"].detach(); G = grads["g"].detach()
    w = G.flatten(2).mean(2)  # [1,K]
    cam = torch.relu((w[:,:,None,None]*A).sum(1))[0].cpu().numpy()
    lo, hi = cam.min(), cam.max()
    # True min-max scaling: divide by the range, not by the max.
    return (cam - lo)/(hi - lo + 1e-6)

def tsne_features(model, loader, k=1500):
    """2-D t-SNE embedding of globally average-pooled penultimate features.

    For ResNets (``layer4`` present) the stem and ``layer1..layer4`` are run manually. For
    other models ``model.features`` is used. Batches are consumed until at least ``k``
    samples are seen; the first ``k`` features/labels are kept and embedded with sklearn
    TSNE (PCA init, ``random_state=0``).

    Returns:
        ``(embedding, labels)``: embedding is (n, 2) with ``n = min(k, samples in loader)``;
        labels are the first ``n`` loader labels, concatenated as yielded.
    """
    model.eval()
    X = []; yall = []; n_seen = 0
    with torch.inference_mode():
        for x, y in loader:
            x = x.to(DEVICE)
            # penultimate
            if hasattr(model, "layer4"):
                z = model.maxpool(model.relu(model.bn1(model.conv1(x))))
                z = model.layer4(model.layer3(model.layer2(model.layer1(z))))
            else:
                z = model.features(x)
            X.append(F.adaptive_avg_pool2d(z, 1).flatten(1).cpu().numpy())
            yall.append(y.numpy())
            # Running count instead of re-concatenating every label each batch.
            n_seen += yall[-1].shape[0]
            if n_seen >= k: break
    X = np.vstack(X)[:k]
    y = np.concatenate(yall)[:k]
    # sklearn requires perplexity < n_samples; this only differs from 30 for tiny inputs.
    perplexity = min(30, max(1, X.shape[0] - 1))
    ts = TSNE(n_components=2, perplexity=perplexity, init="pca", learning_rate="auto", random_state=0).fit_transform(X)
    return ts, y


## Load Teacher (ResNet-50) and Students for DermaMNIST from Checkpoints

In [None]:

# === Load teacher & students from user-provided checkpoints (no training) ===
# Uses torch, os, Path, DEVICE, PROJECT_ROOT, build_arch and class_names_for from the cells above.
import re

num_classes = len(class_names_for("medmnist"))  # 7 for DermaMNIST

# Exact file mapping provided by user, resolved against PROJECT_ROOT so the notebook
# still finds the files when PROJECT_ROOT points somewhere other than ".".
_USER_MODELS_DIR = os.path.join(PROJECT_ROOT, "MedMNIST-EdgeAI", "models")
ckpt_map = {
    "resnet50":        os.path.join(_USER_MODELS_DIR, "resnet50_teacher_dermamnist.pth"),
    "resnet18":        os.path.join(_USER_MODELS_DIR, "resnet18", "resnet18_dermamnist_student.pth"),
    "mobilenet_v2":    os.path.join(_USER_MODELS_DIR, "mobilenet_v2", "mobilenet_v2_dermamnist_student.pth"),
    "efficientnet_b0": os.path.join(_USER_MODELS_DIR, "efficientnet_b0", "efficientnet_b0_dermamnist_student.pth"),
}

# Wrapper prefixes stripped from checkpoint keys; the "+" also unwraps nested ones like "module.model.".
_PREFIX_RE = re.compile(r"^(?:module\.|model\.|student\.|net\.)+")

def _load_state_flexible(model, path):
    """Non-strictly load ``path`` into ``model`` after stripping wrapper key prefixes.

    Accepts a raw state dict or a dict holding a ``"state_dict"`` entry. Prints the
    missing/unexpected key counts and warns when some weights were not found.
    """
    # NOTE(review): torch.load unpickles; only use trusted checkpoint files.
    state = torch.load(path, map_location="cpu")
    if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
        state = state["state_dict"]
    new_sd = {_PREFIX_RE.sub("", k): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(new_sd, strict=False)
    print(f"[state] {os.path.basename(path)} missing={len(missing)} unexpected={len(unexpected)}")
    if missing:
        print(f"[warn] {len(missing)} entries of {type(model).__name__} were not in "
              f"{os.path.basename(path)} and keep their fresh (untrained) initialisation.")
    return model

def load_user_model(arch_tag):
    """Build ``arch_tag`` via ``build_arch`` and load its mapped checkpoint.

    Raises:
        AssertionError: if ``arch_tag`` has no entry in ``ckpt_map``.
        FileNotFoundError: if the mapped checkpoint file does not exist.
    """
    assert arch_tag in ckpt_map, f"No ckpt mapping for {arch_tag}"
    path = ckpt_map[arch_tag]
    if not Path(path).exists():
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    m = build_arch(arch_tag, num_classes).to(DEVICE).eval()
    print(f"[Load] {arch_tag} <- {path}")
    return _load_state_flexible(m, path)

teacher = load_user_model("resnet50")
student_resnet18 = load_user_model("resnet18")
student_mbv2 = load_user_model("mobilenet_v2")
student_effb0 = load_user_model("efficientnet_b0")

# expose a dictionary for later cells
models_for_eval = {
    "resnet50": teacher,
    "resnet18": student_resnet18,
    "mobilenet_v2": student_mbv2,
    "efficientnet_b0": student_effb0,
}


[Load] resnet50 <- MedMNIST-EdgeAI/models/resnet50_teacher_dermamnist.pth


NameError: name 're' is not defined

## Evaluate Teacher and Save Artifacts

In [None]:

# Evaluate the teacher on the test split; artifacts go under tables/medmnist and figs/medmnist.
classes = class_names_for("medmnist")
tab_dir = os.path.join(TABLES_ROOT, "medmnist")
fig_dir = os.path.join(FIGS_ROOT, "medmnist")
test_loader = DataLoader(testset, batch_size=128, shuffle=False,
                         num_workers=2, pin_memory=True)
y_true_T, y_prob_T = eval_and_save(teacher, test_loader, classes, tab_dir, fig_dir,
                                   tag="resnet50")
print("Teacher eval saved.")


## Register Student Ablation Rows from Provided Checkpoints (no KD training)

In [None]:

# === Register ablation records from provided checkpoints (single-point runs) ===
# One metrics.jsonl per 'run' so the downstream ablation aggregator can read it.
# NOTE: these rows hold TEST-set metrics, whereas fit_student_kd logs VALIDATION metrics,
# and the alpha/tau/beta values below are assumed defaults, not read from the checkpoints.
# Relies on json, Path, f1_score, ece_score_torch and REPORTS_ROOT from earlier cells.
ABL_DIR = Path(REPORTS_ROOT) / "ablation_medmnist"

classes = class_names_for("medmnist")
test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# Evaluate each provided student vs test set, write a single ablation row:
grid_defaults = {
    "resnet18":      {"alpha":0.5,"tau":4.0,"beta":0.0},
    "mobilenet_v2":  {"alpha":0.5,"tau":4.0,"beta":0.0},
    "efficientnet_b0":{"alpha":0.5,"tau":4.0,"beta":0.0},
}

best_dirs = {}  # to keep interface stable for later cells
for st_tag, hp in grid_defaults.items():
    m = models_for_eval[st_tag].eval()
    # eval
    ys=[]; ps=[]
    with torch.inference_mode():
        for x,y in test_loader:
            p = torch.softmax(m(x.to(DEVICE)), dim=1).cpu()
            # Batched medmnist labels can be (B, 1); flatten for sklearn and ECE.
            ys.append(y.long().reshape(-1)); ps.append(p)
    y_true = torch.cat(ys).numpy()
    y_prob = torch.cat(ps).numpy()
    y_pred = y_prob.argmax(1)
    macf1 = f1_score(y_true, y_pred, average="macro")
    ece = ece_score_torch(torch.from_numpy(y_prob), torch.from_numpy(y_true))
    # write a jsonl
    outdir = ABL_DIR / f"fixed_{st_tag}_medmnist"
    outdir.mkdir(parents=True, exist_ok=True)
    rec = {"epoch": 0, "macro_f1": float(macf1), "ece": float(ece),
           "alpha": float(hp["alpha"]), "tau": float(hp["tau"]), "beta": float(hp["beta"]),
           "student": st_tag, "dataset": "medmnist"}
    with open(outdir/"metrics.jsonl","w") as f: f.write(json.dumps(rec)+"\n")
    best_dirs[(st_tag,hp["alpha"],hp["tau"],hp["beta"])] = outdir.as_posix()

print("[OK] Wrote single-point ablation rows for provided student checkpoints.")


## Evaluate Teacher and Students and Save Artifacts

In [None]:

# Evaluate every already-loaded model (teacher included) and save its tables/figures.
classes = class_names_for("medmnist")
test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
# The output directories are the same for every model, so resolve them once.
tab_dir, fig_dir = (os.path.join(root, "medmnist") for root in (TABLES_ROOT, FIGS_ROOT))
for tag, m in models_for_eval.items():
    eval_and_save(m, test_loader, classes, tab_dir, fig_dir, tag=tag)
print("Student/Teacher eval saved from provided checkpoints.")


## Efficiency: Params / FLOPs / Latency / RAM

In [None]:

# One efficiency row per loaded model, tagged with its name.
rows = []
for tag, m in models_for_eval.items():
    rows.append({**model_stats(m), "model": tag})

_eff_columns = ["model","params_M","FLOPs_G","latency_ms_gpu_mean","latency_ms_cpu_mean","ram_rss_MB","gpu_peak_MB"]
df_eff = pd.DataFrame(rows)[_eff_columns]
df_eff.to_csv(os.path.join(TABLES_ROOT, "efficiency_medmnist.csv"), index=False)
try:
    # LaTeX export is best-effort (it can fail, e.g. when optional dependencies are missing).
    df_eff.to_latex(os.path.join(TABLES_ROOT, "efficiency_medmnist.tex"), float_format="%.3f", index=False)
except Exception:
    pass
df_eff


## Grad-CAM Panels

In [None]:

def overlay(rgb_u8, cam, alpha=0.35):
    """Blend a CAM (expected in [0, 1]) over an RGB uint8 image.

    The CAM is bilinearly resized to the image size, mapped through OpenCV's JET colormap,
    converted to RGB and mixed as ``(1 - alpha) * image + alpha * heat``. The result is
    clipped to [0, 255] and returned as uint8.
    """
    import cv2
    height, width = rgb_u8.shape[:2]
    cam_resized = cv2.resize(cam, (width, height), interpolation=cv2.INTER_LINEAR)
    heat_rgb = cv2.cvtColor(
        cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET),
        cv2.COLOR_BGR2RGB,
    )
    blended = (1 - alpha) * rgb_u8 + alpha * heat_rgb
    return np.clip(blended, 0, 255).astype(np.uint8)

# Grad-CAM panel per model: one row per class (the CAM target is that row's true class),
# up to K test examples per row.
import cv2

K = 10
n_cls = len(classes)

def _label_of(i):
    """Integer label of test sample ``i``; reads ``testset.labels`` when present to avoid decoding the image."""
    raw = testset.labels[i] if hasattr(testset, "labels") else testset[i][1]
    return int(np.array(raw).squeeze())

idxs_per_class = {c: [] for c in range(n_cls)}
for i in range(len(testset)):
    y = _label_of(i)
    if 0 <= y < n_cls and len(idxs_per_class[y]) < K:
        idxs_per_class[y].append(i)
    if all(len(v) >= K for v in idxs_per_class.values()): break

# De-normalisation constants (ImageNet stats used by TRANS_TEST), hoisted out of the loops.
_mean = torch.tensor([0.485,0.456,0.406])[:,None,None]
_std  = torch.tensor([0.229,0.224,0.225])[:,None,None]
# Black filler tile so rows of classes with fewer than K examples keep the same width.
_blank_tile = np.zeros((IMGSZ, IMGSZ, 3), dtype=np.uint8)

outp = None
for tag, m in models_for_eval.items():
    tl = pick_last_layer(m)
    panel_rows = []
    for c in range(n_cls):
        tiles = []
        for i in idxs_per_class[c]:
            x, _ = testset[i]
            x = x.unsqueeze(0).to(DEVICE)
            cam = gradcam(m, x, tl, c)
            viz = (x[0].cpu()*_std + _mean).clamp(0,1).permute(1,2,0).numpy()
            tiles.append(overlay((viz*255).astype(np.uint8), cam))
        tiles += [_blank_tile] * (K - len(tiles))
        panel_rows.append(np.concatenate(tiles, axis=1))
    panel = np.concatenate(panel_rows, axis=0)
    outp = os.path.join(FIGS_ROOT, "medmnist", f"{tag}_gradcam_panel.png")
    os.makedirs(os.path.dirname(outp), exist_ok=True)
    cv2.imwrite(outp, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
outp


## t-SNE: Teacher vs Students

In [None]:

def orthogonal_procrustes(A, B, scale=True):
    """Align point cloud ``A`` onto ``B`` (rows must correspond) with an orthogonal map.

    Both clouds are centred, and the orthogonal ``R`` minimising ``||A0 R - B0||_F`` is
    taken from the SVD of ``A0^T B0`` (reflections allowed). With ``scale=True`` both
    centred clouds are normalised to unit Frobenius norm and the aligned result is
    rescaled to B's norm, so it has B's overall spread. The result is finally translated
    to B's centroid.

    Args:
        A, B: (n, d) arrays with corresponding rows.

    Returns:
        (n, d) array: A expressed in B's frame.
    """
    mu_a = A.mean(0, keepdims=True)
    mu_b = B.mean(0, keepdims=True)
    A0 = A - mu_a
    B0 = B - mu_b
    if scale:
        sA = np.sqrt((A0**2).sum()); sB = np.sqrt((B0**2).sum())
        A0 = A0/(sA+1e-12); B0 = B0/(sB+1e-12)
    U, _, Vt = np.linalg.svd(A0.T @ B0)
    Ahat = A0 @ (U @ Vt)
    if scale:
        # A0 is already unit-norm here, so only B's norm is re-applied.
        Ahat = Ahat * sB
    return Ahat + mu_b

# The teacher embedding is the reference frame. The loader is unshuffled, so each student's
# rows correspond to the same test samples, which the Procrustes alignment requires.
# NOTE: separate t-SNE runs are only comparable up to this rigid alignment.
loader_small = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)
T2, y = tsne_features(teacher, loader_small, k=1500)

students = ["resnet18", "mobilenet_v2", "efficientnet_b0"]
tsne_dir = os.path.join(FIGS_ROOT, "medmnist")
os.makedirs(tsne_dir, exist_ok=True)
outp = None
for st in students:
    # Prefer the student instance already loaded in this notebook.
    m = models_for_eval.get(st, None)
    if m is None:
        # Fallback: a model_best.pth from a KD sweep directory, if one exists.
        cand = sorted([d for k, d in best_dirs.items() if k[0] == st], reverse=True)
        ck = next((os.path.join(d, "model_best.pth") for d in cand
                   if os.path.exists(os.path.join(d, "model_best.pth"))), None)
        if ck is None: continue
        m = build_arch(st, len(classes)).to(DEVICE); m = _load_state_smart(m, ck).eval()
    S2, _ = tsne_features(m, loader_small, k=1500)
    S2a = orthogonal_procrustes(S2, T2, scale=True)

    fig, ax = plt.subplots(figsize=(6,5))
    ax.scatter(T2[:,0], T2[:,1], s=6, alpha=0.35, label="resnet50")
    ax.scatter(S2a[:,0], S2a[:,1], s=6, alpha=0.35, label=st)
    ax.set_title(f"t-SNE MedMNIST — resnet50 vs {st}")
    ax.legend()
    outp = os.path.join(tsne_dir, f"tsne_resnet50_vs_{st}.png")
    fig.savefig(outp, dpi=200, bbox_inches="tight"); plt.close(fig)
outp


## Aggregate Ablations (α, τ, β)

In [None]:

# Collect every per-run metrics.jsonl under the ablation directory into one frame.
abl_dir = Path(REPORTS_ROOT)/"ablation_medmnist"
rows = []
for fp in abl_dir.rglob("metrics.jsonl"):
    with open(fp,"r") as f:
        rows.extend(json.loads(line) for line in f if line.strip())
df = pd.DataFrame(rows)

if df.empty:
    print("[warn] No ablation rows found in", abl_dir)
else:
    # Keep the highest-macro-F1 row per (student, alpha, tau, beta) configuration.
    df_best = (df.sort_values("macro_f1")
                 .drop_duplicates(subset=["student","alpha","tau","beta"], keep="last"))
    out_tab_dir = os.path.join(TABLES_ROOT, "medmnist")
    out_fig_dir = os.path.join(FIGS_ROOT, "medmnist")
    os.makedirs(out_tab_dir, exist_ok=True)
    os.makedirs(out_fig_dir, exist_ok=True)
    df_best.to_csv(os.path.join(out_tab_dir,"ablation_grid.csv"), index=False)
    try:
        # Best-effort LaTeX export.
        df_best.to_latex(os.path.join(out_tab_dir,"ablation_grid.tex"), float_format="%.3f", index=False)
    except Exception:
        pass
    # One alpha x tau heatmap of macro-F1 per student.
    for stu, g in df_best.groupby("student"):
        piv = (g.sort_values("macro_f1")
                .drop_duplicates(subset=["alpha","tau"], keep="last")
                .pivot_table(index="alpha", columns="tau", values="macro_f1", aggfunc="max"))
        fig, ax = plt.subplots(figsize=(6,5))
        im = ax.imshow(piv.values, aspect="auto")
        ax.set_title(f"MedMNIST — {stu} (Macro-F1)")
        ax.set_xlabel("tau"); ax.set_ylabel("alpha")
        ax.set_xticks(range(len(piv.columns))); ax.set_xticklabels([str(c) for c in piv.columns])
        ax.set_yticks(range(len(piv.index))); ax.set_yticklabels([str(c) for c in piv.index])
        fig.colorbar(im, ax=ax, label="Macro-F1")
        fig.savefig(os.path.join(out_fig_dir, f"ablation_heatmap_alpha_tau_{stu}.png"), dpi=200, bbox_inches="tight")
        plt.close(fig)
    print("Ablation tables/figures saved.")
df.head()
