# Brain Data Classification using Graph-Based Models

## 1. Imports, Configuration, Seeding & Device Setup

This cell handles all necessary imports, sets up global configuration parameters, initializes random seeds for reproducibility, and selects the appropriate compute device (CPU or GPU).

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GCNConv, ChebConv, GATConv, global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data, Dataset
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from tqdm import tqdm
import json
import optuna
import copy
import os
from collections import Counter
import inspect

# --- Configuration ---
SEED = 42
N_SPLITS = 5      # K for K-Fold Cross-Validation (used for final evaluation)
EPOCHS = 150      # Max epochs for training (Early Stopping applies)
PATIENCE = 20     # Patience for Early Stopping
N_HPO_TRIALS = 50 # Number of trials per Optuna study
N_TOP_MODELS = 5  # Number of best configurations to include in the final ensemble
N_INTERNAL_CV_SPLITS = 3 # Number of folds for internal CV within HPO objective

# Flags to control script execution
RUN_HPO = False            # True for running Hyperparameter Optimization
RUN_FINAL_TRAINING = True # True for trainning final models using best HPO params

# Feature Set Selection
# Options: 'basic' (5 wave bands), 'with_ratios' (5 bands + 2 ratios), 'full_domain' (5 bands + ratios + region/hemi OHE)
FEATURE_SET = 'with_ratios'

# Adjacency types to generate and potentially test
ADJACENCY_TYPES = ["KNN", "Threshold", "AnatomicalRegion", "CustomBrainNetwork"]

# --- Seed for Reproducibility ---
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Setting deterministic increases reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Device Selection ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


  from .autonotebook import tqdm as notebook_tqdm


## 2. Electrode Mapping Utilities

Defines functions to map EEG channel indices to standard 10-10 labels, group channels by brain region and hemisphere.

In [2]:
def get_electrode_mapping_info():
    """
    Provides mappings: 64ch index -> 10-10 label, 10-10 label -> region,
    and generates region/hemisphere group lists needed for custom adjacency.
    Uses a combination of Code A and Code B mapping logic.
    Returns:
        electrode_to_1010 (dict): Map from 1-64 index to 10-10 label string.
        electrode_to_region (dict): Map from 0-63 index to simplified region string.
        region_groups (dict): Dict mapping region names to lists of 0-63 indices.
        hemisphere_groups (dict): Dict mapping 'left', 'right', 'midline' to 0-63 indices.
    """
    # Map electrode numbers (1-64) to 10-10 positions
    electrode_to_1010 = {
        1: 'F10', 2: 'AF4', 3: 'F2', 4: 'FCz', 5: 'FP2', 6: 'Fz', 7: 'FC1', 8: 'AFz',
        9: 'F1', 10: 'FP1', 11: 'AF3', 12: 'F3', 13: 'F5', 14: 'FC5', 15: 'FC3',
        16: 'C1', 17: 'F9', 18: 'F7', 19: 'FT7', 20: 'C3', 21: 'CP1', 22: 'C5',
        23: 'T9', 24: 'T7', 25: 'TP7', 26: 'CP5', 27: 'P5', 28: 'P3', 29: 'TP9',
        30: 'P7', 31: 'P1', 32: 'P9', 33: 'PO3', 34: 'Pz', 35: 'O1', 36: 'POz',
        37: 'Oz', 38: 'PO4', 39: 'O2', 40: 'P2', 41: 'CP2', 42: 'P4', 43: 'P10',
        44: 'P8', 45: 'P6', 46: 'CP6', 47: 'TP10', 48: 'TP8', 49: 'C6', 50: 'C4',
        51: 'C2', 52: 'T8', 53: 'FC4', 54: 'FC2', 55: 'T10', 56: 'FC6', 57: 'N/A',
        58: 'F8', 59: 'F6', 60: 'F4', 61: 'N/A', 62: 'N/A', 63: 'N/A', 64: 'N/A'
    }

    # Define brain regions based on 10-10 system
    frontal_labels = ['Fp1', 'Fp2', 'AF3', 'AF4', 'AF7', 'AF8', 'Fz', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9', 'F10', 'AFz', 'FPz']
    central_labels = ['Cz', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'FCz', 'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6']
    temporal_labels = ['T7', 'T8', 'T9', 'T10', 'FT7', 'FT8', 'FT9', 'FT10', 'TP7', 'TP8', 'TP9', 'TP10'] # Grouped T/FT/TP
    parietal_labels = ['Pz', 'P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'P9', 'P10', 'CPz', 'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6']
    occipital_labels = ['Oz', 'O1', 'O2', 'POz', 'PO3', 'PO4', 'PO7', 'PO8']

    # Create mapping from electrode indices (0-63) to simplified brain regions
    electrode_to_region = {}
    regions_simplified = ['frontal', 'central', 'temporal', 'parietal', 'occipital']
    region_label_map = {
        'frontal': frontal_labels, 'central': central_labels, 'temporal': temporal_labels,
        'parietal': parietal_labels, 'occipital': occipital_labels
    }

    for i in range(1, 65):
        electrode_label = electrode_to_1010.get(i, 'N/A')
        assigned_region = 'unknown'
        if electrode_label != 'N/A':
            # Find which simplified region contains this label (case-insensitive check might be needed)
            for region_name, labels in region_label_map.items():
                # Check variations like Fz vs FZ
                if electrode_label.upper() in [l.upper() for l in labels]:
                    assigned_region = region_name
                    break
        electrode_to_region[i-1] = assigned_region # Use 0-based index

    # Group electrodes by region (0-based indices)
    region_groups = {region_name: [] for region_name in regions_simplified + ['unknown']}
    for idx, region_name in electrode_to_region.items():
        if region_name in region_groups:
             region_groups[region_name].append(idx)
        else: # Should not happen with the init above, but safety check
             region_groups['unknown'].append(idx)

    # Hemisphere grouping (0-based indices)
    left_labels = [label for label in electrode_to_1010.values()
                   if label != 'N/A' and (label.endswith(('1', '3', '5', '7', '9')) or label.startswith(('Fp1', 'AF3', 'AF7', 'F1','F3','F5','F7','F9','FC1','FC3','FC5','FT7','FT9', 'C1','C3','C5','CP1','CP3','CP5','TP7','TP9','P1','P3','P5','P7','P9','PO3','PO7','O1')))] # More explicit L/R
    right_labels = [label for label in electrode_to_1010.values()
                    if label != 'N/A' and (label.endswith(('2', '4', '6', '8', '0')) or label.startswith(('Fp2','AF4','AF8','F2','F4','F6','F8','F10','FC2','FC4','FC6','FT8','FT10','C2','C4','C6','CP2','CP4','CP6','TP8','TP10','P2','P4','P6','P8','P10','PO4','PO8','O2')))] # 0 for F10/T10 etc.
    midline_labels = [label for label in electrode_to_1010.values() if label != 'N/A' and label.endswith('z')]

    left_electrodes = [i-1 for i in range(1, 65) if electrode_to_1010.get(i, 'N/A') in left_labels]
    right_electrodes = [i-1 for i in range(1, 65) if electrode_to_1010.get(i, 'N/A') in right_labels]
    midline_electrodes = [i-1 for i in range(1, 65) if electrode_to_1010.get(i, 'N/A') in midline_labels]

    hemisphere_groups = {
        'left': left_electrodes,
        'right': right_electrodes,
        'midline': midline_electrodes
    }


    return electrode_to_1010, electrode_to_region, region_groups, hemisphere_groups

## 3. Data Loading & Feature Engineering

Loads the EEG data and channel coordinates. Preprocesses features using RobustScaler and adds engineered features (e.g., band ratios) based on the `FEATURE_SET` configuration.

In [3]:
def load_and_preprocess_data(feature_set_config):
    """
    Load EEG data and channel coordinates.
    Preprocess basic wave features using RobustScaler.

    Args:
        feature_set_config (str): 'basic', 'with_ratios', 'full_domain'

    Returns:
        X_node (np.ndarray): Node features (n_samples, n_nodes, n_features)
        y (np.ndarray): Target labels (n_samples,)
        channel_coords (pd.DataFrame): DataFrame with channel coordinates.
        electrode_mapping_info (dict): Contains dictionaries from get_electrode_mapping_info().
    """
    print(f"Loading data and preparing feature set: '{feature_set_config}'")
    eeg_data = pd.read_csv('EEG data - Sheet1.csv', index_col=0)
    channel_coords = pd.read_csv('3d_channel_coord.csv')

    # Get electrode mapping info needed for feature engineering and custom graph
    electrode_to_1010, electrode_to_region, region_groups, hemisphere_groups = get_electrode_mapping_info()
    electrode_mapping_info = {
        "electrode_to_1010": electrode_to_1010,
        "electrode_to_region": electrode_to_region,
        "region_groups": region_groups,
        "hemisphere_groups": hemisphere_groups
    }


    y = eeg_data['target'].values
    wave_types = ['alpha', 'beta', 'delta', 'theta', 'gamma']
    n_channels = 64
    n_samples = len(eeg_data)

    # --- 1. Extract and Scale Basic Wave Features ---
    X_by_wave_raw = {}
    X_by_wave_scaled = {}

    for wave in wave_types:
        wave_features_raw = np.zeros((n_samples, n_channels))
        for ch in range(1, n_channels + 1):
            col_name = f"{wave}{ch}"
            if col_name in eeg_data.columns:
                wave_features_raw[:, ch-1] = eeg_data[col_name].values
            else:
                 # Handle missing columns if necessary (e.g., fill with zeros or mean)
                 print(f"Warning: Column {col_name} not found in EEG data.")
        X_by_wave_raw[wave] = wave_features_raw

        # Apply RobustScaler per wave type (scales features across samples and channels for that wave)
        scaler = RobustScaler()
        X_by_wave_scaled[wave] = scaler.fit_transform(X_by_wave_raw[wave])

    # Basic node features: 5 scaled frequency bands
    X_node_basic = np.zeros((n_samples, n_channels, len(wave_types)))
    for i, wave in enumerate(wave_types):
        X_node_basic[:, :, i] = X_by_wave_scaled[wave]

    # --- 2. Adding Engineered Features ---
    features_to_concat = [X_node_basic]

    if feature_set_config == 'basic':
        print("Using basic features (5 wave bands).")

    elif feature_set_config == 'with_ratios' or feature_set_config == 'full_domain':
        print("Calculating ratio features...")
        ratio_features = np.zeros((n_samples, n_channels, 2))
        # Calculate ratios using SCALED data for better numerical stability
        epsilon = 1e-10
        # Ensure denominators are safe
        theta_scaled = X_by_wave_scaled['theta']
        ratio_features[:, :, 0] = X_by_wave_scaled['alpha'] / (theta_scaled + epsilon)
        ratio_features[:, :, 1] = X_by_wave_scaled['beta'] / (theta_scaled + epsilon)
        features_to_concat.append(ratio_features)
        print(f"Added 2 ratio features.")

    if feature_set_config == 'full_domain':
        print("Calculating region and hemisphere one-hot features...")
        # Region features (one-hot encoding based on electrode_to_region map)
        regions_list = sorted([r for r in region_groups.keys() if r != 'unknown']) # Exclude unknown?
        n_regions = len(regions_list)
        region_map = {name: i for i, name in enumerate(regions_list)}
        region_ohe_features = np.zeros((n_samples, n_channels, n_regions))
        for sample_idx in range(n_samples):
             for ch_idx in range(n_channels):
                 region_name = electrode_to_region.get(ch_idx, 'unknown')
                 if region_name in region_map:
                     region_ohe_features[sample_idx, ch_idx, region_map[region_name]] = 1
        features_to_concat.append(region_ohe_features)
        print(f"Added {n_regions} region OHE features.")

        # Hemisphere features (one-hot encoding based on hemisphere_groups map)
        hemispheres_list = ['left', 'right', 'midline']
        n_hemispheres = len(hemispheres_list)
        hemisphere_ohe_features = np.zeros((n_samples, n_channels, n_hemispheres))
        for sample_idx in range(n_samples):
             for ch_idx in range(n_channels):
                 if ch_idx in hemisphere_groups['left']:
                     hemisphere_ohe_features[sample_idx, ch_idx, 0] = 1
                 elif ch_idx in hemisphere_groups['right']:
                     hemisphere_ohe_features[sample_idx, ch_idx, 1] = 1
                 elif ch_idx in hemisphere_groups['midline']:
                     hemisphere_ohe_features[sample_idx, ch_idx, 2] = 1
        features_to_concat.append(hemisphere_ohe_features)
        print(f"Added {n_hemispheres} hemisphere OHE features.")

    # --- 3. Combine Features ---
    X_node = np.concatenate(features_to_concat, axis=2)
    print(f"Final node feature shape: {X_node.shape}") # (n_samples, n_nodes, n_features)

    # Basic check for NaN/Inf values
    if np.isnan(X_node).any() or np.isinf(X_node).any():
        print("Warning: NaN or Inf values detected in final node features! Check scaling and ratios.")
        X_node = np.nan_to_num(X_node, nan=0.0, posinf=0.0, neginf=0.0)

    return X_node, y, channel_coords, electrode_mapping_info

## 4. Graph Construction Utilities

Defines functions to create adjacency matrices based on different strategies:
- K-Nearest Neighbors (KNN)
- Distance Threshold
- Anatomical Regions (simplified)
- Custom Brain Network (based on functional connectivity priors)

In [4]:
def create_custom_brain_adjacency(electrode_mapping_info, binary=True):
    """
    Create adjacency matrix based on brain regions and functional networks.
    Refined to be binary by default, reflecting connection presence.

    Args:
        electrode_mapping_info (dict): Contains region_groups, hemisphere_groups, electrode_to_1010.
        binary (bool): binary is always True, connections are 1/0. I removed possibility edge weights.

    Returns:
        adj_matrix (np.ndarray): The adjacency matrix.
    """
    region_groups = electrode_mapping_info['region_groups']
    hemisphere_groups = electrode_mapping_info['hemisphere_groups']
    # Ensure electrode_to_1010 maps 1-based index to label
    electrode_to_1010 = electrode_mapping_info['electrode_to_1010']
    # Create reverse map: label -> 0-based index for easier lookup
    label_to_index0 = {v: k-1 for k, v in electrode_to_1010.items() if v != 'N/A'}


    n_nodes = 64
    adj_matrix = np.zeros((n_nodes, n_nodes))

    # Use binary weights (1.0 for connection)
    intra_region_w = 1.0
    fronto_parietal_w = 1.0
    # Identifying TP electrodes:
    tp_labels = ['TP7','TP8','TP9','TP10']
    # Use the label_to_index0 map
    tp_indices_0based = {label_to_index0.get(label) for label in tp_labels if label in label_to_index0}
    # Remove None if a label wasn't found (shouldn't happen)
    tp_indices_0based.discard(None)

    occip_temporal_w = 1.0
    midline_w = 1.0
    homologous_w = 1.0

    # 1. Connect electrodes within the same brain region (excluding 'unknown')
    for region, electrodes in region_groups.items():
        if region == 'unknown': continue
        valid_electrodes = [e for e in electrodes if e < n_nodes] # Ensure index is within bounds
        for i in valid_electrodes:
            for j in valid_electrodes:
                if i != j:
                    adj_matrix[i, j] = max(adj_matrix[i, j], intra_region_w) # Use max to avoid overwriting stronger links

    # 2. Connect electrodes in functional networks

    # Frontoparietal network
    frontal_idx = [e for e in region_groups.get('frontal', []) if e < n_nodes]
    parietal_idx = [e for e in region_groups.get('parietal', []) if e < n_nodes]
    for i in frontal_idx:
        for j in parietal_idx:
             adj_matrix[i, j] = max(adj_matrix[i, j], fronto_parietal_w)
             adj_matrix[j, i] = max(adj_matrix[j, i], fronto_parietal_w)

    # Occipital-Temporal connections (simplified: connect all occipital to TP)
    occipital_idx = [e for e in region_groups.get('occipital', []) if e < n_nodes]
    temporal_idx = [e for e in region_groups.get('temporal', []) if e < n_nodes]
    # Connecting occipital to temporo-parietal (TP*) sites
    for i in occipital_idx:
        for j in temporal_idx:
            if j in tp_indices_0based: # Connect only to TP electrodes
                 adj_matrix[i, j] = max(adj_matrix[i, j], occip_temporal_w)
                 adj_matrix[j, i] = max(adj_matrix[j, i], occip_temporal_w)

    # Midline network (connect all midline electrodes)
    midline_idx = [e for e in hemisphere_groups.get('midline', []) if e < n_nodes]
    for i in midline_idx:
        for j in midline_idx:
            if i != j:
                adj_matrix[i, j] = max(adj_matrix[i, j], midline_w)


    # 3. Connect homologous areas across hemispheres
    left_idx = [e for e in hemisphere_groups.get('left', []) if e < n_nodes]
    for i in left_idx:
        label_i = electrode_to_1010.get(i + 1, '') # Get 1-based label
        if not label_i or label_i == 'N/A': continue

        # Find corresponding right label (handling numbers)
        label_j = ''
        # Handle T9/P9 -> T10/P10 specifically
        if label_i in ['T9', 'P9', 'TP9']: # Add other potential 9s if needed
            label_j = label_i[:-1] + '10'
        elif label_i in ['F9']: # F9 doesn't have F10, maybe connects to F8?
             label_j = 'F8'
        elif label_i[-1].isdigit(): # Ends in 1, 3, 5, 7
            num = int(label_i[-1])
            if num % 2 != 0: # Is odd
                label_j = label_i[:-1] + str(num + 1)

        if label_j:
            # Find the 0-based index j for label_j using the reverse map
            found_j = label_to_index0.get(label_j, -1)

            if found_j != -1 and found_j < n_nodes:
                 adj_matrix[i, found_j] = max(adj_matrix[i, found_j], homologous_w)
                 adj_matrix[found_j, i] = max(adj_matrix[found_j, i], homologous_w)


    # Ensure no self-connections
    np.fill_diagonal(adj_matrix, 0)
    return adj_matrix


def create_adjacency_matrices(channel_coords, electrode_mapping_info, k_knn=8, dist_threshold=0.2):
    """
    Create multiple adjacency matrices based on coordinates and anatomical info.

    Args:
        channel_coords (pd.DataFrame): DataFrame with 'X', 'Y', 'Z' columns.
        electrode_mapping_info (dict): Output from get_electrode_mapping_info().
        k_knn (int): Number of neighbors for KNN graph.
        dist_threshold (float): Normalized distance threshold for Threshold graph.

    Returns:
        adj_matrices (dict): Dictionary mapping adjacency type name to np.ndarray matrix.
    """
    # Ensure coords are aligned with 0-63 indexing if channel_coords has labels
    # Assuming channel_coords is already ordered 0-63 or doesn't have conflicting labels
    coords = channel_coords[['X', 'Y', 'Z']].values
    n_nodes = len(coords)
    if n_nodes != 64:
        print(f"Warning: Expected 64 channels based on mapping, but found {n_nodes} coordinates. Using {n_nodes}.")

    # --- Calculate pairwise Euclidean distances (still needed for KNN and Threshold) ---
    dist_matrix = np.zeros((n_nodes, n_nodes))
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            dist = np.sqrt(np.sum((coords[i] - coords[j])**2))
            dist_matrix[i, j] = dist_matrix[j, i] = dist

    # --- Normalize distances (needed for Threshold) ---
    max_dist = np.max(dist_matrix)
    if max_dist > 0:
        norm_dist_matrix = dist_matrix / max_dist
    else:
        norm_dist_matrix = dist_matrix # Avoid division by zero

    adj_matrices = {}

    # 1. KNN adjacency
    knn_adj = np.zeros((n_nodes, n_nodes))
    for i in range(n_nodes):
        # Handle cases where k_knn >= n_nodes
        k_actual = min(k_knn, n_nodes - 1)
        if k_actual > 0:
            neighbors = np.argsort(dist_matrix[i])[1:k_actual+1] # Exclude self
            knn_adj[i, neighbors] = 1
            knn_adj[neighbors, i] = 1 # Make symmetric
    adj_matrices["KNN"] = knn_adj

    # 2. Distance threshold adjacency
    threshold_adj = (norm_dist_matrix < dist_threshold).astype(np.float32)
    np.fill_diagonal(threshold_adj, 0)
    adj_matrices["Threshold"] = threshold_adj

    # 3. Anatomical Region-based adjacency (Connect nodes within the same *simplified* region)
    region_adj = np.zeros((n_nodes, n_nodes))
    region_groups = electrode_mapping_info['region_groups']
    for region, electrodes in region_groups.items():
        if region == 'unknown': continue
        valid_electrodes = [e for e in electrodes if e < n_nodes]
        for i in valid_electrodes:
            for j in valid_electrodes:
                 if i != j:
                     region_adj[i, j] = 1 # Binary connection if in same region
    adj_matrices["AnatomicalRegion"] = region_adj

    # 4. Custom Brain Network Adjacency
    custom_adj = create_custom_brain_adjacency(electrode_mapping_info, binary=True)
    adj_matrices["CustomBrainNetwork"] = custom_adj

    return adj_matrices # Return only the adjacency matrices

## 5. PyTorch Geometric Dataset

Defines a custom PyTorch Geometric `Dataset` class (`EEGGraphDataset`) to wrap the EEG data samples. Includes a helper function (`construct_graph_dataset`) to create instances of this dataset, handling the conversion of adjacency matrices to `edge_index` format.

In [5]:
#--------------------------------
# 4. PYTORCH GEOMETRIC DATASET 
#--------------------------------

class EEGGraphDataset(Dataset):
    """PyTorch Geometric Dataset for EEG graph samples (no edge attributes)."""
    def __init__(self, X_node, y, edge_index):
        super().__init__()
        self.X_node = torch.tensor(X_node, dtype=torch.float)
        self.y = torch.tensor(y, dtype=torch.long)
        self.edge_index = edge_index # Should be a single tensor shared across graphs

    def len(self):
        return len(self.y)

    def get(self, idx):
        x = self.X_node[idx]
        target = self.y[idx]

        # Create Data object for the sample (no edge_attr)
        data = Data(x=x, edge_index=self.edge_index, y=target)
        return data

def construct_graph_dataset(X_node, y, adj_matrix):
    """
    Construct PyTorch Geometric Dataset instance (no edge weights).

    Args:
        X_node (np.ndarray): Node features (n_samples, n_nodes, n_features).
        y (np.ndarray): Target labels (n_samples,).
        adj_matrix (np.ndarray): Adjacency matrix for the graph structure.

    Returns:
        EEGGraphDataset: The constructed dataset object.
    """
    edge_index = torch.tensor(np.array(np.where(adj_matrix > 0)), dtype=torch.long)

    pyg_dataset = EEGGraphDataset(X_node, y, edge_index)
    return pyg_dataset

## 6. Model Architectures

Defines the Graph Neural Network architectures to be tested:
- `SimplifiedGNN`: Basic GCN-based model.
- `SpectralGNN`: ChebConv-based model (spectral graph convolution).
- `SimpleGAT`: Graph Attention Network (GAT)-based model.
Includes options for batch normalization and different global pooling strategies.

In [6]:
#-------------------------------------------------------------------------------
# 5. MODEL ARCHITECTURES (Optional BatchNorm, Flexible Pooling, No Edge Weights)
#--------------------------------------------------------------------------------

def get_pooling_layer(pooling_type):
    """Returns the global pooling function based on type string."""
    if pooling_type == 'mean':
        return global_mean_pool
    elif pooling_type == 'add':
        return global_add_pool
    elif pooling_type == 'max':
        return global_max_pool
    else:
        raise ValueError(f"Unsupported pooling type: {pooling_type}")

class SimplifiedGNN(torch.nn.Module):
    """ GCN-based model with optional BatchNorm. No edge weights. """
    def __init__(self, in_channels, hidden_channels=32, out_channels=2, dropout_rate=0.5, pooling_type='mean', use_batch_norm=False):
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.conv1 = GCNConv(in_channels, hidden_channels)
        if self.use_batch_norm:
             self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        if self.use_batch_norm:
            self.bn2 = nn.BatchNorm1d(hidden_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.pool = get_pooling_layer(pooling_type)
        self.out = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        if self.use_batch_norm:
             x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index)
        if self.use_batch_norm:
             x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout(x) # Dropout after second conv+activation

        x = self.pool(x, batch)
        x = self.out(x)
        return x

class SpectralGNN(torch.nn.Module):
    """ ChebConv-based model"""
    def __init__(self, in_channels, hidden_channels=32, out_channels=2, K=3, dropout_rate=0.5, pooling_type='mean', use_batch_norm=False):
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.conv1 = ChebConv(in_channels, hidden_channels, K=K)
        if self.use_batch_norm:
             self.bn1 = nn.BatchNorm1d(hidden_channels)
        self.conv2 = ChebConv(hidden_channels, hidden_channels, K=K)
        if self.use_batch_norm:
            self.bn2 = nn.BatchNorm1d(hidden_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.pool = get_pooling_layer(pooling_type)
        self.out = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index, batch=batch) # Pass batch if needed by ChebConv version
        if self.use_batch_norm:
             x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.conv2(x, edge_index, batch=batch) # Pass batch if needed by ChebConv version
        if self.use_batch_norm:
             x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.pool(x, batch)
        x = self.out(x)
        return x

class SimpleGAT(torch.nn.Module):
    """ GAT-based model with optional BatchNorm. No edge weights. """
    def __init__(self, in_channels, hidden_channels=32, out_channels=2, heads=4, dropout_rate=0.5, pooling_type='mean', use_batch_norm=False):
        super().__init__()
        self.use_batch_norm = use_batch_norm
        # Note: GAT output features are hidden_channels * heads for the first layer if concat=True
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout_rate) # Dropout within GATConv
        if self.use_batch_norm:
             # BN usually applied *after* convolution, *before* non-linearity
             self.bn1 = nn.BatchNorm1d(hidden_channels * heads)

        # Output layer (using 1 head and concat=False)
        # Input channels for conv2 is hidden_channels * heads
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=False, dropout=dropout_rate)
        if self.use_batch_norm:
            self.bn2 = nn.BatchNorm1d(hidden_channels) # Output is hidden_channels

        self.dropout = nn.Dropout(dropout_rate) # Separate dropout layer
        self.pool = get_pooling_layer(pooling_type)
        self.out = nn.Linear(hidden_channels, out_channels) # Output depends on conv2's output size

    def forward(self, x, edge_index, batch):
        # GATConv call without edge_attr
        x = self.conv1(x, edge_index)
        if self.use_batch_norm:
             x = self.bn1(x)
        x = F.relu(x)
        # Dropout can be applied here (after activation) if not relying solely on GAT's internal dropout

        # GATConv call
        x = self.conv2(x, edge_index)
        if self.use_batch_norm:
             x = self.bn2(x)
        x = F.relu(x)

        x = self.pool(x, batch)
        x = self.out(x)
        return x

MODEL_CLASSES = {
    "SimplifiedGNN": SimplifiedGNN,
    "SpectralGNN": SpectralGNN,
    "SimpleGAT": SimpleGAT
}

## 7. Training & Evaluation Core Function

Defines the `train_evaluate_single_model` function, which encapsulates the logic for training one epoch, evaluating on a validation set, implementing early stopping based on validation accuracy, and optionally interacting with Optuna for pruning during hyperparameter search.

In [7]:
#---------------------------------------
# 6. TRAINING & EVALUATION CORE FUNCTION 
#----------------------------------------

def train_evaluate_single_model(model, train_loader, val_loader, optimizer, criterion,
                                device, epochs=100, patience=15, trial=None):
    """Trains and validates a single model instance with early stopping (no edge weights)."""
    model.to(device)
    best_val_acc = 0.0
    best_model_state = None
    epochs_no_improve = 0
    train_losses, val_losses, val_accs = [], [], []

    # --- Determine number of samples for loss normalization ---
    # Handle Subset datasets correctly for length
    # Access the underlying dataset if it's a Subset
    train_dataset_obj = train_loader.dataset.dataset if isinstance(train_loader.dataset, torch.utils.data.Subset) else train_loader.dataset
    val_dataset_obj = val_loader.dataset.dataset if isinstance(val_loader.dataset, torch.utils.data.Subset) else val_loader.dataset
    # Use the indices if it's a Subset
    len_train_dataset = len(train_loader.dataset.indices) if isinstance(train_loader.dataset, torch.utils.data.Subset) else len(train_loader.dataset)
    len_val_dataset = len(val_loader.dataset.indices) if isinstance(val_loader.dataset, torch.utils.data.Subset) else len(val_loader.dataset)

    for epoch in range(1, epochs + 1):
        # --- Training ---
        model.train()
        total_train_loss = 0.0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            # Model call without edge_attr
            out = model(data.x, data.edge_index, data.batch)
            # Ensure target shape matches output shape for CrossEntropyLoss
            # Output shape: [batch_size, num_classes], Target shape: [batch_size]
            loss = criterion(out, data.y)
            loss.backward()
            # Gradient Clipping (can help stability)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            total_train_loss += loss.item() * data.num_graphs # Use num_graphs in batch
        # Normalize loss by total number of samples in the dataset
        avg_train_loss = total_train_loss / len_train_dataset if len_train_dataset > 0 else 0
        train_losses.append(avg_train_loss)


        # --- Validation ---
        model.eval()
        total_val_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data in val_loader:
                data = data.to(device)
                # Model call without edge_attr
                out = model(data.x, data.edge_index, data.batch)
                loss = criterion(out, data.y)
                total_val_loss += loss.item() * data.num_graphs
                pred = out.argmax(dim=1)
                correct += int((pred == data.y).sum())

        avg_val_loss = total_val_loss / len_val_dataset if len_val_dataset > 0 else 0
        val_acc = correct / len_val_dataset if len_val_dataset > 0 else 0
        val_losses.append(avg_val_loss)
        val_accs.append(val_acc)

        # --- Early Stopping & Best Model Saving ---
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Save model state dict on CPU to avoid GPU memory issues if many models are trained
            best_model_state = copy.deepcopy({k: v.cpu() for k, v in model.state_dict().items()})
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epoch % 25 == 0: # Print progress periodically
            print(f'  Epoch: {epoch:03d}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}')

        # --- Optuna Pruning ---
        if trial:
            trial.report(val_acc, epoch)
            if trial.should_prune():
                print(f"  Trial pruned at epoch {epoch}.")
                raise optuna.TrialPruned()

        if epochs_no_improve >= patience:
            print(f"  Early stopping triggered after epoch {epoch}. Best Val Acc: {best_val_acc:.4f}")
            break

    # Load the best model state before returning
    if best_model_state:
        # Ensure the model is on the correct device before loading state
        model.to(device)
        model.load_state_dict({k: v.to(device) for k, v in best_model_state.items()})
    else:
        print("Warning: No best model state found (possibly due to immediate pruning or no improvement). Using last state.")
        # Ensure model is still on the correct device
        model.to(device)

    return best_val_acc, model # Return best validation accuracy achieved and the model with best weights

## 8. Hyperparameter Optimization (Optuna)

Defines the `objective` function for Optuna, which performs internal cross-validation for a given set of hyperparameters to get a robust performance estimate. Also defines `run_hpo_studies` to manage the optimization process across different model architectures and adjacency matrix types, saving the best parameters found.

In [8]:
#-------------------------------------------------------------------------------
# 7. OPTUNA HYPERPARAMETER OPTIMIZATION
#-------------------------------------------------------------------------------

def objective(trial, model_name, adj_type, X_node, y, adj_matrices, device):
    """
    Optuna objective function using internal k-fold CV.
    Includes use_batch_norm hyperparameter. No edge weights.
    """
    # --- Hyperparameter Sampling ---
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
    dropout_rate = trial.suggest_float("dropout_rate", 0.2, 0.7)
    hidden_channels = trial.suggest_categorical("hidden_channels", [16, 32, 64])
    pooling_type = trial.suggest_categorical("pooling_type", ['mean', 'add', 'max'])
    batch_size = trial.suggest_categorical("batch_size", [4, 8, 16])
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "AdamW"])
    use_batch_norm = trial.suggest_categorical("use_batch_norm", [True, False])

    # --- Model Specific Hyperparameters ---
    model_init_params = {
        # in_channels determined dynamically later
        "hidden_channels": hidden_channels,
        "dropout_rate": dropout_rate,
        "pooling_type": pooling_type,
        "use_batch_norm": use_batch_norm, # Pass BN choice
        "out_channels": 2
    }
    if model_name == "SpectralGNN":
        model_init_params["K"] = trial.suggest_int("K", 2, 5)
    if model_name == "SimpleGAT":
        model_init_params["heads"] = trial.suggest_categorical("heads", [2, 4, 8])

    # --- Internal Cross-Validation Loop ---
    kf_internal = StratifiedKFold(n_splits=N_INTERNAL_CV_SPLITS, shuffle=True, random_state=SEED)
    internal_fold_accuracies = []
    adj_matrix = adj_matrices[adj_type]

    # Construct dataset for this adjacency type (once outside internal loop)
    full_dataset = construct_graph_dataset(X_node, y, adj_matrix)
    # Dynamically determine in_channels from the actual data
    in_channels = full_dataset.num_node_features
    model_init_params["in_channels"] = in_channels # Add to params dict

    for fold_num, (train_idx, val_idx) in enumerate(kf_internal.split(np.zeros(len(y)), y)): # Use dummy X for split indices

        train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
        val_dataset = torch.utils.data.Subset(full_dataset, val_idx)
        # Use suggested batch_size
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True) # drop_last might help with BN issues on small last batches
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

        # --- Instantiate Model & Optimizer *INSIDE* internal loop ---
        ModelClass = MODEL_CLASSES[model_name]
        # Filter params to only those accepted by the model's __init__
        valid_params = {k: v for k, v in model_init_params.items() if k in inspect.signature(ModelClass.__init__).parameters}
        try:
            model = ModelClass(**valid_params)
        except TypeError as e:
             print(f"ERROR Instantiating {model_name} with params {valid_params}: {e}")
             return 0.0 # Penalize trial heavily

        model.to(device)

        if optimizer_name == "Adam":
            optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        elif optimizer_name == "AdamW":
            optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.CrossEntropyLoss()

        # --- Train and Evaluate on this internal fold ---
        try:
            # Pass trial=None to prevent pruning within the internal CV loop itself
            # Pruning decision will be based on the average score reported later
            # Use global EPOCHS and PATIENCE
            best_fold_val_acc, _ = train_evaluate_single_model(
                model, train_loader, val_loader, optimizer, criterion, device,
                epochs=EPOCHS, patience=PATIENCE, trial=None # Pass None for internal folds
            )
            internal_fold_accuracies.append(best_fold_val_acc)

        # Catch pruning signal if it somehow propagates (shouldn't tho)
        except optuna.TrialPruned:
             print(f"    Trial {trial.number} - Pruning detected unexpectedly in internal fold {fold_num+1}.")
             # Re-raise or handle as error for the trial
             internal_fold_accuracies.append(0.0) # Penalize score
             break # Stop internal CV for this trial
        # Catch other errors during training/evaluation
        except Exception as e:
            print(f"    Trial {trial.number} - ERROR during internal fold {fold_num+1}: {e}")
            import traceback
            traceback.print_exc() # Print stack trace for debugging
            internal_fold_accuracies.append(0.0) # Penalize errors

    # --- Calculate Average Score ---
    if not internal_fold_accuracies: # Handle case where all internal folds failed
        avg_val_acc = 0.0
    else:
        avg_val_acc = np.mean(internal_fold_accuracies)

    # --- Report Average Score to Optuna for Pruning ---
    # Report the average value at a single step (e.g., step=1)
    # Use the main trial object passed to the objective function here
    trial.report(avg_val_acc, 1)
    if trial.should_prune():
        raise optuna.TrialPruned()

    return avg_val_acc # Optuna maximizes this average value


def run_hpo_studies(X_node, y, adj_matrices, device):
    """Runs Optuna studies for all model/adjacency combinations (no distance matrix)."""
    all_best_params = {}
    study_results = {}

    # Create a directory for HPO results if it doesn't exist
    hpo_dir = "hpo_results"
    os.makedirs(hpo_dir, exist_ok=True)

    # Use global ADJACENCY_TYPES and FEATURE_SET
    for model_name in MODEL_CLASSES.keys():
        for adj_type in ADJACENCY_TYPES:
            if adj_type not in adj_matrices:
                print(f"Skipping HPO for adjacency type '{adj_type}' as it wasn't generated.")
                continue

            study_name = f"{model_name}_{adj_type}_feat_{FEATURE_SET}" # Include feature set in name
            print(f"\n--- Running HPO for: {study_name} ---")

            # Define the objective function with fixed args for this study (no norm_dist_matrix)
            objective_func = lambda trial: objective(trial, model_name, adj_type, X_node, y, adj_matrices, device)

            # Create and run the Optuna study
            # Use global SEED
            study = optuna.create_study(direction="maximize", study_name=study_name,
                                       pruner=optuna.pruners.MedianPruner(n_warmup_steps=5, n_min_trials=3),
                                       sampler=optuna.samplers.TPESampler(seed=SEED))
            try:
                 # Use global N_HPO_TRIALS
                 study.optimize(objective_func, n_trials=N_HPO_TRIALS, timeout=None)
                 best_params = study.best_trial.params
                 best_value = study.best_value
            except Exception as e:
                 print(f"ERROR during HPO study {study_name}: {e}")
                 import traceback
                 traceback.print_exc()
                 best_params = {}
                 best_value = 0.0

            # Store results
            all_best_params[study_name] = best_params
            study_results[study_name] = {"best_value": best_value, "best_params": best_params}
            print(f"Best Validation Accuracy (Avg Internal CV) for {study_name}: {best_value:.4f}")
            print(f"Best Params: {best_params}")

            # Save study results incrementally
            results_filename = os.path.join(hpo_dir, f'hpo_results_{study_name}.json')
            try:
                with open(results_filename, 'w') as f:
                    json.dump(study_results[study_name], f, indent=4)
            except Exception as e:
                print(f"Error saving HPO results for {study_name} to {results_filename}: {e}")


    # Save all best parameters together
    all_params_filename = os.path.join(hpo_dir, f'best_hpo_params_all_feat_{FEATURE_SET}.json')
    try:
        with open(all_params_filename, 'w') as f:
            json.dump(all_best_params, f, indent=4)
        print(f"\n--- HPO Complete. Best parameters saved to '{all_params_filename}' ---")
    except Exception as e:
        print(f"Error saving combined HPO parameters to {all_params_filename}: {e}")

    return all_best_params

## 9. Final Model Training & Ensemble Evaluation

Defines functions for the final stage:
- `train_final_models_cv`: Loads the best hyperparameters found by Optuna, selects the top N configurations, and trains each using full K-fold cross-validation on the entire dataset. Stores trained models and out-of-fold (OOF) predictions.
- `evaluate_ensemble_cv`: Takes the OOF predictions from the trained models, performs majority voting to create an ensemble prediction, and evaluates the ensemble's performance using metrics like accuracy, F1-score, classification report, and confusion matrix.

In [9]:
#----------------------------------------------
# 8. FINAL MODEL TRAINING WITH CROSS-VALIDATION
#---------------------------------------------

def train_final_models_cv(best_hpo_params_path, X_node, y, adj_matrices, device):
    """
    Trains the final models using best HPO params and full CV.
    Ranks configurations based on their saved robust objective score from HPO.
    """
    # Use global N_SPLITS and SEED
    kf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    trained_models = {} # Store models for ensemble {config_name: [model_fold1_state_dict, model_fold2_state_dict,...]}
    oof_predictions = {} # Store out-of-fold predictions {config_name: np.array([...])}
    oof_indices = {}     # Store indices corresponding to OOF predictions {config_name: np.array([...])}
    cv_scores = {}       # Store CV scores per config {config_name: [acc_fold1, ...]}

    # --- Load and Rank Configurations ---
    try:
        with open(best_hpo_params_path, 'r') as f:
            best_hpo_params_all = json.load(f)
    except FileNotFoundError:
        print(f"Error: HPO parameters file not found at '{best_hpo_params_path}'. Run HPO first.")
        return None, None, None, None
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON from HPO parameters file '{best_hpo_params_path}': {e}")
        return None, None, None, None

    # Extract HPO scores (best_value) to rank configurations
    config_scores = {}
    hpo_results_dir = os.path.dirname(best_hpo_params_path)
    if not hpo_results_dir:
        hpo_results_dir = "."

    for config_name in best_hpo_params_all.keys():
        individual_result_path = os.path.join(hpo_results_dir, f'hpo_results_{config_name}.json')
        try:
             with open(individual_result_path, 'r') as f:
                 result_data = json.load(f)
                 config_scores[config_name] = result_data.get('best_value', 0.0) # Default to 0 if score missing
        except FileNotFoundError:
             print(f"Warning: Individual HPO result file not found for {config_name} at '{individual_result_path}'. Cannot rank accurately. Assigning score 0.")
             config_scores[config_name] = 0.0
        except json.JSONDecodeError as e:
             print(f"Error decoding JSON from individual HPO result file '{individual_result_path}': {e}. Assigning score 0.")
             config_scores[config_name] = 0.0
        except Exception as e:
             print(f"Error loading score for {config_name}: {e}. Assigning score 0.")
             config_scores[config_name] = 0.0


    # Sort configurations by their validation score from HPO
    sorted_configs = sorted(config_scores.items(), key=lambda item: item[1], reverse=True)

    # Filter out configs with score 0 (likely errors during HPO) before selecting top N
    valid_sorted_configs = [ (name, score) for name, score in sorted_configs if score > 0 ]

    if not valid_sorted_configs:
         print("Error: No valid HPO configurations found with score > 0. Cannot proceed with final training.")
         return None, None, None, None

    # Use global N_TOP_MODELS
    num_to_select = min(N_TOP_MODELS, len(valid_sorted_configs))
    top_configs = dict(valid_sorted_configs[:num_to_select])
    print(f"\n--- Selected Top {num_to_select} Configurations for Final Training (Ranked by HPO Score) ---")
    for name, score in top_configs.items():
        print(f"- {name} (HPO Score: {score:.4f})")

    # Determine in_channels dynamically from the provided X_node
    in_channels = X_node.shape[2]
    # Use global FEATURE_SET
    print(f"Using input channels: {in_channels} (based on FEATURE_SET='{FEATURE_SET}')")

    # --- Train each top configuration across all CV folds ---
    for config_name in top_configs.keys():
        print(f"\n--- Training Final Model for: {config_name} ---")
        # --- Improved Parsing Logic ---
        try:
            # Split the config name based on the feature set marker '_feat_'
            base_name, feature_suffix = config_name.split('_feat_', 1)

            # Split the base name (ModelName_AdjType) at the first underscore
            parts = base_name.split('_')
            if len(parts) < 2:
                raise ValueError("Base name does not contain ModelName_AdjType")
            model_name = parts[0]
            adj_type = "_".join(parts[1:]) # Handle adj_types with underscores if any


        except ValueError as e:
            # Fallback if splitting fails
            print(f"ERROR: Could not parse config name '{config_name}' using expected format 'ModelName_AdjType_feat_FeatureSet'. Error: {e}. Skipping.")
            # Assign invalid values to ensure skipping in subsequent checks
            model_name = "PARSE_ERROR"
            adj_type = "PARSE_ERROR"

        if model_name not in MODEL_CLASSES:
             print(f"ERROR: Model name '{model_name}' from config '{config_name}' not found in MODEL_CLASSES. Skipping.")
             continue
        if adj_type not in adj_matrices:
             print(f"ERROR: Adjacency type '{adj_type}' from config '{config_name}' not found in generated matrices. Skipping.")
             continue

        params = best_hpo_params_all[config_name] # Get the parameters for the current top config

        fold_models_states = []
        fold_oof_preds = []
        fold_oof_indices = []
        fold_scores = []

        adj_matrix = adj_matrices[adj_type]
        # Construct dataset for this fold using the appropriate adjacency matrix
        full_dataset = construct_graph_dataset(X_node, y, adj_matrix)
        # Double check in_channels matches dataset
        if full_dataset.num_node_features != in_channels:
             print(f"ERROR: Mismatch between expected in_channels ({in_channels}) and dataset features ({full_dataset.num_node_features}) for {config_name}. Skipping.")
             continue


        for fold, (train_idx, test_idx) in enumerate(kf.split(np.zeros(len(y)), y)): # Use dummy X for split
            print(f"  Fold {fold+1}/{N_SPLITS}")

            train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
            test_dataset = torch.utils.data.Subset(full_dataset, test_idx) # Test set for this fold

            batch_size = params.get('batch_size', 8) # Default batch size if not in HPO params

            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
            test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

            # Instantiate Model
            ModelClass = MODEL_CLASSES[model_name]
            # Prepare parameters for model instantiation
            model_init_params = {
                "in_channels": in_channels, # Use dynamically determined value
                "out_channels": 2 # Assuming binary classification
            }
            # Add HPO params relevant to the model's __init__
            init_signature = inspect.signature(ModelClass.__init__)
            for k, v in params.items():
                 if k in init_signature.parameters:
                     model_init_params[k] = v

             # Filter again to be sure
            valid_params = {k: v for k, v in model_init_params.items() if k in init_signature.parameters}

            try:
                model = ModelClass(**valid_params)
            except TypeError as e:
                print(f"  ERROR: Could not instantiate model {model_name} for fold {fold+1}. Config: {valid_params}. Error: {e}")
                fold_scores.append(0.0) # Record failure
                continue # Skip to next fold

            model.to(device)

            # Instantiate Optimizer
            optimizer_name = params.get('optimizer', 'Adam')
            lr = params.get('lr', 0.001)
            weight_decay = params.get('weight_decay', 0.0)

            if optimizer_name == "Adam":
                optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
            elif optimizer_name == "AdamW":
                optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
            else: # Default fallback
                optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

            criterion = nn.CrossEntropyLoss()

            # Train model on this fold's train/val (using test_loader as val here for early stopping)
            # Use global EPOCHS, PATIENCE
            val_acc, trained_fold_model = train_evaluate_single_model(
                model, train_loader, test_loader, optimizer, criterion, device,
                epochs=EPOCHS, patience=PATIENCE, trial=None # No trial object for final train
            )

            # --- Evaluate on Test set for this fold (for OOF predictions) ---
            trained_fold_model.eval()
            preds = []
            with torch.no_grad():
                for data in test_loader:
                    data = data.to(device)
                    # Model call without edge_attr
                    out = trained_fold_model(data.x, data.edge_index, data.batch)
                    preds.extend(out.argmax(dim=1).cpu().numpy())

            # Use y[test_idx] for calculating fold accuracy
            fold_test_acc = accuracy_score(y[test_idx], preds)
            fold_scores.append(fold_test_acc) # Use test accuracy for the fold score
            fold_oof_preds.extend(preds)
            fold_oof_indices.extend(test_idx) # Store original indices

            # Store model state dict (on CPU to save GPU memory)
            fold_models_states.append({k: v.cpu() for k, v in trained_fold_model.state_dict().items()})
            print(f"  Fold {fold+1} Test Acc: {fold_test_acc:.4f}")

        # Store results for this configuration
        trained_models[config_name] = fold_models_states
        # Ensure OOF preds and indices are numpy arrays
        oof_predictions[config_name] = np.array(fold_oof_preds)
        oof_indices[config_name] = np.array(fold_oof_indices)
        cv_scores[config_name] = fold_scores

        if fold_scores:
             print(f"  {config_name} Avg CV Accuracy: {np.mean(fold_scores):.4f} +/- {np.std(fold_scores):.4f}")
        else:
             print(f"  {config_name} No successful folds completed.")


    return trained_models, oof_predictions, oof_indices, cv_scores

#-------------------------------------------------------------------------------
# 9. ENSEMBLE PREDICTION & EVALUATION (Unchanged)
#-------------------------------------------------------------------------------

def evaluate_ensemble_cv(oof_predictions, oof_indices, y_true):
    """Evaluates the ensemble based on out-of-fold predictions."""
    if not oof_predictions:
         print("No OOF predictions available to evaluate ensemble.")
         return 0.0, None, None, None # Return None for metrics

    all_config_preds_sorted = []
    target_indices_sorted = None
    valid_configs = list(oof_predictions.keys()) # Start with all configs that have predictions

    # Aggregate predictions from all configs, ensuring correct order based on indices
    for config_name in list(valid_configs): 
        preds = oof_predictions[config_name]
        indices = oof_indices[config_name]


        # Sort predictions based on original sample index
        sort_order = np.argsort(indices)
        sorted_preds = preds[sort_order]
        all_config_preds_sorted.append(sorted_preds)

        if target_indices_sorted is None:
            target_indices_sorted = indices[sort_order]
            if not np.all(target_indices_sorted == np.arange(len(y_true))):
                 print("Warning: OOF indices do not cover all samples uniquely. Ensemble evaluation might be inaccurate.")


    if not all_config_preds_sorted:
         print("No valid OOF predictions found after checks. Cannot evaluate ensemble.")
         return 0.0, None, None, None

    # Perform majority voting
    # Shape: (n_models_in_ensemble, n_samples)
    votes = np.array(all_config_preds_sorted)
    ensemble_preds = []
    for i in range(votes.shape[1]): # Iterate through samples
        sample_votes = votes[:, i]
        # Find the most frequent prediction (majority vote)
        # If there's a tie, np.argmax(np.bincount()) typically returns the smallest index among the tied values.
        counts = np.bincount(sample_votes.astype(int))
        final_pred = np.argmax(counts)
        ensemble_preds.append(final_pred)
    ensemble_preds = np.array(ensemble_preds)

    # --- Calculate Final Metrics ---
    # Ensure y_true corresponds to the sorted order if needed (it should if target_indices_sorted is correct)
    accuracy = accuracy_score(y_true, ensemble_preds)
    f1_macro = f1_score(y_true, ensemble_preds, average='macro', zero_division=0)
    report = classification_report(y_true, ensemble_preds, zero_division=0)
    matrix = confusion_matrix(y_true, ensemble_preds)

    print("\n--- Final Ensemble Evaluation (Cross-Validated) ---")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score (Macro): {f1_macro:.4f}")
    print("\nClassification Report:")
    print(report)
    print("\nConfusion Matrix:")
    print(matrix)

    return accuracy, f1_macro, report, matrix

## 10. Main Execution Block

This cell orchestrates the entire workflow:
1. Loads the data and creates features using `load_and_preprocess_data`.
2. Creates the necessary adjacency matrices using `create_adjacency_matrices`.
3. Defines the `main` function which, based on the `RUN_HPO` and `RUN_FINAL_TRAINING` flags:
   - Optionally runs hyperparameter optimization (`run_hpo_studies`).
   - Optionally trains the top N models found during HPO using cross-validation (`train_final_models_cv`).
   - Optionally evaluates the performance of the ensemble model (`evaluate_ensemble_cv`).
4. Executes the `main` function within the standard `if __name__ == "__main__":` block.

In [10]:
#------------------------------
# 10. MAIN EXECUTION SCRIPT
#------------------------------

# --- Phase 0: Load Data & Prepare Features/Graphs ---
print("--- Phase 0: Loading Data & Preparing Features/Graphs ---")
# Use global FEATURE_SET
X_node, y, channel_coords, electrode_mapping_info = load_and_preprocess_data(FEATURE_SET)
print(f"Data shapes: X_node={X_node.shape}, y={y.shape}")

print("Creating adjacency matrices...")
adj_matrices = create_adjacency_matrices(channel_coords, electrode_mapping_info)
print(f"Adjacency matrices created: {list(adj_matrices.keys())}")
print("\n" + "="*60 + "\n")

def main():
    """Main execution function using pre-loaded data and graphs."""

    # --- Phase 1: Hyperparameter Optimization ---
    # Use global FEATURE_SET
    best_hpo_params_path = os.path.join("hpo_results", f'best_hpo_params_all_feat_{FEATURE_SET}.json')
    # Use global RUN_HPO flag
    if RUN_HPO:
        print("\n--- Phase 1: Running Hyperparameter Optimization ---")
        # HPO results (parameters and scores) are saved by run_hpo_studies
        run_hpo_studies(X_node, y, adj_matrices, device)
    else:
        print(f"\n--- Phase 1: Skipping HPO. Will load parameters from '{best_hpo_params_path}' if needed. ---")


    # --- Phase 2: Final Model Training & Evaluation ---
    # Use global RUN_FINAL_TRAINING flag
    if RUN_FINAL_TRAINING:
        print(f"\n--- Phase 2: Training Final Models with CV and Evaluating Ensemble ---")
        if not os.path.exists(best_hpo_params_path):
             print(f"ERROR: HPO results file '{best_hpo_params_path}' not found. Set RUN_HPO=True to first run HPO")
             return

        # Pass required variables explicitly
        trained_models, oof_preds, oof_indices, cv_scores = train_final_models_cv(
            best_hpo_params_path, X_node, y, adj_matrices, device
        )

        if trained_models and oof_preds: # Check if training produced results
            # Evaluate the ensemble based on OOF predictions
            # Pass required variables explicitly
            final_accuracy, f1, report, matrix = evaluate_ensemble_cv(oof_preds, oof_indices, y)
            print("\n--- Analysis Complete ---")
        else:
            print("\n--- Final model training failed or produced no results. Skipping ensemble evaluation. ---")

    else:
        print("\n--- Phase 2: Skipping final model training and evaluation. ---")


if __name__ == "__main__":
    main()

--- Phase 0: Loading Data & Preparing Features/Graphs ---
Loading data and preparing feature set: 'with_ratios'
Calculating ratio features...
Added 2 ratio features.
Final node feature shape: (40, 64, 7)
Data shapes: X_node=(40, 64, 7), y=(40,)
Creating adjacency matrices...
Adjacency matrices created: ['KNN', 'Threshold', 'AnatomicalRegion', 'CustomBrainNetwork']



--- Phase 1: Skipping HPO. Will load parameters from 'hpo_results\best_hpo_params_all_feat_with_ratios.json' if needed. ---

--- Phase 2: Training Final Models with CV and Evaluating Ensemble ---

--- Selected Top 5 Configurations for Final Training (Ranked by HPO Score) ---
- SimpleGAT_Threshold_feat_with_ratios (HPO Score: 0.8022)
- SimpleGAT_AnatomicalRegion_feat_with_ratios (HPO Score: 0.7985)
- SpectralGNN_Threshold_feat_with_ratios (HPO Score: 0.7766)
- SpectralGNN_AnatomicalRegion_feat_with_ratios (HPO Score: 0.7747)
- SimpleGAT_CustomBrainNetwork_feat_with_ratios (HPO Score: 0.7747)
Using input channels: 7 (based o



  Epoch: 025, Train Loss: 10.8765, Val Loss: 3.8408, Val Acc: 0.5000
  Early stopping triggered after epoch 39. Best Val Acc: 0.7500
  Fold 1 Test Acc: 0.7500
  Fold 2/5




  Early stopping triggered after epoch 22. Best Val Acc: 0.7500
  Fold 2 Test Acc: 0.7500
  Fold 3/5




  Epoch: 025, Train Loss: 0.6968, Val Loss: 9.2631, Val Acc: 0.5000
  Early stopping triggered after epoch 38. Best Val Acc: 0.7500
  Fold 3 Test Acc: 0.7500
  Fold 4/5




  Early stopping triggered after epoch 23. Best Val Acc: 0.5000
  Fold 4 Test Acc: 0.5000
  Fold 5/5




  Early stopping triggered after epoch 23. Best Val Acc: 0.6250
  Fold 5 Test Acc: 0.6250
  SimpleGAT_Threshold_feat_with_ratios Avg CV Accuracy: 0.6750 +/- 0.1000

--- Training Final Model for: SimpleGAT_AnatomicalRegion_feat_with_ratios ---
  Fold 1/5
  Early stopping triggered after epoch 23. Best Val Acc: 0.7500
  Fold 1 Test Acc: 0.7500
  Fold 2/5
  Early stopping triggered after epoch 21. Best Val Acc: 0.5000
  Fold 2 Test Acc: 0.5000
  Fold 3/5
  Epoch: 025, Train Loss: 0.5594, Val Loss: 4.2127, Val Acc: 0.5000
  Early stopping triggered after epoch 27. Best Val Acc: 0.7500
  Fold 3 Test Acc: 0.7500
  Fold 4/5
  Epoch: 025, Train Loss: 0.5251, Val Loss: 1.7952, Val Acc: 0.2500
  Early stopping triggered after epoch 28. Best Val Acc: 0.6250
  Fold 4 Test Acc: 0.6250
  Fold 5/5
  Early stopping triggered after epoch 21. Best Val Acc: 0.7500
  Fold 5 Test Acc: 0.7500
  SimpleGAT_AnatomicalRegion_feat_with_ratios Avg CV Accuracy: 0.6750 +/- 0.1000

--- Training Final Model for: Spec



  Early stopping triggered after epoch 24. Best Val Acc: 0.3750
  Fold 1 Test Acc: 0.3750
  Fold 2/5




  Epoch: 025, Train Loss: 0.4120, Val Loss: 0.8215, Val Acc: 0.3750
  Early stopping triggered after epoch 27. Best Val Acc: 0.8750
  Fold 2 Test Acc: 0.8750
  Fold 3/5




  Early stopping triggered after epoch 23. Best Val Acc: 0.7500
  Fold 3 Test Acc: 0.7500
  Fold 4/5




  Epoch: 025, Train Loss: 0.3637, Val Loss: 0.9772, Val Acc: 0.1250
  Early stopping triggered after epoch 34. Best Val Acc: 0.5000
  Fold 4 Test Acc: 0.5000
  Fold 5/5




  Early stopping triggered after epoch 21. Best Val Acc: 0.7500
  Fold 5 Test Acc: 0.7500
  SpectralGNN_Threshold_feat_with_ratios Avg CV Accuracy: 0.6500 +/- 0.1837

--- Training Final Model for: SpectralGNN_AnatomicalRegion_feat_with_ratios ---
  Fold 1/5
  Early stopping triggered after epoch 24. Best Val Acc: 0.6250
  Fold 1 Test Acc: 0.6250
  Fold 2/5
  Epoch: 025, Train Loss: 29.9838, Val Loss: 25.2213, Val Acc: 0.6250
  Early stopping triggered after epoch 32. Best Val Acc: 0.8750
  Fold 2 Test Acc: 0.8750
  Fold 3/5
  Epoch: 025, Train Loss: 35.2835, Val Loss: 13.2348, Val Acc: 0.6250
  Early stopping triggered after epoch 41. Best Val Acc: 0.7500
  Fold 3 Test Acc: 0.7500
  Fold 4/5
  Early stopping triggered after epoch 21. Best Val Acc: 0.7500
  Fold 4 Test Acc: 0.7500
  Fold 5/5
  Early stopping triggered after epoch 23. Best Val Acc: 0.7500
  Fold 5 Test Acc: 0.7500
  SpectralGNN_AnatomicalRegion_feat_with_ratios Avg CV Accuracy: 0.7500 +/- 0.0791

--- Training Final Model