In [1]:
# Mount Google Drive so the project data under /content/drive/MyDrive is reachable.
import google.colab.drive as drive

drive.mount('/content/drive')


Mounted at /content/drive


In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Final Training Script for Beta-VAE + Classifier (AD vs CN)
Using best hyperparameters found during sweep.

Author: Renzo (adapted from sweep script) - 2025-05-03
Requires:  ▸ pip install torch numpy scikit-learn lightgbm
"""
import os
import random
import time
import warnings
import math
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split # Optional: for classifier tuning
import lightgbm as lgb
import joblib # For saving sklearn models

# ────────────────────────── Best Hyperparameters & Config ──────────────────────
BEST_HP = dict(
    lr=2.169848131729575e-05,   # AdamW learning rate
    batch_size=128,
    beta=50,                    # final KL weight, reached at the end of the ramp
    beta_ramp_epochs=200,       # epochs over which beta is ramped up linearly
    latent_dim=512,
    k_filters=48,               # base conv width; encoder uses k, 2k, 4k channels
    epochs=4000,                # user-specified number of final training epochs
)

PROJECT_DIR = '/content/drive/MyDrive/GrandMeanNorm'  # must match the mounted Drive layout
FOLD = 1
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed every RNG the pipeline draws from so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Early stopping: training halts after this many consecutive epochs without a
# validation-loss improvement. A value >= BEST_HP['epochs'] effectively disables it.
EARLY_STOP_PATIENCE = 100

# Every artefact of this run (encoder weights, fitted classifiers) is written here.
OUTPUT_DIR = os.path.join(PROJECT_DIR, 'final_model_fold_' + str(FOLD))
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ──────────────────────────── Data Utilities ───────────────────────────────────
def load_fold(fold=1, project_dir=None):
    """Load the preprocessed inputs and binarized labels for one CV fold.

    Reads six files from ``<project_dir>/fold_<fold>/``: ``train_z.pt``,
    ``val_z.pt``, ``test_z.pt`` (inputs) and ``train_labels.pt``,
    ``val_labels.pt``, ``test_labels.pt`` (labels). Labels are binarized:
    a string containing ``'AD_'`` becomes 1, anything else becomes 0.

    Args:
        fold: Fold index; selects the ``fold_<fold>`` subdirectory.
        project_dir: Root directory holding the fold folders. When None, the
            module-level ``PROJECT_DIR`` is used.

    Returns:
        ``(trX, vaX, teX, trY, vaY, teY)``: the input objects exactly as stored
        on disk, plus three ``torch.long`` label tensors.

    Raises:
        FileNotFoundError: if any of the six files is missing.
    """
    if project_dir is None:
        project_dir = PROJECT_DIR
    fdir = os.path.join(project_dir, f'fold_{fold}')

    def load(name):
        # weights_only=False lets torch.load unpickle arbitrary Python objects,
        # so only point it at trusted, self-generated files.
        return torch.load(os.path.join(fdir, name), map_location='cpu', weights_only=False)

    try:
        trX, vaX, teX = load('train_z.pt'), load('val_z.pt'), load('test_z.pt')
        trY_str, vaY_str, teY_str = load('train_labels.pt'), load('val_labels.pt'), load('test_labels.pt')
    except FileNotFoundError as e:
        print(f"Error loading data for fold {fold}: {e}")
        print(f"Please ensure the files train_z.pt, val_z.pt, test_z.pt, train_labels.pt, etc., exist in {fdir}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred during data loading: {e}")
        raise

    def binarize(Y_str, split):
        """Return a long tensor with 1 where the label string contains 'AD_', else 0."""
        if not isinstance(Y_str, (list, tuple)):
            print(f"Warning: Labels are not in expected list/tuple format. Type: {type(Y_str)}. Attempting conversion.")
            # Tensors / numpy arrays expose tolist(); otherwise fall back to plain iteration.
            Y_str = Y_str.tolist() if hasattr(Y_str, 'tolist') else list(Y_str)

        # Non-string entries can never match 'AD_' and would silently be counted as CN (0).
        # Surface them, since they usually indicate an unexpected label-file format.
        n_non_str = sum(1 for s in Y_str if not isinstance(s, str))
        if n_non_str:
            print(f"Warning: {n_non_str}/{len(Y_str)} {split} labels are not strings; they are treated as CN (0).")

        return torch.tensor([1 if isinstance(s, str) and 'AD_' in s else 0 for s in Y_str], dtype=torch.long)

    trY, vaY, teY = binarize(trY_str, 'train'), binarize(vaY_str, 'val'), binarize(teY_str, 'test')

    print(f"Data loaded for Fold {fold}:")
    print(f"  Train X: {trX.shape}, Train Y: {trY.shape}")
    print(f"  Val   X: {vaX.shape}, Val   Y: {vaY.shape}")
    print(f"  Test  X: {teX.shape}, Test  Y: {teY.shape}")
    print(f"  Class distribution in Train Y: {torch.bincount(trY)}")
    print(f"  Class distribution in Val Y:   {torch.bincount(vaY)}")
    print(f"  Class distribution in Test Y:  {torch.bincount(teY)}")

    return trX, vaX, teX, trY, vaY, teY

# ─────────────────────────────── Beta-VAE Model ────────────────────────────────
# Identical to the sweep script
class Encoder(nn.Module):
    """Convolutional encoder producing Gaussian posterior parameters.

    Three stride-2 convolutions (k -> 2k -> 4k channels, ReLU after each) are
    followed by two linear heads that output ``mu`` and ``logvar``.

    Args:
        c_in: Number of input channels.
        latent: Latent dimensionality (size of ``mu`` / ``logvar``).
        k: Base number of conv filters.
        input_hw: ``(H, W)`` of the inputs this encoder will see. The linear
            heads are sized for exactly this resolution, so it must match the
            data. Defaults to ``(166, 166)``, the input size this script assumes.

    Attributes:
        flat_dim: Number of features after flattening the conv output.
        reshape_dims: ``(C, H, W)`` of the conv output; lets a decoder invert
            the flatten.
    """

    def __init__(self, c_in, latent, k=32, input_hw=(166, 166)):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c_in, k, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(k, 2 * k, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(2 * k, 4 * k, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        # Push a dummy batch through the conv stack to infer the flattened size
        # rather than hard-coding it; it depends on input_hw.
        h, w = input_hw
        with torch.no_grad():
            out = self.conv(torch.zeros(1, c_in, h, w))
            flat_dim = out[0].numel()
            print(f"[DEBUG] Flattened dim after convs: {flat_dim}")

        self.flat = nn.Flatten()
        self.mu = nn.Linear(flat_dim, latent)
        self.logvar = nn.Linear(flat_dim, latent)
        self.flat_dim = flat_dim
        self.reshape_dims = out.shape[1:]  # (C, H, W)

    def forward(self, x):
        """Return ``(mu, logvar)`` for a 4-D batch ``x`` of shape (N, c_in, H, W)."""
        h_flat = self.flat(self.conv(x))
        return self.mu(h_flat), self.logvar(h_flat)


class Decoder(nn.Module):
    """Deconvolutional decoder mapping latent codes back to image space.

    A linear layer (+ReLU) expands ``z`` to ``flat_dim`` features. These are
    reshaped to ``reshape_dims`` and upsampled by three stride-2 transposed
    convs, each doubling H and W. The output spatial size is therefore
    ``8 * reshape_dims[1:]``, which need not equal the encoder's input size.

    Args:
        latent: Latent dimensionality.
        c_out: Number of output channels.
        k: Base number of deconv filters.
        flat_dim: Width of the linear layer; must equal ``prod(reshape_dims)``.
            When None it is computed from ``reshape_dims``.
        reshape_dims: ``(C, H, W)`` the linear output is reshaped to; ``C`` is
            the input channel count of the first transposed conv.

    Raises:
        ValueError: if ``flat_dim`` is given and differs from ``prod(reshape_dims)``.
    """

    def __init__(self, latent, c_out, k=32, flat_dim=None, reshape_dims=(192, 19, 19)):
        super().__init__()
        expected_flat = math.prod(reshape_dims)
        if flat_dim is None:
            flat_dim = expected_flat
        elif flat_dim != expected_flat:
            # A mismatch would crash in forward() or, if flat_dim were a multiple of
            # prod(reshape_dims), be silently folded into the batch axis by view(-1, ...).
            raise ValueError(
                f"flat_dim ({flat_dim}) must equal prod(reshape_dims) = {expected_flat} "
                f"for reshape_dims={tuple(reshape_dims)}"
            )
        self.reshape_dims = reshape_dims
        self.fc = nn.Sequential(
            nn.Linear(latent, flat_dim),
            nn.ReLU()
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(reshape_dims[0], 2 * k, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(2 * k, k, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(k, c_out, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, z):
        """Decode latent codes ``z`` (N, latent) into images (N, c_out, 8*H, 8*W)."""
        h = self.fc(z)
        h_reshaped = h.view(-1, *self.reshape_dims)
        return self.deconv(h_reshaped)


class BetaVAE(nn.Module):
    """Beta-VAE: an Encoder/Decoder pair joined by the reparameterisation trick.

    The decoder is sized from the encoder's inferred ``flat_dim`` and
    ``reshape_dims`` so the two halves always agree. The submodules are kept
    as ``enc`` and ``dec``; renaming them would change the state_dict keys.
    """

    def __init__(self, c_in, latent, k=32):
        super().__init__()
        encoder = Encoder(c_in, latent, k)
        self.enc = encoder
        self.dec = Decoder(latent, c_in, k, encoder.flat_dim, encoder.reshape_dims)

    def reparam(self, mu, logv):
        """Sample z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logv / 2)."""
        sigma = (0.5 * logv).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent, and decode it; returns ``(x_hat, mu, logv)``."""
        mean, log_var = self.enc(x)
        x_hat = self.dec(self.reparam(mean, log_var))
        return x_hat, mean, log_var

def loss_vae(x_hat, x, mu, logv, beta, current_epoch, beta_ramp_epochs):
    """Beta-VAE loss with a linear warm-up of the KL weight.

    Args:
        x_hat: Reconstruction batch. If its shape differs from ``x`` it is
            bilinearly resized to ``x``'s spatial size (e.g. when the decoder's
            output resolution differs from the input's).
        x: Target batch, 4-D (N, C, H, W).
        mu: Posterior means, (N, latent).
        logv: Posterior log-variances, (N, latent).
        beta: Final KL weight.
        current_epoch: Epoch number driving the warm-up.
        beta_ramp_epochs: Number of epochs to reach the full ``beta``. A value
            ``<= 0`` disables the warm-up (full ``beta`` from the start)
            instead of dividing by zero.

    Returns:
        ``(total, recon, kld, beta_effective)``:
        - ``recon`` is the summed squared error averaged over the batch.
        - ``kld`` is KL(q || N(0, I)) summed over latent dims and averaged
          over the batch.
        - ``total = recon + beta_effective * kld``.
    """
    if x_hat.shape != x.shape:
        x_hat = F.interpolate(x_hat, size=x.shape[2:], mode='bilinear', align_corners=False)

    recon = F.mse_loss(x_hat, x, reduction='sum') / x.size(0)
    kld = -0.5 * torch.sum(1 + logv - mu.pow(2) - logv.exp(), dim=1)  # per sample
    kld = torch.mean(kld)

    if beta_ramp_epochs <= 0:
        beta_effective = beta
    else:
        beta_effective = min(beta, beta * (current_epoch / beta_ramp_epochs))

    total_loss = recon + beta_effective * kld
    return total_loss, recon, kld, beta_effective


# ────────────────────────── Latent Space Encoding ─────────────────────────────
def encode_latents(encoder, tensor, use_sigma, bs=128, device=DEVICE):
    """Run ``encoder`` over ``tensor`` in order and return latent features as numpy.

    The encoder is put in eval mode and run under ``torch.no_grad()``. Each
    sample's feature vector is ``mu`` alone, or ``[mu, sigma]`` concatenated
    along dim 1 when ``use_sigma`` is true, with ``sigma = exp(0.5 * logvar)``.

    Returns:
        np.ndarray with the per-batch features stacked row-wise in input order.
    """
    encoder.eval()
    batches = DataLoader(TensorDataset(tensor), batch_size=bs, shuffle=False, pin_memory=True)
    chunks = []
    with torch.no_grad():
        for (batch,) in batches:
            mean, log_var = encoder(batch.to(device))
            feats = torch.cat([mean, torch.exp(0.5 * log_var)], dim=1) if use_sigma else mean
            chunks.append(feats.cpu().numpy())
    return np.vstack(chunks)

# ────────────────────────── Classifier Evaluation ─────────────────────────────
def _build_classifiers(seed):
    """Return a fresh dict of unfitted classifiers keyed by display name."""
    return {
        "Logistic Regression": LogisticRegression(max_iter=5000, class_weight='balanced', random_state=seed, C=1.0, solver='liblinear'),
        "SVM (RBF Kernel)": SVC(probability=True, class_weight='balanced', kernel='rbf', random_state=seed, C=1.0, gamma='scale'),
        "SVM (Linear Kernel)": SVC(probability=True, class_weight='balanced', kernel='linear', random_state=seed, C=1.0),
        "Random Forest": RandomForestClassifier(class_weight='balanced', n_estimators=200, random_state=seed, n_jobs=-1, max_depth=10, min_samples_leaf=5),
        "LightGBM": lgb.LGBMClassifier(class_weight='balanced', random_state=seed, n_estimators=200, learning_rate=0.05, num_leaves=31),
    }


def _classifier_path(name, suffix):
    """Joblib path in OUTPUT_DIR for a classifier, with spaces/parentheses removed from its name."""
    safe_name = name.replace(' ', '_').replace('(', '').replace(')', '')
    return os.path.join(OUTPUT_DIR, f"{safe_name}{suffix}.joblib")


def evaluate_classifiers(encoder, trX, trY, teX, teY, device=DEVICE, seed=SEED):
    """Fit several classifiers on latent features and report test-set metrics.

    Two feature sets are evaluated:
    - ``mu`` only (key ``"_mu"``);
    - ``mu`` concatenated with ``sigma`` (key ``"_mu_sigma"``).
    Every fitted classifier is also saved to ``OUTPUT_DIR`` with joblib; a
    failed save is reported but does not stop the evaluation.

    Args:
        encoder: Trained encoder returning ``(mu, logvar)``.
        trX, trY: Training inputs and 0/1 label tensor.
        teX, teY: Test inputs and 0/1 label tensor.
        device: Device the encoder is run on.
        seed: ``random_state`` for every classifier.

    Returns:
        ``{suffix: {classifier_name: {'ACC': float, 'AUC': float, 'model': clf}}}``.
        AUC is NaN when it is undefined (e.g. ``teY`` contains a single class).

    NOTE(review): the latent features are not standardized, and Logistic
    Regression and the SVMs are scale-sensitive; a StandardScaler pipeline may
    be worth evaluating.
    """
    print("\n--- Evaluating Classifiers on Latent Features ---")

    y_train = trY.numpy()
    y_test = teY.numpy()
    results = {}

    for use_sigma, suffix in [(False, "_mu"), (True, "_mu_sigma")]:
        print(f"\nEncoding latent features (use_sigma={use_sigma})...")
        Z_train = encode_latents(encoder, trX, use_sigma, device=device)
        Z_test = encode_latents(encoder, teX, use_sigma, device=device)

        print(f"Train features shape: {Z_train.shape}, Test features shape: {Z_test.shape}")
        print(f"Feature type: {'Mu + Sigma' if use_sigma else 'Mu only'}")

        results[suffix] = {}

        for name, clf in _build_classifiers(seed).items():
            print(f"\nTraining {name}...")
            start_time = time.time()
            clf.fit(Z_train, y_train)
            train_time = time.time() - start_time
            print(f"Training finished in {train_time:.2f} seconds.")

            print(f"Evaluating {name}...")
            y_pred_proba = clf.predict_proba(Z_test)[:, 1]
            y_pred = clf.predict(Z_test)

            acc = accuracy_score(y_test, y_pred)
            try:
                auc = roc_auc_score(y_test, y_pred_proba)
            except ValueError as e:
                # AUC is undefined (e.g. y_test holds one class). Use NaN rather than
                # 0.0, which would read as a perfectly inverted classifier.
                print(f"Warning: Could not calculate AUC for {name}. Error: {e}. Setting AUC to NaN")
                auc = float('nan')

            print(f"Results for {name}{suffix}:")
            print(f"  Accuracy: {acc:.4f}")
            print(f"  AUC:      {auc:.4f}")
            # labels=[0, 1] keeps the matrix 2x2 and lets the report use both
            # target_names even when a class is missing from y_test / y_pred.
            print("\nConfusion Matrix:")
            print(confusion_matrix(y_test, y_pred, labels=[0, 1]))
            print("\nClassification Report:")
            print(classification_report(y_test, y_pred, labels=[0, 1], target_names=['CN', 'AD'], zero_division=0))

            results[suffix][name] = {'ACC': acc, 'AUC': auc, 'model': clf}

            clf_filename = _classifier_path(name, suffix)
            try:
                joblib.dump(clf, clf_filename)
                print(f"Saved trained {name}{suffix} classifier to {clf_filename}")
            except Exception as e:
                # Best effort: a failed save must not abort the remaining evaluations.
                print(f"Error saving classifier {name}{suffix}: {e}")

    print("\n--- Classifier Evaluation Summary ---")
    for suffix, classifiers_results in results.items():
        print(f"\nFeature Type: {'Mu + Sigma' if suffix == '_mu_sigma' else 'Mu only'}")
        for name, metrics in classifiers_results.items():
            print(f"  {name}: Accuracy={metrics['ACC']:.4f}, AUC={metrics['AUC']:.4f}")

    return results


# ────────────────────────── VAE Training Function ─────────────────────────────
def _train_vae_epoch(vae, loader, optimizer, scaler, hp, epoch):
    """Run one optimisation pass over ``loader`` with the warm-up beta for ``epoch``.

    Returns:
        ``(avg_loss, avg_recon, avg_kld, beta_eff)``: per-batch means of the
        training losses and the beta used at this epoch.
    """
    vae.train()
    use_amp = DEVICE.type == 'cuda'  # mixed precision only on CUDA
    total_loss, total_recon, total_kld = 0.0, 0.0, 0.0
    beta_eff = None

    for (x_batch,) in loader:
        x_batch = x_batch.to(DEVICE)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(enabled=use_amp):
            x_hat, mu, logv = vae(x_batch)
            loss, recon, kld, beta_eff = loss_vae(x_hat, x_batch, mu, logv,
                                                  hp['beta'], epoch, hp['beta_ramp_epochs'])

        # With enabled=False the scaler passes the loss through unchanged and
        # step() simply calls optimizer.step(), so this path also serves the CPU.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        total_recon += recon.item()
        total_kld += kld.item()

    n_batches = len(loader)
    return total_loss / n_batches, total_recon / n_batches, total_kld / n_batches, beta_eff


def _validate_vae(vae, loader, beta):
    """Return per-batch mean ``(loss, recon, kld)`` over ``loader`` at the full KL weight ``beta``."""
    vae.eval()
    use_amp = DEVICE.type == 'cuda'
    val_loss, val_recon, val_kld = 0.0, 0.0, 0.0

    with torch.no_grad():
        for (x_val_batch,) in loader:
            x_val_batch = x_val_batch.to(DEVICE)
            with torch.cuda.amp.autocast(enabled=use_amp):
                # NOTE(review): vae() samples z via reparam even in eval mode, so the
                # validation loss includes sampling noise.
                x_val_hat, mu_val, logv_val = vae(x_val_batch)
                # current_epoch == beta_ramp_epochs == 1 gives beta_effective == beta, so
                # validation losses share one scale at every epoch, including during the
                # warm-up. Otherwise early stopping would compare losses weighted by
                # different betas.
                v_loss, v_recon, v_kld, _ = loss_vae(x_val_hat, x_val_batch, mu_val, logv_val,
                                                     beta, 1, 1)
            val_loss += v_loss.item()
            val_recon += v_recon.item()
            val_kld += v_kld.item()

    n_batches = len(loader)
    return val_loss / n_batches, val_recon / n_batches, val_kld / n_batches


def train_vae(hp):
    """Train the Beta-VAE with early stopping and return the best encoder plus eval data.

    The model with the lowest validation loss (at the full beta) is restored at
    the end, and its encoder's state_dict is saved to ``OUTPUT_DIR``.

    Args:
        hp: Hyperparameter dict with keys 'lr', 'batch_size', 'beta',
            'beta_ramp_epochs', 'latent_dim', 'k_filters', 'epochs'.
            Optional 'train_on_val' (default True) also trains the VAE on the
            validation split, as the original sweep did. When True, the
            early-stopping split is part of the training data, so early
            stopping cannot detect overfitting.

    Returns:
        ``(encoder, trX, trY, teX, teY)``: the encoder of the best-validation
        checkpoint (or of the final model if none was recorded), and the
        train/test tensors for downstream classification.

    Raises:
        ValueError: if the training or the validation split is empty.
    """
    print("--- Starting VAE Training ---")
    print(f"Using device: {DEVICE}")
    print("Hyperparameters:")
    for key, val in hp.items():
        print(f"  {key}: {val}")

    # 1. Load data
    trX, vaX, teX, trY, vaY, teY = load_fold(FOLD)
    C, H, W = trX.shape[1:]
    print(f"Input data shape C,H,W: ({C}, {H}, {W})")
    # NOTE(review): H and W are only reported. BetaVAE does not receive them, so the
    # encoder's linear layers are sized for its own assumed input resolution;
    # confirm that matches (H, W).

    # 2. Model, optimizer, AMP scaler (AMP enabled only on CUDA)
    vae = BetaVAE(C, hp['latent_dim'], hp['k_filters']).to(DEVICE)
    optimizer = optim.AdamW(vae.parameters(), lr=hp['lr'])
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

    # 3. Data loaders. The VAE is unsupervised, so only inputs are loaded.
    fit_X = torch.cat([trX, vaX]) if hp.get('train_on_val', True) else trX
    if len(fit_X) == 0 or len(vaX) == 0:
        raise ValueError(f"train_vae needs non-empty training and validation splits "
                         f"(got {len(fit_X)} training and {len(vaX)} validation samples).")
    n_workers = max(1, (os.cpu_count() or 1) // 2)  # os.cpu_count() may return None
    train_loader = DataLoader(TensorDataset(fit_X), batch_size=hp['batch_size'], shuffle=True,
                              pin_memory=True, num_workers=n_workers)
    val_loader = DataLoader(TensorDataset(vaX), batch_size=hp['batch_size'], shuffle=False)

    print(f"Training VAE for {hp['epochs']} epochs...")
    best_val_loss = math.inf
    best_epoch = None
    best_state_dict = None
    epochs_no_improve = 0

    start_train_time = time.time()

    # 4. Training loop with validation-based early stopping
    for epoch in range(1, hp['epochs'] + 1):
        epoch_start_time = time.time()

        avg_loss, avg_recon, avg_kld, beta_eff = _train_vae_epoch(
            vae, train_loader, optimizer, scaler, hp, epoch)
        avg_val_loss, avg_val_recon, avg_val_kld = _validate_vae(vae, val_loader, hp['beta'])

        epoch_time = time.time() - epoch_start_time

        print(f"Epoch {epoch}/{hp['epochs']} [{epoch_time:.2f}s] - "
              f"Train Loss: {avg_loss:.2f} (Recon: {avg_recon:.2f}, KLD: {avg_kld:.2f}, Beta: {beta_eff:.2f}) | "
              f"Val Loss: {avg_val_loss:.2f} (Recon: {avg_val_recon:.2f}, KLD: {avg_val_kld:.2f})")

        if avg_val_loss < best_val_loss - 1e-4:  # small tolerance against noise-level "improvements"
            best_val_loss = avg_val_loss
            best_epoch = epoch
            epochs_no_improve = 0
            # Snapshot on CPU so later optimizer steps do not overwrite it.
            best_state_dict = {k: v.clone().cpu() for k, v in vae.state_dict().items()}
            print(f"  ** New best validation loss: {best_val_loss:.4f}. Saving model state. **")
        else:
            epochs_no_improve += 1
            print(f"  Validation loss did not improve for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= EARLY_STOP_PATIENCE:
            print(f"\nEarly stopping triggered after {epoch} epochs due to no improvement in validation loss for {EARLY_STOP_PATIENCE} epochs.")
            break

    total_train_time = time.time() - start_train_time
    print(f"\n--- VAE Training Finished ---")
    print(f"Total training time: {total_train_time:.2f} seconds")

    # 5. Restore the best checkpoint
    if best_state_dict is not None:
        print(f"Loading best model state found at epoch {best_epoch} with validation loss: {best_val_loss:.4f}")
        vae.load_state_dict(best_state_dict)
        vae = vae.to(DEVICE)
    else:
        print("Warning: No best model state was saved (perhaps training stopped early or validation loss never improved). Using the final model state.")

    # 6. Persist the encoder (all that downstream classification needs)
    encoder_save_path = os.path.join(OUTPUT_DIR, 'vae_encoder_final.pt')
    try:
        torch.save(vae.enc.state_dict(), encoder_save_path)
        print(f"Saved trained VAE encoder state dictionary to {encoder_save_path}")
    except Exception as e:
        # Best effort: the in-memory encoder is still returned.
        print(f"Error saving VAE encoder: {e}")

    return vae.enc, trX, trY, teX, teY


# ────────────────────────────── Main Execution ────────────────────────────────
if __name__ == "__main__":
    # Quiet FutureWarning/UserWarning chatter from third-party libraries so the
    # per-epoch training log stays readable.
    warnings.filterwarnings('ignore', category=FutureWarning)
    warnings.filterwarnings('ignore', category=UserWarning)

    # Stage 1: fit the Beta-VAE; yields its encoder plus the train/test split.
    final_encoder, train_X, train_Y, test_X, test_Y = train_vae(BEST_HP)

    # Stage 2: fit classifiers on the encoder's latent codes and report test metrics.
    classification_results = evaluate_classifiers(
        final_encoder.to(DEVICE),
        train_X, train_Y, test_X, test_Y,
        device=DEVICE,
        seed=SEED,
    )

    print("\n--- Final Run Finished ---")



[1;30;43mSe truncaron las últimas líneas 5000 del resultado de transmisión.[0m
  Validation loss did not improve for 28 epoch(s).
Epoch 1765/4000 [0.72s] - Train Loss: 75306.75 (Recon: 69980.39, KLD: 106.53, Beta: 50.00) | Val Loss: 74861.84 (Recon: 69478.05, KLD: 107.68)
  Validation loss did not improve for 29 epoch(s).
Epoch 1766/4000 [0.76s] - Train Loss: 75367.89 (Recon: 69959.64, KLD: 108.16, Beta: 50.00) | Val Loss: 75047.98 (Recon: 69577.55, KLD: 109.41)
  Validation loss did not improve for 30 epoch(s).
Epoch 1767/4000 [0.71s] - Train Loss: 75299.74 (Recon: 69799.59, KLD: 110.00, Beta: 50.00) | Val Loss: 75178.19 (Recon: 69782.30, KLD: 107.92)
  Validation loss did not improve for 31 epoch(s).
Epoch 1768/4000 [0.71s] - Train Loss: 75325.19 (Recon: 69913.61, KLD: 108.23, Beta: 50.00) | Val Loss: 75001.81 (Recon: 69580.44, KLD: 108.43)
  Validation loss did not improve for 32 epoch(s).
Epoch 1769/4000 [0.62s] - Train Loss: 75470.27 (Recon: 70019.67, KLD: 109.01, Beta: 50.00) |