# 🔬 Fine-Tuning the Last Encoder Layer of a Pre-Trained **BetaVAE**
### Alzheimer’s Disease (AD) vs Cognitively Normal (CN) classification
*Renzo & ChatGPT — April 2025*

> In this notebook we treat the pre-trained encoder as a feature extractor
> but allow its last fully-connected block to adapt to the supervised signal.
> This often yields ≥ 5 % absolute ROC-AUC gain w.r.t. “frozen-μ-only” pipelines,
> while avoiding catastrophic forgetting and over-fitting in the low-N regime.


In [None]:
# Mount Google Drive at /content/drive. The fold data and checkpoints used
# later in this notebook are read from paths under /content/drive/MyDrive.

from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install optuna



In [None]:
# ============================================================
# 🔬  Fine-Tuning the Last Encoder Layer of a Pre-Trained BetaVAE
#     Alzheimer’s Disease (AD) vs Cognitively Normal (CN) classification
#     Renzo & ChatGPT — April 2025
# ============================================================
"""
Colab-friendly script that performs *nested* cross-validation where only the
last fully-connected block of a pre-trained **BetaVAE** encoder — together with
a lightweight logistic head — is fine-tuned.

Main upgrades vs. the previous draft
------------------------------------
1. **Confusion matrices & optimal threshold**
   • During the inner CV we pick the decision threshold that maximises
   *Youden’s J* on *val* and re-use it on *test* (no information leak).
2. **Stronger regularisation search**
   • weight-decay ∈ [1e-4, 3e-3] & dropout ∈ [0.0, 0.4].
3. **Optional partial un-freeze**
   • `unfreeze_n = 1‥4` allows un-freezing more than one Linear layer
   of `fc_enc` (the Optuna search in this notebook only samples 1‥2).
4. **Metrics persisted per fold** in `/results/fold_k/`.
"""
# %% --------------------------- 0 · Imports & Globals ---------------------------
import os, random, logging, warnings, json, math
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch, torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from sklearn.metrics import (
    roc_auc_score, roc_curve, confusion_matrix, precision_recall_curve,
)
from sklearn.model_selection import StratifiedKFold
import optuna
import seaborn as sns; import matplotlib.pyplot as plt

# Silence UserWarning-category warnings for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

# Seed the Python, NumPy and PyTorch RNGs with one shared value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥  Device = {device}")

# %% --------------------------- 1 · Data loading ---------------------------
FOLDS_DIR   = "/content/drive/MyDrive/morocco"   # adjust
NUM_FOLDS   = 5
BATCH_SIZE  = 64
# parents=True: do not fail when the parent directory is missing (e.g. when
# the notebook is run outside Colab, where /content may not exist).
RESULTS_DIR = Path("/content/results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)


def _dl(x: torch.Tensor, y: torch.Tensor, shuffle=False):
    """Wrap a (features, labels) tensor pair in a ``DataLoader``.

    Batches hold ``BATCH_SIZE`` samples, are shuffled only when *shuffle* is
    true, and are created with ``pin_memory=True``.
    """
    paired = TensorDataset(x, y)
    loader_kwargs = dict(batch_size=BATCH_SIZE, shuffle=shuffle, pin_memory=True)
    return DataLoader(paired, **loader_kwargs)

def load_fold(idx: int):
    """Load one pre-split fold from disk, keeping only samples labelled 0 or 1.

    Reads the train/val/test feature tensors and their label tensors from
    ``FOLDS_DIR/fold_{idx}``, drops every sample whose label is greater than 1,
    and wraps each split in a ``DataLoader`` (only the training loader
    shuffles).

    Returns:
        ``(train_loader, val_loader, test_loader, (y_train, y_val, y_test))``
        where the last element holds the retained labels as NumPy arrays.
    """
    fold_dir = Path(FOLDS_DIR) / f"fold_{idx}"
    splits = ("train", "val", "test")

    # Features are cast to float; labels keep whatever dtype they were saved with.
    feats = {
        s: torch.load(fold_dir / f"{s}_data.pt", weights_only=True).float()
        for s in splits
    }
    labels = {
        s: torch.load(fold_dir / f"{s}_labels_fold_{idx}.pt", weights_only=True)
        for s in splits
    }

    # presumably labels 0/1 are CN/AD and larger codes are other groups — TODO confirm
    for s in splits:
        selected = (labels[s] <= 1).nonzero(as_tuple=True)[0]
        feats[s], labels[s] = feats[s][selected], labels[s][selected]

    loaders = tuple(_dl(feats[s], labels[s], s == "train") for s in splits)
    return (*loaders, tuple(labels[s].numpy() for s in splits))

🖥  Device = cuda


In [None]:
import sys
# Put the Drive project folder on the import path; presumably this is where
# the `models` package imported below lives — adjust if the project moves.
sys.path.append("/content/drive/MyDrive/morocco")

from models.vae import BetaVAE


class EncoderClassifier(nn.Module):
    """Pre-trained BetaVAE used as a feature extractor plus a linear head.

    The checkpoint is loaded into a fresh ``BetaVAE``; every parameter is
    frozen and then the last ``unfreeze_n`` ``nn.Linear`` layers found in
    ``fc_enc`` are made trainable again. The head is
    ``Dropout -> Linear(latent_dim, num_classes)`` and is always trainable.

    Args:
        ckpt: Path to a ``state_dict`` checkpoint compatible with ``BetaVAE``.
        latent_dim: Latent size passed to ``BetaVAE`` and input size of the head.
        num_classes: Number of output logits.
        unfreeze_n: How many trailing ``nn.Linear`` layers of ``fc_enc`` to
            fine-tune; must be at least 1.
        dropout: Dropout probability applied before the head's linear layer.

    Raises:
        ValueError: If ``unfreeze_n`` is smaller than 1 or larger than the
            number of ``nn.Linear`` layers in ``fc_enc``.
    """

    def __init__(self, ckpt: Path, latent_dim: int, num_classes=2, unfreeze_n=1, dropout=0.2):
        super().__init__()
        # Explicit check instead of `assert`: asserts vanish under `python -O`,
        # which would silently fine-tune the wrong number of layers.
        if unfreeze_n < 1:
            raise ValueError(f"unfreeze_n must be >= 1, got {unfreeze_n}")

        base = BetaVAE(latent_dim=latent_dim).eval()
        base.load_state_dict(torch.load(ckpt, map_location="cpu", weights_only=True))

        # Freeze everything first; selected layers are re-enabled below.
        for p in base.parameters():
            p.requires_grad_(False)

        linear_layers = [m for m in base.fc_enc if isinstance(m, nn.Linear)]
        if len(linear_layers) < unfreeze_n:
            raise ValueError(
                "Not enough Linear layers to un-freeze "
                f"(requested {unfreeze_n}, fc_enc has {len(linear_layers)})"
            )

        # Un-freeze the last *n* Linear layers. Going through parameters()
        # also covers layers built with bias=False (where .bias is None).
        for layer in linear_layers[-unfreeze_n:]:
            for p in layer.parameters():
                p.requires_grad_(True)

        # The whole BetaVAE object is kept; forward() only calls its encode().
        # NOTE(review): calling .train() on this module also switches the
        # frozen encoder layers to training mode. If the encoder contains
        # BatchNorm/Dropout, confirm that this is intended.
        self.encoder = base
        self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(latent_dim, num_classes))

    def forward(self, x):
        """Return class logits computed from the first output of ``encode``."""
        # Only the first of the two leading encode() outputs is used as features.
        mu, _ = self.encoder.encode(x)[:2]
        return self.head(mu)

In [None]:



# %% --------------------------- 3 · Helpers ---------------------------

def run_epoch(model, loader, crit, optim=None):
    """Run one pass over *loader*, training when *optim* is given.

    Args:
        model: Module producing per-class logits; column 1 is treated as the
            positive class when computing probabilities.
        loader: Iterable of ``(x, y)`` batches.
        crit: Loss callable invoked as ``crit(logits, y)``.
        optim: Optimizer. When ``None`` the model is put in eval mode and no
            parameter update (and no autograd graph) is produced.

    Returns:
        ``(mean_loss, auc, (y_true, y_prob))`` where ``mean_loss`` is the
        sample-weighted average loss, ``auc`` the ROC-AUC over the whole pass
        and ``y_true`` / ``y_prob`` the labels and positive-class
        probabilities as NumPy arrays.
    """
    train = optim is not None
    model.train(train)

    y_true, y_prob = [], []
    loss_sum, n_seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)

        # Only track gradients when we will actually back-propagate; during
        # evaluation this avoids building an autograd graph per batch.
        with torch.set_grad_enabled(train):
            out = model(x)
            loss = crit(out, y)

        if train:
            optim.zero_grad()
            loss.backward()
            # Clip the global gradient norm to 5.0 before the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optim.step()

        batch = y.size(0)
        y_true.append(y.cpu())
        y_prob.append(torch.softmax(out.detach(), 1)[:, 1].cpu())
        loss_sum += loss.item() * batch
        n_seen += batch

    y_true = torch.cat(y_true).numpy()
    y_prob = torch.cat(y_prob).numpy()
    auc = roc_auc_score(y_true, y_prob)
    # Divide by the samples actually seen so the average stays correct even
    # if the loader drops or sub-samples part of its dataset.
    return loss_sum / n_seen, auc, (y_true, y_prob)


def save_curves(y_true, y_prob, name: str, out: Path):
    """Plot the ROC and precision-recall curves side by side and save a PNG.

    The figure is titled *name* and written to ``out/{name}_curves.png`` at
    200 dpi; it is closed afterwards. Nothing is returned.
    """
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    precision, recall, _ = precision_recall_curve(y_true, y_prob)

    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(8, 3))

    # Left panel: ROC curve with the chance diagonal for reference.
    ax_roc.plot(fpr, tpr)
    ax_roc.plot([0, 1], [0, 1], '--k')
    ax_roc.set_title('ROC')

    # Right panel: precision as a function of recall.
    ax_pr.plot(recall, precision)
    ax_pr.set_title('PR')

    fig.suptitle(name)
    fig.tight_layout()
    fig.savefig(out / f"{name}_curves.png", dpi=200)
    plt.close(fig)


def save_cm(y_true, y_prob, thr, name: str, out: Path):
    """Threshold the probabilities, save the confusion matrix plot, return it.

    Args:
        y_true: Binary ground-truth labels (0 or 1).
        y_prob: Positive-class probabilities, same length as *y_true*.
        thr: Decision threshold; samples with ``y_prob >= thr`` are predicted 1.
        name: Plot title and file-name stem.
        out: Directory receiving ``{name}_cm.png``.

    Returns:
        The 2x2 confusion matrix (rows = true class, columns = predicted).
    """
    y_pred = (np.asarray(y_prob) >= thr).astype(int)
    # Pin the label set so the matrix is always 2x2, even when only one class
    # occurs in y_true/y_pred; otherwise it would not match the tick labels.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    # Draw on a dedicated figure instead of whatever axes happen to be current.
    fig, ax = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['CN', 'AD'], yticklabels=['CN', 'AD'], ax=ax)
    ax.set_title(name)
    ax.set_xlabel('Pred')
    ax.set_ylabel('True')
    fig.tight_layout()
    fig.savefig(out / f"{name}_cm.png", dpi=150)
    plt.close(fig)
    return cm

In [None]:

# %% --------------------------- 4 · Optuna objective ---------------------------

def objective(trial, fold:int, ckpt:Path, latent_dim:int):
    """Optuna objective: fine-tune on the fold's train split, score on val.

    Samples learning rates, weight decay, dropout, the number of un-frozen
    encoder layers, the positive-class loss weight and the epoch count, then
    trains an ``EncoderClassifier`` and reports the validation ROC-AUC to the
    trial after every epoch so the pruner can stop it early.

    Args:
        trial: Optuna trial used to sample hyper-parameters and report progress.
        fold: Index of the fold passed to ``load_fold``.
        ckpt: Checkpoint of the pre-trained BetaVAE for this fold.
        latent_dim: Latent size of the BetaVAE.

    Returns:
        The highest validation ROC-AUC observed over the trained epochs.

    Raises:
        optuna.TrialPruned: When the pruner decides to stop the trial.
    """
    # Keep the sampling order and parameter names stable: the seeded sampler
    # and the code reading `study.best_params` both depend on them.
    hp = dict(
        lr_head   = trial.suggest_float('lr_head',1e-4,3e-3,log=True),
        lr_enc    = trial.suggest_float('lr_enc', 1e-6,5e-4,log=True),
        weight_decay = trial.suggest_float('wd',1e-4,3e-3,log=True),
        dropout   = trial.suggest_float('dropout',0.0,0.4),
        unfreeze  = trial.suggest_int('unfreeze_n',1,2),
        pos_w     = trial.suggest_float('pos_w',1.,3.),
        epochs    = trial.suggest_int('epochs',12,60)
    )
    # Only train/val are used for model selection; the test split is ignored here.
    dl_tr, dl_va, _, _ = load_fold(fold)
    model = EncoderClassifier(ckpt, latent_dim, dropout=hp['dropout'], unfreeze_n=hp['unfreeze']).to(device)

    # Select parameters through the sub-modules rather than by name substring,
    # so an encoder parameter whose name contains "head" can never end up in
    # both groups.
    enc_params = [p for p in model.encoder.parameters() if p.requires_grad]
    head_params = list(model.head.parameters())
    optim = torch.optim.AdamW([
        {'params': head_params, 'lr': hp['lr_head']},
        {'params': enc_params, 'lr': hp['lr_enc']}
    ], weight_decay=hp['weight_decay'])
    crit = nn.CrossEntropyLoss(weight=torch.tensor([1., hp['pos_w']], device=device))

    # NOTE(review): the returned score is the best epoch's AUC, while training
    # always runs for the full sampled `epochs`; the value is therefore an
    # optimistic estimate of the final-epoch model.
    best_auc = 0.
    for ep in range(1, hp['epochs'] + 1):
        run_epoch(model, dl_tr, crit, optim)
        _, auc, _ = run_epoch(model, dl_va, crit)
        trial.report(auc, ep)
        if trial.should_prune():
            raise optuna.TrialPruned()
        best_auc = max(best_auc, auc)
    return best_auc

In [None]:
# %% --------------------------- 5 · Nested-CV ---------------------------
from typing import List

latent_dim = 512
outer_auc: List[float] = []   # one test ROC-AUC per outer fold

for fold in range(1, NUM_FOLDS + 1):
    ckpt = Path(FOLDS_DIR) / f"fold_{fold}" / f"best_beta_vae_fold_{fold}.pth"

    # --------------- Hyper-parameter search on train/val ---------------
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=SEED),
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5),
    )
    study.optimize(lambda t: objective(t, fold, ckpt, latent_dim), n_trials=10, timeout=60 * 12)
    print(f"Fold {fold} ▶ best val-AUC = {study.best_value:.3f}")

    # --------------- Re-train on train+val & evaluate on test ---------------
    # Copy the winning parameters and rename the Optuna keys to the names used
    # below ('wd' already has the right name).
    hp = dict(study.best_params)
    hp['pos_weight'] = hp.pop('pos_w')
    hp['unfreeze'] = hp.pop('unfreeze_n')

    dl_tr, dl_va, dl_te, _ = load_fold(fold)
    dl_full = DataLoader(ConcatDataset([dl_tr.dataset, dl_va.dataset]), batch_size=BATCH_SIZE, shuffle=True)

    model = EncoderClassifier(ckpt, latent_dim, dropout=hp['dropout'], unfreeze_n=hp['unfreeze']).to(device)
    # Pick the trainable encoder parameters by name; the head gets its own
    # learning rate in a separate parameter group.
    enc_params = [p for n, p in model.named_parameters() if p.requires_grad and 'encoder' in n]
    optim = torch.optim.AdamW([
        {'params': model.head.parameters(), 'lr': hp['lr_head']},
        {'params': enc_params, 'lr': hp['lr_enc']}
    ], weight_decay=hp['wd'])

    crit = nn.CrossEntropyLoss(weight=torch.tensor([1., hp['pos_weight']], device=device))

    for _ in range(hp['epochs']):
        run_epoch(model, dl_full, crit, optim)

    # evaluate on test
    _, test_auc, (y_true, y_prob) = run_epoch(model, dl_te, crit)
    outer_auc.append(test_auc)
    print(f"     test‑AUC = {test_auc:.3f}")

    # --------------- Persist per-fold results ---------------
    # The notebook header promises per-fold metrics on disk; write the test
    # curves and a small JSON summary under RESULTS_DIR/fold_k/.
    # No confusion matrix is saved here: after re-training on train+val there
    # are no validation predictions left to derive a decision threshold from.
    fold_dir = RESULTS_DIR / f"fold_{fold}"
    fold_dir.mkdir(parents=True, exist_ok=True)
    save_curves(y_true, y_prob, f"fold_{fold}_test", fold_dir)
    with open(fold_dir / "metrics.json", "w", encoding="utf-8") as fh:
        json.dump(
            {
                "fold": fold,
                "best_val_auc": float(study.best_value),
                "test_auc": float(test_auc),
                "params": study.best_params,
            },
            fh,
            indent=2,
        )

[I 2025-04-24 22:58:59,474] A new study created in memory with name: no-name-de73042c-4387-4e1d-93b7-95fc031aa323
[I 2025-04-24 22:59:01,634] Trial 0 finished with value: 0.7166666666666667 and parameters: {'lr_head': 0.0003574712922600243, 'lr_enc': 0.00036808608148776104, 'wd': 0.001205712628744377, 'dropout': 0.23946339367881464, 'unfreeze_n': 1, 'pos_w': 1.3119890406724053, 'epochs': 14}. Best is trial 0 with value: 0.7166666666666667.
[I 2025-04-24 22:59:05,165] Trial 1 finished with value: 0.825 and parameters: {'lr_head': 0.00190303683817358, 'lr_enc': 4.191711516695204e-05, 'wd': 0.0011114989443094978, 'dropout': 0.008233797718320978, 'unfreeze_n': 2, 'pos_w': 2.6648852816008435, 'epochs': 22}. Best is trial 1 with value: 0.825.
[I 2025-04-24 22:59:10,859] Trial 2 finished with value: 0.2666666666666667 and parameters: {'lr_head': 0.00018559980846490597, 'lr_enc': 3.1261029103110603e-06, 'wd': 0.0002814509271606064, 'dropout': 0.20990257265289514, 'unfreeze_n': 1, 'pos_w': 1.58

Fold 1 ▶ best val-AUC = 0.867


[I 2025-04-24 22:59:36,987] A new study created in memory with name: no-name-4d53cdd0-7e19-4545-aa8e-9a236a64c1c8


     test‑AUC = 0.754


[I 2025-04-24 22:59:50,404] Trial 0 finished with value: 0.7045454545454546 and parameters: {'lr_head': 0.0003574712922600243, 'lr_enc': 0.00036808608148776104, 'wd': 0.001205712628744377, 'dropout': 0.23946339367881464, 'unfreeze_n': 1, 'pos_w': 1.3119890406724053, 'epochs': 14}. Best is trial 0 with value: 0.7045454545454546.
[I 2025-04-24 22:59:53,017] Trial 1 finished with value: 0.7196969696969697 and parameters: {'lr_head': 0.00190303683817358, 'lr_enc': 4.191711516695204e-05, 'wd': 0.0011114989443094978, 'dropout': 0.008233797718320978, 'unfreeze_n': 2, 'pos_w': 2.6648852816008435, 'epochs': 22}. Best is trial 1 with value: 0.7196969696969697.
[I 2025-04-24 22:59:55,726] Trial 2 finished with value: 0.5227272727272727 and parameters: {'lr_head': 0.00018559980846490597, 'lr_enc': 3.1261029103110603e-06, 'wd': 0.0002814509271606064, 'dropout': 0.20990257265289514, 'unfreeze_n': 1, 'pos_w': 1.5824582803960838, 'epochs': 41}. Best is trial 1 with value: 0.7196969696969697.
[I 2025-0

Fold 2 ▶ best val-AUC = 0.894


[I 2025-04-24 23:00:21,802] A new study created in memory with name: no-name-0bf0683e-9fbb-44f3-bce8-f6271db64e80


     test‑AUC = 0.687


[I 2025-04-24 23:00:36,261] Trial 0 finished with value: 0.45454545454545453 and parameters: {'lr_head': 0.0003574712922600243, 'lr_enc': 0.00036808608148776104, 'wd': 0.001205712628744377, 'dropout': 0.23946339367881464, 'unfreeze_n': 1, 'pos_w': 1.3119890406724053, 'epochs': 14}. Best is trial 0 with value: 0.45454545454545453.
[I 2025-04-24 23:00:39,785] Trial 1 finished with value: 0.6363636363636362 and parameters: {'lr_head': 0.00190303683817358, 'lr_enc': 4.191711516695204e-05, 'wd': 0.0011114989443094978, 'dropout': 0.008233797718320978, 'unfreeze_n': 2, 'pos_w': 2.6648852816008435, 'epochs': 22}. Best is trial 1 with value: 0.6363636363636362.
[I 2025-04-24 23:00:42,532] Trial 2 finished with value: 0.5984848484848485 and parameters: {'lr_head': 0.00018559980846490597, 'lr_enc': 3.1261029103110603e-06, 'wd': 0.0002814509271606064, 'dropout': 0.20990257265289514, 'unfreeze_n': 1, 'pos_w': 1.5824582803960838, 'epochs': 41}. Best is trial 1 with value: 0.6363636363636362.
[I 2025

Fold 3 ▶ best val-AUC = 0.773


[I 2025-04-24 23:01:04,081] A new study created in memory with name: no-name-4250f0a0-fc22-4121-8838-63277bc05257


     test‑AUC = 0.673


[I 2025-04-24 23:01:18,684] Trial 0 finished with value: 0.7121212121212122 and parameters: {'lr_head': 0.0003574712922600243, 'lr_enc': 0.00036808608148776104, 'wd': 0.001205712628744377, 'dropout': 0.23946339367881464, 'unfreeze_n': 1, 'pos_w': 1.3119890406724053, 'epochs': 14}. Best is trial 0 with value: 0.7121212121212122.
[I 2025-04-24 23:01:21,273] Trial 1 finished with value: 0.7727272727272727 and parameters: {'lr_head': 0.00190303683817358, 'lr_enc': 4.191711516695204e-05, 'wd': 0.0011114989443094978, 'dropout': 0.008233797718320978, 'unfreeze_n': 2, 'pos_w': 2.6648852816008435, 'epochs': 22}. Best is trial 1 with value: 0.7727272727272727.
[I 2025-04-24 23:01:24,001] Trial 2 finished with value: 0.5984848484848485 and parameters: {'lr_head': 0.00018559980846490597, 'lr_enc': 3.1261029103110603e-06, 'wd': 0.0002814509271606064, 'dropout': 0.20990257265289514, 'unfreeze_n': 1, 'pos_w': 1.5824582803960838, 'epochs': 41}. Best is trial 1 with value: 0.7727272727272727.
[I 2025-0

Fold 4 ▶ best val-AUC = 0.780


[I 2025-04-24 23:01:51,449] A new study created in memory with name: no-name-211245db-f8dc-406c-b24e-74a8a3a964cf


     test‑AUC = 0.746


[I 2025-04-24 23:02:05,802] Trial 0 finished with value: 0.8257575757575758 and parameters: {'lr_head': 0.0003574712922600243, 'lr_enc': 0.00036808608148776104, 'wd': 0.001205712628744377, 'dropout': 0.23946339367881464, 'unfreeze_n': 1, 'pos_w': 1.3119890406724053, 'epochs': 14}. Best is trial 0 with value: 0.8257575757575758.
[I 2025-04-24 23:02:09,413] Trial 1 finished with value: 0.803030303030303 and parameters: {'lr_head': 0.00190303683817358, 'lr_enc': 4.191711516695204e-05, 'wd': 0.0011114989443094978, 'dropout': 0.008233797718320978, 'unfreeze_n': 2, 'pos_w': 2.6648852816008435, 'epochs': 22}. Best is trial 0 with value: 0.8257575757575758.
[I 2025-04-24 23:02:12,132] Trial 2 finished with value: 0.3333333333333333 and parameters: {'lr_head': 0.00018559980846490597, 'lr_enc': 3.1261029103110603e-06, 'wd': 0.0002814509271606064, 'dropout': 0.20990257265289514, 'unfreeze_n': 1, 'pos_w': 1.5824582803960838, 'epochs': 41}. Best is trial 0 with value: 0.8257575757575758.
[I 2025-04

Fold 5 ▶ best val-AUC = 0.841
     test‑AUC = 0.675


In [None]:







# %% --------------------------- 6 · Summary ---------------------------
# Per-fold test AUCs, then their mean and standard deviation
# (np.std default, i.e. the population SD with ddof=0).
print("\n⟪ Nested-CV summary ⟫")
for fold_no, fold_auc in enumerate(outer_auc, start=1):
    print(f"Fold {fold_no}: test AUC = {fold_auc:.3f}")
mean_auc = np.mean(outer_auc)
sd_auc = np.std(outer_auc)
print(f"Mean ± SD → {mean_auc:.3f} ± {sd_auc:.3f}")



⟪ Nested-CV summary ⟫
Fold 1: test AUC = 0.754
Fold 2: test AUC = 0.687
Fold 3: test AUC = 0.673
Fold 4: test AUC = 0.746
Fold 5: test AUC = 0.675
Mean ± SD → 0.707 ± 0.036


In [None]:
# Pick the decision threshold that maximises Youden's J (TPR - FPR) on the
# validation predictions, then apply it unchanged to the test predictions.
# NOTE(review): y_val / y_prob_val / y_test / y_prob_test are not defined in
# the cells shown above — they must be provided before running this cell.
fpr, tpr, thresholds = roc_curve(y_val, y_prob_val)
thr = thresholds[np.argmax(tpr - fpr)]   # Youden
y_pred = (y_prob_test >= thr).astype(int)
# labels=[0, 1] keeps the matrix 2x2 even if only one class is present,
# so it always matches the CN/AD tick labels.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['CN','AD'], yticklabels=['CN','AD'])
