# In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, confusion_matrix
from pathlib import Path
import logging
import os
import gc
from typing import Tuple, Optional, Dict, List
import shutil
from tqdm import tqdm
import random
import math 

# --- 0. Global Configuration and Constants ---
# Configure root logging once at import time; every message carries a
# timestamp, level and originating module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Module-level logger used by the functions and classes below.
logger = logging.getLogger(__name__)

# --- Script Configuration ---
# NOTE: main() writes every UPPER_CASE global of a simple type (int, float,
# str, list, tuple, bool, dict) to config_summary.txt, so anything added here
# with such a name and type ends up in that dump.
# Input root; main() reads <base>/fold_<i>/fold_<i>_preprocessed_data.pt.
PREPROCESSED_DATA_BASE_DIR = Path('./preprocessed_connectomes_for_dl')
# Output root; each run writes into BASE_OUTPUT_DIR / RUN_NAME.
BASE_OUTPUT_DIR = Path('./training_results_vae_classifier')
RUN_NAME = "vae_connectome_clf_run_7_quick_wins"

# main() iterates range(N_FOLDS) and skips folds whose .pt file is missing.
# NOTE(review): 51 folds is unusually many -- confirm it matches the number of
# fold_<i> directories produced by preprocessing.
N_FOLDS = 51
RANDOM_STATE = 42
# Diagnosis string -> integer label; all MCI variants share label 1.
LABEL_MAPPING = {'CN': 0, 'MCI': 1, 'LMCI': 1, 'EMCI': 1, 'AD': 2}
CLASSIFIER_LABEL_MAPPING = {0: 0, 2: 1} # CN:0, AD:1 for classifier

# VAE Hyperparameters
VAE_LATENT_DIM = 256
VAE_BETA_MAX = 0.8
VAE_BETA_ANNEAL_EPOCHS = 50 # Linear ramp-up, then fixed
VAE_LR = 1e-4
VAE_WEIGHT_DECAY = 1e-6
VAE_EPOCHS = 150
VAE_BATCH_SIZE = 32
VAE_EARLY_STOPPING_PATIENCE = 15
VAE_ENCODER_DROPOUT_RATE = 0.0 # Dropout removed from VAE encoder

# Classifier Hyperparameters
# Classifier input width is VAE_LATENT_DIM * this multiplier.
CLASSIFIER_INPUT_DIM_MULTIPLIER = 1 # Using only mu
CLASSIFIER_HIDDEN_DIMS = [128, 64, 32]
CLASSIFIER_DROPOUT = 0.2
CLASSIFIER_LR = 1e-4
CLASSIFIER_EPOCHS = 150
CLASSIFIER_FINAL_RETRAIN_EPOCHS = 30
CLASSIFIER_BATCH_SIZE = 32
CLASSIFIER_EARLY_STOPPING_PATIENCE = 10
# CLASSIFIER_CLASS_WEIGHTS will be calculated dynamically per fold
# NOTE(review): main() still references the name CLASSIFIER_CLASS_WEIGHTS in a
# fallback branch, but no such constant is defined in this section.
FOCAL_LOSS_GAMMA = 2.0

# Scheduler T0 for CosineAnnealingWarmRestarts
VAE_SCHEDULER_T_0 = VAE_EPOCHS // 4 # Approx 4 cycles, can be tuned
CLASSIFIER_SCHEDULER_T_0 = 10 # Short cycle for classifier

# Run on the GPU when one is available; mixed precision is tied to CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = torch.cuda.is_available()
# Half of the reported CPU count, or 1 when the count is unknown.
# Evaluates to 0 (in-process loading) on a single-core machine.
NUM_WORKERS = os.cpu_count() // 2 if os.cpu_count() else 1

# --- Reproducibility Utilities ---
def seed_everything(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available), exports ``PYTHONHASHSEED``, and switches cuDNN
    to deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The three CPU-side generators are independent of one another.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    logger.info(f"Seeded everything with seed {seed}")

def worker_init_fn(worker_id: int):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The seed is derived from ``torch.initial_seed()``, which inside a worker
    is the DataLoader's base seed plus the worker id. The base seed is drawn
    from the (seeded) main-process torch RNG each time the workers are
    started, so runs stay reproducible while different epochs and different
    workers get different random streams. Seeding with a fixed constant here
    would instead replay the identical stream every epoch.

    Args:
        worker_id: Index of the worker, supplied by the DataLoader. It is
            not used directly because ``torch.initial_seed()`` already
            incorporates it.
    """
    # NumPy only accepts 32-bit seeds, hence the modulo.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# --- EarlyStopping Class ---
class EarlyStopping:
    """Track a validation metric, checkpoint on improvement, signal when to stop.

    Call the instance once per epoch with the current metric and the model.
    When the metric improves by more than ``delta`` the model's
    ``state_dict`` is saved to ``path`` and the patience counter resets;
    otherwise the counter increments and ``early_stop`` becomes True once it
    reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: When True, report progress through ``trace_func``.
        delta: Minimum change that counts as an improvement.
        path: Checkpoint destination (anything ``pathlib.Path`` accepts).
        trace_func: Callable taking one message string. ``None`` (the
            default) resolves to the module logger's ``info`` method.
        mode: ``'min'`` if lower is better, ``'max'`` if higher is better.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, patience: int = 7, verbose: bool = False, delta: float = 0, path: str = 'checkpoint.pt', trace_func=None, mode: str = 'min'):
        self.mode = mode.lower()
        if self.mode not in ('min', 'max'):
            raise ValueError("Mode should be 'min' or 'max'.")
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        self.val_metric_min_delta = delta
        self.path = Path(path)
        # Resolve the default lazily so the logger is looked up at
        # construction time rather than at class-definition time.
        self.trace_func = logger.info if trace_func is None else trace_func
        # Plain Python infinities: the np.Inf alias no longer exists in
        # NumPy 2.x.
        self.best_score = float('inf') if self.mode == 'min' else float('-inf')

    def __call__(self, current_metric_val, model):
        """Record one epoch's metric; checkpoint or advance the counter."""
        if self.mode == 'min':
            score_improved = current_metric_val < self.best_score - self.val_metric_min_delta
        else:  # mode == 'max'
            score_improved = current_metric_val > self.best_score + self.val_metric_min_delta

        if score_improved:
            # Remember the old best so the log line can show old --> new.
            previous_best = self.best_score
            self.best_score = current_metric_val
            self.save_checkpoint(current_metric_val, model, previous_best)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience} (Best: {self.best_score:.6f}|Current: {current_metric_val:.6f})')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_metric, model, previous_best=None):
        """Save ``model.state_dict()`` to ``self.path``.

        Args:
            val_metric: Metric value that triggered the save.
            model: Object exposing ``state_dict()``.
            previous_best: Best value before this improvement, used only for
                the log message. Defaults to ``self.best_score``.
        """
        if self.verbose:
            shown_previous = self.best_score if previous_best is None else previous_best
            self.trace_func(f'Validation metric improved ({shown_previous:.6f} --> {val_metric:.6f}). Saving model to {self.path} ...')
        torch.save(model.state_dict(), self.path)

# --- Focal Loss Class ---
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * CE``.

    ``p_t`` is the softmax probability of the true class. It is computed
    from the unweighted cross-entropy; the per-class weight ``alpha`` is
    applied afterwards. Feeding ``alpha`` into ``cross_entropy`` as
    ``weight`` would make ``exp(-ce)`` equal ``p_t ** alpha_t`` and distort
    the modulating factor.

    Args:
        alpha: Optional 1-D tensor of per-class weights, indexed by the
            target class id. ``None`` disables class weighting.
        gamma: Focusing exponent; ``0`` reduces to (weighted) cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-sample losses unreduced.
    """

    def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss of ``logits`` against integer ``targets``."""
        # Unweighted CE so that exp(-ce) is exactly the true-class probability.
        ce_loss = nn.functional.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            # alpha is a plain attribute (not a buffer), so align device and
            # dtype here instead of relying on Module.to().
            class_weights = self.alpha.to(device=focal_loss.device, dtype=focal_loss.dtype)
            focal_loss = class_weights[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# --- 1. ConnectomeDataset class ---
class ConnectomeDataset(Dataset):
    """One split (``train``/``val``/``test``) of a preprocessed fold file.

    The file is read with ``torch.load`` and must be a mapping holding
    ``X_<split>``, ``y_<split>`` and ``<split>_subject_ids``. A ``test``
    split with any of those keys missing yields an empty dataset; for any
    other split a missing key raises ``KeyError``.

    Args:
        pt_file: Path of the fold file.
        split: Split name used to build the three keys.
        subject_ids_to_load: Optional array of subject IDs; when given (and
            the split is non-empty) only matching samples are kept.

    Raises:
        FileNotFoundError: If ``pt_file`` does not exist.
        KeyError: If a required key is missing for a non-test split.
    """

    def __init__(self, pt_file: str, split: str = 'train', subject_ids_to_load: Optional[np.ndarray] = None):
        logger.info(f"Loading data from {pt_file} for {split} split.")
        required_keys = [f'X_{split}', f'y_{split}', f'{split}_subject_ids']
        try:
            # NOTE(review): relies on torch.load's default `weights_only`
            # behaviour; if the file stores NumPy arrays, newer PyTorch
            # releases may refuse to unpickle them -- confirm.
            data = torch.load(pt_file, map_location='cpu')
        except FileNotFoundError:
            logger.error(f"File not found: {pt_file}")
            raise

        missing_keys = [key for key in required_keys if key not in data]
        if not missing_keys:
            self.X_all = data[required_keys[0]]
            self.y_all = data[required_keys[1]]
            # Coerce to an array so boolean-mask filtering below also works
            # when the IDs were saved as a plain list.
            self.sids_all = np.asarray(data[required_keys[2]])
        elif split == 'test':
            logger.warning(f"Test split keys not found in {pt_file}. Test set will be empty.")
            self.X_all = torch.empty(0, 4, 116, 116)  # Ensure correct empty tensor shape
            self.y_all = torch.empty(0)
            self.sids_all = np.array([])
        else:
            key_error = KeyError(missing_keys[0])
            logger.error(f"Key error {key_error} in file {pt_file}. Ensure X_{split}, y_{split}, {split}_subject_ids exist.")
            raise key_error

        if subject_ids_to_load is not None and len(self.sids_all) > 0:
            logger.info(f"Filtering dataset for {len(subject_ids_to_load)} specific subject IDs.")
            indices_to_keep = np.isin(self.sids_all, subject_ids_to_load)
            self.X = self.X_all[indices_to_keep]
            self.y = self.y_all[indices_to_keep]
            self.sids = self.sids_all[indices_to_keep]
            if len(self.X) != len(subject_ids_to_load):
                logger.warning(f"Could not find all requested subject IDs. Found {len(self.X)} out of {len(subject_ids_to_load)}.")
        else:
            self.X = self.X_all
            self.y = self.y_all
            self.sids = self.sids_all
        logger.info(f"Loaded {len(self.X)} samples for {split} split.")

    def __len__(self):
        """Number of samples, i.e. the size of the first axis of ``X``."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(X[idx], y[idx], subject_id[idx])``."""
        return self.X[idx], self.y[idx], self.sids[idx]

class LatentFeatureDataset(Dataset):
    """In-memory dataset of precomputed feature vectors.

    Each item is the triple ``(features[i], labels[i], subject_ids[i])``.
    """

    def __init__(self, features: torch.Tensor, labels: torch.Tensor, subject_ids: np.ndarray):
        # Store the three parallel containers as given; nothing is copied.
        self.features, self.labels, self.subject_ids = features, labels, subject_ids

    def __len__(self):
        """Number of samples, taken from the feature container."""
        n_samples = len(self.features)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(feature, label, subject_id)`` triple at ``idx``."""
        feature = self.features[idx]
        label = self.labels[idx]
        subject_id = self.subject_ids[idx]
        return feature, label, subject_id

# --- 2. VAE Model (ConvVAE for Connectomes) ---
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for multi-channel square inputs.

    Encoder: four stride-2 Conv2d -> GroupNorm -> ReLU -> (Dropout2d or
    Identity) stages, then two linear heads producing ``mu`` and ``logvar``.
    Decoder: a linear layer back to 256x8x8, three ConvTranspose2d ->
    GroupNorm -> ReLU stages and a final ConvTranspose2d with no activation.

    The bottleneck size is hard-coded to 256 * 8 * 8, which is what the four
    encoder convolutions produce for 116x116 inputs (116 -> 58 -> 29 -> 15
    -> 8); the decoder paddings mirror that path back up to 116.

    Args:
        input_channels: Channels of the input and of the reconstruction.
        latent_dim: Size of the latent vector.
        dropout_rate: Dropout2d probability after each encoder stage; a
            value of 0 or less inserts ``nn.Identity`` placeholders instead,
            so layer indices stay the same either way.
    """

    def __init__(self, input_channels=4, latent_dim=VAE_LATENT_DIM, dropout_rate=VAE_ENCODER_DROPOUT_RATE):
        super().__init__()
        self.latent_dim = latent_dim

        def make_dropout():
            # Identity keeps the Sequential indices stable when dropout is off.
            if dropout_rate > 0:
                return nn.Dropout2d(p=dropout_rate)
            return nn.Identity()

        # (in_channels, out_channels, kernel, padding, norm_groups); stride is 2.
        encoder_stages = (
            (input_channels, 32, 5, 2, 8),
            (32, 64, 5, 2, 16),
            (64, 128, 3, 1, 32),
            (128, 256, 3, 1, 64),
        )
        encoder_layers = []
        for in_ch, out_ch, kernel, pad, groups in encoder_stages:
            encoder_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=pad))
            encoder_layers.append(nn.GroupNorm(groups, out_ch))
            encoder_layers.append(nn.ReLU())
            encoder_layers.append(make_dropout())
        self.encoder_conv = nn.Sequential(*encoder_layers)

        self.flatten_size = 256 * 8 * 8
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_size, latent_dim)

        self.decoder_fc = nn.Linear(latent_dim, self.flatten_size)
        # (in_channels, out_channels, kernel, padding, output_padding, norm_groups)
        decoder_stages = (
            (256, 128, 3, 1, 0, 32),
            (128, 64, 3, 1, 0, 16),
            (64, 32, 5, 2, 1, 8),
        )
        decoder_layers = []
        for in_ch, out_ch, kernel, pad, out_pad, groups in decoder_stages:
            decoder_layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=pad, output_padding=out_pad)
            )
            decoder_layers.append(nn.GroupNorm(groups, out_ch))
            decoder_layers.append(nn.ReLU())
        # Final layer maps back to the input channel count, with no norm/activation.
        decoder_layers.append(
            nn.ConvTranspose2d(32, input_channels, kernel_size=5, stride=2, padding=2, output_padding=1)
        )
        self.decoder_conv = nn.Sequential(*decoder_layers)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map a batch of inputs to ``(mu, logvar)``."""
        hidden = self.encoder_conv(x)
        flat = hidden.view(hidden.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors back to input space."""
        batch_size = z.size(0)
        feature_map = self.decoder_fc(z).view(batch_size, 256, 8, 8)
        return self.decoder_conv(feature_map)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(reconstruction, mu, logvar, z)`` for a batch of inputs."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

# --- 3. VAE Loss Function ---
def vae_loss_function(recon_x: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor, beta: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Beta-weighted VAE objective.

    Both terms are plain means: the reconstruction term is the MSE over
    every element, and the KL term is averaged over batch and latent
    dimensions.

    Returns:
        ``(total, reconstruction_mse, kl_divergence)`` where
        ``total = reconstruction_mse + beta * kl_divergence``.
    """
    reconstruction = nn.functional.mse_loss(recon_x, x, reduction='mean')
    kl_per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_per_element.mean()
    total = reconstruction + beta * kl_divergence
    return total, reconstruction, kl_divergence

# --- 4. Classifier Model (SimpleMLP) ---
class SimpleMLP(nn.Module):
    """Feed-forward classifier head.

    Each hidden stage is weight-normalised Linear -> BatchNorm1d -> ReLU ->
    Dropout, followed by one weight-normalised Linear output layer. With an
    empty ``hidden_dims`` the network is just that output layer.

    Args:
        input_dim: Width of the input features.
        hidden_dims: Widths of the hidden stages, in order.
        output_dim: Number of output logits.
        dropout_rate: Dropout probability used after every hidden stage.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int, dropout_rate: float):
        super().__init__()
        widths = [input_dim, *hidden_dims]
        stack = []
        # Walk consecutive (fan_in, fan_out) pairs of the width sequence.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend([
                torch.nn.utils.weight_norm(nn.Linear(fan_in, fan_out)),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
        stack.append(torch.nn.utils.weight_norm(nn.Linear(widths[-1], output_dim)))
        self.network = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output logits for a batch of feature vectors."""
        logits = self.network(x)
        return logits

# --- 5. Training and Evaluation Utilities ---
def train_vae_epoch(model: ConvVAE, loader: DataLoader, optimizer: optim.Optimizer, device: torch.device, current_beta: float, epoch_num: int, total_epochs: int, scaler: Optional["torch.amp.GradScaler"] = None):
    """Run one VAE training epoch and return the batch-averaged losses.

    Args:
        model: VAE whose forward returns ``(recon, mu, logvar, z)``.
        loader: Yields ``(data, label, subject_id)`` batches; only ``data``
            is used.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        current_beta: KL weight passed to ``vae_loss_function``.
        epoch_num: Zero-based epoch index (progress-bar text only).
        total_epochs: Total epoch count (progress-bar text only).
        scaler: Optional gradient scaler to reuse across epochs so its loss
            scale persists. When omitted a fresh one is created for this
            call.

    Returns:
        ``(avg_loss, avg_mse, avg_kld)`` averaged over batches; all zeros if
        the loader yields no batches.
    """
    model.train()
    amp_enabled = USE_AMP and device.type == 'cuda'
    if scaler is None:
        scaler = torch.amp.GradScaler(enabled=amp_enabled)

    num_batches = len(loader)
    if num_batches == 0:
        logger.warning("train_vae_epoch called with an empty DataLoader.")

    total_loss = 0
    total_mse = 0
    total_kld = 0
    progress_bar = tqdm(loader, desc=f"VAE Train Epoch {epoch_num+1}/{total_epochs} (β={current_beta:.3f})", leave=False)
    for data, _, _ in progress_bar:
        data = data.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
            recon_batch, mu, logvar, _ = model(data)
            loss, loss_mse, loss_kld = vae_loss_function(recon_batch, data, mu, logvar, current_beta)
        if amp_enabled:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        batch_loss, batch_mse, batch_kld = loss.item(), loss_mse.item(), loss_kld.item()
        total_loss += batch_loss
        total_mse += batch_mse
        total_kld += batch_kld
        progress_bar.set_postfix({'loss': batch_loss, 'mse': batch_mse, 'kld': batch_kld})

    # Guard the averages the same way train_classifier_epoch does.
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    avg_mse = total_mse / num_batches if num_batches > 0 else 0
    avg_kld = total_kld / num_batches if num_batches > 0 else 0
    return avg_loss, avg_mse, avg_kld

def evaluate_vae_epoch(model: ConvVAE, loader: DataLoader, device: torch.device, current_beta: float, epoch_num: int, total_epochs: int):
    """Evaluate the VAE on ``loader`` without gradient tracking.

    Args:
        model: VAE whose forward returns ``(recon, mu, logvar, z)``.
        loader: Yields ``(data, label, subject_id)`` batches; only ``data``
            is used.
        device: Device the batches are moved to.
        current_beta: KL weight passed to ``vae_loss_function``.
        epoch_num: Zero-based epoch index (progress-bar text only).
        total_epochs: Total epoch count (progress-bar text only).

    Returns:
        ``(avg_loss, avg_mse, avg_kld)`` averaged over batches; all zeros if
        the loader yields no batches.
    """
    model.eval()
    amp_enabled = USE_AMP and device.type == 'cuda'

    num_batches = len(loader)
    if num_batches == 0:
        logger.warning("evaluate_vae_epoch called with an empty DataLoader.")

    total_loss = 0
    total_mse = 0
    total_kld = 0
    progress_bar = tqdm(loader, desc=f"VAE Eval Epoch {epoch_num+1}/{total_epochs}", leave=False)
    with torch.no_grad():
        for data, _, _ in progress_bar:
            data = data.to(device)
            with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
                recon_batch, mu, logvar, _ = model(data)
                loss, loss_mse, loss_kld = vae_loss_function(recon_batch, data, mu, logvar, current_beta)
            batch_loss, batch_mse, batch_kld = loss.item(), loss_mse.item(), loss_kld.item()
            total_loss += batch_loss
            total_mse += batch_mse
            total_kld += batch_kld
            progress_bar.set_postfix({'loss': batch_loss, 'mse': batch_mse, 'kld': batch_kld})

    # Guard the averages the same way the classifier helpers do.
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    avg_mse = total_mse / num_batches if num_batches > 0 else 0
    avg_kld = total_kld / num_batches if num_batches > 0 else 0
    return avg_loss, avg_mse, avg_kld

def extract_latent_features(vae_model: ConvVAE, data_loader: DataLoader, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, np.ndarray]:
    """Encode every sample in ``data_loader`` and collect the results on CPU.

    Args:
        vae_model: Model providing ``encode(x) -> (mu, logvar)``.
        data_loader: Yields ``(data, labels, subject_ids)`` batches.
        device: Device used for encoding.

    Returns:
        ``(mu, logvar, labels, subject_ids)`` concatenated over all batches.
        ``mu`` and ``logvar`` are always float32, whether or not autocast
        produced half-precision outputs, which matches the dtype of the
        empty-loader return value. For an empty dataset, empty tensors of
        width ``VAE_LATENT_DIM`` and an empty array are returned.
    """
    vae_model.eval()
    if len(data_loader.dataset) == 0:
        logger.warning("extract_latent_features called with an empty DataLoader.")
        return torch.empty(0, VAE_LATENT_DIM), torch.empty(0, VAE_LATENT_DIM), torch.empty(0), np.array([])

    amp_enabled = USE_AMP and device.type == 'cuda'
    all_mu = []
    all_logvar = []  # Keep for potential future use, though classifier only uses mu now
    all_labels = []
    all_sids = []
    with torch.no_grad():
        for data, labels, sids in tqdm(data_loader, desc="Extracting Latent Features", leave=False):
            data = data.to(device)
            with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
                mu, logvar = vae_model.encode(data)
            # Upcast before storing so downstream code never sees float16.
            all_mu.append(mu.float().cpu())
            all_logvar.append(logvar.float().cpu())
            all_labels.append(labels.cpu())
            # Numeric IDs are collated into a tensor; unpack it to plain
            # Python scalars rather than a list of 0-d tensors.
            if isinstance(sids, torch.Tensor):
                all_sids.extend(sids.tolist())
            else:
                all_sids.extend(list(sids))
    return torch.cat(all_mu), torch.cat(all_logvar), torch.cat(all_labels), np.array(all_sids)

def train_classifier_epoch(model: SimpleMLP, loader: DataLoader, optimizer: optim.Optimizer, criterion, device: torch.device, epoch_num: int, total_epochs: int, use_scheduler: bool, scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, scaler: Optional["torch.amp.GradScaler"] = None):
    """Run one classifier training epoch.

    Args:
        model: Classifier producing per-class logits.
        loader: Yields ``(features, labels, subject_ids)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to.
        epoch_num: Zero-based epoch index (progress-bar text only).
        total_epochs: Total epoch count (progress-bar text only).
        use_scheduler: Whether ``scheduler`` should be stepped.
        scheduler: Optional LR scheduler, stepped after every batch unless
            it is a ``ReduceLROnPlateau``.
        scaler: Optional gradient scaler to reuse across epochs so its loss
            scale persists. When omitted a fresh one is created for this
            call.

    Returns:
        ``(avg_loss, accuracy)``: loss averaged over batches and accuracy
        over all samples seen; both 0 when the loader is empty.
    """
    model.train()
    amp_enabled = USE_AMP and device.type == 'cuda'
    if scaler is None:
        scaler = torch.amp.GradScaler(enabled=amp_enabled)
    step_scheduler = (
        use_scheduler
        and scheduler is not None
        and not isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau)
    )

    total_loss = 0
    correct_predictions = 0
    total_samples = 0
    progress_bar = tqdm(loader, desc=f"CLF Train Epoch {epoch_num+1}/{total_epochs}", leave=False)
    for features, labels, _ in progress_bar:
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
            outputs = model(features)
            loss = criterion(outputs, labels)
        if amp_enabled:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # NOTE(review): stepping per batch makes the scheduler count batches,
        # not epochs, so a warm-restart period T_0 is measured in iterations
        # here -- confirm that is the intended cycle length.
        if step_scheduler:
            scheduler.step()

        batch_loss = loss.item()
        _, predicted = torch.max(outputs.detach(), 1)
        batch_size = labels.size(0)
        batch_correct = (predicted == labels).sum().item()

        total_loss += batch_loss
        total_samples += batch_size
        correct_predictions += batch_correct
        progress_bar.set_postfix({'loss': batch_loss, 'acc': batch_correct / batch_size})

    avg_loss = total_loss / len(loader) if len(loader) > 0 else 0
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0
    return avg_loss, accuracy

def evaluate_classifier(model: SimpleMLP, loader: DataLoader, criterion, device: torch.device, desc_prefix: str = "Evaluating") -> Dict:
    """Evaluate a two-class classifier and compute summary metrics.

    Args:
        model: Classifier producing per-class logits.
        loader: Yields ``(features, labels, subject_ids)`` batches.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Device the batches are moved to.
        desc_prefix: Prefix for the progress bar and log messages.

    Returns:
        Dict with ``loss`` (batch average), ``accuracy``, ``auc``, ``f1``
        (binary average), ``labels`` and ``predictions``. ``auc`` is 0.0
        when it cannot be computed (single class present, or scikit-learn
        raises ``ValueError``). For an empty dataset every metric is 0 and
        both lists are empty.
    """
    model.eval()
    if len(loader.dataset) == 0:
        logger.warning(f"{desc_prefix} Classifier called with an empty DataLoader.")
        return {"loss": 0, "accuracy": 0, "auc": 0, "f1": 0, "labels": [], "predictions": []}

    amp_enabled = USE_AMP and device.type == 'cuda'
    total_loss = 0
    all_labels = []
    all_predictions = []
    all_probabilities = []
    with torch.no_grad():
        for features, labels, _ in tqdm(loader, desc=f"{desc_prefix} Classifier", leave=False):
            features, labels = features.to(device), labels.to(device)
            with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
                outputs = model(features)
                loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
            # Softmax in float32: half-precision probabilities saturate and
            # create ties that distort the AUC ranking.
            all_probabilities.extend(torch.softmax(outputs.float(), dim=1).cpu().numpy())

    avg_loss = total_loss / len(loader) if len(loader) > 0 else 0
    accuracy = accuracy_score(all_labels, all_predictions) if len(all_labels) > 0 else 0

    auc = 0.0
    if len(np.unique(all_labels)) < 2:
        if len(all_labels) > 0:
            logger.warning(f"{desc_prefix} AUC calculation skipped: only one class present in evaluation.")
    elif len(all_probabilities) == 0:
        logger.warning(f"{desc_prefix} AUC calculation skipped: no probabilities to evaluate.")
    else:
        try:
            # Column 1 is the probability of the positive class.
            auc = roc_auc_score(all_labels, np.array(all_probabilities)[:, 1])
        except ValueError as e:
            logger.warning(f"{desc_prefix} Could not calculate AUC: {e}.")
            auc = 0.0

    f1 = f1_score(all_labels, all_predictions, average='binary', zero_division=0) if len(all_labels) > 0 else 0
    return {"loss": avg_loss, "accuracy": accuracy, "auc": auc, "f1": f1, "labels": all_labels, "predictions": all_predictions}

# --- 6. Main Script Logic ---
def main():
    seed_everything(RANDOM_STATE)
    logger.info(f"Device: {DEVICE}")
    logger.info(f"Using AMP: {USE_AMP}")
    logger.info(f"Number of workers for DataLoader: {NUM_WORKERS}")

    run_output_dir = BASE_OUTPUT_DIR / RUN_NAME
    run_output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Outputs will be saved to: {run_output_dir}")

    config_summary = {k: v for k, v in globals().items() if k.isupper() and isinstance(v, (int, float, str, list, tuple, bool, dict))}
    config_summary["DEVICE"] = str(DEVICE)
    with open(run_output_dir / "config_summary.txt", 'w') as f:
        for key, value in config_summary.items():
            f.write(f"{key}: {value}\n")
    logger.info(f"Saved run configuration to {run_output_dir / 'config_summary.txt'}")
    logger.warning("Stratification by sex/age should be handled during the preprocessing step that creates the .pt files.")

    overall_fold_val_metrics = []
    overall_fold_test_metrics = []

    for fold_idx in range(N_FOLDS):
        logger.info(f"--- Processing Fold {fold_idx + 1}/{N_FOLDS} ---")
        fold_output_dir = run_output_dir / f"fold_{fold_idx}"
        fold_output_dir.mkdir(parents=True, exist_ok=True)
        tb_writer_fold = SummaryWriter(log_dir=str(fold_output_dir / "tensorboard"))

        logger.info("--- VAE Training Phase ---")
        preprocessed_fold_file = PREPROCESSED_DATA_BASE_DIR / f"fold_{fold_idx}" / f"fold_{fold_idx}_preprocessed_data.pt"
        if not preprocessed_fold_file.exists():
            logger.error(f"Preprocessed data for fold {fold_idx} not found at {preprocessed_fold_file}. Skipping fold.")
            continue

        vae_train_dataset = ConnectomeDataset(str(preprocessed_fold_file), 'train')
        vae_val_dataset = ConnectomeDataset(str(preprocessed_fold_file), 'val')
        _worker_init_fn = worker_init_fn if NUM_WORKERS > 0 else None
        vae_train_loader = DataLoader(vae_train_dataset, batch_size=VAE_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, worker_init_fn=_worker_init_fn)
        vae_val_loader = DataLoader(vae_val_dataset, batch_size=VAE_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, worker_init_fn=_worker_init_fn)

        vae_model = ConvVAE(input_channels=4, latent_dim=VAE_LATENT_DIM, dropout_rate=VAE_ENCODER_DROPOUT_RATE).to(DEVICE)
        vae_optimizer = optim.AdamW(vae_model.parameters(), lr=VAE_LR, weight_decay=VAE_WEIGHT_DECAY)
        vae_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(vae_optimizer, T_0=VAE_SCHEDULER_T_0, T_mult=1, eta_min=1e-7, verbose=False)
        
        vae_model_path = fold_output_dir / "best_vae_model.pt"
        vae_early_stopper = EarlyStopping(patience=VAE_EARLY_STOPPING_PATIENCE, verbose=True, path=vae_model_path, mode='min') # Monitor val_mse

        for epoch in range(VAE_EPOCHS):
            # Linear ramp-up for beta then fixed
            if epoch < VAE_BETA_ANNEAL_EPOCHS:
                current_beta = VAE_BETA_MAX * (epoch / VAE_BETA_ANNEAL_EPOCHS)
            else:
                current_beta = VAE_BETA_MAX
            
            train_loss, train_mse, train_kld = train_vae_epoch(vae_model, vae_train_loader, vae_optimizer, DEVICE, current_beta, epoch, VAE_EPOCHS)
            val_loss, val_mse, val_kld = evaluate_vae_epoch(vae_model, vae_val_loader, DEVICE, current_beta, epoch, VAE_EPOCHS)
            logger.info(f"Fold {fold_idx+1} VAE Epoch {epoch+1}/{VAE_EPOCHS} | β: {current_beta:.3f} | Train Loss: {train_loss:.4f} (MSE: {train_mse:.4f}, KLD: {train_kld:.4f}) | Val Loss: {val_loss:.4f} (MSE: {val_mse:.4f}, KLD: {val_kld:.4f})")
            tb_writer_fold.add_scalar("VAE/Train_Loss", train_loss, epoch)
            tb_writer_fold.add_scalar("VAE/Train_MSE", train_mse, epoch)
            tb_writer_fold.add_scalar("VAE/Train_KLD", train_kld, epoch)
            tb_writer_fold.add_scalar("VAE/Val_Loss", val_loss, epoch)
            tb_writer_fold.add_scalar("VAE/Val_MSE", val_mse, epoch) 
            tb_writer_fold.add_scalar("VAE/Val_KLD", val_kld, epoch)
            tb_writer_fold.add_scalar("VAE/Beta", current_beta, epoch)
            tb_writer_fold.add_scalar("VAE/Learning_Rate", vae_optimizer.param_groups[0]['lr'], epoch)
            
            vae_scheduler.step() 
            vae_early_stopper(val_mse, vae_model) 
            if vae_early_stopper.early_stop:
                logger.info("Early stopping VAE training.")
                break
        
        vae_model.load_state_dict(torch.load(vae_model_path, map_location=DEVICE, weights_only=True))
        logger.info(f"Loaded best VAE model from {vae_model_path}")

        logger.info("--- Extracting Latent Features (mu only) for Train, Val, Test ---")
        # Prepare extraction loaders
        extract_train_loader = DataLoader(
            vae_train_dataset,
            batch_size=VAE_BATCH_SIZE * 2,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            worker_init_fn=_worker_init_fn
        )
        extract_val_loader = DataLoader(
            vae_val_dataset,
            batch_size=VAE_BATCH_SIZE * 2,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            worker_init_fn=_worker_init_fn
        )

        train_mu, _, train_orig_labels, train_sids = extract_latent_features(
            vae_model, extract_train_loader, DEVICE
        )
        val_mu,   _, val_orig_labels,   val_sids   = extract_latent_features(
            vae_model, extract_val_loader,   DEVICE
        )

        
        vae_test_dataset = ConnectomeDataset(str(preprocessed_fold_file), 'test')
        test_mu, _, test_orig_labels, test_sids = torch.empty(0, VAE_LATENT_DIM), torch.empty(0, VAE_LATENT_DIM), torch.empty(0), np.array([])
        if len(vae_test_dataset) > 0:
            extract_test_loader = DataLoader(vae_test_dataset, batch_size=VAE_BATCH_SIZE*2, shuffle=False, num_workers=NUM_WORKERS, worker_init_fn=_worker_init_fn)
            test_mu, _, test_orig_labels, test_sids = extract_latent_features(vae_model, extract_test_loader, DEVICE)
        else:
            logger.warning(f"Fold {fold_idx+1}: Test dataset is empty. Skipping test set evaluation for this fold.")

        # Use only mu for classifier features
        train_clf_latent_features = train_mu
        val_clf_latent_features = val_mu
        test_clf_latent_features = test_mu if test_mu.nelement() > 0 else torch.empty(0, VAE_LATENT_DIM)


        logger.info("--- Initial Classifier Training Phase (CN vs AD) on Train, Validate on Val ---")
        train_cn_ad_mask = (train_orig_labels == LABEL_MAPPING['CN']) | (train_orig_labels == LABEL_MAPPING['AD'])
        clf_train_features_filtered = train_clf_latent_features[train_cn_ad_mask]
        clf_train_sids_initial = train_sids[train_cn_ad_mask]
        clf_train_labels_initial = torch.tensor([CLASSIFIER_LABEL_MAPPING[l.item()] for l in train_orig_labels[train_cn_ad_mask]], dtype=torch.long)
        
        val_cn_ad_mask = (val_orig_labels == LABEL_MAPPING['CN']) | (val_orig_labels == LABEL_MAPPING['AD'])
        clf_val_features_filtered = val_clf_latent_features[val_cn_ad_mask]
        clf_val_sids_initial = val_sids[val_cn_ad_mask]
        clf_val_labels_initial = torch.tensor([CLASSIFIER_LABEL_MAPPING[l.item()] for l in val_orig_labels[val_cn_ad_mask]], dtype=torch.long)

        if len(clf_train_features_filtered) == 0 or len(clf_val_features_filtered) == 0:
            logger.warning(f"Fold {fold_idx+1}: Not enough CN/AD samples for initial classifier training. Skipping classifier steps for this fold.")
            overall_fold_test_metrics.append({"accuracy": 0, "auc": 0, "f1": 0}) 
            overall_fold_val_metrics.append({"accuracy": 0, "auc": 0, "f1": 0}) 
            tb_writer_fold.close()
            gc.collect()
            if DEVICE.type == 'cuda': torch.cuda.empty_cache()
            continue 
        
        clf_train_dataset_initial = LatentFeatureDataset(clf_train_features_filtered, clf_train_labels_initial, clf_train_sids_initial)
        clf_val_dataset_initial = LatentFeatureDataset(clf_val_features_filtered, clf_val_labels_initial, clf_val_sids_initial)
        
        # WeightedRandomSampler for initial classifier training
        class_counts = np.bincount(clf_train_labels_initial.numpy())
        if len(class_counts) < 2 : # Handle cases where one class might be missing after filtering
            logger.warning(f"Fold {fold_idx+1}: Only one class present in initial classifier training data. Sampler not used.")
            sampler = None
            # Calculate dynamic class weights for FocalLoss, or use default if issues
            if len(class_counts) == 1 and class_counts[0] == len(clf_train_labels_initial): # Only one class
                 class_weights_tensor_alpha = None # No weights if only one class
                 logger.warning("Only one class for classifier, FocalLoss alpha will be None.")
            else: # Should not happen if bincount < 2 but not 1. Defaulting.
                 class_weights_tensor_alpha = torch.tensor(CLASSIFIER_CLASS_WEIGHTS, dtype=torch.float).to(DEVICE)
                 logger.warning("Issue with class counts for FocalLoss, using default weights.")

        else:
            weights_sampler = 1. / torch.tensor(class_counts, dtype=torch.float)
            sample_weights = weights_sampler[clf_train_labels_initial]
            sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
            
            # Dynamic class weights for FocalLoss alpha
            class_weights_tensor_alpha = (len(clf_train_labels_initial) / (2 * torch.tensor(class_counts, dtype=torch.float))).to(DEVICE)
            logger.info(f"Using DYNAMIC FocalLoss alpha: {class_weights_tensor_alpha.cpu().numpy()}")


        clf_train_loader_initial = DataLoader(clf_train_dataset_initial, batch_size=CLASSIFIER_BATCH_SIZE, sampler=sampler, num_workers=NUM_WORKERS, worker_init_fn=_worker_init_fn) # shuffle=False if sampler is used
        clf_val_loader_initial = DataLoader(clf_val_dataset_initial, batch_size=CLASSIFIER_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, worker_init_fn=_worker_init_fn)

        initial_classifier_model = SimpleMLP(
            input_dim=VAE_LATENT_DIM * CLASSIFIER_INPUT_DIM_MULTIPLIER,
            hidden_dims=CLASSIFIER_HIDDEN_DIMS, output_dim=2, dropout_rate=CLASSIFIER_DROPOUT
        ).to(DEVICE)
        
        clf_criterion_initial = FocalLoss(alpha=class_weights_tensor_alpha, gamma=FOCAL_LOSS_GAMMA)
        
        clf_optimizer_initial = optim.AdamW(initial_classifier_model.parameters(), lr=CLASSIFIER_LR)
        clf_scheduler_initial = optim.lr_scheduler.CosineAnnealingWarmRestarts(clf_optimizer_initial, T_0=CLASSIFIER_SCHEDULER_T_0, T_mult=1, eta_min=1e-7, verbose=False)
        
        clf_model_path_initial = fold_output_dir / "best_initial_classifier_model.pt"
        clf_early_stopper_initial = EarlyStopping(patience=CLASSIFIER_EARLY_STOPPING_PATIENCE, verbose=True, path=clf_model_path_initial, mode='max')

        for epoch in range(CLASSIFIER_EPOCHS):
            train_clf_loss, train_clf_acc = train_classifier_epoch(initial_classifier_model, clf_train_loader_initial, clf_optimizer_initial, clf_criterion_initial, DEVICE, epoch, CLASSIFIER_EPOCHS, True, clf_scheduler_initial)
            val_metrics = evaluate_classifier(initial_classifier_model, clf_val_loader_initial, clf_criterion_initial, DEVICE, desc_prefix="Initial Val")
            tb_writer_fold.add_scalar("Initial_Classifier/Train_Loss", train_clf_loss, epoch)
            tb_writer_fold.add_scalar("Initial_Classifier/Train_Acc", train_clf_acc, epoch)
            tb_writer_fold.add_scalar("Initial_Classifier/Val_Loss", val_metrics['loss'], epoch)
            tb_writer_fold.add_scalar("Initial_Classifier/Val_Acc", val_metrics['accuracy'], epoch)
            tb_writer_fold.add_scalar("Initial_Classifier/Val_AUC", val_metrics['auc'], epoch)
            tb_writer_fold.add_scalar("Initial_Classifier/Val_F1", val_metrics['f1'], epoch)
            tb_writer_fold.add_scalar("Initial_Classifier/Learning_Rate", clf_optimizer_initial.param_groups[0]['lr'], epoch)

            clf_early_stopper_initial(val_metrics['auc'], initial_classifier_model)
            if clf_early_stopper_initial.early_stop:
                logger.info("Early stopping initial classifier training.")
                break
        
        initial_classifier_model.load_state_dict(torch.load(clf_model_path_initial, map_location=DEVICE, weights_only=True))
        logger.info(f"Loaded best initial classifier model from {clf_model_path_initial}.")
        fold_val_metrics = evaluate_classifier(initial_classifier_model, clf_val_loader_initial, clf_criterion_initial, DEVICE, desc_prefix="Final Val (from initial train)")
        overall_fold_val_metrics.append(fold_val_metrics)
        logger.info(f"Fold {fold_idx+1} Initial Best Classifier Val Metrics: Acc: {fold_val_metrics['accuracy']:.4f}, AUC: {fold_val_metrics['auc']:.4f}, F1: {fold_val_metrics['f1']:.4f}")

        logger.info("--- Classifier Re-training Phase (CN vs AD) on Train+Val ---")
        all_train_val_features = torch.cat((clf_train_features_filtered, clf_val_features_filtered), dim=0)
        all_train_val_labels = torch.cat((clf_train_labels_initial, clf_val_labels_initial), dim=0)
        all_train_val_sids = np.concatenate((clf_train_sids_initial, clf_val_sids_initial))

        clf_train_val_dataset = LatentFeatureDataset(all_train_val_features, all_train_val_labels, all_train_val_sids)
        
        # Re-calculate sampler and weights for combined train+val dataset for FocalLoss alpha
        class_counts_retrain = np.bincount(all_train_val_labels.numpy())
        if len(class_counts_retrain) < 2:
            logger.warning(f"Fold {fold_idx+1}: Only one class present in combined train+val data. Sampler not used for retraining.")
            sampler_retrain = None
            if len(class_counts_retrain) == 1 and class_counts_retrain[0] == len(all_train_val_labels):
                 class_weights_tensor_alpha_retrain = None
                 logger.warning("Only one class for retrain classifier, FocalLoss alpha will be None.")
            else:
                 class_weights_tensor_alpha_retrain = torch.tensor(CLASSIFIER_CLASS_WEIGHTS, dtype=torch.float).to(DEVICE) # Fallback
                 logger.warning("Issue with class counts for retrain FocalLoss, using default weights.")
        else:
            weights_sampler_retrain = 1. / torch.tensor(class_counts_retrain, dtype=torch.float)
            sample_weights_retrain = weights_sampler_retrain[all_train_val_labels]
            sampler_retrain = WeightedRandomSampler(sample_weights_retrain, len(sample_weights_retrain), replacement=True)
            class_weights_tensor_alpha_retrain = (len(all_train_val_labels) / (2 * torch.tensor(class_counts_retrain, dtype=torch.float))).to(DEVICE)
            logger.info(f"Using DYNAMIC FocalLoss alpha for retraining: {class_weights_tensor_alpha_retrain.cpu().numpy()}")


        clf_train_val_loader = DataLoader(clf_train_val_dataset, batch_size=CLASSIFIER_BATCH_SIZE, sampler=sampler_retrain, num_workers=NUM_WORKERS, worker_init_fn=_worker_init_fn)

        final_classifier_model = SimpleMLP(
            input_dim=VAE_LATENT_DIM * CLASSIFIER_INPUT_DIM_MULTIPLIER,
            hidden_dims=CLASSIFIER_HIDDEN_DIMS, output_dim=2, dropout_rate=CLASSIFIER_DROPOUT
        ).to(DEVICE)
        
        clf_criterion_final = FocalLoss(alpha=class_weights_tensor_alpha_retrain, gamma=FOCAL_LOSS_GAMMA)
        
        clf_optimizer_final = optim.AdamW(final_classifier_model.parameters(), lr=CLASSIFIER_LR)
        clf_scheduler_final = optim.lr_scheduler.CosineAnnealingWarmRestarts(clf_optimizer_final, T_0=CLASSIFIER_FINAL_RETRAIN_EPOCHS, T_mult=1, eta_min=1e-7, verbose=False)

        logger.info(f"Re-training classifier on combined train+val data for {CLASSIFIER_FINAL_RETRAIN_EPOCHS} epochs.")
        for epoch in range(CLASSIFIER_FINAL_RETRAIN_EPOCHS):
            train_loss, train_acc = train_classifier_epoch(final_classifier_model, clf_train_val_loader, clf_optimizer_final, clf_criterion_final, DEVICE, epoch, CLASSIFIER_FINAL_RETRAIN_EPOCHS, True, clf_scheduler_final)
            tb_writer_fold.add_scalar("Final_Classifier/Train_Loss", train_loss, epoch)
            tb_writer_fold.add_scalar("Final_Classifier/Train_Acc", train_acc, epoch)
            tb_writer_fold.add_scalar("Final_Classifier/Learning_Rate", clf_optimizer_final.param_groups[0]['lr'], epoch)
            if (epoch + 1) % 5 == 0 or epoch == CLASSIFIER_FINAL_RETRAIN_EPOCHS -1 :
                 logger.info(f"Fold {fold_idx+1} CLF Re-train Epoch {epoch+1}/{CLASSIFIER_FINAL_RETRAIN_EPOCHS} | Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        
        torch.save(final_classifier_model.state_dict(), fold_output_dir / "final_classifier_model.pt")
        logger.info(f"Saved final classifier model to {fold_output_dir / 'final_classifier_model.pt'}")

        if len(test_clf_latent_features) > 0 and test_clf_latent_features.nelement() > 0 and len(test_orig_labels) > 0:
            logger.info("--- Evaluating Final Classifier on Test Set ---")
            test_cn_ad_mask = (test_orig_labels == LABEL_MAPPING['CN']) | (test_orig_labels == LABEL_MAPPING['AD'])
            clf_test_features_filtered = test_clf_latent_features[test_cn_ad_mask]
            clf_test_sids_filtered = test_sids[test_cn_ad_mask]
            clf_test_labels_filtered = torch.tensor([CLASSIFIER_LABEL_MAPPING[l.item()] for l in test_orig_labels[test_cn_ad_mask] if l.item() in CLASSIFIER_LABEL_MAPPING], dtype=torch.long)

            if len(clf_test_features_filtered) > 0 and len(clf_test_labels_filtered) > 0:
                clf_test_dataset = LatentFeatureDataset(clf_test_features_filtered, clf_test_labels_filtered, clf_test_sids_filtered)
                clf_test_loader = DataLoader(clf_test_dataset, batch_size=CLASSIFIER_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, worker_init_fn=_worker_init_fn)
                
                test_metrics = evaluate_classifier(final_classifier_model, clf_test_loader, clf_criterion_final, DEVICE, desc_prefix="Test Set")
                logger.info(f"Fold {fold_idx+1} TEST SET METRICS: Acc: {test_metrics['accuracy']:.4f}, AUC: {test_metrics['auc']:.4f}, F1: {test_metrics['f1']:.4f}")
                overall_fold_test_metrics.append(test_metrics)
                
                tb_writer_fold.add_scalar("Test_Set/Accuracy", test_metrics['accuracy'])
                tb_writer_fold.add_scalar("Test_Set/AUC", test_metrics['auc'])
                tb_writer_fold.add_scalar("Test_Set/F1", test_metrics['f1'])
                cm_test = confusion_matrix(test_metrics['labels'], test_metrics['predictions'])
                logger.info(f"Fold {fold_idx+1} Test Set Confusion Matrix (CN vs AD):\n{cm_test}")
            else:
                logger.warning(f"Fold {fold_idx+1}: No CN/AD samples in the test set after filtering. Skipping test evaluation.")
                overall_fold_test_metrics.append({"accuracy": 0, "auc": 0, "f1": 0}) 
        else:
            logger.warning(f"Fold {fold_idx+1}: Test data latent features are empty. Skipping test set evaluation.")
            overall_fold_test_metrics.append({"accuracy": 0, "auc": 0, "f1": 0}) 

        tb_writer_fold.close()
        gc.collect()
        if DEVICE.type == 'cuda': torch.cuda.empty_cache()
    
    if not overall_fold_val_metrics:
        logger.error("No folds completed initial classifier training. Cannot aggregate validation results.")
    else:
        avg_val_acc = np.mean([m['accuracy'] for m in overall_fold_val_metrics])
        avg_val_auc = np.mean([m['auc'] for m in overall_fold_val_metrics])
        avg_val_f1 = np.mean([m['f1'] for m in overall_fold_val_metrics])
        std_val_acc = np.std([m['accuracy'] for m in overall_fold_val_metrics])
        std_val_auc = np.std([m['auc'] for m in overall_fold_val_metrics])
        std_val_f1 = np.std([m['f1'] for m in overall_fold_val_metrics])

        logger.info("--- Overall Cross-Validation Results (Validation Set - Initial Classifier) ---")
        logger.info(f"Average Validation Accuracy: {avg_val_acc:.4f} +/- {std_val_acc:.4f}")
        logger.info(f"Average Validation AUC:      {avg_val_auc:.4f} +/- {std_val_auc:.4f}")
        logger.info(f"Average Validation F1-score: {avg_val_f1:.4f} +/- {std_val_f1:.4f}")

        results_summary_path = run_output_dir / "cross_validation_summary.txt"
        with open(results_summary_path, 'w') as f:
            f.write("--- Overall Cross-Validation Results (Validation Set - Initial Classifier) ---\n")
            f.write(f"Average Validation Accuracy: {avg_val_acc:.4f} +/- {std_val_acc:.4f}\n")
            f.write(f"Average Validation AUC:      {avg_val_auc:.4f} +/- {std_val_auc:.4f}\n")
            f.write(f"Average Validation F1-score: {avg_val_f1:.4f} +/- {std_val_f1:.4f}\n\n")
            f.write("Individual Fold Validation Metrics (Initial Classifier):\n")
            for i, metrics in enumerate(overall_fold_val_metrics):
                f.write(f"Fold {i+1}: Acc={metrics['accuracy']:.4f}, AUC={metrics['auc']:.4f}, F1={metrics['f1']:.4f}\n")
            
            if overall_fold_test_metrics and any(m['auc'] > 0 or m['accuracy'] > 0 for m in overall_fold_test_metrics):
                # Filter out placeholder metrics before calculating mean/std for test
                valid_test_metrics = [m for m in overall_fold_test_metrics if m['auc'] > 0 or m['accuracy'] > 0]
                if valid_test_metrics:
                    avg_test_acc = np.mean([m['accuracy'] for m in valid_test_metrics])
                    avg_test_auc = np.mean([m['auc'] for m in valid_test_metrics])
                    avg_test_f1 = np.mean([m['f1'] for m in valid_test_metrics])
                    std_test_acc = np.std([m['accuracy'] for m in valid_test_metrics])
                    std_test_auc = np.std([m['auc'] for m in valid_test_metrics])
                    std_test_f1 = np.std([m['f1'] for m in valid_test_metrics])
                    
                    f.write("\n--- Overall Cross-Validation Results (Test Set - Final Classifier) ---\n")
                    f.write(f"Average Test Accuracy: {avg_test_acc:.4f} +/- {std_test_acc:.4f}\n")
                    f.write(f"Average Test AUC:      {avg_test_auc:.4f} +/- {std_test_auc:.4f}\n")
                    f.write(f"Average Test F1-score: {avg_test_f1:.4f} +/- {std_test_f1:.4f}\n\n")
                    f.write("Individual Fold Test Metrics (Final Classifier):\n")
                    for i, metrics in enumerate(overall_fold_test_metrics): # Log all, including placeholders
                         f.write(f"Fold {i+1}: Acc={metrics['accuracy']:.4f}, AUC={metrics['auc']:.4f}, F1={metrics['f1']:.4f}\n")
                else:
                    f.write("\n--- No Valid Test Set Results to Report ---\n")
            else:
                f.write("\n--- No Test Set Results to Report ---\n")
        logger.info(f"Saved cross-validation summary to {results_summary_path}")


if __name__ == '__main__':
    # Entry-point pre-flight checks: make sure the preprocessed inputs exist
    # before handing control to main().
    #
    # Fatal problems terminate with SystemExit(1) so that shells, schedulers
    # and CI pipelines observe a non-zero status. The bare `exit()` helper is
    # injected by the `site` module for interactive use, is not guaranteed to
    # be available (e.g. under `python -S`), and reports success (status 0)
    # when called without arguments.

    # torch is imported unconditionally at the top of the file, so a missing
    # installation already fails there; re-importing it here inside a
    # try/except could never reach the ImportError branch.
    logger.info(f"PyTorch version: {torch.__version__}")

    if not PREPROCESSED_DATA_BASE_DIR.exists():
        logger.critical(f"Preprocessed data directory not found: {PREPROCESSED_DATA_BASE_DIR}")
        logger.critical("Please run the preprocessing script first.")
        raise SystemExit(1)

    if N_FOLDS == 0:
        # main() is still invoked below, matching the previous behaviour.
        logger.warning("N_FOLDS is set to 0. No training will occur.")
    elif N_FOLDS > 0:
        # One candidate data file per configured fold; at least one must exist.
        fold_files = [
            PREPROCESSED_DATA_BASE_DIR / f"fold_{i}" / f"fold_{i}_preprocessed_data.pt"
            for i in range(N_FOLDS)
        ]
        if not any(fold_file.exists() for fold_file in fold_files):
            logger.critical(f"No preprocessed fold data found in subdirectories under {PREPROCESSED_DATA_BASE_DIR} for the configured N_FOLDS={N_FOLDS}")
            # Derive the example from the configured base directory so the
            # hint stays accurate if PREPROCESSED_DATA_BASE_DIR is changed.
            logger.critical(f"Example expected path for fold 0: {fold_files[0]}")
            raise SystemExit(1)
    main()


2025-05-26 03:47:02.722883: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-26 03:47:04 - INFO - 3473208690 - PyTorch version: 2.4.0
2025-05-26 03:47:04 - INFO - 3473208690 - Seeded everything with seed 42
2025-05-26 03:47:04 - INFO - 3473208690 - Device: cuda
2025-05-26 03:47:04 - INFO - 3473208690 - Using AMP: True
2025-05-26 03:47:04 - INFO - 3473208690 - Number of workers for DataLoader: 6
2025-05-26 03:47:04 - INFO - 3473208690 - Outputs will be saved to: training_results_vae_classifier/vae_connectome_clf_run_7_quick_wins
2025-05-26 03:47:04 - INFO - 3473208690 - Saved run configuration to training_results_vae_classifier/vae_connectome_clf_run_7_quick_wins/config_summary.txt
2025-05-26 03:47:04 - INFO - 3473208690 - --- Processing Fold 1/51 