In [21]:
# Mount Google Drive at the standard Colab location so the project data under
# MyDrive/AD_fMRI_GNN (configured further below) is reachable as a local path.
from google.colab import drive

COLAB_DRIVE_MOUNT = '/content/drive'
drive.mount(COLAB_DRIVE_MOUNT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## PyTorch Geometric Installation

In [None]:
import torch

def format_pytorch_version(version):
  """Strip the local build tag from a torch version string ('2.8.0+cu126' -> '2.8.0')."""
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  """Turn ``torch.version.cuda`` (e.g. '12.6') into a PyG wheel tag (e.g. 'cu126').

  CPU-only torch builds report ``torch.version.cuda`` as ``None``; PyG publishes
  CPU wheels under the ``cpu`` tag, so ``None`` maps to ``'cpu'`` instead of
  crashing on ``None.replace``.
  """
  if version is None:
    return 'cpu'
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

print(f'▶︎ Successfully installed PyTorch {TORCH} with CUDA {CUDA}')

Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
▶︎ Successfully installed PyTorch 2.8.0 with CUDA cu126


In [None]:
# Standard library
from pathlib import Path

# Scientific Python stack
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import ttest_ind
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split

# PyTorch / PyTorch Geometric
import torch.nn.functional as F
from torch.nn import Dropout, Linear
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

print("▶︎ Finished Environment Setting")

▶︎ Finished Environment Setting


In [24]:
import sys

# Record the interpreter version alongside the results for reproducibility.
python_version = sys.version
print(python_version)

3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]


In [None]:
# Project data root on the mounted Google Drive. Every atlas directory below is
# derived from it, so the data can be relocated by editing this one line.
DATA_ROOT = Path('/content/drive/MyDrive/AD_fMRI_GNN')

# Schaefer Atlas
CN_FC_SCHAEFER_DIR = DATA_ROOT / 'CN_conn_roi_data' / 'schaefer' / 'fc'
CN_ROI_SCHAEFER_DIR = DATA_ROOT / 'CN_conn_roi_data' / 'schaefer' / 'roi'
AD_FC_SCHAEFER_DIR = DATA_ROOT / 'AD_conn_roi_data' / 'schaefer' / 'fc'
AD_ROI_SCHAEFER_DIR = DATA_ROOT / 'AD_conn_roi_data' / 'schaefer' / 'roi'

# AAL Atlas
CN_FC_AAL_DIR = DATA_ROOT / 'CN_conn_roi_data' / 'aal' / 'fc'
CN_ROI_AAL_DIR = DATA_ROOT / 'CN_conn_roi_data' / 'aal' / 'roi'
AD_FC_AAL_DIR = DATA_ROOT / 'AD_conn_roi_data' / 'aal' / 'fc'
AD_ROI_AAL_DIR = DATA_ROOT / 'AD_conn_roi_data' / 'aal' / 'roi'


# Hyperparameters
LEARNING_RATE = 0.001
BATCH_SIZE = 32
INNER_NUM_EPOCHS = 30
OUTER_NUM_EPOCHS = 100
DROPOUT_RATE = 0.5  # NOTE: the GCN model reads this global when it is constructed

# Reproducibility: seed every RNG used downstream (`random` for hyperparameter
# sampling, numpy, torch for weight init and DataLoader shuffling).
# NOTE(review): GPU scatter kernels used by GCNConv may still be nondeterministic.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("▶︎ Finished Drive Mount and Directory Settings.")


▶︎ Finished Drive Mount and Directory Settings.


In [None]:
import os 
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

class fMRIDataset(Dataset):
    """
    Processing fMRI Dataset for single Atlas (AAL or Schaefer) using KNN graph.

    One graph is built per subject:
      * x          -- the subject's FC matrix as a float tensor; node i's features
                      are FC row i.
      * edge_index -- directed edges from every node to its ``k_neighbors`` nearest
                      other nodes (Euclidean distance between FC rows).
      * edge_attr  -- shape (num_edges, 1); weight ``1 / (distance + 1e-8)``.
      * y          -- 1 if the subject ID appears among ``ad_fc_dir``'s CSV files
                      (AD), otherwise 0 (CN).

    The ROI time-series file is loaded only to check that one of its axes matches
    the FC matrix's node count; it does not contribute to the graph.

    Graphs are cached as ``<root>/processed/data_{i}.pt`` in ``all_subjects`` order.
    NOTE: PyG skips ``process`` when all expected processed files already exist, so
    reusing a ``root`` after the input directories change serves stale graphs.

    Args:
        root: PyG dataset root directory.
        cn_fc_dir, cn_roi_dir, ad_fc_dir, ad_roi_dir: ``pathlib.Path`` directories
            holding ``{subject}_...connectivity_matrix.csv`` and
            ``{subject}_...roi_timeseries.csv`` files.
        k_neighbors: neighbours per node in the KNN graph.
        transform, pre_transform: forwarded to ``torch_geometric.data.Dataset``.
    """

    FC_SUFFIX = 'connectivity_matrix.csv'
    ROI_SUFFIX = 'roi_timeseries.csv'

    def __init__(self, root, cn_fc_dir, cn_roi_dir, ad_fc_dir, ad_roi_dir, k_neighbors=10, transform=None, pre_transform=None):
        self.cn_fc_dir = cn_fc_dir
        self.cn_roi_dir = cn_roi_dir
        self.ad_fc_dir = ad_fc_dir
        self.ad_roi_dir = ad_roi_dir
        self.k_neighbors = k_neighbors

        # Determine atlas_type based on directory names
        dir_name = str(cn_fc_dir).lower()
        if 'schaefer' in dir_name:
            self.atlas_type = 'schaefer'
        elif 'aal' in dir_name:
            self.atlas_type = 'aal'
        else:
            self.atlas_type = 'unknown'
            print(f"Warning: Could not determine atlas type from directory names: {cn_fc_dir}")

        # Subject IDs are the filename prefix before the first underscore.
        self.ad_subjects = [f.stem.split('_')[0] for f in self.ad_fc_dir.glob('*.csv')]
        self.cn_subjects = [f.stem.split('_')[0] for f in self.cn_fc_dir.glob('*.csv')]
        self.all_subjects = sorted(set(self.ad_subjects + self.cn_subjects))

        # Keep only subjects with both an FC and an ROI file in their OWN group's
        # directories -- exactly where `process` looks -- so validation and
        # processing agree on which subjects are usable.
        valid_subjects = []
        for sub in self.all_subjects:
            _, fc_dir, roi_dir = self._subject_group(sub)
            fc_found = self._find_subject_file(fc_dir, sub, self.FC_SUFFIX) is not None
            roi_found = self._find_subject_file(roi_dir, sub, self.ROI_SUFFIX) is not None

            if fc_found and roi_found:
                valid_subjects.append(sub)
            elif not fc_found and not roi_found:
                print(f"Warning: FC and ROI data missing for subject {sub}. Skipping.")
            elif not fc_found:
                print(f"Warning: FC data missing for subject {sub}. Skipping.")
            else:
                print(f"Warning: ROI data missing for subject {sub}. Skipping.")

        self._restrict_subjects(valid_subjects)
        print(f"Initialized dataset with {len(self.all_subjects)} valid subjects.")

        super(fMRIDataset, self).__init__(root, transform, pre_transform)

    # ------------------------------------------------------------------ helpers

    def _subject_group(self, subject_id):
        """Return (label, fc_dir, roi_dir): AD subjects -> 1, everyone else -> CN / 0."""
        if subject_id in self.ad_subjects:
            return 1, self.ad_fc_dir, self.ad_roi_dir
        return 0, self.cn_fc_dir, self.cn_roi_dir

    def _find_subject_file(self, directory, subject_id, suffix):
        """Return the path of `subject_id`'s file ending in `suffix` in `directory`, or None.

        The canonical ``{id}_task-rest_bold_{atlas}_{suffix}`` name is preferred.
        Otherwise the glob requires an underscore right after the ID, so ``sub-01``
        cannot pick up ``sub-010_...`` files, and candidates are sorted so the choice
        does not depend on filesystem listing order.
        """
        preferred = directory / f"{subject_id}_task-rest_bold_{self.atlas_type}_{suffix}"
        if preferred.exists():
            return preferred
        candidates = sorted(directory.glob(f"{subject_id}_*{suffix}"))
        return candidates[0] if candidates else None

    def _restrict_subjects(self, subjects):
        """Keep only `subjects` (in the given order) in all/ad/cn subject lists."""
        ad_set = set(self.ad_subjects)
        cn_set = set(self.cn_subjects)
        self.all_subjects = list(subjects)
        self.ad_subjects = [sub for sub in self.all_subjects if sub in ad_set]
        self.cn_subjects = [sub for sub in self.all_subjects if sub in cn_set]

    def _knn_edges(self, features):
        """Directed KNN edges (node -> each of its k nearest other nodes), 1/distance weights.

        Returns:
            edge_index: LongTensor of shape (2, num_nodes * k_neighbors), grouped by
                source node.
            edge_attr: FloatTensor of shape (num_nodes * k_neighbors, 1).
        """
        # +1 because each point's nearest neighbour is itself (column 0), dropped below.
        # NOTE(review): if two FC rows are identical, the twin may sort ahead of the
        # point itself; this assumption is inherited from the original loop.
        knn = NearestNeighbors(n_neighbors=self.k_neighbors + 1, metric='euclidean')
        knn.fit(features)
        distances, indices = knn.kneighbors(features)

        neighbor_idx = indices[:, 1:]
        neighbor_dist = distances[:, 1:]
        if neighbor_idx.size == 0:  # k_neighbors == 0 -> graph without edges
            return torch.empty((2, 0), dtype=torch.long), torch.empty((0, 1), dtype=torch.float)

        num_nodes, k = neighbor_idx.shape
        sources = np.repeat(np.arange(num_nodes), k)
        edge_index = torch.as_tensor(np.stack([sources, neighbor_idx.reshape(-1)]), dtype=torch.long).contiguous()
        # Inverse distance: closer neighbours get larger weights; epsilon avoids /0.
        edge_attr = torch.as_tensor(1.0 / (neighbor_dist.reshape(-1, 1) + 1e-8), dtype=torch.float)
        return edge_index, edge_attr

    def _build_graph(self, subject_id):
        """Load one subject's CSVs and build its graph; return None (after a message) on failure."""
        label, fc_dir, roi_dir = self._subject_group(subject_id)
        fc_path = self._find_subject_file(fc_dir, subject_id, self.FC_SUFFIX)
        roi_path = self._find_subject_file(roi_dir, subject_id, self.ROI_SUFFIX)

        if fc_path is None or roi_path is None:
            print(f"Error during processing: Data for subject {subject_id} not found at expected paths. Skipping.")
            return None

        try:
            fc_matrix = pd.read_csv(fc_path, index_col=0).values
            roi_data = pd.read_csv(roi_path).values
            print(f"  Successfully loaded data. FC shape: {fc_matrix.shape}, ROI shape: {roi_data.shape}")
        except Exception as e:
            print(f"Error reading files for subject {subject_id}: {e}. Skipping.")
            return None

        # The ROI table may be stored nodes x time or time x nodes; accept either
        # orientation as long as one axis matches the FC matrix's node count.
        initial_roi_shape = roi_data.shape
        expected_nodes = fc_matrix.shape[0]
        if initial_roi_shape[0] != expected_nodes:
            if initial_roi_shape[1] == expected_nodes:
                roi_data = roi_data.T
                print(f"  Transposed ROI data for subject {subject_id}. Initial shape: {initial_roi_shape}, New shape: {roi_data.shape}")
            else:
                print(f"Warning: ROI data shape mismatch for subject {subject_id}. Expected one dimension to match number of nodes ({expected_nodes}), got shape {initial_roi_shape}. Skipping.")
                return None

        # Node features: the rows of the FC matrix
        x = torch.tensor(fc_matrix, dtype=torch.float)
        print(f"  Node features (x) shape: {x.shape}")

        edge_index, edge_attr = self._knn_edges(x.cpu().numpy())

        return Data(x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    y=torch.tensor([label], dtype=torch.long))

    # ------------------------------------------------------------ PyG interface

    @property
    def raw_file_names(self):
        # Names of all input CSVs. They live outside `raw_dir`, so PyG will call the
        # no-op `download`; the list is informational only.
        cn_files = [f.name for f in self.cn_fc_dir.glob('*.csv')] + [f.name for f in self.cn_roi_dir.glob('*.csv')]
        ad_files = [f.name for f in self.ad_fc_dir.glob('*.csv')] + [f.name for f in self.ad_roi_dir.glob('*.csv')]
        return list(set(cn_files + ad_files))  # Return unique file names

    @property
    def processed_file_names(self):
        # One processed file per subject, indexed by position in `all_subjects`.
        return [f'data_{i}.pt' for i in range(len(self.all_subjects))]

    def download(self):
        pass

    def process(self):
        """Build and save one graph per subject as ``data_{i}.pt``.

        Files are numbered by position among SUCCESSFULLY saved subjects and the
        subject lists are trimmed to match, so a skipped subject never leaves a
        missing ``data_{i}.pt`` that `get` / `len` would trip over. (Because the
        full subject list is rebuilt on the next instantiation, PyG will simply
        re-run `process` for a root where some subject failed.)
        """
        print(f"Processing {len(self.all_subjects)} subjects...")
        saved_subjects = []
        for idx, subject_id in enumerate(self.all_subjects):
            print(f"Processing subject {subject_id} ({idx + 1}/{len(self.all_subjects)})...")

            data = self._build_graph(subject_id)
            if data is None:
                continue

            save_path = os.path.join(self.processed_dir, f'data_{len(saved_subjects)}.pt')
            try:
                torch.save(data, save_path)
            except Exception as e:
                print(f"Error saving processed data for subject {subject_id} to {save_path}: {e}. Skipping.")
                continue
            saved_subjects.append(subject_id)

        skipped = len(self.all_subjects) - len(saved_subjects)
        if skipped:
            print(f"Warning: {skipped} subject(s) could not be processed; dataset reduced to {len(saved_subjects)} subjects.")
        self._restrict_subjects(saved_subjects)

        print("Finished processing subjects.")

    def get(self, idx):
        # weights_only=False: the file holds a pickled PyG `Data` object rather than
        # plain tensors. Only ever load files this class wrote itself.
        return torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'), weights_only=False)

    def len(self) -> int:
        """Returns the number of data objects stored in the dataset."""
        return len(self.all_subjects)

    @property
    def num_node_features(self):
        """Node-feature width of the first processed graph (0 if it cannot be loaded)."""
        if not self.processed_file_names:
            return 0
        try:
            first_processed_file = os.path.join(self.processed_dir, self.processed_file_names[0])
            if not os.path.exists(first_processed_file):
                print(f"Warning: First processed file not found at {first_processed_file}. Cannot determine num_node_features.")
                return 0
            data = torch.load(first_processed_file, weights_only=False)
            return data.x.shape[1]
        except Exception as e:
            print(f"Error loading first processed file to determine num_node_features: {e}")
            return 0

    @property
    def num_classes(self):
        return 2

In [None]:
import torch.nn.functional as F
from torch.nn import Linear, Dropout
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data 
import torch

class GCN(torch.nn.Module):
    """Graph-level classifier: 2 x GCNConv (edge-weighted) -> global mean pool -> Linear.

    Args:
        num_node_features: size of each node's feature vector (``data.x.shape[1]``).
        num_classes: number of output classes.
        hidden_channels: width of both GCN layers.
        dropout_rate: dropout probability applied after conv1 and after pooling.
            ``None`` (default) uses the global ``DROPOUT_RATE`` as it is at
            construction time, which is what this class always did. Pass it
            explicitly to avoid depending on when that global was last reassigned.

    ``forward(data)`` returns log-probabilities of shape (num_graphs, num_classes).
    """

    def __init__(self, num_node_features, num_classes, hidden_channels=64, dropout_rate=None):
        super(GCN, self).__init__()
        if dropout_rate is None:
            dropout_rate = DROPOUT_RATE

        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin1 = Linear(hidden_channels, num_classes)
        self.dropout = Dropout(p=dropout_rate)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # GCNConv expects a 1-D edge_weight of length num_edges. reshape(-1) keeps that
        # shape even for a single edge, where squeeze() would produce a 0-d tensor.
        edge_weight = edge_attr.reshape(-1) if edge_attr is not None else None

        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_weight)
        x = F.relu(x)

        # One vector per graph (batch maps nodes -> graph index).
        x = global_mean_pool(x, batch)
        x = self.dropout(x)
        x = self.lin1(x)
        # NOTE: callers pair this output with CrossEntropyLoss, which applies
        # log_softmax again; that is harmless (log_softmax is idempotent), though
        # NLLLoss is the conventional partner for log-probabilities.
        return F.log_softmax(x, dim=-1)

print("▶︎ Finished Defining Dataset Class and GCN Model (Modified for weighted edges).")

▶︎ Finished Defining Dataset Class and GCN Model (Modified for weighted edges).


In [None]:
from tqdm import tqdm 
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, recall_score 

def train_model(model, loader, criterion, optimizer, device, show_progress=True):
    """Run one training epoch over `loader` and return the mean per-graph loss.

    Args:
        model: module mapping a PyG batch to class scores.
        loader: PyG DataLoader yielding batches with ``.y`` and ``.num_graphs``.
        criterion: loss applied to ``(model(batch), batch.y)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        show_progress: draw a tqdm bar for the epoch. The bar is cleared when the
            epoch ends (``leave=False``): nested CV calls this tens of thousands of
            times, and persistent bars flood the notebook output.

    Returns:
        Sum over batches of (batch loss * graphs in batch) / ``len(loader.dataset)``.
    """
    model.train()
    total_loss = 0
    for data in tqdm(loader, desc="Training", leave=False, disable=not show_progress):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the result is a per-graph mean.
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)

def test_model(model, loader, device):
    """Evaluate `model` on every batch of `loader` without gradients.

    Returns:
        (accuracy, weighted_f1, auc, recall_ad, preds, labels, probs):
        ``preds`` and ``labels`` are lists of per-graph values; ``probs`` is an
        ndarray of ``exp(model output)``. AUC stays 0 unless both classes occur in
        the labels and the output has exactly two columns. Recall (positive class
        1 = AD) stays 0 when no positive label is present. Either metric falls back
        to 0 with a printed warning if sklearn raises.
    """
    model.eval()
    preds, labels, probs = [], [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            log_probs = model(batch)
            preds.extend(log_probs.argmax(dim=1).cpu().numpy())
            labels.extend(batch.y.cpu().numpy())
            probs.extend(torch.exp(log_probs).cpu().numpy())

    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='weighted')
    probs_array = np.array(probs)

    # AUC needs both classes present and a two-column probability output.
    auc_score = 0
    if len(np.unique(labels)) > 1 and probs_array.shape[1] == 2:
        try:
            auc_score = roc_auc_score(labels, probs_array[:, 1])
        except Exception as e:
            print(f"Warning: Could not calculate AUC: {e}")
            auc_score = 0

    # Recall of the positive (AD, label 1) class.
    recall = 0
    has_positive_label = len(labels) > 0 and 1 in labels
    if has_positive_label:
        try:
            recall = recall_score(labels, preds, pos_label=1)
        except Exception as e:
            print(f"Warning: Could not calculate Recall: {e}")
            recall = 0

    return accuracy, f1, auc_score, recall, preds, labels, probs_array

def plot_confusion_matrix(cm, class_names, title):
    """
    Visualising Confusion Matrix

    Draws `cm` as an annotated heatmap (rows = actual, columns = predicted),
    labelling both axes with `class_names`, and shows the figure.
    """
    frame = pd.DataFrame(cm, index=class_names, columns=class_names)
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(frame, annot=True, fmt="d", cmap='Blues', ax=ax)
    ax.set_title(title)
    ax.set_ylabel('Actual')
    ax.set_xlabel('Predicted')
    plt.show()

print("▶︎ Finished Defining Training and Testing Function.")

▶︎ Finished Defining Training and Testing Function.


In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"🖥️ Current Device: {device}")


🖥️ Current Device: cuda


In [None]:
# Model Training
def run_training(atlas_name, dataset, num_epochs=100, hidden_channels=64, test_size=0.2, random_state=42):
    """Train one GCN on a stratified hold-out split of `dataset` and report test metrics.

    Dropout comes from the global ``DROPOUT_RATE`` (read by GCN at construction).

    Args:
        atlas_name: label used in log messages and in the saved weights' filename.
        dataset: PyG dataset exposing ``num_node_features`` / ``num_classes`` whose
            graphs carry a single-element ``y``.
        num_epochs: number of training epochs.
        hidden_channels: GCN hidden width (previously hard-coded to 64).
        test_size: fraction held out for testing (previously hard-coded to 0.2).
        random_state: seed for the split (previously hard-coded to 42).

    Returns:
        (model, test_loader): the trained model and a DataLoader over the held-out graphs.

    Side effects:
        Saves the model's state_dict to ``./{atlas_name}_gcn_model.pth``.
    """
    print("-" * 50)
    print(f"▶︎ Start Training {atlas_name} Atlas Model ...")

    # Data Splitting, stratified by graph label so both splits keep the class balance
    labels = [data.y.item() for data in dataset]
    train_idx, test_idx = train_test_split(
        np.arange(len(dataset)),
        test_size=test_size,
        random_state=random_state,
        stratify=labels
    )
    train_dataset = dataset[train_idx]
    test_dataset = dataset[test_idx]
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Initialise Model -> Optimiser, Loss function
    model = GCN(
        num_node_features=dataset.num_node_features,
        num_classes=dataset.num_classes,
        hidden_channels=hidden_channels
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.CrossEntropyLoss()

    # Training the model
    for epoch in range(1, num_epochs + 1):
        loss = train_model(model, train_loader, criterion, optimizer, device)
        if epoch % 10 == 0:
            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')

    # Evaluating the model (predictions/labels/probabilities are not needed here)
    test_acc, test_f1, test_auc, test_recall, *_ = test_model(model, test_loader, device)
    print(f"{atlas_name} Model Test Accuracy: {test_acc:.4f}")
    print(f"{atlas_name} Model Test F1-Score: {test_f1:.4f}")
    print(f"{atlas_name} Model Test AUC: {test_auc:.4f}")
    print(f"{atlas_name} Model Test Recall (AD): {test_recall:.4f}")

    # Saving the trained model
    model_save_path = f'./{atlas_name}_gcn_model.pth'
    torch.save(model.state_dict(), model_save_path)
    print(f"{atlas_name} model is save in the path: '{model_save_path}'.")
    print("-" * 50)

    return model, test_loader

# Ensemble based on recall

In [31]:
# Hyperparameter candidates for Random Search.
# Key order matters: the nested-CV cells build combinations with
# itertools.product(*params.values()) and zip them back with params.keys().
params = dict(
    learning_rate=[0.01, 0.005, 0.001],
    dropout_rate=[0.3, 0.5, 0.7],
    hidden_channels=[32, 64, 128],
    k_neighbors=[5, 10, 15],
)

print("▶︎ Defined hyperparameters:")
print(params)

▶︎ Defined hyperparameters:
{'learning_rate': [0.01, 0.005, 0.001], 'dropout_rate': [0.3, 0.5, 0.7], 'hidden_channels': [32, 64, 128], 'k_neighbors': [5, 10, 15]}


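## AAL Model Training with Hyperparameter Tuning via Nested CV
- Nested cross-validation for the single-atlas (AAL) model
- Random search over hyperparameter combinations within each outer fold, selecting the combination with the highest mean inner-fold recall (AD)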

In [None]:

from sklearn.model_selection import StratifiedKFold, train_test_split
import gc
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, accuracy_score, f1_score, roc_curve, auc, roc_auc_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import random 
print("\n" + "="*80)
print(" Starting AAL Model Nested Cross-Validation with KNN Graph and Random Search Tuning")
print("="*80 + "\n")

# Cross-Validation & Hyperparameter Setup
OUTER_K_FOLDS = 10
INNER_K_FOLDS = 5
N_RANDOM = 20
CV_SEED = 42  # seeds hyperparameter sampling and torch (weight init, DataLoader shuffling)

torch.manual_seed(CV_SEED)

# Generate Random Hyperparameter Combinations.
# A private, seeded RNG makes the sampled combinations reproducible across kernel
# restarts (an unseeded global `random.sample` can differ between runs).
all_combinations = list(itertools.product(*params.values()))
num_to_sample = min(N_RANDOM, len(all_combinations))
random_combinations_tuples = random.Random(CV_SEED).sample(all_combinations, k=num_to_sample)
# Convert tuples to dictionaries
random_combinations = [dict(zip(params.keys(), combo)) for combo in random_combinations_tuples]
print(f"Total possible hyperparameter combinations: {len(all_combinations)}")
print(f"Testing {len(random_combinations)} random combinations in each outer fold.")


# Cross-Validation Setup
skf_outer = StratifiedKFold(n_splits=OUTER_K_FOLDS, shuffle=True, random_state=42)
skf_inner = StratifiedKFold(n_splits=INNER_K_FOLDS, shuffle=True, random_state=84) # Inner CV uses a different seed

# Lists to store results for each Outer Fold
aal_outer_fold_history = []
all_aal_preds = []
all_aal_labels = []
all_aal_probs = []

# Load Full Dataset Info for subject list and labels
full_dataset_aal_info = fMRIDataset(
    root='./data/aal_full_cv_info_tuning',
    cn_fc_dir=CN_FC_AAL_DIR, cn_roi_dir=CN_ROI_AAL_DIR,
    ad_fc_dir=AD_FC_AAL_DIR, ad_roi_dir=AD_ROI_AAL_DIR,
    k_neighbors=10
)
dataset_labels = [data.y.item() for data in full_dataset_aal_info]
full_subject_list = full_dataset_aal_info.all_subjects

# Each subject's graph is built only from that subject's own files, so the processed
# graphs depend on k_neighbors alone -- not on the fold. Build one dataset per k and
# share it across all folds instead of re-processing every subject for each fold.
aal_datasets_by_k = {full_dataset_aal_info.k_neighbors: full_dataset_aal_info}

def get_aal_dataset(k_neighbors):
    """Return the AAL dataset built with `k_neighbors`, creating it on first use.

    Raises:
        AssertionError: if its subject order differs from `full_subject_list`,
            because all fold indices are positions in that list.
    """
    if k_neighbors not in aal_datasets_by_k:
        dataset = fMRIDataset(
            root=f'./data/aal_knn_k{k_neighbors}',
            cn_fc_dir=CN_FC_AAL_DIR, cn_roi_dir=CN_ROI_AAL_DIR,
            ad_fc_dir=AD_FC_AAL_DIR, ad_roi_dir=AD_ROI_AAL_DIR,
            k_neighbors=k_neighbors
        )
        assert dataset.all_subjects == full_subject_list, (
            f"Subject order for k={k_neighbors} differs from the reference dataset; fold indices would be misaligned.")
        aal_datasets_by_k[k_neighbors] = dataset
    return aal_datasets_by_k[k_neighbors]


# Start Outer Cross-Validation Loop
for outer_fold, (train_idx_outer, test_idx_outer) in enumerate(skf_outer.split(np.arange(len(full_subject_list)), dataset_labels)):
    print(f"\n=============== OUTER FOLD {outer_fold+1}/{OUTER_K_FOLDS} ================")
    print("\n--- Starting Inner CV Loop for Hyperparameter Tuning ---")

    inner_tuning_results = [] # store the performance of each hyperparameter combination
    # Labels of the outer-training subjects, used to stratify the inner splits
    # (identical for every hyperparameter combination, so computed once per fold).
    outer_train_labels = [dataset_labels[idx] for idx in train_idx_outer]

    for combo_idx, current_hyperparams in enumerate(random_combinations):
        print(f"\n  Testing Hyperparams {combo_idx+1}/{len(random_combinations)}: {current_hyperparams}")

        # --- Inner Cross-Validation ---
        inner_fold_recalls = [] # recall (AD) for each inner fold
        inner_fold_f1s = []     # weighted F1 for each inner fold (used as tie-breaker)
        current_k_neighbors = current_hyperparams['k_neighbors']
        dataset_aal_inner_cv = get_aal_dataset(current_k_neighbors)

        # GCN reads the global DROPOUT_RATE when a model is constructed, so the
        # combination's rate must be set BEFORE building any model. (Setting it after
        # construction gave each combination's first inner fold a stale rate.)
        DROPOUT_RATE = current_hyperparams['dropout_rate']

        for inner_fold, (train_idx_inner, val_idx_inner) in enumerate(skf_inner.split(train_idx_outer, outer_train_labels)):
            # Map inner-split positions back to indices into the full dataset.
            actual_train_idx_inner = train_idx_outer[train_idx_inner].tolist()
            actual_val_idx_inner = train_idx_outer[val_idx_inner].tolist()

            # Subsets and DataLoaders for the inner fold
            train_dataset_inner = torch.utils.data.Subset(dataset_aal_inner_cv, actual_train_idx_inner)
            val_dataset_inner = torch.utils.data.Subset(dataset_aal_inner_cv, actual_val_idx_inner)
            train_loader_inner = DataLoader(train_dataset_inner, batch_size=BATCH_SIZE, shuffle=True)
            val_loader_inner = DataLoader(val_dataset_inner, batch_size=BATCH_SIZE, shuffle=False)

            # Initialise and Train Model for Inner Fold
            model_inner = GCN(
                num_node_features=dataset_aal_inner_cv.num_node_features,
                num_classes=dataset_aal_inner_cv.num_classes,
                hidden_channels=current_hyperparams['hidden_channels']
            ).to(device)
            optimizer_inner = torch.optim.Adam(model_inner.parameters(), lr=current_hyperparams['learning_rate'])
            criterion_inner = torch.nn.CrossEntropyLoss()

            for epoch_inner in range(1, INNER_NUM_EPOCHS + 1):
                train_model(model_inner, train_loader_inner, criterion_inner, optimizer_inner, device)

            # Evaluate on inner validation set and store recall (and F1 for tie-breaks)
            _, val_f1, _, val_recall, _, _, _ = test_model(model_inner, val_loader_inner, device)
            inner_fold_recalls.append(val_recall)
            inner_fold_f1s.append(val_f1)

            # Memory Cleanup for Inner Fold
            del model_inner, optimizer_inner, criterion_inner
            del train_dataset_inner, val_dataset_inner, train_loader_inner, val_loader_inner
            gc.collect()
            torch.cuda.empty_cache()

        # Calculate average recall across inner folds for the current hyperparameter set
        avg_inner_recall = np.mean(inner_fold_recalls)
        avg_inner_f1 = np.mean(inner_fold_f1s)
        print(f"  --> Avg. Inner CV Recall for these hyperparams: {avg_inner_recall:.4f}")
        inner_tuning_results.append({
            'hyperparams': current_hyperparams,
            'avg_recall': avg_inner_recall,
            'avg_f1': avg_inner_f1
        })


    # Select the highest average inner CV recall. Recall alone rewards predicting
    # every subject as AD (recall 1.0), so ties are broken by average weighted F1.
    best_performing_set = max(inner_tuning_results, key=lambda result: (result['avg_recall'], result['avg_f1']))
    best_hyperparams = best_performing_set['hyperparams']
    best_avg_recall = best_performing_set['avg_recall']

    print("\n--- Inner CV Complete. Best Hyperparameters for Outer Fold "
          f"{outer_fold+1}: {best_hyperparams} (Avg. Inner Recall: {best_avg_recall:.4f}) ---")

    # Train Final Model on Outer Training Set with Best Hyperparameters
    print(f"\nFold {outer_fold+1}: Training Final AAL Model with Best Hyperparameters...")

    # Extract best hyperparams
    best_k_neighbors = best_hyperparams['k_neighbors']
    best_hidden_channels = best_hyperparams['hidden_channels']
    best_dropout_rate = best_hyperparams['dropout_rate']
    best_learning_rate = best_hyperparams['learning_rate']

    # Datasets and dataloaders: one shared dataset, disjoint train/test subsets
    dataset_aal_outer = get_aal_dataset(best_k_neighbors)
    train_dataset_outer = torch.utils.data.Subset(dataset_aal_outer, train_idx_outer.tolist())
    train_loader_outer = DataLoader(train_dataset_outer, batch_size=BATCH_SIZE, shuffle=True)
    test_dataset_outer = torch.utils.data.Subset(dataset_aal_outer, test_idx_outer.tolist())
    test_loader_outer = DataLoader(test_dataset_outer, batch_size=BATCH_SIZE, shuffle=False)

    # Set BEFORE constructing the model: GCN reads DROPOUT_RATE at construction time.
    # (Previously it was set afterwards, so the final model used the dropout of the
    # last combination tried in the inner loop instead of the best one.)
    DROPOUT_RATE = best_dropout_rate

    # Initialise Final Model with Best Hyperparameters
    model_final = GCN(
        num_node_features=dataset_aal_outer.num_node_features,
        num_classes=dataset_aal_outer.num_classes,
        hidden_channels=best_hidden_channels
    ).to(device)
    optimizer_final = torch.optim.Adam(model_final.parameters(), lr=best_learning_rate)
    criterion_final = torch.nn.CrossEntropyLoss()

    # Train the final model on the full outer training set
    for epoch_outer in range(1, OUTER_NUM_EPOCHS + 1):
        loss_outer = train_model(model_final, train_loader_outer, criterion_final, optimizer_final, device)
        if epoch_outer % 10 == 0:
            print(f'  Outer Epoch: {epoch_outer:03d}, Train Loss: {loss_outer:.4f}')

    # Evaluate Final Model on Outer Test Set
    test_accuracy, test_f1, test_auc, test_recall, preds_aal_test, labels_aal_test, probs_aal_test = test_model(model_final, test_loader_outer, device)

    # Store Outer Fold Results
    aal_outer_fold_history.append({
        'fold': outer_fold + 1,
        'best_hyperparams': best_hyperparams,
        'test_accuracy': test_accuracy,
        'test_f1': test_f1,
        'test_auc': test_auc,
        'test_recall': test_recall
    })
    all_aal_preds.extend(preds_aal_test)
    all_aal_labels.extend(labels_aal_test)
    all_aal_probs.extend(probs_aal_test)

    print(f"Outer Fold {outer_fold+1} Test Accuracy: {test_accuracy:.4f}, AUC: {test_auc:.4f}, Recall (AD): {test_recall:.4f}")

    # Memory Cleanup for Outer Loop (cached per-k datasets are kept for reuse)
    del model_final, optimizer_final, criterion_final
    del train_dataset_outer, test_dataset_outer, train_loader_outer, test_loader_outer
    gc.collect()
    torch.cuda.empty_cache()

print("\n=============== Outer Cross-Validation Complete ===============")

# Final Result Analysis
if aal_outer_fold_history:
    avg_acc = np.mean([f['test_accuracy'] for f in aal_outer_fold_history])
    # NOTE: folds where AUC was undefined contribute 0 (see test_model).
    avg_auc = np.mean([f['test_auc'] for f in aal_outer_fold_history])
    avg_recall = np.mean([f['test_recall'] for f in aal_outer_fold_history])
    avg_f1 = np.mean([f['test_f1'] for f in aal_outer_fold_history])

    print(f"\n{OUTER_K_FOLDS}-Fold Nested Cross-validation Result (AAL with KNN Graph and Random Search Tuning):")
    print(f"  - Average Test Accuracy: {avg_acc:.4f}")
    print(f"  - Average Test AUC: {avg_auc:.4f}")
    print(f"  - Average Test Recall (AD): {avg_recall:.4f}")
    print(f"  - Average Test F1-Score: {avg_f1:.4f}")

    # Overall Confusion Matrix (pooled over all outer test folds)
    all_aal_labels_np = np.array(all_aal_labels)
    all_aal_preds_np = np.array(all_aal_preds)
    cm = confusion_matrix(all_aal_labels_np, all_aal_preds_np)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['CN', 'AD'], yticklabels=['CN', 'AD'])
    plt.title(f'{OUTER_K_FOLDS}-Fold Nested CV Confusion Matrix (AAL Tuned Model)')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.show()

    # Overall Classification Report
    print("\nOverall Classification Report:")
    print(classification_report(all_aal_labels_np, all_aal_preds_np, target_names=['CN', 'AD'], digits=4))

    # Overall ROC Curve (pooled probabilities of the AD class)
    all_aal_probs_np = np.array(all_aal_probs)
    if all_aal_probs_np.ndim > 1 and all_aal_probs_np.shape[1] > 1:
        probs_positive_class = all_aal_probs_np[:, 1]
        if len(all_aal_labels_np) == len(probs_positive_class):
            fpr, tpr, thresholds = roc_curve(all_aal_labels_np, probs_positive_class)
            roc_auc_overall = auc(fpr, tpr)

            plt.figure(figsize=(8, 8))
            plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (overall area = {roc_auc_overall:.4f})')
            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title('Overall Receiver Operating Characteristic (ROC) Curve for AAL Tuned Model')
            plt.legend(loc="lower right")
            plt.show()
        else:
            print(f"Error: Length of overall labels ({len(all_aal_labels_np)}) does not match length of overall probabilities ({len(probs_positive_class)}). Cannot plot overall ROC curve.")
    else:
        print("Warning: Unexpected shape for overall AAL probabilities. Cannot plot overall ROC curve.")

else:
    print("Couldn't execute nested cross-validation with hyperparameter tuning.")

print("\n" + "="*80)
print(" AAL Model Nested Cross-Validation with Hyperparameter Tuning Complete")
print("="*80 + "\n")

## Schaefer Model Training with Hyperparameter Tuning via Nested CV
- Nested cross-validation for the single-atlas (Schaefer) model
- Random search over hyperparameter combinations within each outer fold

In [None]:
from sklearn.model_selection import StratifiedKFold, train_test_split
import gc
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, accuracy_score, f1_score, roc_curve, auc, roc_auc_score, recall_score
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import random 

print("\n" + "="*80)
print(" Starting Schaefer Model Nested Cross-Validation with KNN Graph and Random Search Tuning")
print("="*80 + "\n")

# Cross-Validation & Hyperparameter Setup
OUTER_K_FOLDS = 10   # outer folds: held-out performance estimate
INNER_K_FOLDS = 5    # inner folds: hyperparameter selection inside each outer training split
N_RANDOM = 20        # maximum number of random hyperparameter combinations to evaluate
SEARCH_SEED = 42     # seed for the random-search sampler, so the tested combinations are reproducible


# Random Hyperparameter Combinations
# `params` comes from an earlier cell; it is presumably a dict mapping each hyperparameter name
# to its candidate values (it is expanded with itertools.product below). The tuning loop reads
# these four keys, so fail fast with a clear message if any of them is missing.
REQUIRED_PARAM_KEYS = {'k_neighbors', 'hidden_channels', 'dropout_rate', 'learning_rate'}
missing_param_keys = REQUIRED_PARAM_KEYS - set(params)
if missing_param_keys:
    raise KeyError(f"`params` is missing required hyperparameter keys: {sorted(missing_param_keys)}")

all_combinations = list(itertools.product(*params.values()))
num_to_sample = min(N_RANDOM, len(all_combinations))
# A dedicated, seeded RNG keeps the sampled combinations identical across runs without
# touching the global `random` state used elsewhere in the notebook.
random_combinations_tuples = random.Random(SEARCH_SEED).sample(all_combinations, k=num_to_sample)
# Tuple to dictionary conversion
random_combinations = [dict(zip(params.keys(), combo)) for combo in random_combinations_tuples]
print(f"Total possible hyperparameter combinations: {len(all_combinations)}")
print(f"Testing {len(random_combinations)} random combinations in each outer fold.")


# Cross-Validation Setup
skf_outer = StratifiedKFold(n_splits=OUTER_K_FOLDS, shuffle=True, random_state=42)
skf_inner = StratifiedKFold(n_splits=INNER_K_FOLDS, shuffle=True, random_state=84)

# Lists to store results for each Outer Fold
schaefer_outer_fold_history = []
all_schaefer_preds = []
all_schaefer_labels = []
all_schaefer_probs = []

# Load Full Dataset Info for subject list and labels
full_dataset_schaefer_info = fMRIDataset(
    root='./data/schaefer_full_cv_info_tuning',
    cn_fc_dir=CN_FC_SCHAEFER_DIR, cn_roi_dir=CN_ROI_SCHAEFER_DIR,
    ad_fc_dir=AD_FC_SCHAEFER_DIR, ad_roi_dir=AD_ROI_SCHAEFER_DIR,
    k_neighbors=10
)
dataset_labels = [data.y.item() for data in full_dataset_schaefer_info]
full_subject_list = full_dataset_schaefer_info.all_subjects
# The outer split indexes subjects by position and stratifies on dataset_labels,
# so the two sequences must have the same length.
assert len(dataset_labels) == len(full_subject_list), (
    f"Label count ({len(dataset_labels)}) does not match subject count ({len(full_subject_list)})."
)


# Outer Cross-Validation Loop
# Each outer fold does three things:
#   (1) random-search the hyperparameters with inner CV on the outer-training subjects,
#   (2) retrain with the best set on all outer-training subjects,
#   (3) score on the held-out outer-test subjects.
for outer_fold, (train_idx_outer, test_idx_outer) in enumerate(skf_outer.split(np.arange(len(full_subject_list)), dataset_labels)):
    print(f"\n=============== OUTER FOLD {outer_fold+1}/{OUTER_K_FOLDS} ================")
    print("\n--- Starting Inner CV Loop for Hyperparameter Tuning ---")

    inner_tuning_results = []  # one {'hyperparams', 'avg_recall'} record per tested combination

    # Stratification labels for the inner splitter. They depend only on the outer fold,
    # so compute them once here instead of once per hyperparameter combination.
    outer_train_labels = [dataset_labels[idx] for idx in train_idx_outer]

    # Inner Cross-Validation for Hyperparameter Tuning
    for i, current_hyperparams in enumerate(random_combinations):
        print(f"\n  Testing Hyperparams {i+1}/{len(random_combinations)}: {current_hyperparams}")

        inner_fold_recalls = []  # validation recall of each inner fold
        current_k_neighbors = current_hyperparams['k_neighbors']

        # GCN(...) receives no dropout argument, so it presumably reads the global DROPOUT_RATE
        # (TODO confirm in the GCN definition). Assign it BEFORE any model is constructed.
        # If GCN reads it in __init__, assigning after construction would leave the new model
        # with the previous combination's dropout. Assigning first is correct whether the
        # value is read in __init__ or in forward.
        DROPOUT_RATE = current_hyperparams['dropout_rate']

        # The graph structure depends on k_neighbors, so the dataset is built per combination.
        dataset_schaefer_inner_cv = fMRIDataset(
             root=f'./data/schaefer_outer_{outer_fold+1}_inner_cv_k{current_k_neighbors}',
             cn_fc_dir=CN_FC_SCHAEFER_DIR, cn_roi_dir=CN_ROI_SCHAEFER_DIR,
             ad_fc_dir=AD_FC_SCHAEFER_DIR, ad_roi_dir=AD_ROI_SCHAEFER_DIR,
             k_neighbors=current_k_neighbors
        )

        for inner_fold, (train_idx_inner, val_idx_inner) in enumerate(skf_inner.split(train_idx_outer, outer_train_labels)):
            # skf_inner yields positions within train_idx_outer; map them back to dataset indices.
            actual_train_idx_inner = train_idx_outer[train_idx_inner].tolist()
            actual_val_idx_inner = train_idx_outer[val_idx_inner].tolist()

            # Subsets and DataLoaders for the inner fold
            train_dataset_inner = torch.utils.data.Subset(dataset_schaefer_inner_cv, actual_train_idx_inner)
            val_dataset_inner = torch.utils.data.Subset(dataset_schaefer_inner_cv, actual_val_idx_inner)
            train_loader_inner = DataLoader(train_dataset_inner, batch_size=BATCH_SIZE, shuffle=True)
            val_loader_inner = DataLoader(val_dataset_inner, batch_size=BATCH_SIZE, shuffle=False)

            # Initialise and Train Model for Inner Fold
            model_inner = GCN(
                num_node_features=dataset_schaefer_inner_cv.num_node_features,
                num_classes=dataset_schaefer_inner_cv.num_classes,
                hidden_channels=current_hyperparams['hidden_channels']
            ).to(device)
            optimizer_inner = torch.optim.Adam(model_inner.parameters(), lr=current_hyperparams['learning_rate'])
            criterion_inner = torch.nn.CrossEntropyLoss()

            for epoch_inner in range(1, INNER_NUM_EPOCHS + 1):
                train_model(model_inner, train_loader_inner, criterion_inner, optimizer_inner, device)

            # Evaluate on the inner validation set; only the recall (4th return value) drives selection.
            _, _, _, val_recall, _, _, _ = test_model(model_inner, val_loader_inner, device)
            inner_fold_recalls.append(val_recall)

            # Memory Cleanup for Inner Fold
            del model_inner, optimizer_inner, criterion_inner
            del train_dataset_inner, val_dataset_inner, train_loader_inner, val_loader_inner
            gc.collect()
            torch.cuda.empty_cache()

        # Calculate average recall across inner folds
        avg_inner_recall = np.mean(inner_fold_recalls)
        print(f"  --> Avg. Inner CV Recall for these hyperparams: {avg_inner_recall:.4f}")
        inner_tuning_results.append({
            'hyperparams': current_hyperparams,
            'avg_recall': avg_inner_recall
        })
        # Cleanup dataset for this k_neighbors
        del dataset_schaefer_inner_cv
        gc.collect()


    # Find the best hyperparameters based on the highest average inner CV recall
    # (on ties, max() keeps the first combination tested).
    best_performing_set = max(inner_tuning_results, key=lambda x: x['avg_recall'])
    best_hyperparams = best_performing_set['hyperparams']
    best_avg_recall = best_performing_set['avg_recall']

    print("\n--- Inner CV Complete. Best Hyperparameters for Outer Fold "
          f"{outer_fold+1}: {best_hyperparams} (Avg. Inner Recall: {best_avg_recall:.4f}) ---")

    # Train Final Model on Outer Training Set with Best Hyperparameters
    print(f"\nFold {outer_fold+1}: Training Final Schaefer Model with Best Hyperparameters...")

    # Extract best hyperparams
    best_k_neighbors = best_hyperparams['k_neighbors']
    best_hidden_channels = best_hyperparams['hidden_channels']
    best_dropout_rate = best_hyperparams['dropout_rate']
    best_learning_rate = best_hyperparams['learning_rate']

    # Set the best dropout rate before constructing the final model (see the note in the inner loop).
    DROPOUT_RATE = best_dropout_rate

    # Outer-train and outer-test subsets come from a single dataset instance. Separate instances
    # would be built with identical arguments (only `root` would differ), so one build per fold
    # suffices. Assumes fMRIDataset's processing does not depend on `root` (TODO confirm).
    dataset_schaefer_outer = fMRIDataset(
        root=f'./data/schaefer_outer_fold_{outer_fold+1}_final_train_k{best_k_neighbors}',
        cn_fc_dir=CN_FC_SCHAEFER_DIR, cn_roi_dir=CN_ROI_SCHAEFER_DIR,
        ad_fc_dir=AD_FC_SCHAEFER_DIR, ad_roi_dir=AD_ROI_SCHAEFER_DIR,
        k_neighbors=best_k_neighbors
    )
    train_dataset_outer = torch.utils.data.Subset(dataset_schaefer_outer, train_idx_outer)
    train_loader_outer = DataLoader(train_dataset_outer, batch_size=BATCH_SIZE, shuffle=True)
    test_dataset_outer = torch.utils.data.Subset(dataset_schaefer_outer, test_idx_outer)
    test_loader_outer = DataLoader(test_dataset_outer, batch_size=BATCH_SIZE, shuffle=False)

    # Initialise Final Model with Best Hyperparameters
    model_final = GCN(
        num_node_features=dataset_schaefer_outer.num_node_features,
        num_classes=dataset_schaefer_outer.num_classes,
        hidden_channels=best_hidden_channels
    ).to(device)
    optimizer_final = torch.optim.Adam(model_final.parameters(), lr=best_learning_rate)
    criterion_final = torch.nn.CrossEntropyLoss()

    # Train the final model on the full outer training set
    for epoch_outer in range(1, OUTER_NUM_EPOCHS + 1):
        loss_outer = train_model(model_final, train_loader_outer, criterion_final, optimizer_final, device)
        if epoch_outer % 10 == 0:
            print(f'  Outer Epoch: {epoch_outer:03d}, Train Loss: {loss_outer:.4f}')

    # Evaluate the Performance of Final Model on Outer Test Set
    test_accuracy, test_f1, test_auc, test_recall, preds_schaefer_test, labels_schaefer_test, probs_schaefer_test = test_model(model_final, test_loader_outer, device)

    # Store Outer Fold Results
    schaefer_outer_fold_history.append({
        'fold': outer_fold + 1,
        'best_hyperparams': best_hyperparams,
        'test_accuracy': test_accuracy,
        'test_f1': test_f1,
        'test_auc': test_auc,
        'test_recall': test_recall
    })
    all_schaefer_preds.extend(preds_schaefer_test)
    all_schaefer_labels.extend(labels_schaefer_test)
    all_schaefer_probs.extend(probs_schaefer_test)

    print(f"Outer Fold {outer_fold+1} Test Accuracy: {test_accuracy:.4f}, AUC: {test_auc:.4f}, Recall (AD): {test_recall:.4f}")

    # Memory Cleanup for Outer Loop
    del model_final, optimizer_final, criterion_final
    del train_dataset_outer, test_dataset_outer, train_loader_outer, test_loader_outer
    del dataset_schaefer_outer
    gc.collect()
    torch.cuda.empty_cache()

print("\n=============== Outer Cross-Validation Complete ===============")

# Final Result Analysis
if schaefer_outer_fold_history:
    # Per-fold outer-test metrics. The mean is the nested-CV estimate. The standard deviation
    # across folds is printed next to it, because a bare mean over OUTER_K_FOLDS folds hides
    # how much the estimate varies from fold to fold.
    acc_per_fold = [f['test_accuracy'] for f in schaefer_outer_fold_history]
    auc_per_fold = [f['test_auc'] for f in schaefer_outer_fold_history]
    recall_per_fold = [f['test_recall'] for f in schaefer_outer_fold_history]
    f1_per_fold = [f['test_f1'] for f in schaefer_outer_fold_history]

    avg_acc = np.mean(acc_per_fold)
    avg_auc = np.mean(auc_per_fold)
    avg_recall = np.mean(recall_per_fold)
    avg_f1 = np.mean(f1_per_fold)

    print(f"\n{OUTER_K_FOLDS}-Fold Nested Cross-validation Result (SCHAEFER with KNN Graph and Random Search Tuning):")
    print(f"  - Average Test Accuracy: {avg_acc:.4f} (SD across folds: {np.std(acc_per_fold):.4f})")
    print(f"  - Average Test AUC: {avg_auc:.4f} (SD across folds: {np.std(auc_per_fold):.4f})")
    print(f"  - Average Test Recall (AD): {avg_recall:.4f} (SD across folds: {np.std(recall_per_fold):.4f})")
    print(f"  - Average Test F1-Score: {avg_f1:.4f} (SD across folds: {np.std(f1_per_fold):.4f})")

    # Overall Confusion Matrix, pooled over the out-of-fold predictions of every outer fold
    all_schaefer_labels_np = np.array(all_schaefer_labels)
    all_schaefer_preds_np = np.array(all_schaefer_preds)
    cm = confusion_matrix(all_schaefer_labels_np, all_schaefer_preds_np)
    fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['CN', 'AD'], yticklabels=['CN', 'AD'], ax=ax_cm)
    ax_cm.set(
        title=f'{OUTER_K_FOLDS}-Fold Nested CV Confusion Matrix (Schaefer Tuned Model)',
        xlabel='Predicted Label',
        ylabel='True Label',
    )
    plt.show()

    # Overall Classification Report
    print("\nOverall Classification Report:")
    print(classification_report(all_schaefer_labels_np, all_schaefer_preds_np, target_names=['CN', 'AD'], digits=4))

    # Overall ROC Curve from the pooled probability of class 1 (AD)
    all_schaefer_probs_np = np.array(all_schaefer_probs)
    if all_schaefer_probs_np.ndim > 1 and all_schaefer_probs_np.shape[1] > 1:
        probs_positive_class = all_schaefer_probs_np[:, 1]
        if len(all_schaefer_labels_np) == len(probs_positive_class):
            fpr, tpr, thresholds = roc_curve(all_schaefer_labels_np, probs_positive_class)
            roc_auc_overall = auc(fpr, tpr)

            fig_roc, ax_roc = plt.subplots(figsize=(8, 8))
            ax_roc.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (overall area = {roc_auc_overall:.4f})')
            ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            ax_roc.set(
                xlim=(0.0, 1.0),
                ylim=(0.0, 1.05),
                xlabel='False Positive Rate',
                ylabel='True Positive Rate',
                title='Overall Receiver Operating Characteristic (ROC) Curve for Schaefer Tuned Model',
            )
            ax_roc.legend(loc="lower right")
            plt.show()
        else:
            print(f"Error: Length of overall labels ({len(all_schaefer_labels_np)}) does not match length of overall probabilities ({len(probs_positive_class)}). Cannot plot overall ROC curve.")
    else:
        print("Warning: Unexpected shape for overall Schaefer probabilities. Cannot plot overall ROC curve.")

else:
    print("Couldn't execute nested cross-validation with hyperparameter tuning.")

print("\n" + "="*80)
print(" Schaefer Model Nested Cross-Validation with Hyperparameter Tuning Complete")
print("="*80 + "\n")

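## Hard Voting Ensemble with Recall-Weighted Tie-Breaking (Tuned Models)
- Combines the out-of-fold predictions of the tuned AAL and Schaefer models by hard voting
- With only two voters, every disagreement is a tie; it is broken by choosing the class with the larger recall-weighted sum of the two models' probabilities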

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report, roc_curve, auc

print("\n" + "="*80)
print(" Starting Hard Voting + Tie Based on Recall Weighted Sum Ensemble (using Tuned Models)")
print("="*80 + "\n")


# Load AAL and Schaefer results from the hyperparameter tuning CV.
# Each nested-CV cell accumulated its out-of-fold predictions, labels and class probabilities
# over all outer folds; they are converted to arrays here.

try:
    # AAL Tuning Results (the AAL labels serve as ground truth for the ensemble)
    aal_preds = np.array(all_aal_preds)
    true_labels = np.array(all_aal_labels)
    aal_probs = np.array(all_aal_probs)
    print("✅ Tuned AAL results variables found.")
except NameError:
    print("--- !!! ERROR !!! ---")
    print("Tuned AAL results variables (`all_aal_preds`, `all_aal_labels`, `all_aal_probs`) not found.")
    print("Please ensure you have run the AAL CV with Hyperparameter Tuning cell successfully.")
    raise NameError("Tuned AAL results not found. Cannot perform ensemble.")


try:
    # Schaefer Tuning Results (labels are loaded as well, to verify alignment with the AAL run)
    schaefer_preds = np.array(all_schaefer_preds)
    schaefer_labels = np.array(all_schaefer_labels)
    schaefer_probs = np.array(all_schaefer_probs)
    print("✅ Tuned Schaefer results variables found.")
except NameError:
    print("--- !!! ERROR !!! ---")
    print("Tuned Schaefer results variables (`all_schaefer_preds`, `all_schaefer_labels`, `all_schaefer_probs`) not found.")
    print("Please ensure you have run the Schaefer CV with Hyperparameter Tuning cell successfully.")
    raise NameError("Tuned Schaefer results not found. Cannot perform ensemble.")

# The ensemble combines the two models index-by-index. That is only valid if both CV runs emitted
# their out-of-fold results for the same subjects in the same order. Equal sizes and identical
# label sequences are necessary conditions for that alignment, so check them before combining.
if not (len(aal_preds) == len(aal_probs) == len(true_labels)
        == len(schaefer_preds) == len(schaefer_probs) == len(schaefer_labels)):
    raise ValueError(
        f"AAL ({len(aal_preds)} preds, {len(aal_probs)} probs, {len(true_labels)} labels) and "
        f"Schaefer ({len(schaefer_preds)} preds, {len(schaefer_probs)} probs, {len(schaefer_labels)} labels) "
        "result sizes differ. Cannot perform ensemble."
    )
if not np.array_equal(true_labels, schaefer_labels):
    raise ValueError(
        "AAL and Schaefer out-of-fold labels differ, so their predictions are not aligned per subject. "
        "Both CV cells must use the same subject order and outer splits. Cannot perform ensemble."
    )
print("✅ AAL and Schaefer out-of-fold results are aligned.")

# Evaluate the outer fold history to calculate average Recall for each model.
# NOTE(review): these weights come from the outer-fold *test* recall, i.e. from the same
# predictions the ensemble is evaluated on below, which leaks test information into the
# ensemble. Inner-CV recall would avoid this.
try:
    # Compute average Recall for AAL and Schaefer models
    aal_avg_recall = np.mean([f['test_recall'] for f in aal_outer_fold_history])
    schaefer_avg_recall = np.mean([f['test_recall'] for f in schaefer_outer_fold_history])

    # Total Recall sum for weighted average calculation
    total_recall_sum = aal_avg_recall + schaefer_avg_recall

    # Weight Calculation : w_j = Recall_j / sum(Recall_i)
    # np.mean of an empty history is NaN, and `NaN > 0` is False. Check finiteness explicitly
    # so that case is reported accurately instead of as a zero sum.
    if not np.isfinite(total_recall_sum):
        print("Warning: Average Recall is undefined (empty or NaN fold history). Using equal weights (0.5) for tie-breaking.")
        aal_weight = 0.5
        schaefer_weight = 0.5
    elif total_recall_sum > 0:
        aal_weight = aal_avg_recall / total_recall_sum
        schaefer_weight = schaefer_avg_recall / total_recall_sum
    else:
        print("Warning: Total average Recall sum is zero. Using equal weights (0.5) for tie-breaking.")
        aal_weight = 0.5
        schaefer_weight = 0.5

    print(f"Using Average Recall-based weights for tie-breaking: AAL={aal_weight:.4f}, Schaefer={schaefer_weight:.4f}")

except NameError:
    print("Warning: Could not access outer_fold_history to get average Recalls for weighted sum calculation.")
    print("Using equal weights (0.5) for tie-breaking.")
    aal_weight = 0.5
    schaefer_weight = 0.5
except Exception as e:
    print(f"Warning: Error calculating Average Recall-based weights: {e}")
    print("Using equal weights (0.5) for tie-breaking.")
    aal_weight = 0.5
    schaefer_weight = 0.5

# Ensemble Prediction (Hard Voting with Recall Weighted Sum for Ties) 

def ensemble_predict_hard_weighted_tie(aal_preds, aal_probs, schaefer_preds, schaefer_probs, aal_weight, schaefer_weight):
    """
    Hard-voting ensemble of two models with a weighted-probability tie-break.

    Where both models predict the same class, that class is returned. With only two voters,
    any disagreement is a tie; it is resolved by the argmax of the weighted sum
    ``aal_probs[i] * aal_weight + schaefer_probs[i] * schaefer_weight``.
    If the weighted scores are exactly equal, argmax picks the lowest class index.

    Parameters
    ----------
    aal_preds, schaefer_preds : array-like of int, shape (n_samples,)
        Predicted class labels of each model, aligned sample-by-sample.
    aal_probs, schaefer_probs : array-like of float, shape (n_samples, n_classes)
        Per-class probabilities of each model, aligned with the predictions.
    aal_weight, schaefer_weight : float
        Tie-break weights (in this notebook, derived from each model's average recall).

    Returns
    -------
    numpy.ndarray of int, shape (n_samples,)
        Ensemble class predictions.

    Raises
    ------
    ValueError
        If the four inputs do not all contain the same number of samples.
    """
    aal_preds = np.asarray(aal_preds)
    schaefer_preds = np.asarray(schaefer_preds)
    aal_probs = np.asarray(aal_probs)
    schaefer_probs = np.asarray(schaefer_probs)

    num_samples = len(aal_preds)
    if not (len(schaefer_preds) == len(aal_probs) == len(schaefer_probs) == num_samples):
        raise ValueError(
            "All inputs must have the same number of samples: "
            f"aal_preds={num_samples}, aal_probs={len(aal_probs)}, "
            f"schaefer_preds={len(schaefer_preds)}, schaefer_probs={len(schaefer_probs)}"
        )
    if num_samples == 0:
        # argmax(axis=1) is undefined for 1-D empty probability arrays; nothing to predict.
        return np.zeros(0, dtype=int)

    # Vectorized over all samples: the weighted tie-break is computed everywhere,
    # but only used where the two models disagree.
    weighted_probs = (aal_probs * aal_weight) + (schaefer_probs * schaefer_weight)
    tie_break_preds = np.argmax(weighted_probs, axis=1)
    agree = aal_preds == schaefer_preds
    return np.where(agree, aal_preds, tie_break_preds).astype(int)


print("\n--- Performing Ensemble Prediction ---")
ensemble_predictions = ensemble_predict_hard_weighted_tie(
    aal_preds, aal_probs,
    schaefer_preds, schaefer_probs,
    aal_weight, schaefer_weight,
)


# Ensemble Model Performance Evaluation
# All scores compare the hard-voting predictions against the out-of-fold labels (true_labels).
print("\n--- Evaluating Ensemble Performance ---")

ensemble_accuracy = accuracy_score(true_labels, ensemble_predictions)
ensemble_f1 = f1_score(true_labels, ensemble_predictions, average='weighted')
ensemble_cm = confusion_matrix(true_labels, ensemble_predictions)
ensemble_classification_report_dict = classification_report(
    true_labels, ensemble_predictions, target_names=['CN', 'AD'], output_dict=True
)
ensemble_classification_report_str = classification_report(
    true_labels, ensemble_predictions, target_names=['CN', 'AD'], digits=4
)

# Headline numbers first, then the full per-class report.
print(f"Ensemble Accuracy: {ensemble_accuracy:.4f}")
print(f"Ensemble F1-Score (weighted): {ensemble_f1:.4f}")
print(f"Ensemble Recall (AD): {ensemble_classification_report_dict['AD']['recall']:.4f}")
print("\nEnsemble Classification Report:")
print(ensemble_classification_report_str)


# Ensemble Model Results Visualization

# Confusion matrix of the hard-voting predictions
plt.figure(figsize=(8, 6))
sns.heatmap(ensemble_cm, annot=True, fmt='d', cmap='Blues', xticklabels=['CN', 'AD'], yticklabels=['CN', 'AD'])
plt.gca().set(
    title='Hard Voting + Tie Based on Recall Weighted Sum Confusion Matrix (AAL + Schaefer Tuned Models)',
    xlabel='Predicted Label',
    ylabel='True Label',
)
plt.show()

# ROC curve: hard votes carry no score, so subjects are ranked by the
# recall-weighted soft probability of class 1 (AD) instead.
weighted_ensemble_probs_for_roc = aal_weight * aal_probs + schaefer_weight * schaefer_probs
probs_positive_class_ensemble_for_roc = weighted_ensemble_probs_for_roc[:, 1]

if len(true_labels) != len(probs_positive_class_ensemble_for_roc):
    print(f"Error: Length of true labels ({len(true_labels)}) does not match length of ensemble probabilities ({len(probs_positive_class_ensemble_for_roc)}). Cannot plot ROC curve.")
else:
    fpr, tpr, thresholds = roc_curve(true_labels, probs_positive_class_ensemble_for_roc)
    roc_auc_ensemble = auc(fpr, tpr)

    plt.figure(figsize=(8, 8))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_ensemble:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.gca().set(
        xlim=(0.0, 1.0),
        ylim=(0.0, 1.05),
        xlabel='False Positive Rate',
        ylabel='True Positive Rate',
        title='Hard Voting + Tie Based on Recall Weighted Sum ROC Curve (Tuned Models)',
    )
    plt.legend(loc="lower right")
    plt.show()

    print(f"Ensemble AUC: {roc_auc_ensemble:.4f}")

print("\n--- Ensemble Analysis Complete ---")