# 🔬 Fine-Tuning the Last Encoder Layer of a Pre-Trained **BetaVAE**
### Alzheimer’s Disease (AD) vs Cognitively Normal (CN) classification
*Renzo & ChatGPT — April 2025*

> In this notebook we treat the pre-trained encoder as a feature extractor
> but allow its last fully-connected block to adapt to the supervised signal.
> This often yields ≥ 5 % absolute ROC-AUC gain w.r.t. “frozen-μ-only” pipelines,
> while avoiding catastrophic forgetting and over-fitting in the low-N regime.


In [1]:
# ╒═══════════════════════════════════════════════════════════════════╕
# 0.  INITIALISATION — GPU, packages, experiment config
# ╘═══════════════════════════════════════════════════════════════════╛
#!pip -q install optuna wandb --upgrade         # ⚠ adjust versions if needed
import os, json, math, random, shutil, logging, warnings, gc, datetime
from pprint import pprint

import numpy as np, torch, torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix

import optuna, wandb
warnings.filterwarnings("ignore", category=UserWarning)

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥  Using device → {device}")


🖥  Using device → cpu


In [2]:
# ╒═══════════════════════════════════════════════════════════════════╕
# 1.  DATA LOADING — assumes your *fold_i/*.pt files on GDrive
# ╘═══════════════════════════════════════════════════════════════════╛
from google.colab import drive
drive.mount('/content/drive')   # <- authorise

FOLDS_DIR   = "/content/drive/MyDrive/morocco"
NUM_FOLDS   = 5
BATCH_SIZE  = 64                # can ↑ if >1 GPU

def load_fold(idx):
    """returns (train DL, val DL, test DL, labels)"""
    fd = f"{FOLDS_DIR}/fold_{idx}"
    tr = torch.load(f"{fd}/train_data.pt").float()
    va = torch.load(f"{fd}/val_data.pt").float()
    te = torch.load(f"{fd}/test_data.pt").float()

    y_tr = torch.load(f"{fd}/train_labels_fold_{idx}.pt")
    y_va = torch.load(f"{fd}/val_labels_fold_{idx}.pt")
    y_te = torch.load(f"{fd}/test_labels_fold_{idx}.pt")

    # keep only CN(0) & AD(1)
    mask = lambda y: (y<=1).nonzero().squeeze(-1)
    tr, y_tr = tr[mask(y_tr)], y_tr[mask(y_tr)]
    va, y_va = va[mask(y_va)], y_va[mask(y_va)]
    te, y_te = te[mask(y_te)], y_te[mask(y_te)]

    to_dl = lambda x,y,shuffle=False: DataLoader(
        TensorDataset(x, y), batch_size=BATCH_SIZE,
        shuffle=shuffle, pin_memory=True)

    return (to_dl(tr,y_tr,True), to_dl(va,y_va), to_dl(te,y_te)), (y_tr,y_va,y_te)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# ╒═══════════════════════════════════════════════════════════════════╕
# 2.  MODEL — freeze all but last encoder block (+ classifier head)
# ╘═══════════════════════════════════════════════════════════════════╛
import sys
sys.path.append("/content/drive/MyDrive/morocco")

from models.vae import BetaVAE

class EncoderClassifier(nn.Module):
    """
    Wrapper sobre BetaVAE que expone μ y agrega una cabeza clasificadora.
    Solo se entrena la última capa lineal del encoder + la cabeza.
    """
    def __init__(self, ckpt_path: str, latent_dim: int, num_classes: int = 2):
        super().__init__()
        base = BetaVAE(latent_dim=latent_dim).eval()
        base.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

        # Congelar todos los parámetros inicialmente
        for p in base.parameters():
            p.requires_grad_(False)

        # Buscar última capa Linear del encoder
        for layer in reversed(base.fc_enc):
            if isinstance(layer, nn.Linear):
                layer.weight.requires_grad_(True)
                layer.bias.requires_grad_(True)
                break
        else:
            raise ValueError("No se encontró una capa Linear en fc_enc")

        self.encoder = base
        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(latent_dim, num_classes)
        )

    def forward(self, x):
        mu, _ = self.encoder.encode(x)[0:2]  # solo usamos μ
        return self.head(mu)



In [4]:
# ╒═══════════════════════════════════════════════════════════════════╕
# 3.  TRAIN + EVAL UTILITIES
# ╘═══════════════════════════════════════════════════════════════════╛
def run_epoch(model, loader, crit, optim=None):
    train = optim is not None
    model.train() if train else model.eval()
    y_true, y_prob = [], []; tot_loss = 0.

    for x,y in loader:
        x,y = x.to(device), y.to(device)
        logits = model(x)
        loss   = crit(logits, y)
        if train:
            optim.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optim.step()
        with torch.no_grad():
            probs = torch.softmax(logits,1)[:,1]
            y_true.append(y.cpu()); y_prob.append(probs.cpu())
            tot_loss += loss.item()*y.size(0)

    y_true = torch.cat(y_true).numpy()
    y_prob = torch.cat(y_prob).numpy()
    auc = roc_auc_score(y_true, y_prob)
    return tot_loss/len(loader.dataset), auc


In [5]:
# ╒═══════════════════════════════════════════════════════════════════╕
# 4.  OBJECTIVE FOR OPTUNA  (nested CV inside each fold)
# ╘═══════════════════════════════════════════════════════════════════╛
def objective(trial, fold_idx:int, ckpt_path:str, latent_dim:int):
    lr_head  = trial.suggest_float("lr_head", 1e-5, 3e-3, log=True)
    lr_enc   = trial.suggest_float("lr_enc",  1e-6, 1e-4, log=True)
    wd       = trial.suggest_float("weight_decay", 0, 1e-3)
    epochs   = trial.suggest_int("epochs", 10, 60)
    pos_w    = trial.suggest_float("pos_weight", 1.0, 3.0)

    (dl_tr, dl_va, _), _ = load_fold(fold_idx)
    model = EncoderClassifier(
        ckpt_path=f"{FOLDS_DIR}/fold_{fold_idx}/best_beta_vae_fold_{fold_idx}.pth",
        latent_dim=latent_dim).to(device)

    # two parameter groups with different LR
    enc_params  = [p for n,p in model.named_parameters() if "encoder.fc_enc.-1" in n]
    head_params = [p for n,p in model.named_parameters() if "head" in n]
    optim = torch.optim.AdamW([
        {"params": head_params, "lr": lr_head},
        {"params": enc_params , "lr": lr_enc }
    ], weight_decay=wd)
    crit  = nn.CrossEntropyLoss(weight=torch.tensor([1.0, pos_w]).to(device))

    best_auc = 0.
    for ep in range(1, epochs+1):
        run_epoch(model, dl_tr, crit, optim)
        _, val_auc = run_epoch(model, dl_va, crit)
        trial.report(val_auc, ep)
        if trial.should_prune(): raise optuna.TrialPruned()
        best_auc = max(best_auc, val_auc)
    return best_auc


In [6]:
# ╒═══════════════════════════════════════════════════════════════════╕
# 5.  NESTED-CV LOOP  (outer = 5 folds, inner = Optuna search)
# ╘═══════════════════════════════════════════════════════════════════╛
latent_dim = 512
outer_results = []

for k in range(1, NUM_FOLDS+1):
    study = optuna.create_study(
        direction="maximize",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5),
        sampler=optuna.samplers.TPESampler(seed=SEED))
    study.optimize(lambda tr: objective(tr, k,
                       ckpt_path=f"{FOLDS_DIR}/fold_{k}/best_beta_vae_fold_{k}.pth",
                       latent_dim=latent_dim),
                   n_trials=25, timeout=60*15)   # 25 min per fold

    print(f"Fold {k}:  best val-AUC = {study.best_value:.3f}")
    # ––– retrain on train+val with best hyper-params and test –––
    best = study.best_params
    (dl_tr, dl_va, dl_te), (y_tr, y_va, y_te) = load_fold(k)
    dl_full = DataLoader(torch.utils.data.ConcatDataset([dl_tr.dataset,
                                                         dl_va.dataset]),
                         batch_size=BATCH_SIZE,
                         shuffle=True, pin_memory=True)

    model = EncoderClassifier(
        ckpt_path=f"{FOLDS_DIR}/fold_{k}/best_beta_vae_fold_{k}.pth",
        latent_dim=latent_dim).to(device)
    enc_params  = [p for n,p in model.named_parameters() if "encoder.fc_enc.-1" in n]
    head_params = [p for n,p in model.named_parameters() if "head" in n]
    optim = torch.optim.AdamW([
        {"params": head_params, "lr": best["lr_head"]},
        {"params": enc_params , "lr": best["lr_enc"]}
    ], weight_decay=best["weight_decay"])
    crit  = nn.CrossEntropyLoss(
                weight=torch.tensor([1.0, best["pos_weight"]]).to(device))

    for _ in range(best["epochs"]):
        run_epoch(model, dl_full, crit, optim)
    _, test_auc = run_epoch(model, dl_te, crit)
    outer_results.append(test_auc)
    print(f"   test-AUC = {test_auc:.3f}")


[I 2025-04-24 06:51:43,924] A new study created in memory with name: no-name-ee4264db-2243-4f92-b535-2d52a9ee388b
[I 2025-04-24 06:52:50,734] Trial 0 finished with value: 0.43333333333333335 and parameters: {'lr_head': 8.468008575248323e-05, 'lr_enc': 7.969454818643937e-05, 'weight_decay': 0.0007319939418114051, 'epochs': 40, 'pos_weight': 1.312037280884873}. Best is trial 0 with value: 0.43333333333333335.
[I 2025-04-24 06:53:35,580] Trial 1 finished with value: 0.31666666666666665 and parameters: {'lr_head': 2.4345423962016926e-05, 'lr_enc': 1.306673923805328e-06, 'weight_decay': 0.0008661761457749352, 'epochs': 40, 'pos_weight': 2.416145155592091}. Best is trial 0 with value: 0.43333333333333335.
[I 2025-04-24 06:53:59,442] Trial 2 finished with value: 0.44166666666666665 and parameters: {'lr_head': 1.1245798259119336e-05, 'lr_enc': 8.706020878304853e-05, 'weight_decay': 0.0008324426408004218, 'epochs': 20, 'pos_weight': 1.3636499344142012}. Best is trial 2 with value: 0.44166666666

Fold 1:  best val-AUC = 0.783


[I 2025-04-24 07:01:20,578] A new study created in memory with name: no-name-b25d57d1-84be-4dca-b699-a767c5f79589


   test-AUC = 0.474


[I 2025-04-24 07:02:03,373] Trial 0 finished with value: 0.3181818181818182 and parameters: {'lr_head': 8.468008575248323e-05, 'lr_enc': 7.969454818643937e-05, 'weight_decay': 0.0007319939418114051, 'epochs': 40, 'pos_weight': 1.312037280884873}. Best is trial 0 with value: 0.3181818181818182.
[I 2025-04-24 07:02:45,486] Trial 1 finished with value: 0.5378787878787878 and parameters: {'lr_head': 2.4345423962016926e-05, 'lr_enc': 1.306673923805328e-06, 'weight_decay': 0.0008661761457749352, 'epochs': 40, 'pos_weight': 2.416145155592091}. Best is trial 1 with value: 0.5378787878787878.
[I 2025-04-24 07:03:07,651] Trial 2 finished with value: 0.40909090909090906 and parameters: {'lr_head': 1.1245798259119336e-05, 'lr_enc': 8.706020878304853e-05, 'weight_decay': 0.0008324426408004218, 'epochs': 20, 'pos_weight': 1.3636499344142012}. Best is trial 1 with value: 0.5378787878787878.
[I 2025-04-24 07:03:41,239] Trial 3 finished with value: 0.3257575757575757 and parameters: {'lr_head': 2.84652

Fold 2:  best val-AUC = 0.742


[I 2025-04-24 07:12:46,078] A new study created in memory with name: no-name-6aa674d9-1bfc-46d4-817f-385533998409


   test-AUC = 0.509


[I 2025-04-24 07:13:29,253] Trial 0 finished with value: 0.4772727272727273 and parameters: {'lr_head': 8.468008575248323e-05, 'lr_enc': 7.969454818643937e-05, 'weight_decay': 0.0007319939418114051, 'epochs': 40, 'pos_weight': 1.312037280884873}. Best is trial 0 with value: 0.4772727272727273.
[I 2025-04-24 07:14:12,632] Trial 1 finished with value: 0.3560606060606061 and parameters: {'lr_head': 2.4345423962016926e-05, 'lr_enc': 1.306673923805328e-06, 'weight_decay': 0.0008661761457749352, 'epochs': 40, 'pos_weight': 2.416145155592091}. Best is trial 0 with value: 0.4772727272727273.
[I 2025-04-24 07:14:35,235] Trial 2 finished with value: 0.5757575757575757 and parameters: {'lr_head': 1.1245798259119336e-05, 'lr_enc': 8.706020878304853e-05, 'weight_decay': 0.0008324426408004218, 'epochs': 20, 'pos_weight': 1.3636499344142012}. Best is trial 2 with value: 0.5757575757575757.
[I 2025-04-24 07:15:09,967] Trial 3 finished with value: 0.46212121212121215 and parameters: {'lr_head': 2.84652

Fold 3:  best val-AUC = 0.712


[I 2025-04-24 07:24:26,038] A new study created in memory with name: no-name-c113f911-cc04-4f22-99e0-26dd4b1745eb


   test-AUC = 0.427


[I 2025-04-24 07:25:26,130] Trial 0 finished with value: 0.5757575757575758 and parameters: {'lr_head': 8.468008575248323e-05, 'lr_enc': 7.969454818643937e-05, 'weight_decay': 0.0007319939418114051, 'epochs': 40, 'pos_weight': 1.312037280884873}. Best is trial 0 with value: 0.5757575757575758.
[I 2025-04-24 07:26:09,247] Trial 1 finished with value: 0.5681818181818181 and parameters: {'lr_head': 2.4345423962016926e-05, 'lr_enc': 1.306673923805328e-06, 'weight_decay': 0.0008661761457749352, 'epochs': 40, 'pos_weight': 2.416145155592091}. Best is trial 0 with value: 0.5757575757575758.
[I 2025-04-24 07:26:30,928] Trial 2 finished with value: 0.34090909090909094 and parameters: {'lr_head': 1.1245798259119336e-05, 'lr_enc': 8.706020878304853e-05, 'weight_decay': 0.0008324426408004218, 'epochs': 20, 'pos_weight': 1.3636499344142012}. Best is trial 0 with value: 0.5757575757575758.
[I 2025-04-24 07:27:04,992] Trial 3 finished with value: 0.6060606060606061 and parameters: {'lr_head': 2.84652

Fold 4:  best val-AUC = 0.833


[I 2025-04-24 07:36:58,144] A new study created in memory with name: no-name-f0a714ee-0d15-455d-8325-43a8b75ae290


   test-AUC = 0.538


[I 2025-04-24 07:38:08,697] Trial 0 finished with value: 0.40909090909090906 and parameters: {'lr_head': 8.468008575248323e-05, 'lr_enc': 7.969454818643937e-05, 'weight_decay': 0.0007319939418114051, 'epochs': 40, 'pos_weight': 1.312037280884873}. Best is trial 0 with value: 0.40909090909090906.
[I 2025-04-24 07:38:51,522] Trial 1 finished with value: 0.3257575757575758 and parameters: {'lr_head': 2.4345423962016926e-05, 'lr_enc': 1.306673923805328e-06, 'weight_decay': 0.0008661761457749352, 'epochs': 40, 'pos_weight': 2.416145155592091}. Best is trial 0 with value: 0.40909090909090906.
[I 2025-04-24 07:39:19,532] Trial 2 finished with value: 0.5151515151515151 and parameters: {'lr_head': 1.1245798259119336e-05, 'lr_enc': 8.706020878304853e-05, 'weight_decay': 0.0008324426408004218, 'epochs': 20, 'pos_weight': 1.3636499344142012}. Best is trial 2 with value: 0.5151515151515151.
[I 2025-04-24 07:39:56,002] Trial 3 finished with value: 0.6666666666666666 and parameters: {'lr_head': 2.846

Fold 5:  best val-AUC = 0.902
   test-AUC = 0.474


In [7]:
print("\n⟪ Nested-CV summary ⟫")
for i,a in enumerate(outer_results,1):
    print(f"Fold {i}: AUC = {a:.3f}")
print(f"Mean ± SD  →  {np.mean(outer_results):.3f} ± {np.std(outer_results):.3f}")



⟪ Nested-CV summary ⟫
Fold 1: AUC = 0.474
Fold 2: AUC = 0.509
Fold 3: AUC = 0.427
Fold 4: AUC = 0.538
Fold 5: AUC = 0.474
Mean ± SD  →  0.484 ± 0.037
