In [1]:
%load_ext autoreload
%autoreload 1
from pathlib import Path

import pandas as pd
import seaborn as sns
import torch
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch.utils.data import Subset, WeightedRandomSampler
# from torch.utils.data import DataLoader
from src.utils.seeder import seed_everything

# set seaborn theme
sns.set_theme()

# create useful constants
RANDOM_SEED = 42
IS_SCITAS = True # set to True if running on SCITAS cluster
LOCAL_DATA_ROOT = Path("./data")
DATA_ROOT = Path("/home/ogut/data") if IS_SCITAS else LOCAL_DATA_ROOT
CHECKPOINT_ROOT = Path("./.checkpoints")
SUBMISSION_ROOT = Path("./.submissions")

# create directories if they do not exist
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
SUBMISSION_ROOT.mkdir(parents=True, exist_ok=True)

# set dataset root
seed_everything(RANDOM_SEED)

# setup torch device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cuda


In [None]:
# import subprocess

# # execute feature extraction script
# process = None
# try:
#     process = subprocess.Popen(["python3", "scripts/feature_extractor.py"])
#     process.wait()
# except KeyboardInterrupt:
#     print("Process interrupted, terminating...")
#     if process:
#         process.terminate()
#         process.wait()
# except Exception as e:
#     print(f"Error occurred: {e}")
#     if process:
#         process.terminate()
#         process.wait()

In [3]:
# spacial distance matrix between sensors
spatial_distance_file = LOCAL_DATA_ROOT / "distances_3d.csv"

# training data
train_dir = DATA_ROOT / "train"
train_dir_metadata = train_dir / "segments.parquet"
train_dataset_correlation_dir = LOCAL_DATA_ROOT / "graph_dataset_correlation_train"
train_dataset_spatial_dir = LOCAL_DATA_ROOT / "graph_dataset_spatial_train"

# test data
test_dir = DATA_ROOT / "test"
test_dir_metadata = test_dir / "segments.parquet"
test_dataset_correlation_dir = LOCAL_DATA_ROOT / "graph_dataset_correlation_test"
test_dataset_spatial_dir = LOCAL_DATA_ROOT / "graph_dataset_spatial_test"

# additional features
extracted_features_dir = LOCAL_DATA_ROOT / "extracted_features"
embeddings_dir =  LOCAL_DATA_ROOT / "embeddings"

In [4]:
from src.utils.index import ensure_eeg_multiindex 

# Load clips from datasets
clips_tr = pd.read_parquet(train_dir_metadata)
clips_tr = ensure_eeg_multiindex(clips_tr)
clips_tr['id'] = clips_tr.index.map(lambda x: '_'.join(str(i) for i in x))
assert clips_tr.id.nunique() == len(clips_tr), "There are duplicate IDs"
clips_tr = clips_tr[~clips_tr.label.isna()].reset_index()

# Load clips from datasets
clips_te = pd.read_parquet(test_dir_metadata)
clips_te = ensure_eeg_multiindex(clips_te)
clips_te['id'] = clips_te.index.map(lambda x: '_'.join(str(i) for i in x))
assert clips_te.id.nunique() == len(clips_te), "There are duplicate IDs"
clips_te = clips_te.reset_index()

# sort in order to maintain the same submission order
clips_te = clips_te.sort_values(by="id")

## Create + load spatial graph datasets

In [None]:
%aimport
from src.data.dataset_graph import GraphEEGDataset

# dataset settings
low_bandpass_frequency = 0.5
high_bandpass_frequency = 50

# additional settings
oversampling_power = 1.0

# load training dataset
dataset_spatial_tr = GraphEEGDataset(
    root=train_dataset_spatial_dir,
    clips=clips_tr,
    signal_folder=train_dir,
    extracted_features_dir=extracted_features_dir,
    use_selected_features=False,
    embeddings_dir=embeddings_dir,
    use_embeddings=False,
    edge_strategy="spatial",
    spatial_distance_file=(
        spatial_distance_file
    ),
    top_k=None,
    force_reprocess=False,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=True,
    apply_normalization=True,
    sampling_rate=250,
    # extract graph features
    extract_graph_features=True,
    graph_feature_types=None # collect all graph features
)

# Check the length of the dataset
print(f"Length of train_dataset: {len(dataset_spatial_tr)}")
print(f' Eliminated IDs: {dataset_spatial_tr.ids_to_eliminate}')
clips_spatial_tr = clips_tr[~clips_tr.index.isin(dataset_spatial_tr.ids_to_eliminate)]

2025-06-08 17:19:00 - INFO - Initializing GraphEEGDataset...
2025-06-08 17:19:00 - INFO - Dataset parameters:
2025-06-08 17:19:00 - INFO -   - Root directory: data/graph_dataset_spatial_train
2025-06-08 17:19:00 - INFO -   - Edge strategy: spatial
2025-06-08 17:19:00 - INFO -   - Top-k neighbors: None
2025-06-08 17:19:00 - INFO -   - Correlation threshold: 0.7
2025-06-08 17:19:00 - INFO -   - Force reprocess: False
2025-06-08 17:19:00 - INFO -   - Bandpass frequencies: (0.5, 50)
2025-06-08 17:19:00 - INFO -   - Segment length: 3000
2025-06-08 17:19:00 - INFO -   - Apply filtering: True
2025-06-08 17:19:00 - INFO -   - Apply rereferencing: True
2025-06-08 17:19:00 - INFO -   - Apply normalization: True
2025-06-08 17:19:00 - INFO - Dataset parameters:
2025-06-08 17:19:00 - INFO -   - Root directory: data/graph_dataset_spatial_train
2025-06-08 17:19:00 - INFO -   - Edge strategy: spatial
2025-06-08 17:19:00 - INFO -   - Top-k neighbors: None
2025-06-08 17:19:00 - INFO -   - Correlation th

Modules to reload:


Modules to skip:

Length of train_dataset: 12993
 Eliminated IDs: []


In [None]:
%aimport
from src.data.dataset_graph import GraphEEGDataset

# load training dataset
dataset_corr_te = GraphEEGDataset(
    root=test_dataset_correlation_dir,
    clips=clips_te,
    signal_folder=test_dir,
    extracted_features_dir=extracted_features_dir,
    use_selected_features=False,
    embeddings_dir=embeddings_dir,
    use_embeddings=False,
    edge_strategy="spatial",
    spatial_distance_file=(
        spatial_distance_file
    ),
    top_k=None,
    force_reprocess=False,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=True,
    apply_normalization=True,
    sampling_rate=250,
    # extract graph features
    is_test=True, # NOTE: needed to let the dataset know that is okay to now have labels!
    extract_graph_features=True,
    graph_feature_types=None # collect all graph features
)

# Check the length of the dataset
print(f"Length of test_dataset: {len(dataset_corr_te)}")
print(f' Eliminated IDs:{dataset_corr_te.ids_to_eliminate}')
clips_spatial_te = clips_te[~clips_te.index.isin(dataset_corr_te.ids_to_eliminate)].reset_index(drop=True)

2025-06-08 17:19:00 - INFO - Initializing GraphEEGDataset...
2025-06-08 17:19:00 - INFO - Dataset parameters:
2025-06-08 17:19:00 - INFO -   - Root directory: data/graph_dataset_correlation_test
2025-06-08 17:19:00 - INFO -   - Edge strategy: spatial
2025-06-08 17:19:00 - INFO -   - Top-k neighbors: None
2025-06-08 17:19:00 - INFO -   - Correlation threshold: 0.7
2025-06-08 17:19:00 - INFO -   - Force reprocess: False
2025-06-08 17:19:00 - INFO -   - Bandpass frequencies: (0.5, 50)
2025-06-08 17:19:00 - INFO -   - Segment length: 3000
2025-06-08 17:19:00 - INFO -   - Apply filtering: True
2025-06-08 17:19:00 - INFO -   - Apply rereferencing: True
2025-06-08 17:19:00 - INFO -   - Apply normalization: True
2025-06-08 17:19:00 - INFO -   - Sampling rate: 250
2025-06-08 17:19:00 - INFO -   - Test mode: True
2025-06-08 17:19:00 - INFO -   - Extract graph features: True
2025-06-08 17:19:00 - INFO - Initializing graph feature extractor...
2025-06-08 17:19:00,319 - src.utils.graph_features - I

Modules to reload:


Modules to skip:

Length of test_dataset: 3612
 Eliminated IDs:[]


## Create + load correlation-based graph datasets

In [5]:
%aimport
from src.data.dataset_graph import GraphEEGDataset

# dataset settings
top_k = 5
low_bandpass_frequency = 0.5
high_bandpass_frequency = 50

# additional settings
oversampling_power = 1.0

# load training dataset
dataset_corr_tr = GraphEEGDataset(
    root=train_dataset_correlation_dir,
    clips=clips_tr,
    signal_folder=train_dir,
    extracted_features_dir=extracted_features_dir,
    use_selected_features=False,
    embeddings_dir=embeddings_dir,
    use_embeddings=False,
    edge_strategy="correlation",
    spatial_distance_file=None,
    top_k=top_k,
    force_reprocess=False,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=True,
    apply_normalization=True,
    sampling_rate=250,
    # extract graph features
    extract_graph_features=True,
    graph_feature_types=None # collect all graph features
)

# Check the length of the dataset
print(f"Length of train_dataset: {len(dataset_corr_tr)}")
print(f' Eliminated IDs: {dataset_corr_tr.ids_to_eliminate}')
clips_corr_tr = clips_tr[~clips_tr.index.isin(dataset_corr_tr.ids_to_eliminate)]

2025-06-08 17:19:00 - INFO - Initializing GraphEEGDataset...
2025-06-08 17:19:00 - INFO - Dataset parameters:
2025-06-08 17:19:00 - INFO -   - Root directory: data/graph_dataset_correlation_train
2025-06-08 17:19:00 - INFO -   - Edge strategy: correlation
2025-06-08 17:19:00 - INFO -   - Top-k neighbors: 5
2025-06-08 17:19:00 - INFO -   - Correlation threshold: 0.7
2025-06-08 17:19:00 - INFO -   - Force reprocess: False
2025-06-08 17:19:00 - INFO -   - Bandpass frequencies: (0.5, 50)
2025-06-08 17:19:00 - INFO -   - Segment length: 3000
2025-06-08 17:19:00 - INFO -   - Apply filtering: True
2025-06-08 17:19:00 - INFO -   - Apply rereferencing: True
2025-06-08 17:19:00 - INFO -   - Apply normalization: True
2025-06-08 17:19:00 - INFO -   - Sampling rate: 250
2025-06-08 17:19:00 - INFO -   - Test mode: False
2025-06-08 17:19:00 - INFO -   - Extract graph features: True
2025-06-08 17:19:00 - INFO - Initializing graph feature extractor...
2025-06-08 17:19:00,400 - src.utils.graph_features 

Modules to reload:


Modules to skip:

Length of train_dataset: 12986
 Eliminated IDs: []


In [None]:
%aimport
from src.data.dataset_graph import GraphEEGDataset

# load training dataset
dataset_corr_te = GraphEEGDataset(
    root=test_dataset_correlation_dir,
    clips=clips_te,
    signal_folder=test_dir,
    extracted_features_dir=extracted_features_dir,
    use_selected_features=False,
    embeddings_dir=embeddings_dir,
    use_embeddings=False,
    edge_strategy="correlation",
    spatial_distance_file=None,
    top_k=top_k,
    force_reprocess=False,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=True,
    apply_normalization=True,
    sampling_rate=250,
    # extract graph features
    is_test=True, # NOTE: needed to let the dataset know that is okay to now have labels!
    extract_graph_features=True,
    graph_feature_types=None # collect all graph features
)

# Check the length of the dataset
print(f"Length of test_dataset: {len(dataset_corr_te)}")
print(f' Eliminated IDs:{dataset_corr_te.ids_to_eliminate}')
clips_corr_te = clips_te[~clips_te.index.isin(dataset_corr_te.ids_to_eliminate)].reset_index(drop=True)

2025-06-08 17:19:00 - INFO - Initializing GraphEEGDataset...
2025-06-08 17:19:00 - INFO - Dataset parameters:
2025-06-08 17:19:00 - INFO -   - Root directory: data/graph_dataset_correlation_test
2025-06-08 17:19:00 - INFO -   - Edge strategy: correlation
2025-06-08 17:19:00 - INFO -   - Top-k neighbors: 5
2025-06-08 17:19:00 - INFO -   - Correlation threshold: 0.7
2025-06-08 17:19:00 - INFO -   - Force reprocess: False
2025-06-08 17:19:00 - INFO -   - Bandpass frequencies: (0.5, 50)
2025-06-08 17:19:00 - INFO -   - Segment length: 3000
2025-06-08 17:19:00 - INFO -   - Apply filtering: True
2025-06-08 17:19:00 - INFO -   - Apply rereferencing: True
2025-06-08 17:19:00 - INFO -   - Apply normalization: True
2025-06-08 17:19:00 - INFO -   - Sampling rate: 250
2025-06-08 17:19:00 - INFO -   - Test mode: True
2025-06-08 17:19:00 - INFO -   - Extract graph features: True
2025-06-08 17:19:00 - INFO - Initializing graph feature extractor...
2025-06-08 17:19:00,457 - src.utils.graph_features - 

Modules to reload:


Modules to skip:

Length of test_dataset: 3612
 Eliminated IDs:[]


In [None]:
# remove the original clips from memory
del clips_tr, clips_te

In [None]:
import numpy as np
from src.utils.general_funcs import labels_stats

# Split settings
TRAIN_RATIO = 0.8
oversampling_power = 1.0
BATCH_SIZE = 64

print("=== SPATIAL DATASET SPLITTING ===")
# Get total samples and split sizes for spatial dataset
total_samples_spatial = len(dataset_spatial_tr)
train_size_spatial = int(TRAIN_RATIO * total_samples_spatial)
val_size_spatial = total_samples_spatial - train_size_spatial

print(f"Spatial dataset - Total: {total_samples_spatial}, Train: {train_size_spatial}, Val: {val_size_spatial}")

# Get labels for spatial dataset split
y_spatial = clips_spatial_tr["label"].values

# Create initial train/val split using random permutation
indices_spatial = torch.randperm(total_samples_spatial)
train_indices_spatial = indices_spatial[:train_size_spatial].numpy()
val_indices_spatial = indices_spatial[train_size_spatial:].numpy()

print('Spatial dataset labels distribution before split:')
labels_stats(y_spatial, train_indices_spatial, val_indices_spatial)

# Create train and val datasets for spatial
train_dataset_spatial = Subset(dataset_spatial_tr, train_indices_spatial)
val_dataset_spatial = Subset(dataset_spatial_tr, val_indices_spatial)

# Compute sample weights for oversampling - spatial
train_labels_spatial = [clips_spatial_tr.iloc[i]["label"] for i in train_indices_spatial]
class_counts_spatial = np.bincount(train_labels_spatial)
class_weights_spatial = (1. / class_counts_spatial) ** oversampling_power
sample_weights_spatial = [class_weights_spatial[label] for label in train_labels_spatial]

# Define sampler for spatial
sampler_spatial = WeightedRandomSampler(sample_weights_spatial, num_samples=len(sample_weights_spatial), replacement=True)

print(f"\nSpatial dataset - Class weights: {class_weights_spatial}")
print(f"Spatial dataset - Class distribution in train: {np.bincount(train_labels_spatial)}")

print("\n=== CORRELATION DATASET SPLITTING ===")

# Get total samples and split sizes for correlation dataset
total_samples_corr = len(dataset_corr_tr)
train_size_corr = int(TRAIN_RATIO * total_samples_corr)
val_size_corr = total_samples_corr - train_size_corr

print(f"Correlation dataset - Total: {total_samples_corr}, Train: {train_size_corr}, Val: {val_size_corr}")

# Get labels for correlation dataset split (should be same as spatial, but let's be explicit)
y_corr = clips_corr_tr["label"].values

# Create initial train/val split using random permutation
indices_corr = torch.randperm(total_samples_corr)
train_indices_corr = indices_corr[:train_size_corr].numpy()
val_indices_corr = indices_corr[train_size_corr:].numpy()

print('Correlation dataset labels distribution before split:')
labels_stats(y_corr, train_indices_corr, val_indices_corr)

# Create train and val datasets for correlation
train_dataset_corr = Subset(dataset_corr_tr, train_indices_corr)
val_dataset_corr = Subset(dataset_corr_tr, val_indices_corr)

# Compute sample weights for oversampling - correlation
train_labels_corr = [clips_corr_tr.iloc[i]["label"] for i in train_indices_corr]
class_counts_corr = np.bincount(train_labels_corr)
class_weights_corr = (1. / class_counts_corr) ** oversampling_power
sample_weights_corr = [class_weights_corr[label] for label in train_labels_corr]

# Define sampler for correlation
sampler_corr = WeightedRandomSampler(sample_weights_corr, num_samples=len(sample_weights_corr), replacement=True)

print(f"\nCorrelation dataset - Class weights: {class_weights_corr}")
print(f"Correlation dataset - Class distribution in train: {np.bincount(train_labels_corr)}")

print("\n=== SUMMARY ===")
print(f"Spatial: {len(train_dataset_spatial)} train, {len(val_dataset_spatial)} val")
print(f"Correlation: {len(train_dataset_corr)} train, {len(val_dataset_corr)} val")

# Create GeoDataLoaders for spatial dataset
train_loader_spatial = GeoDataLoader(
    train_dataset_spatial,
    batch_size=BATCH_SIZE,
    sampler=sampler_spatial,
    drop_last=True
)

val_loader_spatial = GeoDataLoader(
    val_dataset_spatial,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False
)

te_loader_spatial = GeoDataLoader(
    dataset_corr_te, # Use full spatial test dataset
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False
)

# Create GeoDataLoaders for correlation dataset
train_loader_corr = GeoDataLoader(
    train_dataset_corr,
    batch_size=BATCH_SIZE,
    sampler=sampler_corr,
    drop_last=True
)

val_loader_corr = GeoDataLoader(
    val_dataset_corr,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False
)

te_loader_corr = GeoDataLoader(
    dataset_corr_tr,  # Use full correlation test dataset
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False
)

print("\n=== DATA LOADERS CREATED ===")
print(f"Spatial - Train: {len(train_loader_spatial)} batches, Val: {len(val_loader_spatial)} batches, Test: {len(te_loader_spatial)} batches")
print(f"Correlation - Train: {len(train_loader_corr)} batches, Val: {len(val_loader_corr)} batches, Test: {len(te_loader_corr)} batches")

=== SPATIAL DATASET SPLITTING ===
Spatial dataset - Total: 12993, Train: 10394, Val: 2599
Spatial dataset labels distribution before split:
[17:23:13] Train labels: 0 -> 8374, 1 -> 2020
[17:23:13] Val labels:   0 -> 2102, 1 -> 497

Spatial dataset - Class weights: [0.00011942 0.00049505]
Spatial dataset - Class distribution in train: [8374 2020]

=== CORRELATION DATASET SPLITTING ===
Correlation dataset - Total: 12986, Train: 10388, Val: 2598
Correlation dataset labels distribution before split:
[17:23:13] Train labels: 0 -> 8377, 1 -> 2011
[17:23:13] Val labels:   0 -> 2098, 1 -> 500

Correlation dataset - Class weights: [0.00011937 0.00049727]
Correlation dataset - Class distribution in train: [8377 2011]

=== SUMMARY ===
Spatial: 10394 train, 2599 val
Correlation: 10388 train, 2598 val

=== DATA LOADERS CREATED ===
Spatial - Train: 162 batches, Val: 41 batches, Test: 57 batches
Correlation - Train: 162 batches, Val: 41 batches, Test: 203 batches


In [8]:
import numpy as np

# ==============================================================================
# DATASET SELECTION FOR TRAINING
# ==============================================================================
# Choose which dataset type to use for training:
# - 'spatial': Uses spatial distance-based graph connections
# - 'correlation': Uses correlation-based graph connections

DATASET_TYPE = 'spatial'  # Change this to 'correlation' to train with correlation-based graphs

if DATASET_TYPE == 'spatial':
    print("🌐 Selected SPATIAL dataset for training")
    train_loader = train_loader_spatial
    val_loader = val_loader_spatial
    te_loader = te_loader_spatial
    current_dataset = dataset_spatial_tr
elif DATASET_TYPE == 'correlation':
    print("🔗 Selected CORRELATION dataset for training")
    train_loader = train_loader_corr
    val_loader = val_loader_corr
    te_loader = te_loader_corr
    current_dataset = dataset_corr_tr
else:
    raise ValueError(f"Unknown dataset type: {DATASET_TYPE}. Choose 'spatial' or 'correlation'")

print(f"✅ Using {DATASET_TYPE} dataset:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"   Test batches: {len(te_loader)}")
print(f"   Total samples in dataset: {len(current_dataset)}")

# Optional: Print first batch info to verify data loading
try:
    first_batch = next(iter(train_loader))
    print(f"   First batch - Nodes: {first_batch.x.shape}, Edges: {first_batch.edge_index.shape}")
    print(f"   First batch - Labels: {first_batch.y.shape}, Batch size: {first_batch.num_graphs}")
except Exception as e:
    print(f"   Could not inspect first batch: {e}")

print(f"\n🚀 Ready to train with {DATASET_TYPE} dataset!")

🌐 Selected SPATIAL dataset for training
✅ Using spatial dataset:
   Train batches: 162
   Val batches: 41
   Test batches: 57
   Total samples in dataset: 12993
   First batch - Nodes: torch.Size([1216, 3009]), Edges: torch.Size([2, 21888])
   First batch - Labels: torch.Size([64]), Batch size: 64

🚀 Ready to train with spatial dataset!


In [9]:
%aimport
import torch.optim as optim
import torch.nn as nn
from src.layers.hybrid.cnn_bilstm_gcn import EEGCNNBiLSTMGCN
from src.utils.train import train_model
from src.utils.plot import plot_training_loss

config = {
    "learning_rate": 1e-4,
    "weight_decay": 1e-2,
    "patience": 10,
    "epochs": 100,
}

Modules to reload:


Modules to skip:



ModuleNotFoundError: No module named 'data.dataset_graph'

In [11]:
def wrap_train(model, save_path):
    model = model.to(device)
    # optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
    # optimizer = Lion(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=1e-4, weight_decay=0.01, betas=(0.9, 0.999))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5)
    loss = nn.BCEWithLogitsLoss()  # Not weighted as we use a balanced sampler!

    # train model
    train_history, val_history = train_model(
        wandb_config=None,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=loss,
        scheduler=scheduler,
        optimizer=optimizer,
        device=device,
        num_epochs=config["epochs"],
        patience=config["patience"],
        save_path=save_path,
        use_gnn=True,
        # hidden attribute
        try_load_checkpoint=True,
    )
    plot_training_loss(train_history["loss"], val_history["loss"])

### Test 3 - First breakthrough model

In [19]:
SAVE_PATH = CHECKPOINT_ROOT / "cnn_bilstm_gcn_test_3.pt"
model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    encoder_use_batch_norm= True,
    encoder_use_layer_norm= False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 128,
    gcn_pooling_type= "mean",
    gcn_use_batch_norm = True,
    gcn_num_layers = 3,
    gcn_dropout_prob = 0.5,
    num_channels = 19,
)
wrap_train(model, SAVE_PATH)

### Test 4 - Smaller CGN output channels

In [12]:
SAVE_PATH = CHECKPOINT_ROOT / "cnn_bilstm_gcn_test_4.pt"

model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    encoder_use_batch_norm= True,
    encoder_use_layer_norm= False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 64,
    gcn_pooling_type= "mean",
    gcn_use_batch_norm = True,
    gcn_num_layers = 3,
    gcn_dropout_prob = 0.5,
    num_channels = 19,
)
wrap_train(model, SAVE_PATH)

### Test 5 - Smaller GCN output channels + increased embedding length + Deeper GCN

In [13]:
SAVE_PATH = CHECKPOINT_ROOT / "cnn_bilstm_gcn_test_5.pt"
model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    encoder_use_batch_norm= True,
    encoder_use_layer_norm= False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 64,
    gcn_pooling_type= "mean",
    gcn_use_batch_norm = True,
    gcn_num_layers = 4,
    gcn_dropout_prob = 0.5,
    num_channels = 19,
)
wrap_train(model, SAVE_PATH)


### Test 6: slighly bigger GCN output channels
>[HIGHEST F1 SCORE EVER RECORDED]
```
✅ Checkpoint loaded. Resuming from epoch 33. Best 'val_f1' score: 0.7346
```

In [14]:
SAVE_PATH = CHECKPOINT_ROOT / "cnn_bilstm_gcn_test_6.pt"
model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    encoder_use_batch_norm= True,
    encoder_use_layer_norm= False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 96,
    gcn_pooling_type= "mean",
    gcn_use_batch_norm = True,
    gcn_num_layers = 4,
    gcn_dropout_prob = 0.5,
    num_channels = 19,
)
wrap_train(model, SAVE_PATH)

### Test 7B: Alternative architecture to improve generalization

In [15]:
SAVE_PATH = CHECKPOINT_ROOT / "cnn_bilstm_gcn_test_8.pt"
model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.35, # slightly higher dropout to avoid overfitting
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.35, # slightly higher dropout to avoid overfitting
    encoder_use_batch_norm= True,
    encoder_use_layer_norm= False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 96,
    gcn_pooling_type= "mean",
    gcn_use_batch_norm = True,
    gcn_num_layers = 4,
    gcn_dropout_prob = 0.6, # slightly higher dropout to avoid overfitting
    num_channels = 19,
)
wrap_train(model, SAVE_PATH)

### Test 7C: slightly bigger GCN layers

BEST MODEL YET!

In [None]:
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_generalizable_bigger.pt"
model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25, # slightly higher dropout to avoid overfitting
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25, # slightly higher dropout to avoid overfitting
    encoder_use_batch_norm = True,
    encoder_use_layer_norm = False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 192,
    gcn_out_channels = 128,
    gcn_pooling_type = "mean",
    gcn_use_batch_norm = True,
    gcn_num_layers = 4,
    gcn_dropout_prob = 0.6, # slightly higher dropout to avoid overfitting
    num_channels = 19,
)
wrap_train(model, SAVE_PATH)

### Test 7D: even bigger GCN layers

Comparable performance to best model. We might need to increase the number of GCN layers

In [None]:
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_generalizable_even_bigger.pt"
model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25, # slightly higher dropout to avoid overfitting
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25, # slightly higher dropout to avoid overfitting
    encoder_use_batch_norm = True,
    encoder_use_layer_norm = False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 224,
    gcn_out_channels = 192,
    gcn_pooling_type = "mean",
    gcn_use_batch_norm = True,
    gcn_num_layers = 4,
    gcn_dropout_prob = 0.6, # slightly higher dropout to avoid overfitting
    num_channels = 19,
)
wrap_train(model, SAVE_PATH)

### Test 7E: increased number of GCN layers

Assumption: the previous model was unable to learn enough, maybe the GCN was unable to capture

```
Epochs:   9%| | 9/100 [17:54<3:23:31, 134.20s/it, train_loss=0.4532, val_loss=0.3489, best_val_f1=0.6695, lr=5.00e-05, b2025-06-07 17:01:05 - INFO - 
```

In [None]:
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_generalizable_even_more_bigger.pt"
model_generalizable_even_more_bigger = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 224,
    gcn_out_channels = 192,
    gcn_num_layers = 5,
    gcn_dropout_prob = 0.6, # slightly higher dropout to avoid overfitting
    num_classes = 1,
    num_channels = 19,
)

### Test 7F: Increased number of BiLSTM layers + Test 7E architecture

Assumpion: we saw a drammatical increase in accuracy by increasing the number of GCN layers. This hints that the model was now able to learn the most from the embeddings. To improve the performance even further without having to increase the number of GCN layers even more (overall reduce complexity, improve generalization), we will try to increase the number of BiLSTM layers. 

Using multiple BiLSTM layers will allow embeddings to be processed in a more complex way, potentially capturing more intricate relationships in the data. The GCN layers will take care of the graph structure, while the BiLSTM layers will enhance the temporal dependencies and relationships in the data.


In [None]:
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_generalizable_even_more_bigger.pt"
model_generalizable_even_more_bigger = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    lstm_num_layers = 2,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 224,
    gcn_out_channels = 192,
    gcn_num_layers = 5,
    gcn_dropout_prob = 0.6, # slightly higher dropout to avoid overfitting
    num_classes = 1,
    num_channels = 19,
)

```
Epochs:   1%|▊                                                                                  | 1/100 [00:00<?, ?it/s]2025-06-07 18:55:16 - INFO -
Epochs:   2%| | 2/100 [04:35<7:29:19, 275.10s/it, train_loss=0.6212, val_loss=0.4619, best_val_f1=0.4055, lr=1.00e-04, b2025-06-07 18:59:51 - INFO -
Epochs:   3%| | 3/100 [09:09<7:23:49, 274.53s/it, train_loss=0.5819, val_loss=0.4295, best_val_f1=0.4055, lr=1.00e-04, b2025-06-07 19:04:25 - INFO -
Epochs:   4%| | 4/100 [13:42<7:18:31, 274.08s/it, train_loss=0.5628, val_loss=0.4437, best_val_f1=0.4055, lr=1.00e-04, b2025-06-07 19:08:59 - INFO -
Epochs:   5%| | 5/100 [18:16<7:13:28, 273.78s/it, train_loss=0.5452, val_loss=0.3942, best_val_f1=0.4858, lr=1.00e-04, b2025-06-07 19:13:32 - INFO -
Epochs:   6%| | 6/100 [22:49<7:08:41, 273.63s/it, train_loss=0.5334, val_loss=0.4563, best_val_f1=0.4858, lr=1.00e-04, b2025-06-07 19:18:05 - INFO -
Epochs:   7%| | 7/100 [27:22<7:04:01, 273.57s/it, train_loss=0.5319, val_loss=0.3738, best_val_f1=0.5137, lr=1.00e-04, b2025-06-07 19:22:39 - INFO -
Epochs:   8%| | 8/100 [31:56<6:59:20, 273.48s/it, train_loss=0.5181, val_loss=0.4369, best_val_f1=0.5695, lr=1.00e-04, b2025-06-07 19:27:12 - INFO -
Epochs:   9%| | 9/100 [36:29<6:54:50, 273.52s/it, train_loss=0.5220, val_loss=0.4202, best_val_f1=0.5695, lr=1.00e-04, b2025-06-07 19:31:46 - INFO -
Epochs:  10%| | 10/100 [41:03<6:50:17, 273.52s/it, train_loss=0.5286, val_loss=0.4167, best_val_f1=0.5695, lr=1.00e-04, 2025-06-07 19:36:19 - INFO -
Epochs:  11%| | 11/100 [45:36<6:45:44, 273.53s/it, train_loss=0.5065, val_loss=0.3864, best_val_f1=0.5695, lr=1.00e-04, 2025-06-07 19:40:53 - INFO -
Epochs:  12%| | 12/100 [50:10<6:41:03, 273.45s/it, train_loss=0.5158, val_loss=0.5175, best_val_f1=0.5695, lr=5.00e-05, 2025-06-07 19:45:26 - INFO -
Epochs:  13%|▏| 13/100 [54:43<6:36:23, 273.37s/it, train_loss=0.5035, val_loss=0.3785, best_val_f1=0.5940, lr=5.00e-05, 2025-06-07 19:49:59 - INFO -
Epochs:  14%|▏| 14/100 [59:16<6:31:50, 273.38s/it, train_loss=0.4842, val_loss=0.3838, best_val_f1=0.5981, lr=5.00e-05, 2025-06-07 19:54:33 - INFO -
Epochs:  15%|▏| 15/100 [1:03:50<6:27:17, 273.38s/it, train_loss=0.4644, val_loss=0.3493, best_val_f1=0.6106, lr=5.00e-052025-06-07 19:59:06 - INFO -
Epochs:  16%|▏| 16/100 [1:08:23<6:22:46, 273.41s/it, train_loss=0.4887, val_loss=0.3737, best_val_f1=0.6106, lr=5.00e-052025-06-07 20:03:39 - INFO -
Epochs:  17%|▏| 17/100 [1:12:57<6:18:12, 273.41s/it, train_loss=0.4775, val_loss=0.3565, best_val_f1=0.6106, lr=5.00e-052025-06-07 20:08:13 - INFO -
Epochs:  18%|▏| 18/100 [1:17:30<6:13:42, 273.44s/it, train_loss=0.4635, val_loss=0.3704, best_val_f1=0.6106, lr=2.50e-052025-06-07 20:12:46 - INFO -
Epochs:  19%|▏| 19/100 [1:22:04<6:09:15, 273.53s/it, train_loss=0.4501, val_loss=0.3635, best_val_f1=0.6131, lr=2.50e-052025-06-07 20:17:20 - INFO -
Epochs:  20%|▏| 20/100 [1:26:37<6:04:39, 273.49s/it, train_loss=0.4379, val_loss=0.3638, best_val_f1=0.6179, lr=2.50e-052025-06-07 20:21:53 - INFO -
Epochs:  21%|▏| 21/100 [1:31:10<6:00:01, 273.43s/it, train_loss=0.4494, val_loss=0.3543, best_val_f1=0.6179, lr=2.50e-052025-06-07 20:26:27 - INFO -
Epochs:  22%|▏| 22/100 [1:35:44<5:55:26, 273.42s/it, train_loss=0.4616, val_loss=0.3616, best_val_f1=0.6659, lr=2.50e-052025-06-07 20:31:00 - INFO -
Epochs:  23%|▏| 23/100 [1:40:17<5:50:54, 273.44s/it, train_loss=0.4381, val_loss=0.3532, best_val_f1=0.6659, lr=2.50e-052025-06-07 20:35:34 - INFO -
Epochs:  24%|▏| 24/100 [1:44:51<5:46:22, 273.45s/it, train_loss=0.4423, val_loss=0.3635, best_val_f1=0.6659, lr=1.25e-052025-06-07 20:40:07 - INFO -
Epochs:  25%|▎| 25/100 [1:49:24<5:41:52, 273.49s/it, train_loss=0.4291, val_loss=0.3473, best_val_f1=0.6659, lr=1.25e-052025-06-07 20:44:41 - INFO -
Epochs:  26%|▎| 26/100 [1:53:58<5:37:12, 273.42s/it, train_loss=0.4403, val_loss=0.3380, best_val_f1=0.6659, lr=1.25e-052025-06-07 20:49:14 - INFO -
Epochs:  27%|▎| 27/100 [1:58:31<5:32:38, 273.40s/it, train_loss=0.4312, val_loss=0.3374, best_val_f1=0.6659, lr=1.25e-052025-06-07 20:53:47 - INFO -
Epochs:  28%|▎| 28/100 [2:03:05<5:28:07, 273.44s/it, train_loss=0.4393, val_loss=0.3441, best_val_f1=0.6659, lr=1.25e-052025-06-07 20:58:21 - INFO -
Epochs:  29%|▎| 29/100 [2:07:38<5:23:35, 273.46s/it, train_loss=0.4226, val_loss=0.3392, best_val_f1=0.6659, lr=1.25e-052025-06-07 21:02:54 - INFO -
Epochs:  30%|▎| 30/100 [2:12:11<5:19:02, 273.46s/it, train_loss=0.4240, val_loss=0.3525, best_val_f1=0.6659, lr=6.25e-062025-06-07 21:07:28 - INFO -
Epochs:  31%|▎| 31/100 [2:16:45<5:14:28, 273.46s/it, train_loss=0.4249, val_loss=0.3492, best_val_f1=0.6659, lr=6.25e-062025-06-07 21:12:01 - INFO -
```

In [None]:
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_generalizable_optimized.pt"
model_generalizable_optimized = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 160,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    lstm_num_layers = 2,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 192,
    gcn_out_channels = 128,
    gcn_num_layers = 4,
    gcn_dropout_prob = 0.5, # slightly higher dropout to avoid overfitting
    num_channels = 19,
)

NameError: name 'EEGCNNBiLSTMGCN' is not defined

### Test 8: Narrow but Deep GCN model

In [None]:
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_narrow_deep_model.pt"
narrow_deep_model = EEGCNNBiLSTMGCN(
    # --- Simplify the Temporal Encoder ---
    cnn_dropout_prob = 0.2,
    lstm_hidden_dim = 64,  # Reduced
    lstm_out_dim = 64,     # Reduced
    lstm_dropout_prob = 0.2,
    # --- Focus on the GCN ---
    gcn_hidden_channels = 128, # Keep GCN capacity high
    gcn_out_channels = 64,
    gcn_num_layers = 5,      # Try going even deeper
    gcn_dropout_prob = 0.5,
    num_classes = 1,
    num_channels = 19,
)

### Test 9: First best model, with wider + deeper GCN

In [None]:
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_new_old_best_model.pt"
new_old_best_model = EEGCNNBiLSTMGCN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25,
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 128, # from 64 to 128
    gcn_num_layers = 4, # from 3 to 4
    gcn_dropout_prob = 0.5,
    num_classes = 1,
    num_channels = 19,
)

### Best model + attention BiLSTM

In [None]:
%aimport src.layers.hybrid.cnn_bilstm_attention_gcn
from src.layers.hybrid.cnn_bilstm_attention_gcn import EEGCNNBiLSTMAttentionGNN

SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_attention.pt"
model_first_attention = EEGCNNBiLSTMAttentionGNN(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout_prob = 0.25, # slightly higher dropout to avoid overfitting
    lstm_hidden_dim = 128,
    lstm_out_dim = 128,
    lstm_dropout_prob = 0.25, # slightly higher dropout to avoid overfitting
    encoder_use_batch_norm= True,
    encoder_use_layer_norm= False,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 192,
    gcn_out_channels = 128,
    gcn_num_layers = 4,
    gcn_dropout_prob = 0.6, # slightly higher dropout to avoid overfitting
    gcn_pooling_type= "mean",
    gcn_use_batch_norm = True,
    num_channels = 19,
)

In [None]:
%aimport
import torch.optim as optim
import torch.nn as nn
from src.utils.train import train_model

model = model_small_gcn_bigger_embedding
model = model.to(device)
# optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
# optimizer = Lion(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
loss = nn.BCEWithLogitsLoss() # Not weighted as we use a balanced sampler!

# empty cache in order to free up VRAM (if available)
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# train model
train_history, val_history = train_model(
    wandb_config=None,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=loss,
    scheduler=scheduler,
    optimizer=optimizer,
    device=device,
    num_epochs=config["epochs"],
    patience=config["patience"],
    save_path=SAVE_PATH,
    use_gnn=True,
    # hidden attribute
    try_load_checkpoint=True,
)

from src.utils.plot import plot_training_loss

plot_training_loss(train_history["loss"], val_history["loss"])

In [None]:
# torch cuda clear cache
torch.cuda.empty_cache()

In [None]:
from src.utils.plot import plot_training_loss

plot_training_loss(train_history["loss"], val_history["loss"])

In [None]:
print("=== CREATING DATA LOADERS ===")

# Create data loaders for SPATIAL dataset
print("Creating spatial data loaders...")
train_loader_spatial = GeoDataLoader(
    train_dataset_spatial,
    batch_size=BATCH_SIZE,
    sampler=sampler_spatial,
    shuffle=False,  # Don't shuffle when using sampler
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4
)

val_loader_spatial = GeoDataLoader(
    val_dataset_spatial,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4
)

te_loader_spatial = GeoDataLoader(
    dataset_corr_te,  # Using correlation test dataset for consistency
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4
)

print(f"Spatial - Train batches: {len(train_loader_spatial)}")
print(f"Spatial - Val batches: {len(val_loader_spatial)}")
print(f"Spatial - Test batches: {len(te_loader_spatial)}")

# Create data loaders for CORRELATION dataset
print("\nCreating correlation data loaders...")
train_loader_corr = GeoDataLoader(
    train_dataset_corr,
    batch_size=BATCH_SIZE,
    sampler=sampler_corr,
    shuffle=False,  # Don't shuffle when using sampler
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4
)

val_loader_corr = GeoDataLoader(
    val_dataset_corr,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4
)

te_loader_corr = GeoDataLoader(
    dataset_corr_te,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=4
)

print(f"Correlation - Train batches: {len(train_loader_corr)}")
print(f"Correlation - Val batches: {len(val_loader_corr)}")
print(f"Correlation - Test batches: {len(te_loader_corr)}")

print("\n✅ All data loaders created successfully!")
print("\nYou can now use:")
print("  - train_loader_spatial, val_loader_spatial, te_loader_spatial for spatial graph training")
print("  - train_loader_corr, val_loader_corr, te_loader_corr for correlation graph training")

In [None]:
# ==============================================================================
# DATASET VERIFICATION AND COMPARISON
# ==============================================================================

print("=== DATASET COMPARISON ===")
print(f"Spatial dataset size: {len(dataset_spatial_tr)} samples")
print(f"Correlation dataset size: {len(dataset_corr_tr)} samples")
print(f"Test dataset size: {len(dataset_corr_te)} samples")

# Verify split consistency
print("\n=== SPLIT VERIFICATION ===")
print(f"Spatial splits - Train: {len(train_dataset_spatial)}, Val: {len(val_dataset_spatial)}")
print(f"Correlation splits - Train: {len(train_dataset_corr)}, Val: {len(val_dataset_corr)}")

# Check split ratios
spatial_train_ratio = len(train_dataset_spatial) / len(dataset_spatial_tr)
corr_train_ratio = len(train_dataset_corr) / len(dataset_corr_tr)
print(f"\nTrain ratios - Spatial: {spatial_train_ratio:.3f}, Correlation: {corr_train_ratio:.3f}")

# Verify labels are balanced
print("\n=== LABEL BALANCE VERIFICATION ===")
print("Spatial train labels:", np.bincount([clips_tr.iloc[i]['label'] for i in train_indices_spatial]))
print("Spatial val labels:", np.bincount([clips_tr.iloc[i]['label'] for i in val_indices_spatial]))
print("Correlation train labels:", np.bincount([clips_tr.iloc[i]['label'] for i in train_indices_corr]))
print("Correlation val labels:", np.bincount([clips_tr.iloc[i]['label'] for i in val_indices_corr]))

print("\n✅ All splits created successfully and verified!")
print("\n📝 Note: To train with different datasets, change DATASET_TYPE in the cell above.")

## Training Instructions

### Dataset Selection
You can now train with either dataset type by changing the `DATASET_TYPE` variable:

- **Spatial**: `DATASET_TYPE = 'spatial'` - Uses spatial distance-based graph connections
- **Correlation**: `DATASET_TYPE = 'correlation'` - Uses correlation-based graph connections

### Available Data Loaders

#### For Spatial Dataset:
- `train_loader_spatial` - Training data with weighted sampling for class balance
- `val_loader_spatial` - Validation data
- `te_loader_spatial` - Test data

#### For Correlation Dataset:
- `train_loader_corr` - Training data with weighted sampling for class balance
- `val_loader_corr` - Validation data
- `te_loader_corr` - Test data

### Split Details
- **Train/Validation ratio**: 80/20
- **Random seed**: 42 (for reproducibility)
- **Class balancing**: WeightedRandomSampler with oversampling power = 1.0
- **Batch size**: 64

### Training Tips
1. The `train_loader`, `val_loader`, and `te_loader` variables are automatically set based on your `DATASET_TYPE` selection
2. Both datasets use the same preprocessing pipeline but different graph construction strategies
3. The correlation dataset uses top-k=5 connections, while spatial uses distance-based connections
4. All data loaders include proper error handling and batch verification