# Initializations and Imports

In [1]:
import matplotlib.pyplot as plt
import matplotlib as mpl
! pip install nibabel
import nibabel as nib
import numpy as np
import os, glob
import psutil
import gc
import pandas as pd
import zipfile
import os, math, random, gc, pickle
import seaborn as sns
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import (
    f1_score, accuracy_score, classification_report, confusion_matrix
)
from sklearn.utils.class_weight import compute_class_weight
import os, math, random, gc, pickle
import torch, torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.decomposition import PCA
from sklearn.metrics import (
    f1_score, accuracy_score, classification_report, confusion_matrix,
    roc_curve, auc, precision_recall_curve, average_precision_score,
    brier_score_loss
)
from sklearn.utils.class_weight import compute_class_weight
from sklearn.calibration import calibration_curve
! pip install nilearn
import nilearn
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, TensorDataset
from collections import Counter
from io import BytesIO


mpl.rcParams.update({
    "font.family": "Liberation Sans",   # or "Liberation Sans"
    "pdf.fonttype": 42,             # Embed TrueType fonts
    "ps.fonttype": 42,              # Prevent outline fonts
    "text.usetex": False,
    "font.size": 12,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "savefig.dpi": 300
})

Collecting nilearn
  Downloading nilearn-0.12.1-py3-none-any.whl.metadata (9.9 kB)
Downloading nilearn-0.12.1-py3-none-any.whl (12.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.7/12.7 MB[0m [31m103.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nilearn
Successfully installed nilearn-0.12.1


In [2]:
# changes directory from colab to gdrive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
# Choose your base folder (local runtime or Drive)
BASE_DIR = "/content/drive/MyDrive/Capstone-Project/figures"

# Create it if it doesn't exist
os.makedirs(BASE_DIR, exist_ok=True)

print(f"Figures will be saved to: {BASE_DIR}")

Figures will be saved to: /content/drive/MyDrive/Capstone-Project/figures


# Import Tissue Volume CSV

In [4]:
in_roots    = ["/content/drive/MyDrive/Capstone-Project/derivatives/pp_preproc_anat"]
in_patterns = ["**/*_preproc.nii.gz", "**/*_preproc.nii"]

seg_out_dir = "/content/drive/MyDrive/Capstone-Project/derivatives/pp_seg"

sup_dir = "/content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised"

subjects_csv = os.path.join(sup_dir, "subjects_with_string_labels.csv")

# Save volumes CSV
df_tissue=pd.read_csv(subjects_csv)
print("Saved:", subjects_csv)

df = df_tissue

Saved: /content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised/subjects_with_string_labels.csv


# Deep Learning Modeling

## MultiLayer Perceptrons

In [5]:
# ============================================================
#  Publication-ready MLP (PCA → MLP) nested CV with figures & tables
#  - leakage-safe folds and PCA
#  - saves figures as PDF+SVG and tables as PDF+CSV
# ============================================================

# =======================
# Config
# =======================
SEED = 42
K_OUTER = 5
EPOCHS = 200
PATIENCE = 18
BATCH_TRAIN = 64
BATCH_VAL = 128
LR_MAX = 3e-3            # peak LR for OneCycle
CLIP_NORM = 5.0
MIXUP_ALPHA = 0.2        # 0 disables
FOCAL_GAMMA = 1.5
LABEL_SMOOTH = 0.10
SWA_START_EPOCH = 140
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparam grid
PCA_GRID = [32, 64, 128]   # per-fold clamped to rank
H_GRID = [128, 256]
DROP_GRID = [0.2, 0.35]
WD_GRID = [1e-4, 5e-4]

# Output dirs
OUT_DIR = "/content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl"
FIG_DIR = os.path.join(OUT_DIR, "figs")
TAB_DIR = os.path.join(OUT_DIR, "tables")
CKPT_DIR = os.path.join(OUT_DIR, "checkpoints")
for d in (OUT_DIR, FIG_DIR, TAB_DIR, CKPT_DIR):
    os.makedirs(d, exist_ok=True)

# =======================
# Style (journal)
# =======================
def set_pub_style():
    mpl.rcParams.update({
        "figure.dpi": 300, "savefig.dpi": 300,
        "pdf.fonttype": 42, "ps.fonttype": 42,  # editable text
        "font.size": 9,
        "axes.labelsize": 9, "axes.titlesize": 10,
        "xtick.labelsize": 8, "ytick.labelsize": 8, "legend.fontsize": 8,
        "axes.linewidth": 0.8, "xtick.major.width": 0.8, "ytick.major.width": 0.8,
        "lines.linewidth": 1.2,
        "axes.spines.top": False, "axes.spines.right": False,
        "figure.autolayout": True,
    })
    sns.set_palette("colorblind")

set_pub_style()

# =======================
# Table helpers (PDF + CSV)
# =======================
def save_df_as_pdf(df: pd.DataFrame, path_pdf: str, title: str = None,
                   max_w=7.5, row_ht=0.28, col_w=1.6):
    df = df.copy()
    df.columns = [str(c) for c in df.columns]
    # round float cols
    for c in df.columns:
        if pd.api.types.is_float_dtype(df[c]):
            df[c] = df[c].map(lambda z: f"{z:.3f}" if pd.notnull(z) else "")
    n_rows, n_cols = df.shape
    fig_w = min(max_w, max(2.5, col_w * max(1, n_cols)))
    fig_h = max(1.5, 0.6 + row_ht * (n_rows + 1) + (0.4 if title else 0))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.axis("off")
    if title:
        ax.set_title(title, loc="left", fontsize=10, pad=6)
    tbl = ax.table(cellText=df.values, colLabels=df.columns, loc="upper left", cellLoc="left")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(8)
    # zebra stripes
    for (r, c), cell in tbl.get_celld().items():
        cell.set_linewidth(0.4)
        if r == 0:
            cell.set_facecolor("#f2f2f2"); cell.set_fontsize(8.5)
        else:
            if r % 2 == 0:
                cell.set_facecolor("#fafafa")
    for c in range(n_cols):
        tbl.auto_set_column_width(col=list(range(n_cols)))
    fig.tight_layout()
    fig.savefig(path_pdf)
    plt.close(fig)

def report_to_df(y_true, y_pred, class_names):
    rep = classification_report(y_true, y_pred, labels=class_names, output_dict=True, zero_division=0)
    rows = []
    for lab in class_names:
        m = rep.get(lab, {})
        rows.append({"class": lab,
                     "precision": m.get("precision", np.nan),
                     "recall":    m.get("recall", np.nan),
                     "f1-score":  m.get("f1-score", np.nan),
                     "support":   int(m.get("support", 0))})
    rows.append({"class": "macro avg",    **{k:v for k,v in rep["macro avg"].items() if k!="support"}})
    rows.append({"class": "weighted avg", **{k:v for k,v in rep["weighted avg"].items() if k!="support"}})
    return pd.DataFrame(rows)

# =======================
# Reproducibility
# =======================
def set_seed(s=SEED):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed()

# =======================
# Data (flexible 3- or 4-class mapping)
# =======================
df = df_tissue  # your dataframe

features = ["GM_frac", "WM_frac", "CSF_frac"]
label_col = "Cluster_Label"

assert set(features).issubset(df.columns), "Some features missing from df"
assert label_col in df.columns, "Label column missing from df"

# Detect which labels exist
uniq_labels = sorted(df[label_col].dropna().unique().tolist())

three_class_order = ["CN", "MCI", "AD"]
four_class_order  = ["CN", "MCI", "MCI/AD", "AD"]

if "MCI/AD" in uniq_labels:
    base_order = four_class_order   # CN, MCI, MCI/AD, AD
else:
    base_order = three_class_order  # CN, MCI, AD

# Keep only those actually present
LABELS = [lab for lab in base_order if lab in uniq_labels]
lab2id = {lab: i for i, lab in enumerate(LABELS)}
N_CLASS = len(LABELS)

print("Detected raw labels:", uniq_labels)
print("Using LABELS order:", LABELS)
print("lab2id mapping:", lab2id)
print("N_CLASS:", N_CLASS)

# Map labels and drop unmapped, just in case
y_all = df[label_col].map(lab2id)
mask_valid = y_all.notna()
if not mask_valid.all():
    print(f"Dropping {(~mask_valid).sum()} rows with unmapped labels.")
df_valid = df.loc[mask_valid].reset_index(drop=True)

X_all = df_valid[features].to_numpy(dtype=np.float32)
y_all = y_all[mask_valid].to_numpy(dtype=np.int64)

print("Final X_all shape:", X_all.shape)
print("Final y_all shape:", y_all.shape)

# =======================
# Helpers
# =======================
def to_torch(x, y=None):
    xt = torch.from_numpy(x.astype(np.float32))
    yt = None if y is None else torch.from_numpy(y.astype(np.int64))
    return xt, yt

def class_weights(y):
    """
    Balanced class weights, robust to folds that are missing some classes.
    Returns a torch tensor of shape [N_CLASS].
    """
    present = np.unique(y)
    cw = compute_class_weight(class_weight="balanced", classes=present, y=y)
    w = np.ones(N_CLASS, dtype=np.float32)
    w[present] = cw.astype(np.float32)
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def mixup(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, torch.ones(len(x), device=x.device)
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0), device=x.device)
    x_mix = lam * x + (1 - lam) * x[idx]
    y2 = y[idx]
    return x_mix, y2, torch.full((x.size(0),), lam, device=x.device)

# =======================
# Model
# =======================
class FocalCrossEntropy(nn.Module):
    def __init__(self, nclass, gamma=FOCAL_GAMMA, label_smooth=LABEL_SMOOTH, weights=None):
        super().__init__()
        self.gamma = gamma
        self.eps = label_smooth
        self.n = nclass
        self.register_buffer("w", weights if weights is not None else torch.ones(nclass))
    def forward(self, logits, y):
        logp = torch.log_softmax(logits, dim=1)
        with torch.no_grad():
            tgt = torch.full_like(logp, self.eps/(self.n-1))
            tgt.scatter_(1, y.unsqueeze(1), 1 - self.eps)
        p = torch.exp(logp)
        focal = (1 - p) ** self.gamma
        loss = -(tgt * focal * logp) * self.w
        return loss.sum(dim=1).mean()

class MLP(nn.Module):
    def __init__(self, d, h=256, drop=0.3, nclass=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, h), nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(h, h//2), nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(h//2, nclass)
        )
    def forward(self, x):
        return self.net(x)

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    yp, yt = [], []
    for xb, yb in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        yp.append(torch.softmax(logits, dim=1).cpu().numpy())
        yt.append(yb.numpy())
    P = np.vstack(yp)
    Y = np.concatenate(yt)
    yhat = P.argmax(1)
    return (accuracy_score(Y, yhat),
            f1_score(Y, yhat, average="macro"),
            Y, yhat, P)

# =======================
# Visual helpers (figures)
# =======================
def save_confusions(y_true, y_pred, labels, prefix, title_suffix=""):
    cm = confusion_matrix(y_true, y_pred, labels=range(len(labels)))
    cm_norm = cm / cm.sum(axis=1, keepdims=True).clip(min=1)

    # Counts
    fig, ax = plt.subplots(figsize=(3.25, 2.8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=labels, yticklabels=labels, cbar=False, ax=ax)
    ax.set_xlabel("Predicted"); ax.set_ylabel("True")
    ax.set_title(f"Confusion (counts){title_suffix}")
    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, f"{prefix}_cm_counts.pdf"))
    fig.savefig(os.path.join(FIG_DIR, f"{prefix}_cm_counts.svg"))
    plt.close(fig)

    # Row-normalized
    fig, ax = plt.subplots(figsize=(3.25, 2.8))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                xticklabels=labels, yticklabels=labels, cbar=False, vmin=0, vmax=1, ax=ax)
    ax.set_xlabel("Predicted"); ax.set_ylabel("True")
    ax.set_title(f"Confusion (row-normalized){title_suffix}")
    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, f"{prefix}_cm_norm.pdf"))
    fig.savefig(os.path.join(FIG_DIR, f"{prefix}_cm_norm.svg"))
    plt.close(fig)

    # Also export as tables
    cm_df = pd.DataFrame(cm, index=labels, columns=labels).reset_index().rename(columns={"index":"True\\Pred"})
    cmn_df = pd.DataFrame(np.round(cm_norm,3), index=labels, columns=labels).reset_index().rename(columns={"index":"True\\Pred"})
    cm_df.to_csv(os.path.join(TAB_DIR, f"{prefix}_cm_counts.csv"), index=False)
    cmn_df.to_csv(os.path.join(TAB_DIR, f"{prefix}_cm_norm.csv"), index=False)
    save_df_as_pdf(cm_df,  os.path.join(TAB_DIR, f"{prefix}_cm_counts.pdf"),     title=f"Confusion (counts){title_suffix}")
    save_df_as_pdf(cmn_df, os.path.join(TAB_DIR, f"{prefix}_cm_norm.pdf"),       title=f"Confusion (row-normalized){title_suffix}")

def save_fold_metric_bars(accs, f1s, title_prefix="Nested CV (MLP)", fname="summary"):
    accs = np.array(accs); f1s = np.array(f1s)
    means = np.array([accs.mean(), f1s.mean()])
    stds  = np.array([accs.std(),  f1s.std()])
    cis   = 1.96 * stds / np.sqrt(len(accs))
    labels = ["Accuracy", "Macro-F1"]

    fig, ax = plt.subplots(figsize=(3.25, 2.5))
    idx = np.arange(len(labels))
    ax.bar(idx, means, yerr=cis, capsize=3, linewidth=0.8, edgecolor="black")
    ax.set_xticks(idx, labels)
    ax.set_ylim(0, 1)
    for i, m in enumerate(means):
        ax.text(i, min(0.98, m+0.02), f"{m:.2f}", ha="center", va="bottom", fontsize=8)
    ax.set_title(title_prefix)
    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, f"{fname}.pdf"))
    fig.savefig(os.path.join(FIG_DIR, f"{fname}.svg"))
    plt.close(fig)

def save_pca_scatter(Xtr_s, ytr, title="Train PCA (first 2 PCs)", fname="pca_scatter"):
    # purely EDA inside each fold train only (no leakage)
    ncomp = min(2, Xtr_s.shape[1])
    if ncomp < 2:
        return
    pca = PCA(n_components=2, random_state=SEED).fit(Xtr_s)
    Z = pca.transform(Xtr_s)
    fig, ax = plt.subplots(figsize=(3.25, 2.8))
    for c, lab in enumerate(LABELS):
        if c not in np.unique(ytr):
            continue
        m = (ytr == c)
        ax.scatter(Z[m,0], Z[m,1], s=12, alpha=0.8, label=lab)
    ax.set_xlabel("PC1"); ax.set_ylabel("PC2")
    ax.set_title(title); ax.legend(frameon=False, ncol=2, fontsize=7)
    fig.tight_layout()
    fig.savefig(os.path.join(FIG_DIR, f"{fname}.pdf"))
    fig.savefig(os.path.join(FIG_DIR, f"{fname}.svg"))
    plt.close(fig)

# =======================
# Train one fold (with PCA)
# =======================
def train_one_fold(Ztr, ytr, Zva, yva, h=256, drop=0.3, wd=1e-4, epochs=EPOCHS):
    Xtr_t, ytr_t = to_torch(Ztr, ytr)
    Xva_t, yva_t = to_torch(Zva, yva)
    tr_loader = DataLoader(TensorDataset(Xtr_t, ytr_t), batch_size=BATCH_TRAIN, shuffle=True,
                           pin_memory=(DEVICE.type=="cuda"))
    va_loader = DataLoader(TensorDataset(Xva_t, yva_t), batch_size=BATCH_VAL, shuffle=False,
                           pin_memory=(DEVICE.type=="cuda"))

    model = MLP(Ztr.shape[1], h=h, drop=drop, nclass=N_CLASS).to(DEVICE)
    weights = class_weights(ytr)
    crit = FocalCrossEntropy(nclass=N_CLASS, gamma=FOCAL_GAMMA,
                             label_smooth=LABEL_SMOOTH, weights=weights)
    opt = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=wd)
    steps_per_epoch = max(1, math.ceil(len(tr_loader)))
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=LR_MAX, epochs=epochs, steps_per_epoch=steps_per_epoch,
        pct_start=0.15, anneal_strategy="cos"
    )
    swa_model = torch.optim.swa_utils.AveragedModel(model)
    swa_sched = torch.optim.swa_utils.SWALR(opt, swa_lr=LR_MAX*0.2)

    best_f1, best_state, stale = -1.0, None, 0
    for ep in range(1, epochs+1):
        model.train()
        for xb, yb in tr_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            xb_m, yb2, lam = mixup(xb, yb, alpha=MIXUP_ALPHA)
            opt.zero_grad(set_to_none=True)
            logits = model(xb_m)
            loss = lam * crit(logits, yb) + (1 - lam) * crit(logits, yb2)
            loss.mean().backward()
            nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            opt.step()
            sched.step()

        if ep >= SWA_START_EPOCH:
            swa_model.update_parameters(model)
            swa_sched.step()

        acc_tr, f1_tr, _, _, _ = evaluate(model, tr_loader)
        acc_va, f1_va, Yv_tmp, yhat_v_tmp, _ = evaluate(model, va_loader)
        print(f"[ep {ep:03d}] tr acc={acc_tr:.3f} f1={f1_tr:.3f} | va acc={acc_va:.3f} f1={f1_va:.3f}")

        if f1_va > best_f1 + 1e-4:
            best_f1, stale = f1_va, 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            stale += 1
            if stale >= PATIENCE:
                print(f"Early stop at epoch {ep}")
                break

    if EPOCHS >= SWA_START_EPOCH:
        torch.optim.swa_utils.update_bn(tr_loader, swa_model, device=DEVICE)
        acc_va_swa, f1_va_swa, *_ = evaluate(swa_model, va_loader)
        if f1_va_swa >= best_f1:
            best_state = {k: v.detach().cpu().clone() for k, v in swa_model.state_dict().items()}
            best_f1 = f1_va_swa

    model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()})
    return model, best_f1

# =======================
# Nested CV (outer) with inner tuning PCA + MLP
# =======================
outer = StratifiedKFold(n_splits=K_OUTER, shuffle=True, random_state=SEED)

fold_reports = []
best_cfgs = []
perfold_reports = []

for fold, (tr_idx, va_idx) in enumerate(outer.split(X_all, y_all), 1):
    Xtr_raw, Xva_raw = X_all[tr_idx], X_all[va_idx]
    ytr, yva = y_all[tr_idx], y_all[va_idx]

    # PCA scatter (train-only)
    scaler_tr = StandardScaler().fit(Xtr_raw)
    Xtr_s_eda = scaler_tr.transform(Xtr_raw)
    save_pca_scatter(Xtr_s_eda, ytr, title=f"Fold {fold} Train PCA (PC1/PC2)",
                     fname=f"fold{fold}_pca_scatter")

    best_fold_f1 = -1.0
    best_cfg = None
    best_artifacts = None

    for pca_dim in PCA_GRID:
        scaler = StandardScaler()
        Xtr_s = scaler.fit_transform(Xtr_raw)
        Xva_s = scaler.transform(Xva_raw)

        ncomp = min(pca_dim, Xtr_s.shape[1], max(1, Xtr_s.shape[0] - 1))
        pca = PCA(n_components=ncomp, random_state=SEED)
        Ztr = pca.fit_transform(Xtr_s)
        Zva = pca.transform(Xva_s)

        for h in H_GRID:
            for drop in DROP_GRID:
                for wd in WD_GRID:
                    print(f"\n[FOLD {fold}] Try PCA={ncomp}, h={h}, drop={drop}, wd={wd}")
                    model, f1_va = train_one_fold(
                        Ztr, ytr, Zva, yva,
                        h=h, drop=drop, wd=wd, epochs=EPOCHS
                    )
                    print(f"[FOLD {fold}] Result macro-F1={f1_va:.3f}")
                    if f1_va > best_fold_f1:
                        best_fold_f1 = f1_va
                        best_cfg = {"pca": ncomp, "h": h, "drop": drop, "wd": wd}
                        best_artifacts = {
                            "scaler": scaler,
                            "pca": pca,
                            "model_state": {k: v.cpu().clone() for k, v in model.state_dict().items()},
                        }

    # Evaluate best config cleanly
    scaler = best_artifacts["scaler"]
    pca = best_artifacts["pca"]
    Ztr_best = pca.transform(scaler.transform(Xtr_raw))
    Zva_best = pca.transform(scaler.transform(Xva_raw))

    model_best = MLP(Ztr_best.shape[1], h=best_cfg["h"],
                     drop=best_cfg["drop"], nclass=N_CLASS).to(DEVICE)
    model_best.load_state_dict({k: v.to(DEVICE) for k, v in best_artifacts["model_state"].items()})

    va_loader = DataLoader(
        TensorDataset(*to_torch(Zva_best, yva)),
        batch_size=BATCH_VAL,
        shuffle=False,
        pin_memory=(DEVICE.type=="cuda"),
    )
    acc, f1m, Yv, yhat, P = evaluate(model_best, va_loader)

    print(f"\n[FOLD {fold}] BEST CFG: {best_cfg} | val acc={acc:.3f} macro-F1={f1m:.3f}")
    print(pd.Series(yva).value_counts(), "train fold label counts:", pd.Series(ytr).value_counts())
    print(classification_report(Yv, yhat, target_names=LABELS, digits=3))

    # Save confusion + report
    save_confusions(Yv, yhat, LABELS, prefix=f"fold{fold}",
                    title_suffix=f" (Fold {fold})")
    y_true_str = [LABELS[i] for i in Yv]
    y_pred_str = [LABELS[i] for i in yhat]
    rep_df = report_to_df(y_true_str, y_pred_str, LABELS)
    rep_df.to_csv(os.path.join(TAB_DIR, f"fold{fold}_report.csv"), index=False)
    save_df_as_pdf(rep_df, os.path.join(TAB_DIR, f"fold{fold}_report.pdf"),
                   title=f"Fold {fold} Classification Report")

    fold_reports.append({"fold": fold, "acc": acc, "f1": f1m})
    best_cfgs.append(best_cfg)
    perfold_reports.append({"fold": fold, "cfg": best_cfg, "acc": acc, "f1": f1m})
    gc.collect()

# =======================
# Summary across folds
# =======================
accs = [r["acc"] for r in fold_reports]
f1s  = [r["f1"] for r in fold_reports]
print("\n==== MLP Nested CV Summary ====")
print(f"Accuracy  mean±sd: {np.mean(accs):.3f} ± {np.std(accs):.3f}")
print(f"Macro-F1 mean±sd: {np.mean(f1s):.3f} ± {np.std(f1s):.3f}")
print("Best configs per fold:", best_cfgs)

save_fold_metric_bars(accs, f1s, title_prefix="Nested CV (MLP)",
                      fname="mlp_nestedcv_summary")

summary_tbl = pd.DataFrame(fold_reports)
summary_tbl.to_csv(os.path.join(TAB_DIR, "nestedcv_fold_metrics.csv"), index=False)
save_df_as_pdf(summary_tbl, os.path.join(TAB_DIR, "nestedcv_fold_metrics.pdf"),
               title="Fold-wise Metrics (Nested CV)")

cfg_tbl = pd.DataFrame([{**{"fold": i+1}, **cfg} for i, cfg in enumerate(best_cfgs)])
cfg_tbl.to_csv(os.path.join(TAB_DIR, "nestedcv_best_configs.csv"), index=False)
save_df_as_pdf(cfg_tbl, os.path.join(TAB_DIR, "nestedcv_best_configs.pdf"),
               title="Best Hyperparameters per Fold")

# =======================
# Final refit on ALL data with most frequent config
# =======================
cfg_counts = Counter([tuple(sorted(c.items())) for c in best_cfgs])
top_cfg_tuple, _ = cfg_counts.most_common(1)[0]
final_cfg = dict(top_cfg_tuple)
print("\n[FINAL] Using config:", final_cfg)

scaler_all = StandardScaler()
X_all_s = scaler_all.fit_transform(X_all)
ncomp_final = min(final_cfg["pca"], X_all_s.shape[1], max(1, X_all_s.shape[0]-1))
pca_all = PCA(n_components=ncomp_final, random_state=SEED)
Z_all = pca_all.fit_transform(X_all_s)

y_all_arr = y_all.copy()
model_final, _ = train_one_fold(
    Z_all, y_all_arr, Z_all, y_all_arr,
    h=final_cfg["h"], drop=final_cfg["drop"], wd=final_cfg["wd"],
    epochs=max(EPOCHS//2, 80)
)

final_state = {k: v.cpu() for k, v in model_final.state_dict().items()}
artifacts = {
    "state_dict": final_state,
    "scaler_mean": scaler_all.mean_.astype(np.float32),
    "scaler_scale": scaler_all.scale_.astype(np.float32),
    "pca_components": pca_all.components_.astype(np.float32),
    "pca_mean": pca_all.mean_.astype(np.float32),
    "pca_ncomp": ncomp_final,
    "labels": LABELS,
    "config": final_cfg,
}
ckpt_path = os.path.join(CKPT_DIR, "mlp_final.pkl")
with open(ckpt_path, "wb") as f:
    pickle.dump(artifacts, f)
print("Saved final MLP to", ckpt_path)

run_sum = pd.DataFrame({
    "metric": ["Accuracy mean", "Accuracy sd", "Macro-F1 mean", "Macro-F1 sd"],
    "value": [np.mean(accs), np.std(accs), np.mean(f1s), np.std(f1s)],
})
run_sum.to_csv(os.path.join(TAB_DIR, "nestedcv_summary.csv"), index=False)
save_df_as_pdf(run_sum, os.path.join(TAB_DIR, "nestedcv_summary.pdf"),
               title="Nested CV Summary (MLP)")

print("\nSaved figures to:", FIG_DIR)
print("Saved tables  to:", TAB_DIR)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[ep 003] tr acc=0.818 f1=0.822 | va acc=0.760 f1=0.763
[ep 004] tr acc=0.838 f1=0.841 | va acc=0.760 f1=0.768
[ep 005] tr acc=0.859 f1=0.863 | va acc=0.760 f1=0.768
[ep 006] tr acc=0.838 f1=0.841 | va acc=0.760 f1=0.768
[ep 007] tr acc=0.828 f1=0.828 | va acc=0.720 f1=0.720
[ep 008] tr acc=0.818 f1=0.819 | va acc=0.640 f1=0.633
[ep 009] tr acc=0.828 f1=0.829 | va acc=0.640 f1=0.633
[ep 010] tr acc=0.828 f1=0.829 | va acc=0.640 f1=0.633
[ep 011] tr acc=0.838 f1=0.838 | va acc=0.640 f1=0.633
[ep 012] tr acc=0.828 f1=0.829 | va acc=0.640 f1=0.633
[ep 013] tr acc=0.859 f1=0.857 | va acc=0.720 f1=0.720
[ep 014] tr acc=0.859 f1=0.857 | va acc=0.840 f1=0.849
[ep 015] tr acc=0.919 f1=0.919 | va acc=0.840 f1=0.849
[ep 016] tr acc=0.919 f1=0.916 | va acc=0.840 f1=0.849
[ep 017] tr acc=0.909 f1=0.905 | va acc=0.800 f1=0.806
[ep 018] tr acc=0.909 f1=0.906 | va acc=0.800 f1=0.806
[ep 019] tr acc=0.899 f1=0.895 | va acc=0.760 f1=0.763


# DL Summary pipeline -- All Aggregate measures

In [6]:
# ============================================================
#  Summary-only publication figures (PCA→MLP nested CV)
#  - Collects OUT-OF-FOLD predictions
#  - Saves ONLY aggregate/summary PDFs + tables (no per-fold figs)
#  Output root:
#    /content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl_summary/
# ============================================================

# =======================
# Config
# =======================
SEED = 42
K_OUTER = 5
EPOCHS = 200
PATIENCE = 18
BATCH_TRAIN = 64
BATCH_VAL = 128
LR_MAX = 3e-3
CLIP_NORM = 5.0
MIXUP_ALPHA = 0.2
FOCAL_GAMMA = 1.5
LABEL_SMOOTH = 0.10
SWA_START_EPOCH = 140
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparam grid
PCA_GRID = [32, 64, 128]
H_GRID   = [128, 256]
DROP_GRID= [0.2, 0.35]
WD_GRID  = [1e-4, 5e-4]

# Output dirs (NEW summary folder)
OUT_DIR = "/content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl_summary"
FIG_DIR = os.path.join(OUT_DIR, "figs")
TAB_DIR = os.path.join(OUT_DIR, "tables")
CKPT_DIR= os.path.join(OUT_DIR, "checkpoints")
for d in (OUT_DIR, FIG_DIR, TAB_DIR, CKPT_DIR):
    os.makedirs(d, exist_ok=True)

# =======================
# Style (journal)
# =======================
def set_pub_style():
    mpl.rcParams.update({
        "figure.dpi": 300, "savefig.dpi": 300,
        "pdf.fonttype": 42, "ps.fonttype": 42,
        "font.size": 9,
        "axes.labelsize": 9, "axes.titlesize": 10,
        "xtick.labelsize": 8, "ytick.labelsize": 8, "legend.fontsize": 8,
        "axes.linewidth": 0.8, "xtick.major.width": 0.8, "ytick.major.width": 0.8,
        "lines.linewidth": 1.2,
        "axes.spines.top": False, "axes.spines.right": False,
        "figure.autolayout": True,
    })
    sns.set_palette("colorblind")

set_pub_style()

# =======================
# Table helpers (PDF + CSV)
# =======================
def save_df_as_pdf(df: pd.DataFrame, path_pdf: str, title: str = None,
                   max_w=7.5, row_ht=0.28, col_w=1.6):
    df = df.copy()
    df.columns = [str(c) for c in df.columns]
    for c in df.columns:
        if pd.api.types.is_float_dtype(df[c]):
            df[c] = df[c].map(lambda z: f"{z:.3f}" if pd.notnull(z) else "")
    n_rows, n_cols = df.shape
    fig_w = min(max_w, max(2.5, col_w * max(1, n_cols)))
    fig_h = max(1.5, 0.6 + row_ht * (n_rows + 1) + (0.4 if title else 0))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.axis("off")
    if title:
        ax.set_title(title, loc="left", fontsize=10, pad=6)
    tbl = ax.table(cellText=df.values, colLabels=df.columns, loc="upper left", cellLoc="left")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(8)
    for (r, c), cell in tbl.get_celld().items():
        cell.set_linewidth(0.4)
        if r == 0:
            cell.set_facecolor("#f2f2f2"); cell.set_fontsize(8.5)
        else:
            if r % 2 == 0:
                cell.set_facecolor("#fafafa")
    for c in range(n_cols):
        tbl.auto_set_column_width(col=list(range(n_cols)))
    fig.tight_layout()
    fig.savefig(path_pdf)
    plt.close(fig)

def report_to_df(y_true, y_pred, class_names):
    rep = classification_report(y_true, y_pred, labels=class_names, output_dict=True, zero_division=0)
    rows = []
    for lab in class_names:
        m = rep.get(lab, {})
        rows.append({"class": lab,
                     "precision": m.get("precision", np.nan),
                     "recall":    m.get("recall", np.nan),
                     "f1-score":  m.get("f1-score", np.nan),
                     "support":   int(m.get("support", 0))})
    rows.append({"class": "macro avg",    **{k:v for k,v in rep["macro avg"].items() if k!="support"}})
    rows.append({"class": "weighted avg", **{k:v for k,v in rep["weighted avg"].items() if k!="support"}})
    return pd.DataFrame(rows)

# =======================
# Repro / utils
# =======================
def set_seed(s=SEED):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed()

def to_torch(x, y=None):
    xt = torch.from_numpy(x.astype(np.float32))
    yt = None if y is None else torch.from_numpy(y.astype(np.int64))
    return xt, yt

def class_weights(y, n_class):
    """
    Balanced class weights, robust to folds that are missing some classes.
    Returns a torch tensor of shape [n_class].
    """
    present = np.unique(y)
    cw = compute_class_weight(class_weight="balanced", classes=present, y=y)
    w = np.ones(n_class, dtype=np.float32)
    w[present] = cw.astype(np.float32)
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def mixup(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, torch.ones(len(x), device=x.device)
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0), device=x.device)
    x_mix = lam * x + (1 - lam) * x[idx]
    y2 = y[idx]
    return x_mix, y2, torch.full((x.size(0),), lam, device=x.device)

# =======================
# Model
# =======================
class FocalCrossEntropy(nn.Module):
    def __init__(self, nclass, gamma=FOCAL_GAMMA, label_smooth=LABEL_SMOOTH, weights=None):
        super().__init__()
        self.gamma = gamma
        self.eps = label_smooth
        self.n = nclass
        self.register_buffer("w", weights if weights is not None else torch.ones(nclass))
    def forward(self, logits, y):
        logp = torch.log_softmax(logits, dim=1)
        with torch.no_grad():
            tgt = torch.full_like(logp, self.eps/(self.n-1))
            tgt.scatter_(1, y.unsqueeze(1), 1 - self.eps)
        p = torch.exp(logp)
        focal = (1 - p) ** self.gamma
        loss = -(tgt * focal * logp) * self.w
        return loss.sum(dim=1).mean()

class MLP(nn.Module):
    def __init__(self, d, h=256, drop=0.3, nclass=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, h), nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(h, h//2), nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(h//2, nclass)
        )
    def forward(self, x):
        return self.net(x)

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    yp, yt = [], []
    for xb, yb in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        yp.append(torch.softmax(logits, dim=1).cpu().numpy())
        yt.append(yb.numpy())
    P = np.vstack(yp)
    Y = np.concatenate(yt)
    yhat = P.argmax(1)
    return (accuracy_score(Y, yhat),
            f1_score(Y, yhat, average="macro"),
            Y, yhat, P)

def train_one_fold(Ztr, ytr, Zva, yva, n_class, h=256, drop=0.3, wd=1e-4, epochs=EPOCHS):
    Xtr_t, ytr_t = to_torch(Ztr, ytr)
    Xva_t, yva_t = to_torch(Zva, yva)
    tr_loader = DataLoader(TensorDataset(Xtr_t, ytr_t), batch_size=BATCH_TRAIN, shuffle=True,
                           pin_memory=(DEVICE.type=="cuda"))
    va_loader = DataLoader(TensorDataset(Xva_t, yva_t), batch_size=BATCH_VAL, shuffle=False,
                           pin_memory=(DEVICE.type=="cuda"))

    model = MLP(Ztr.shape[1], h=h, drop=drop, nclass=n_class).to(DEVICE)
    weights = class_weights(ytr, n_class=n_class)
    crit = FocalCrossEntropy(nclass=n_class, gamma=FOCAL_GAMMA,
                             label_smooth=LABEL_SMOOTH, weights=weights)
    opt = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=wd)
    steps_per_epoch = max(1, math.ceil(len(tr_loader)))
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=LR_MAX, epochs=epochs, steps_per_epoch=steps_per_epoch,
        pct_start=0.15, anneal_strategy="cos"
    )
    swa_model = torch.optim.swa_utils.AveragedModel(model)
    swa_sched = torch.optim.swa_utils.SWALR(opt, swa_lr=LR_MAX*0.2)

    best_f1, best_state, stale = -1.0, None, 0
    for ep in range(1, epochs+1):
        model.train()
        for xb, yb in tr_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            xb_m, yb2, lam = mixup(xb, yb, alpha=MIXUP_ALPHA)
            opt.zero_grad(set_to_none=True)
            logits = model(xb_m)
            loss = lam * crit(logits, yb) + (1 - lam) * crit(logits, yb2)
            loss.mean().backward()
            nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            opt.step()
            sched.step()

        if ep >= SWA_START_EPOCH:
            swa_model.update_parameters(model)
            swa_sched.step()

        acc_va, f1_va, *_ = evaluate(model, va_loader)
        if f1_va > best_f1 + 1e-4:
            best_f1, stale = f1_va, 0
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            stale += 1
            if stale >= PATIENCE:
                break

    if EPOCHS >= SWA_START_EPOCH:
        torch.optim.swa_utils.update_bn(tr_loader, swa_model, device=DEVICE)
        acc_va_swa, f1_va_swa, *_ = evaluate(swa_model, va_loader)
        if f1_va_swa >= best_f1:
            best_state = {k: v.detach().cpu().clone() for k, v in swa_model.state_dict().items()}
            best_f1 = f1_va_swa

    model.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()})
    return model, best_f1

# =======================
# Data (flexible 3- or 4-class mapping)
# =======================
df = df_tissue  # your dataframe

features = ["GM_frac", "WM_frac", "CSF_frac"]
label_col = "Cluster_Label"

assert set(features).issubset(df.columns)
assert label_col in df.columns

uniq_labels = sorted(df[label_col].dropna().unique().tolist())
three_class_order = ["CN", "MCI", "AD"]
four_class_order  = ["CN", "MCI", "MCI/AD", "AD"]

if "MCI/AD" in uniq_labels:
    base_order = four_class_order
else:
    base_order = three_class_order

LABELS = [lab for lab in base_order if lab in uniq_labels]
lab2id = {lab: i for i, lab in enumerate(LABELS)}
N_CLASS = len(LABELS)

print("Detected raw labels:", uniq_labels)
print("Using LABELS order:", LABELS)
print("lab2id mapping:", lab2id)
print("N_CLASS:", N_CLASS)

y_all = df[label_col].map(lab2id)
mask_valid = y_all.notna()
if not mask_valid.all():
    print(f"Dropping {(~mask_valid).sum()} rows with unmapped labels.")
df_valid = df.loc[mask_valid].reset_index(drop=True)

X_all = df_valid[features].to_numpy(dtype=np.float32)
y_all = y_all[mask_valid].to_numpy(dtype=np.int64)

N, D = X_all.shape
C = N_CLASS

print("Final X_all shape:", X_all.shape)
print("Final y_all shape:", y_all.shape)

# =======================
# Nested CV with OUT-OF-FOLD predictions (summary only)
# =======================
set_seed()
outer = StratifiedKFold(n_splits=K_OUTER, shuffle=True, random_state=SEED)

oof_pred = np.zeros((N, C), dtype=np.float32)
oof_y    = np.zeros(N, dtype=np.int64)

fold_reports = []
best_cfgs = []

for fold, (tr_idx, va_idx) in enumerate(outer.split(X_all, y_all), 1):
    Xtr_raw, Xva_raw = X_all[tr_idx], X_all[va_idx]
    ytr, yva = y_all[tr_idx], y_all[va_idx]

    best_fold_f1 = -1.0
    best_cfg = None
    best_artifacts = None

    for pca_dim in PCA_GRID:
        scaler = StandardScaler()
        Xtr_s = scaler.fit_transform(Xtr_raw)
        Xva_s = scaler.transform(Xva_raw)

        ncomp = min(pca_dim, Xtr_s.shape[1], max(1, Xtr_s.shape[0] - 1))
        pca = PCA(n_components=ncomp, random_state=SEED)
        Ztr = pca.fit_transform(Xtr_s)
        Zva = pca.transform(Xva_s)

        for h in H_GRID:
            for drop in DROP_GRID:
                for wd in WD_GRID:
                    model, f1_va = train_one_fold(
                        Ztr, ytr, Zva, yva,
                        n_class=C, h=h, drop=drop, wd=wd, epochs=EPOCHS
                    )
                    if f1_va > best_fold_f1:
                        best_fold_f1 = f1_va
                        best_cfg = {"pca": ncomp, "h": h, "drop": drop, "wd": wd}
                        best_artifacts = {
                            "scaler": scaler,
                            "pca": pca,
                            "model_state": {k:v.cpu().clone() for k,v in model.state_dict().items()}
                        }

    # Finalize best on this fold and write OOF predictions
    scaler = best_artifacts["scaler"]; pca = best_artifacts["pca"]
    Ztr = pca.transform(scaler.transform(Xtr_raw))
    Zva = pca.transform(scaler.transform(Xva_raw))

    model = MLP(Ztr.shape[1], h=best_cfg["h"], drop=best_cfg["drop"], nclass=C).to(DEVICE)
    model.load_state_dict({k: v.to(DEVICE) for k, v in best_artifacts["model_state"].items()})

    va_loader = DataLoader(
        TensorDataset(*to_torch(Zva, yva)),
        batch_size=BATCH_VAL,
        shuffle=False,
        pin_memory=(DEVICE.type=="cuda"),
    )
    acc, f1m, Yv, yhat, P = evaluate(model, va_loader)

    oof_pred[va_idx] = P
    oof_y[va_idx]    = yva

    fold_reports.append({"fold": fold, "acc": acc, "f1": f1m})
    best_cfgs.append(best_cfg)
    gc.collect()

# =======================
# SUMMARY-ONLY VISUALS (using OOF predictions)
# =======================
# 1) Confusion matrices (counts + row-normalized)
cm = confusion_matrix(oof_y, oof_pred.argmax(1), labels=range(C))
cm_norm = cm / cm.sum(axis=1, keepdims=True).clip(min=1)

fig, ax = plt.subplots(figsize=(3.4, 3.0))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=LABELS, yticklabels=LABELS, cbar=False, ax=ax)
ax.set_xlabel("Predicted"); ax.set_ylabel("True")
ax.set_title("Confusion (counts)")
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "summary_cm_counts.pdf"))
plt.close(fig)

fig, ax = plt.subplots(figsize=(3.4, 3.0))
sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=LABELS, yticklabels=LABELS, cbar=False, vmin=0, vmax=1, ax=ax)
ax.set_xlabel("Predicted"); ax.set_ylabel("True")
ax.set_title("Confusion (row-normalized)")
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "summary_cm_norm.pdf"))
plt.close(fig)

# 2) Per-class bars from report
rep_df = report_to_df(
    [LABELS[i] for i in oof_y],
    [LABELS[i] for i in oof_pred.argmax(1)],
    LABELS
)
cls_only = rep_df[rep_df["class"].isin(LABELS)].copy()
melted = cls_only.melt(id_vars="class", value_vars=["precision","recall","f1-score"],
                       var_name="metric", value_name="value")

fig, ax = plt.subplots(figsize=(5.2, 3.0))
sns.barplot(data=melted, x="class", y="value", hue="metric",
            ax=ax, edgecolor="black", linewidth=0.5)
ax.set_ylim(0, 1); ax.set_ylabel("Score"); ax.set_xlabel("")
ax.set_title("Per-class metrics (OOF)")
ax.legend(frameon=False, ncol=3, loc="upper center", bbox_to_anchor=(0.5, 1.25))
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "summary_perclass_bars.pdf"))
plt.close(fig)

# 3) ROC & PR (OvR)
Y_bin = label_binarize(oof_y, classes=list(range(C)))

fig, ax = plt.subplots(figsize=(4.0, 3.2))
auc_rows = []
for c in range(C):
    fpr, tpr, _ = roc_curve(Y_bin[:, c], oof_pred[:, c])
    AUC = auc(fpr, tpr); auc_rows.append(AUC)
    ax.plot(fpr, tpr, label=f"{LABELS[c]} (AUC={AUC:.2f})")
ax.plot([0,1],[0,1], linestyle="--", color="gray", linewidth=0.8)
ax.set_xlabel("FPR"); ax.set_ylabel("TPR")
ax.set_title("ROC (OOF)")
ax.legend(frameon=False, fontsize=7, ncol=2)
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "summary_roc.pdf"))
plt.close(fig)

fig, ax = plt.subplots(figsize=(4.0, 3.2))
ap_rows = []
for c in range(C):
    prec, rec, _ = precision_recall_curve(Y_bin[:, c], oof_pred[:, c])
    AP = average_precision_score(Y_bin[:, c], oof_pred[:, c]); ap_rows.append(AP)
    ax.plot(rec, prec, label=f"{LABELS[c]} (AP={AP:.2f})")
ax.set_xlabel("Recall"); ax.set_ylabel("Precision")
ax.set_title("Precision–Recall (OOF)")
ax.legend(frameon=False, fontsize=7, ncol=2)
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "summary_pr.pdf"))
plt.close(fig)

# 4) Calibration (Brier/ECE 10 bins)
fig, ax = plt.subplots(figsize=(4.0, 3.2))
cal_rows = []
for c in range(C):
    y_bin = (oof_y == c).astype(int)
    prob_c = oof_pred[:, c]
    frac_pos, mean_pred = calibration_curve(y_bin, prob_c, n_bins=10, strategy="uniform")
    ax.plot(mean_pred, frac_pos, marker="o", label=f"{LABELS[c]}")
    brier = brier_score_loss(y_bin, prob_c)
    # ECE
    bins = np.linspace(0,1,11)
    idx = np.digitize(prob_c, bins) - 1
    ece = 0.0; Ntot = len(prob_c)
    for b in range(10):
        m = (idx == b)
        if not m.any(): continue
        conf = prob_c[m].mean()
        acc  = y_bin[m].mean()
        ece += (m.sum()/Ntot) * abs(acc - conf)
    cal_rows.append({"class": LABELS[c], "Brier": brier, "ECE_10bin": ece})

ax.plot([0,1],[0,1], linestyle="--", color="gray", linewidth=0.8)
ax.set_xlabel("Predicted probability"); ax.set_ylabel("Empirical accuracy")
ax.set_title("Reliability (OOF)")
ax.legend(frameon=False, fontsize=7, ncol=2)
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "summary_calibration.pdf"))
plt.close(fig)

cal_df = pd.DataFrame(cal_rows)
cal_df.loc[len(cal_df)] = {
    "class":"macro",
    "Brier": cal_df["Brier"].mean(),
    "ECE_10bin": cal_df["ECE_10bin"].mean()
}
cal_df.to_csv(os.path.join(TAB_DIR, "summary_calibration.csv"), index=False)
save_df_as_pdf(cal_df, os.path.join(TAB_DIR, "summary_calibration.pdf"),
               title="Calibration (Brier/ECE) — OOF")

# 5) Expected probability heatmap
M = np.zeros((C, C), float)
for i in range(C):
    idx = (oof_y == i)
    if idx.any():
        M[i] = oof_pred[idx].mean(axis=0)
fig, ax = plt.subplots(figsize=(3.4, 3.0))
sns.heatmap(M, annot=True, fmt=".2f", cmap="Purples",
            xticklabels=LABELS, yticklabels=LABELS, vmin=0, vmax=1, cbar=False, ax=ax)
ax.set_xlabel("Predicted class prob"); ax.set_ylabel("True class")
ax.set_title("Expected P(class | true) (OOF)")
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "summary_exp_prob_heatmap.pdf"))
plt.close(fig)

# 6) Nested CV summary bars (Accuracy & Macro-F1 mean±95% CI)
accs = [fr["acc"] for fr in fold_reports]
f1s  = [fr["f1"]  for fr in fold_reports]
means = np.array([np.mean(accs), np.mean(f1s)])
stds  = np.array([np.std(accs),  np.std(f1s)])
cis   = 1.96 * stds / np.sqrt(len(accs))
labels2 = ["Accuracy", "Macro-F1"]

fig, ax = plt.subplots(figsize=(3.6, 2.8))
idx = np.arange(len(labels2))
ax.bar(idx, means, yerr=cis, capsize=3, linewidth=0.8, edgecolor="black")
ax.set_xticks(idx, labels2); ax.set_ylim(0, 1)
for i, m in enumerate(means):
    ax.text(i, min(0.98, m+0.02), f"{m:.2f}", ha="center", va="bottom", fontsize=8)
ax.set_title("Nested CV (OOF) — Summary")
fig.tight_layout()
fig.savefig(os.path.join(FIG_DIR, "nestedcv_summary_bars.pdf"))
plt.close(fig)

# =======================
# Summary tables (CSV + PDF)
# =======================
summary_tbl = pd.DataFrame(fold_reports)
summary_tbl.to_csv(os.path.join(TAB_DIR, "nestedcv_fold_metrics.csv"), index=False)
save_df_as_pdf(summary_tbl, os.path.join(TAB_DIR, "nestedcv_fold_metrics.pdf"),
               title="Fold-wise Metrics (Nested CV)")

cfg_tbl = pd.DataFrame([{**{"fold": i+1}, **cfg} for i, cfg in enumerate(best_cfgs)])
cfg_tbl.to_csv(os.path.join(TAB_DIR, "nestedcv_best_configs.csv"), index=False)
save_df_as_pdf(cfg_tbl, os.path.join(TAB_DIR, "nestedcv_best_configs.pdf"),
               title="Best Hyperparameters per Fold")

# AUC/AP table from OOF
auc_df = pd.DataFrame({"class": LABELS, "AUC": auc_rows})
ap_df  = pd.DataFrame({"class": LABELS, "AP":  ap_rows})
auc_df.loc[len(auc_df)] = {"class":"macro","AUC":auc_df["AUC"].mean()}
ap_df.loc [len(ap_df )] = {"class":"macro","AP": ap_df["AP"].mean()}
aucap_sum = (auc_df.merge(ap_df, on="class"))
aucap_sum.to_csv(os.path.join(TAB_DIR, "nestedcv_auc_ap_summary.csv"), index=False)
save_df_as_pdf(aucap_sum, os.path.join(TAB_DIR, "nestedcv_auc_ap_summary.pdf"),
               title="ROC AUC / Average Precision (OOF)")

# Classification report table (OOF)
rep_df.to_csv(os.path.join(TAB_DIR, "summary_classification_report.csv"), index=False)
save_df_as_pdf(rep_df, os.path.join(TAB_DIR, "summary_classification_report.pdf"),
               title="Classification Report (OOF)")

# High-level run summary
run_sum = pd.DataFrame({
    "metric": ["Accuracy mean", "Accuracy sd", "Macro-F1 mean", "Macro-F1 sd"],
    "value": [np.mean(accs), np.std(accs), np.mean(f1s), np.std(f1s)]
})
run_sum.to_csv(os.path.join(TAB_DIR, "nestedcv_summary.csv"), index=False)
save_df_as_pdf(run_sum, os.path.join(TAB_DIR, "nestedcv_summary.pdf"),
               title="Nested CV Summary (OOF)")

print("\nSaved SUMMARY-ONLY figures to:", FIG_DIR)
print("Saved SUMMARY-ONLY tables  to:", TAB_DIR)


Detected raw labels: ['AD', 'CN', 'MCI', 'MCI/AD']
Using LABELS order: ['CN', 'MCI', 'MCI/AD', 'AD']
lab2id mapping: {'CN': 0, 'MCI': 1, 'MCI/AD': 2, 'AD': 3}
N_CLASS: 4
Final X_all shape: (124, 3)
Final y_all shape: (124,)

Saved SUMMARY-ONLY figures to: /content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl_summary/figs
Saved SUMMARY-ONLY tables  to: /content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl_summary/tables


# DL Summary Pipeline
# MultiLayer Perceptrons  Lean → Deep → Hybrid

In [9]:
# ============================================================
#  DL Lean / Deep / Hybrid MLPs on parcellation features
#  - Assumes upstream: X_full (N,D) and y (N,) of string labels
# ============================================================
# ------------------------------------------------------------
# Sanity: need X_full and y from the parcellation pipeline
# ------------------------------------------------------------
assert "X_full" in globals(), "X_full not found – run the parcellation/feature pipeline first."
assert "y" in globals(), "y (string labels) not found – run the parcellation/feature pipeline first."

X_all = X_full.astype(np.float32)
y_str_all = np.asarray(y)

# ============================================================
#  Config
# ============================================================
SEED = 42
K_OUTER = 5          # outer CV folds
EPOCHS = 200
PATIENCE = 18
BATCH_TRAIN = 64
BATCH_VAL = 128
LR_MAX = 3e-3
CLIP_NORM = 5.0
MIXUP_ALPHA = 0.2
FOCAL_GAMMA = 1.5
LABEL_SMOOTH = 0.10
SWA_START_EPOCH = 140

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ROOT_OUT = "/content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl_summary"
os.makedirs(ROOT_OUT, exist_ok=True)

# ============================================================
#  Style
# ============================================================
def set_pub_style():
    mpl.rcParams.update({
        "figure.dpi": 300, "savefig.dpi": 300,
        "pdf.fonttype": 42, "ps.fonttype": 42,
        "font.size": 9,
        "axes.labelsize": 9, "axes.titlesize": 10,
        "xtick.labelsize": 8, "ytick.labelsize": 8, "legend.fontsize": 8,
        "axes.linewidth": 0.8, "xtick.major.width": 0.8, "ytick.major.width": 0.8,
        "lines.linewidth": 1.2,
        "axes.spines.top": False, "axes.spines.right": False,
        "figure.autolayout": True,
    })
    sns.set_palette("colorblind")

set_pub_style()

# ============================================================
#  Helpers: tables, reproducibility
# ============================================================
def save_df_as_pdf(df: pd.DataFrame, path_pdf: str, title: str = None,
                   max_w=7.5, row_ht=0.28, col_w=1.6):
    df = df.copy()
    df.columns = [str(c) for c in df.columns]
    for c in df.columns:
        if pd.api.types.is_float_dtype(df[c]):
            df[c] = df[c].map(lambda z: f"{z:.3f}" if pd.notnull(z) else "")
    n_rows, n_cols = df.shape
    fig_w = min(max_w, max(2.5, col_w * max(1, n_cols)))
    fig_h = max(1.5, 0.6 + row_ht * (n_rows + 1) + (0.4 if title else 0))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    ax.axis("off")
    if title:
        ax.set_title(title, loc="left", fontsize=10, pad=6)
    tbl = ax.table(cellText=df.values, colLabels=df.columns,
                   loc="upper left", cellLoc="left")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(8)
    for (r, c), cell in tbl.get_celld().items():
        cell.set_linewidth(0.4)
        if r == 0:
            cell.set_facecolor("#f2f2f2"); cell.set_fontsize(8.5)
        else:
            if r % 2 == 0:
                cell.set_facecolor("#fafafa")
    tbl.auto_set_column_width(col=list(range(n_cols)))
    fig.tight_layout()
    fig.savefig(path_pdf)
    plt.close(fig)

def report_to_df(y_true_str, y_pred_str, class_names):
    rep = classification_report(y_true_str, y_pred_str, labels=class_names,
                                output_dict=True, zero_division=0)
    rows = []
    for lab in class_names:
        m = rep.get(lab, {})
        rows.append({"class": lab,
                     "precision": m.get("precision", np.nan),
                     "recall":    m.get("recall", np.nan),
                     "f1-score":  m.get("f1-score", np.nan),
                     "support":   int(m.get("support", 0))})
    rows.append({"class": "macro avg",
                 **{k: v for k, v in rep["macro avg"].items() if k != "support"}})
    rows.append({"class": "weighted avg",
                 **{k: v for k, v in rep["weighted avg"].items() if k != "support"}})
    return pd.DataFrame(rows)

def set_seed(s=SEED):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed()

# ============================================================
#  Label encoding: flexible 3/4-class (and subsets)
# ============================================================
uniq_labels = sorted(pd.unique(y_str_all).tolist())
canonical_order = ["CN", "MCI", "MCI/AD", "AD"]

LABELS = [lab for lab in canonical_order if lab in uniq_labels]
# if you ever have extra labels, append them at the end
for lab in uniq_labels:
    if lab not in LABELS:
        LABELS.append(lab)

LAB2ID = {lab: i for i, lab in enumerate(LABELS)}
C = len(LABELS)

print("Detected raw labels:", uniq_labels)
print("Using LABELS order:", LABELS)
print("lab2id mapping:", LAB2ID)
print("N_CLASS:", C)

y_all = pd.Series(y_str_all).map(LAB2ID).to_numpy(dtype=np.int64)

N, D = X_all.shape
print(f"[data] N,D = ({N}, {D})")

# ============================================================
#  Utilities
# ============================================================
def to_torch(x, y=None):
    xt = torch.from_numpy(x.astype(np.float32))
    yt = None if y is None else torch.from_numpy(y.astype(np.int64))
    return xt, yt

def class_weights(y):
    """
    Balanced class weights, robust to folds that miss some classes.
    """
    present = np.unique(y)
    cw = compute_class_weight(class_weight="balanced", classes=present, y=y)
    w = np.ones(C, dtype=np.float32)
    w[present] = cw.astype(np.float32)
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def mixup(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0:
        return x, y, torch.ones(len(x), device=x.device)
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0), device=x.device)
    x_mix = lam * x + (1 - lam) * x[idx]
    y2 = y[idx]
    return x_mix, y2, torch.full((x.size(0),), lam, device=x.device)

def _sanitize_state_dict(sd: dict) -> dict:
    clean = {}
    for k, v in sd.items():
        if k == "n_averaged":
            continue
        if k.startswith("module."):
            k = k[len("module."):]
        clean[k] = v
    return clean

# ============================================================
#  Loss & Models (Lean / Deep / Hybrid)
# ============================================================
class FocalCrossEntropy(nn.Module):
    def __init__(self, nclass=C, gamma=FOCAL_GAMMA,
                 label_smooth=LABEL_SMOOTH, weights=None):
        super().__init__()
        self.gamma = gamma
        self.eps = label_smooth
        self.n = nclass
        self.register_buffer("w", weights if weights is not None else torch.ones(nclass))
    def forward(self, logits, y):
        logp = torch.log_softmax(logits, dim=1)
        with torch.no_grad():
            tgt = torch.full_like(logp, self.eps/(self.n-1))
            tgt.scatter_(1, y.unsqueeze(1), 1 - self.eps)
        p = torch.exp(logp)
        focal = (1 - p) ** self.gamma
        loss = -(tgt * focal * logp) * self.w
        return loss.sum(dim=1).mean()

class MLPLean(nn.Module):
    """1 hidden layer, smaller capacity."""
    def __init__(self, d, h=128, drop=0.2, nclass=C):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, h), nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(h, nclass)
        )
    def forward(self, x): return self.net(x)

class MLPDeep(nn.Module):
    """3 hidden layers, larger capacity."""
    def __init__(self, d, h=256, drop=0.3, nclass=C):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, h), nn.GELU(), nn.Dropout(drop),
            nn.Linear(h, h), nn.GELU(), nn.Dropout(drop),
            nn.Linear(h, h//2), nn.GELU(), nn.Dropout(drop),
            nn.Linear(h//2, nclass)
        )
    def forward(self, x): return self.net(x)

class MLPHybrid(nn.Module):
    """2 hidden layers, intermediate capacity."""
    def __init__(self, d, h=256, drop=0.3, nclass=C):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, h), nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(h, h//2), nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(h//2, nclass)
        )
    def forward(self, x): return self.net(x)

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    yp, yt = [], []
    for xb, yb in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        yp.append(torch.softmax(logits, dim=1).cpu().numpy())
        yt.append(yb.numpy())
    P = np.vstack(yp)
    Y = np.concatenate(yt)
    yhat = P.argmax(1)
    return (accuracy_score(Y, yhat),
            f1_score(Y, yhat, average="macro"),
            Y, yhat, P)

def train_one_fold(Ztr, ytr, Zva, yva, model_ctor,
                   h=256, drop=0.3, wd=1e-4, epochs=EPOCHS):
    Xtr_t, ytr_t = to_torch(Ztr, ytr)
    Xva_t, yva_t = to_torch(Zva, yva)
    tr_loader = DataLoader(TensorDataset(Xtr_t, ytr_t),
                           batch_size=BATCH_TRAIN, shuffle=True,
                           pin_memory=(DEVICE.type=="cuda"))
    va_loader = DataLoader(TensorDataset(Xva_t, yva_t),
                           batch_size=BATCH_VAL, shuffle=False,
                           pin_memory=(DEVICE.type=="cuda"))

    model = model_ctor(Ztr.shape[1], h=h, drop=drop, nclass=C).to(DEVICE)
    weights = class_weights(ytr)
    crit = FocalCrossEntropy(nclass=C, gamma=FOCAL_GAMMA,
                             label_smooth=LABEL_SMOOTH, weights=weights)
    opt = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=wd)
    steps_per_epoch = max(1, len(tr_loader))
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=LR_MAX, epochs=epochs, steps_per_epoch=steps_per_epoch,
        pct_start=0.15, anneal_strategy="cos"
    )
    swa_model = torch.optim.swa_utils.AveragedModel(model)
    swa_sched = torch.optim.swa_utils.SWALR(opt, swa_lr=LR_MAX*0.2)

    best_f1, best_state, stale = -1.0, None, 0
    for ep in range(1, epochs+1):
        model.train()
        for xb, yb in tr_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            xb_m, yb2, lam = mixup(xb, yb, alpha=MIXUP_ALPHA)
            opt.zero_grad(set_to_none=True)
            logits = model(xb_m)
            loss = lam * crit(logits, yb) + (1 - lam) * crit(logits, yb2)
            loss.mean().backward()
            nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
            opt.step()
            sched.step()

        # Validation
        acc_va, f1_va, *_ = evaluate(model, va_loader)
        if f1_va > best_f1 + 1e-4:
            best_f1, stale = f1_va, 0
            raw = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_state = _sanitize_state_dict(raw)
        else:
            stale += 1
            if stale >= PATIENCE:
                break

        # SWA tail
        if ep >= SWA_START_EPOCH:
            swa_model.update_parameters(model)
            swa_sched.step()

    # Final SWA check
    if epochs >= SWA_START_EPOCH:
        torch.optim.swa_utils.update_bn(tr_loader, swa_model, device=DEVICE)
        acc_va_swa, f1_va_swa, *_ = evaluate(swa_model, va_loader)
        if f1_va_swa >= best_f1:
            raw = {k: v.detach().cpu().clone() for k, v in swa_model.state_dict().items()}
            best_state = _sanitize_state_dict(raw)
            best_f1 = f1_va_swa

    return best_state, best_f1

# ============================================================
#  Architectures & grids
# ============================================================
ARCHES = {
    "lean": {
        "ctor": MLPLean,
        "PCA_GRID": [16, 32, 64],
        "H_GRID":   [64, 128],
        "DROP_GRID":[0.10, 0.20],
        "WD_GRID":  [1e-4, 5e-4],
    },
    "deep": {
        "ctor": MLPDeep,
        "PCA_GRID": [64, 128, 192],
        "H_GRID":   [256, 384],
        "DROP_GRID":[0.30, 0.40],
        "WD_GRID":  [1e-4, 5e-4],
    },
    "hybrid": {
        "ctor": MLPHybrid,
        "PCA_GRID": [32, 64, 128],
        "H_GRID":   [128, 256],
        "DROP_GRID":[0.20, 0.35],
        "WD_GRID":  [1e-4, 5e-4],
    },
}

# ============================================================
#  Per-architecture runner
# ============================================================
def run_model_summary(arch_name, config, X_all, y_all):
    set_seed()
    out_dir = os.path.join(ROOT_OUT, arch_name)
    fig_dir = os.path.join(out_dir, "figs")
    tab_dir = os.path.join(out_dir, "tables")
    ckpt_dir= os.path.join(out_dir, "checkpoints")
    for d in (out_dir, fig_dir, tab_dir, ckpt_dir):
        os.makedirs(d, exist_ok=True)

    ctor = config["ctor"]
    PCA_GRID = config["PCA_GRID"]
    H_GRID   = config["H_GRID"]
    DROP_GRID= config["DROP_GRID"]
    WD_GRID  = config["WD_GRID"]

    outer = StratifiedKFold(n_splits=K_OUTER, shuffle=True, random_state=SEED)
    oof_pred = np.zeros((N, C), dtype=np.float32)
    oof_y    = np.zeros(N, dtype=np.int64)
    fold_reports, best_cfgs = [], []

    for fold, (tr_idx, va_idx) in enumerate(outer.split(X_all, y_all), 1):
        print(f"\n[{arch_name.upper()}] Outer fold {fold}/{K_OUTER}")
        Xtr_raw, Xva_raw = X_all[tr_idx], X_all[va_idx]
        ytr, yva = y_all[tr_idx], y_all[va_idx]

        best_fold_f1 = -1.0
        best_cfg = None
        best_artifacts = None

        for pca_dim in PCA_GRID:
            scaler = StandardScaler()
            Xtr_s = scaler.fit_transform(Xtr_raw)
            Xva_s = scaler.transform(Xva_raw)

            ncomp = min(pca_dim, Xtr_s.shape[1], max(1, Xtr_s.shape[0] - 1))
            pca = PCA(n_components=ncomp, random_state=SEED)
            Ztr = pca.fit_transform(Xtr_s)
            Zva = pca.transform(Xva_s)

            for h in H_GRID:
                for drop in DROP_GRID:
                    for wd in WD_GRID:
                        best_state, f1_va = train_one_fold(
                            Ztr, ytr, Zva, yva,
                            model_ctor=ctor, h=h, drop=drop, wd=wd, epochs=EPOCHS
                        )
                        if f1_va > best_fold_f1:
                            best_fold_f1 = f1_va
                            best_cfg = {"pca": ncomp, "h": h, "drop": drop, "wd": wd}
                            best_artifacts = {
                                "scaler": scaler,
                                "pca": pca,
                                "model_state": {k: v.clone() for k, v in best_state.items()}
                            }

        # finalize OOF for this fold
        scaler = best_artifacts["scaler"]; pca = best_artifacts["pca"]
        Zva = pca.transform(scaler.transform(Xva_raw))
        model = ctor(Zva.shape[1], h=best_cfg["h"], drop=best_cfg["drop"], nclass=C).to(DEVICE)
        model.load_state_dict({k: v.to(DEVICE) for k, v in best_artifacts["model_state"].items()})
        va_loader = DataLoader(TensorDataset(*to_torch(Zva, yva)),
                               batch_size=BATCH_VAL, shuffle=False,
                               pin_memory=(DEVICE.type=="cuda"))
        acc, f1m, Yv, yhat, P = evaluate(model, va_loader)

        oof_pred[va_idx] = P
        oof_y[va_idx] = yva

        fold_reports.append({"fold": fold, "acc": acc, "f1": f1m})
        best_cfgs.append(best_cfg)
        gc.collect()

    # --------------------
    # OOF summary & plots
    # --------------------
    # Confusion matrices
    cm = confusion_matrix(oof_y, oof_pred.argmax(1), labels=np.arange(C))
    cm_norm = cm / cm.sum(axis=1, keepdims=True).clip(min=1)

    fig, ax = plt.subplots(figsize=(3.4, 3.0))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=LABELS, yticklabels=LABELS, cbar=False, ax=ax)
    ax.set_xlabel("Predicted"); ax.set_ylabel("True"); ax.set_title("Confusion (counts)")
    fig.tight_layout(); fig.savefig(os.path.join(fig_dir, "summary_cm_counts.pdf")); plt.close(fig)

    fig, ax = plt.subplots(figsize=(3.4, 3.0))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
                xticklabels=LABELS, yticklabels=LABELS, cbar=False, vmin=0, vmax=1, ax=ax)
    ax.set_xlabel("Predicted"); ax.set_ylabel("True"); ax.set_title("Confusion (row-normalized)")
    fig.tight_layout(); fig.savefig(os.path.join(fig_dir, "summary_cm_norm.pdf")); plt.close(fig)

    cm_df = pd.DataFrame(cm, index=LABELS, columns=LABELS)
    cmn_df = pd.DataFrame(cm_norm, index=LABELS, columns=LABELS)
    cm_df.to_csv(os.path.join(tab_dir, "summary_cm_counts.csv"))
    cmn_df.to_csv(os.path.join(tab_dir, "summary_cm_norm.csv"))
    save_df_as_pdf(cm_df.reset_index().rename(columns={"index":"True\\Pred"}),
                   os.path.join(tab_dir, "summary_cm_counts.pdf"),
                   title="Confusion (counts)")
    save_df_as_pdf(cmn_df.reset_index().rename(columns={"index":"True\\Pred"}),
                   os.path.join(tab_dir, "summary_cm_norm.pdf"),
                   title="Confusion (row-normalized)")

    # Per-class metrics
    y_true_str = [LABELS[i] for i in oof_y]
    y_pred_str = [LABELS[i] for i in oof_pred.argmax(1)]
    rep_df = report_to_df(y_true_str, y_pred_str, LABELS)
    cls_only = rep_df[rep_df["class"].isin(LABELS)].copy()
    melted = cls_only.melt(id_vars="class",
                           value_vars=["precision","recall","f1-score"],
                           var_name="metric", value_name="value")

    fig, ax = plt.subplots(figsize=(5.2, 3.0))
    sns.barplot(data=melted, x="class", y="value", hue="metric",
                ax=ax, edgecolor="black", linewidth=0.5)
    ax.set_ylim(0, 1); ax.set_ylabel("Score"); ax.set_xlabel("")
    ax.set_title("Per-class metrics (OOF)")
    ax.legend(frameon=False, ncol=3, loc="upper center", bbox_to_anchor=(0.5, 1.25))
    fig.tight_layout(); fig.savefig(os.path.join(fig_dir, "summary_perclass_bars.pdf")); plt.close(fig)

    rep_df.to_csv(os.path.join(tab_dir, "summary_classification_report.csv"), index=False)
    save_df_as_pdf(rep_df, os.path.join(tab_dir, "summary_classification_report.pdf"),
                   title="Classification Report (OOF)")

    # ROC & PR
    Y_bin = label_binarize(oof_y, classes=list(range(C)))
    auc_rows, ap_rows = [], []

    fig, ax = plt.subplots(figsize=(4.0, 3.2))
    for c_i in range(C):
        fpr, tpr, _ = roc_curve(Y_bin[:, c_i], oof_pred[:, c_i])
        AUC = auc(fpr, tpr); auc_rows.append(AUC)
        ax.plot(fpr, tpr, label=f"{LABELS[c_i]} (AUC={AUC:.2f})")
    ax.plot([0,1],[0,1], linestyle="--", color="gray", linewidth=0.8)
    ax.set_xlabel("FPR"); ax.set_ylabel("TPR"); ax.set_title("ROC (OOF)")
    ax.legend(frameon=False, fontsize=7, ncol=2)
    fig.tight_layout(); fig.savefig(os.path.join(fig_dir, "summary_roc.pdf")); plt.close(fig)

    fig, ax = plt.subplots(figsize=(4.0, 3.2))
    for c_i in range(C):
        prec, rec, _ = precision_recall_curve(Y_bin[:, c_i], oof_pred[:, c_i])
        AP = average_precision_score(Y_bin[:, c_i], oof_pred[:, c_i]); ap_rows.append(AP)
        ax.plot(rec, prec, label=f"{LABELS[c_i]} (AP={AP:.2f})")
    ax.set_xlabel("Recall"); ax.set_ylabel("Precision"); ax.set_title("Precision–Recall (OOF)")
    ax.legend(frameon=False, fontsize=7, ncol=2)
    fig.tight_layout(); fig.savefig(os.path.join(fig_dir, "summary_pr.pdf")); plt.close(fig)

    auc_df = pd.DataFrame({"class": LABELS, "AUC": auc_rows})
    ap_df  = pd.DataFrame({"class": LABELS, "AP":  ap_rows})
    auc_df.loc[len(auc_df)] = {"class":"macro","AUC":auc_df["AUC"].mean()}
    ap_df.loc [len(ap_df )] = {"class":"macro","AP": ap_df["AP"].mean()}
    aucap_sum = auc_df.merge(ap_df, on="class")
    aucap_sum.to_csv(os.path.join(tab_dir, "nestedcv_auc_ap_summary.csv"), index=False)
    save_df_as_pdf(aucap_sum, os.path.join(tab_dir, "nestedcv_auc_ap_summary.pdf"),
                   title="ROC AUC / Average Precision (OOF)")

    # Expected probability heatmap (E[P(pred=j | true=i)])
    M = np.zeros((C, C), float)
    for i_c in range(C):
        idx = (oof_y == i_c)
        if idx.any():
            M[i_c] = oof_pred[idx].mean(axis=0)
    fig, ax = plt.subplots(figsize=(3.4, 3.0))
    sns.heatmap(M, annot=True, fmt=".2f", cmap="Purples",
                xticklabels=LABELS, yticklabels=LABELS, vmin=0, vmax=1,
                cbar=False, ax=ax)
    ax.set_xlabel("Predicted class prob"); ax.set_ylabel("True class")
    ax.set_title("Expected P(class | true) (OOF)")
    fig.tight_layout(); fig.savefig(os.path.join(fig_dir, "summary_exp_prob_heatmap.pdf")); plt.close(fig)

    # Fold-level summary bars
    accs = [fr["acc"] for fr in fold_reports]
    f1s  = [fr["f1"]  for fr in fold_reports]
    means = np.array([np.mean(accs), np.mean(f1s)])
    stds  = np.array([np.std(accs),  np.std(f1s)])
    cis   = 1.96 * stds / np.sqrt(len(accs))
    labels2 = ["Accuracy", "Macro-F1"]

    fig, ax = plt.subplots(figsize=(3.6, 2.8))
    idx = np.arange(len(labels2))
    ax.bar(idx, means, yerr=cis, capsize=3, linewidth=0.8, edgecolor="black")
    ax.set_xticks(idx, labels2); ax.set_ylim(0, 1)
    for i, m in enumerate(means):
        ax.text(i, min(0.98, m+0.02), f"{m:.2f}", ha="center", va="bottom", fontsize=8)
    ax.set_title("Nested CV (OOF) — Summary")
    fig.tight_layout(); fig.savefig(os.path.join(fig_dir, "nestedcv_summary_bars.pdf")); plt.close(fig)

    fold_tbl = pd.DataFrame(fold_reports)
    fold_tbl.to_csv(os.path.join(tab_dir, "nestedcv_fold_metrics.csv"), index=False)
    save_df_as_pdf(fold_tbl, os.path.join(tab_dir, "nestedcv_fold_metrics.pdf"),
                   title="Fold-wise Metrics (Nested CV)")

    run_sum = pd.DataFrame({
        "metric": ["Accuracy mean", "Accuracy sd", "Macro-F1 mean", "Macro-F1 sd"],
        "value": [np.mean(accs), np.std(accs), np.mean(f1s), np.std(f1s)]
    })
    run_sum.to_csv(os.path.join(tab_dir, "nestedcv_summary.csv"), index=False)
    save_df_as_pdf(run_sum, os.path.join(tab_dir, "nestedcv_summary.pdf"),
                   title="Nested CV Summary (OOF)")

    return {
        "arch": arch_name,
        "oof_pred": oof_pred,
        "oof_y": oof_y,
        "acc_mean": np.mean(accs),
        "acc_sd": np.std(accs),
        "f1_mean": np.mean(f1s),
        "f1_sd": np.std(f1s),
        "fig_dir": fig_dir,
        "tab_dir": tab_dir,
    }

# ============================================================
#  Run Lean / Deep / Hybrid + comparison table
# ============================================================
all_runs = []
for arch in ("lean", "deep", "hybrid"):
    print(f"\n=== Running {arch.upper()} MLP ===")
    res = run_model_summary(arch, ARCHES[arch], X_all, y_all)
    all_runs.append(res)

cmp_dir = os.path.join(ROOT_OUT, "compare")
os.makedirs(cmp_dir, exist_ok=True)

cmp_df = pd.DataFrame({
    "arch":   [r["arch"] for r in all_runs],
    "acc_mean":[r["acc_mean"] for r in all_runs],
    "acc_sd": [r["acc_sd"] for r in all_runs],
    "f1_mean":[r["f1_mean"] for r in all_runs],
    "f1_sd":  [r["f1_sd"] for r in all_runs],
})
cmp_df["acc_ci95"] = 1.96 * cmp_df["acc_sd"] / np.sqrt(K_OUTER)
cmp_df["f1_ci95"]  = 1.96 * cmp_df["f1_sd"]  / np.sqrt(K_OUTER)
cmp_df.to_csv(os.path.join(cmp_dir, "model_comparison_metrics.csv"), index=False)
save_df_as_pdf(cmp_df, os.path.join(cmp_dir, "model_comparison_metrics.pdf"),
               title="Lean vs Deep vs Hybrid — Nested CV")

# Accuracy comparison bars
order = ["lean","deep","hybrid"]
x = np.arange(len(order))
fig, ax = plt.subplots(figsize=(4.8, 3.2))
acc_means = cmp_df.set_index("arch").loc[order, "acc_mean"].values
acc_cis   = cmp_df.set_index("arch").loc[order, "acc_ci95"].values
ax.bar(x, acc_means, yerr=acc_cis, capsize=3, edgecolor="black")
ax.set_xticks(x, [s.capitalize() for s in order])
ax.set_ylim(0, 1); ax.set_ylabel("Accuracy")
ax.set_title("Accuracy (mean ± 95% CI) — Nested CV")
for i, m in enumerate(acc_means):
    ax.text(i, min(0.98, m+0.02), f"{m:.2f}", ha="center", va="bottom", fontsize=8)
fig.tight_layout(); fig.savefig(os.path.join(cmp_dir, "comparison_accuracy_bars.pdf")); plt.close(fig)

# Macro-F1 comparison bars
fig, ax = plt.subplots(figsize=(4.8, 3.2))
f1_means = cmp_df.set_index("arch").loc[order, "f1_mean"].values
f1_cis   = cmp_df.set_index("arch").loc[order, "f1_ci95"].values
ax.bar(x, f1_means, yerr=f1_cis, capsize=3, edgecolor="black")
ax.set_xticks(x, [s.capitalize() for s in order])
ax.set_ylim(0, 1); ax.set_ylabel("Macro-F1")
ax.set_title("Macro-F1 (mean ± 95% CI) — Nested CV")
for i, m in enumerate(f1_means):
    ax.text(i, min(0.98, m+0.02), f"{m:.2f}", ha="center", va="bottom", fontsize=8)
fig.tight_layout(); fig.savefig(os.path.join(cmp_dir, "comparison_macroF1_bars.pdf")); plt.close(fig)

print("\nSaved per-model outputs under:", ROOT_OUT)
print("Comparison figures/tables under:", cmp_dir)


Detected raw labels: ['AD', 'CN', 'MCI', 'MCI/AD']
Using LABELS order: ['CN', 'MCI', 'MCI/AD', 'AD']
lab2id mapping: {'CN': 0, 'MCI': 1, 'MCI/AD': 2, 'AD': 3}
N_CLASS: 4
[data] N,D = (124, 203)

=== Running LEAN MLP ===

[LEAN] Outer fold 1/5

[LEAN] Outer fold 2/5

[LEAN] Outer fold 3/5

[LEAN] Outer fold 4/5

[LEAN] Outer fold 5/5

=== Running DEEP MLP ===

[DEEP] Outer fold 1/5

[DEEP] Outer fold 2/5

[DEEP] Outer fold 3/5

[DEEP] Outer fold 4/5

[DEEP] Outer fold 5/5

=== Running HYBRID MLP ===

[HYBRID] Outer fold 1/5

[HYBRID] Outer fold 2/5

[HYBRID] Outer fold 3/5

[HYBRID] Outer fold 4/5

[HYBRID] Outer fold 5/5

Saved per-model outputs under: /content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl_summary
Comparison figures/tables under: /content/drive/MyDrive/Capstone-Project/derivatives/pp_supervised_dl_summary/compare
