In [1]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 0.C SET-UP: Imports                                                   ║
# ╚══════════════════════════════════════════════════════════════════════╝
import random
import math
import time
import json
import shutil
import warnings
import itertools
from pathlib import Path # Using pathlib can be nice for paths

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.loader import DataLoader # Corrected import for DataLoader
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool

from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import wandb

PROJECT_BASE_DIR = "../EEG_nml" # ✏️ ADJUST THIS PATH!

warnings.filterwarnings("ignore")
print("Libraries imported.")

Libraries imported.


In [178]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 1. CONFIGURATION                                                      ║
# ╚══════════════════════════════════════════════════════════════════════╝
config = dict(
    project         = "eeg-gnn-group-project",  # ✏️ ADJUST WandB project name if needed
    entity          = "danielebelfiore7-epfl", # ✏️ ADJUST your WandB entity (username or team)
    gnn_type        = "sage",                    # "gcn" | "sage" | "gat"
    edge_variant    = "knn",                   # "grid" | "knn" (ensure corresponding .pt file exists)
    hidden_dim      = 64,
    num_layers      = 6,
    dropout         = 0.25,
    lr              = 5e-5,
    weight_decay    = 1e-4,
    batch_size      = 32,
    epochs          = 25,                       # Can reduce for quick testing (e.g., 5)
    seed            = 1,
    base_dir        = PROJECT_BASE_DIR          # Using the verified path from Cell 1
)

# Initialize Weights & Biases (do this once at the beginning)
try:
    wandb.init(project=config["project"], entity=config["entity"],
               config=config, mode="online") # Use "disabled" for no logging during tests
    print(f"WandB initialized. Project: {config['project']}, Entity: {config['entity']}")
    print(f"Run link: {wandb.run.get_url() if wandb.run else 'WandB run not active (possibly disabled mode)'}")
except Exception as e:
    print(f"WandB initialization failed: {e}")
    print("Continuing without WandB logging. Check your entity and project name.")
    # Fallback to disabled mode if init fails
    if wandb.run: wandb.finish(quiet=True) # Ensure any partial run is closed
    wandb.init(project=config["project"], mode="disabled", config=config)
    print("WandB running in disabled mode.")

WandB initialized. Project: eeg-gnn-group-project, Entity: danielebelfiore7-epfl
Run link: https://wandb.ai/danielebelfiore7-epfl/eeg-gnn-group-project/runs/x04vg56z


In [144]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 2. REPRODUCIBILITY                                                    ║
# ╚══════════════════════════════════════════════════════════════════════╝
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) # for multi-GPU.
    # Potentially add:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

set_seed(config["seed"])

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"Seed set to: {config['seed']}")

Using device: cuda
Seed set to: 1


In [145]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 3a. INSPECT TRAIN DATA                                               ║
# ╚══════════════════════════════════════════════════════════════════════╝

import pandas as pd
import os
import numpy as np # For checking .npy files later

# Ensure config is defined if you're running this in your main notebook
# If not, define PROJECT_BASE_DIR directly:
# PROJECT_BASE_DIR = "/content/drive/MyDrive/EEG_nml"
PROJECT_BASE_DIR = config["base_dir"] # Assuming 'config' is available from your notebook

train_meta_path = os.path.join(PROJECT_BASE_DIR, "train_data", "metadata.parquet")

print(f"--- Inspecting TRAIN Metadata ({train_meta_path}) ---")
if os.path.exists(train_meta_path):
    try:
        df_meta = pd.read_parquet(train_meta_path)
        print(f"Successfully loaded {train_meta_path}")
        print(f"\nTotal rows in metadata: {len(df_meta)}")
        print(f"\nColumns in metadata.parquet:")
        print(list(df_meta.columns))

        print(f"\nMetadata Index Type: {type(df_meta.index)}")
        print(f"Metadata Index Names: {df_meta.index.names}")
        if not df_meta.empty:
            print(f"First 3 index entries: {list(df_meta.index[:3])}")
            print(f"Are indices unique? {df_meta.index.is_unique}")

        print(f"\nFirst 3 rows of metadata.parquet:")
        print(df_meta.head(3))

        # Check for NaN/missing values in critical columns
        critical_cols = ["signals_path", "segment", "label"] # Based on your script
        for col in critical_cols:
            if col in df_meta.columns:
                print(f"\nMissing values in '{col}': {df_meta[col].isnull().sum()}")
            else:
                print(f"\nColumn '{col}' NOT FOUND in metadata!")

        # Try to identify what data might be around the 999th entry (0-indexed)
        if len(df_meta) > 999:
            print("\nData around the 999th entry (0-indexed, so rows 998, 999, 1000):")
            print(df_meta.iloc[998:1001])
        elif not df_meta.empty:
            print("\nMetadata has fewer than 1000 entries. Showing last few entries:")
            print(df_meta.tail(3))


    except Exception as e:
        print(f"Error reading or inspecting {train_meta_path}: {e}")
else:
    print(f"File NOT FOUND: {train_meta_path}")

--- Inspecting TRAIN Metadata (../EEG_nml/train_data/metadata.parquet) ---
Successfully loaded ../EEG_nml/train_data/metadata.parquet

Total rows in metadata: 12993

Columns in metadata.parquet:
['patient', 'session', 'segment', 'label', 'start_time', 'end_time', 'date', 'sampling_rate', 'signals_path']

Metadata Index Type: <class 'pandas.core.indexes.range.RangeIndex'>
Metadata Index Names: [None]
First 3 index entries: [0, 1, 2]
Are indices unique? True

First 3 rows of metadata.parquet:
    patient    session  segment  label  start_time  end_time       date  \
0  pqejgcff  s001_t000        0      1         0.0      12.0 2003-01-01   
1  pqejgcff  s001_t000        1      1        12.0      24.0 2003-01-01   
2  pqejgcff  s001_t000        2      1        24.0      36.0 2003-01-01   

   sampling_rate                        signals_path  
0            250  signals/pqejgcff_s001_t000.parquet  
1            250  signals/pqejgcff_s001_t000.parquet  
2            250  signals/pqejgcff_s00

In [147]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 3b. LOAD HUGGING-FACE DATASET (TRAIN SPLIT)                          ║
# ╚══════════════════════════════════════════════════════════════════════╝
import gc

# List all potential variables that might hold references
vars_to_delete = ['hf_train', 'hf_test', 'train_ds', 'val_ds', 'test_ds',
                  'train_loader', 'val_loader', 'test_loader', 'model']

for var_name in vars_to_delete:
    if var_name in globals():
        try:
            del globals()[var_name]
            print(f"Deleted {var_name}")
        except Exception as e:
            print(f"Could not delete {var_name}: {e}")

# Force garbage collection
gc.collect()
print("Garbage collection run.")


SCRIPT_PATH = os.path.join(config["base_dir"], "EEG_nml.py")
print(f"Attempting to load dataset using script: {SCRIPT_PATH}")
print(f"Using data_dir for script: {config['base_dir']}")

hf_train = None # Initialize to allow later check

hf_train = load_dataset(path=SCRIPT_PATH, data_dir=config["base_dir"], split="train", name="default", download_mode="force_redownload")

try:
    # Assumes EEG_nml.py uses config["base_dir"] to find "train_data" for split="train"
    print(f"HF train dataset loaded successfully. Size: {len(hf_train)}")

    # Simple 80/20 random split for validation from the training patient data
    # This creates a validation set from the 50 training patients
    if len(hf_train) > 0:
        split_idx = int(0.8 * len(hf_train))
        perm      = np.random.permutation(len(hf_train))
        train_idx, val_idx = perm[:split_idx], perm[split_idx:]
        print(f"Train/validation split created: {len(train_idx)} train, {len(val_idx)} validation samples.")
    else:
        print("Warning: Loaded training dataset is empty. Cannot create train/val split.")
        train_idx, val_idx = [], []


except Exception as e:
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    print(f"ERROR loading training dataset: {e}")
    print(f"Please check SCRIPT_PATH ('{SCRIPT_PATH}') and that EEG_nml.py is correctly configured.")
    print(f"Ensure 'EEG_nml.py' can find the 'train_data' folder within '{config['base_dir']}' when split='train' is used.")
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    # Optional: Stop execution
    # raise

Old caching folder /home/.cache/huggingface/datasets/eeg_nml/default-data_dir=..%2FEEG_nml/1.0.0/16d9681eeba9a735676e78110d14b899d1de320df916d70e9f8d11ee07540654 for dataset eeg_nml exists but no data were found. Removing it. 


Deleted hf_train
Garbage collection run.
Attempting to load dataset using script: ../EEG_nml/EEG_nml.py
Using data_dir for script: ../EEG_nml
../EEG_nml


Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

HF train dataset loaded successfully. Size: 12993
Train/validation split created: 10394 train, 2599 validation samples.


In [183]:
# In a new cell after loading hf_train or from df_meta in Cell 4
# Assuming 'label' is the column name
# If using df_meta from Cell 4:
# label_counts = df_meta['label'].value_counts()
# print(f"Label distribution in metadata:\n{label_counts}")

# If using hf_train from Cell 5:
# Assuming hf_train is a Hugging Face Dataset object
labels = [sample['label'] for sample in hf_train]
import collections
label_counts = collections.Counter(labels)
print(f"Label distribution in hf_train:\n{label_counts}")
print(f"Class 0 percentage: {label_counts[0] / len(labels) * 100:.2f}%")
print(f"Class 1 percentage: {label_counts[1] / len(labels) * 100:.2f}%")
pos_weight_tensor_for_train = pos_w = torch.tensor([(label_counts[0] / len(labels))/(label_counts[1] / len(labels))], device=device)
print(pos_weight_tensor_for_train)

Label distribution in hf_train:
Counter({0: 10476, 1: 2517})
Class 0 percentage: 80.63%
Class 1 percentage: 19.37%
tensor([4.1621], device='cuda:0')


In [149]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 4.A DEFINE PYTORCH GEOMETRIC DATASET CLASS & LOAD EDGE INDEX         ║
# ╚══════════════════════════════════════════════════════════════════════╝
EDGE_INDEX = None # Initialize
edge_path = os.path.join(config["base_dir"], "edge_index", f"edge_index_{config['edge_variant']}.pt")
print(f"Attempting to load edge index from: {edge_path}")

try:
    EDGE_INDEX = torch.load(edge_path).long()  # shape (2, E)
    print(f"Edge index loaded successfully. Shape: {EDGE_INDEX.shape}")
except FileNotFoundError:
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    print(f"ERROR: Edge index file NOT FOUND at {edge_path}")
    print(f"Please ensure the 'edge_index' folder and the file 'edge_index_{config['edge_variant']}.pt' exist in '{os.path.join(config['base_dir'], 'edge_index')}'.")
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    # Optional: Stop execution
    # raise
except Exception as e:
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    print(f"An unexpected error occurred while loading edge index: {e}")
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    # Optional: Stop execution
    # raise

class GraphEEGDataset(Dataset):
    def __init__(self, hf_dataset, indices):
        super().__init__()
        if hf_dataset is None or len(indices) == 0 and len(hf_dataset)>0 : # check if hf_dataset is None or if indices are empty for non-empty dataset
             print("Warning: GraphEEGDataset initialized with None or empty hf_dataset/indices.")
             self.hf = [] # Make it an empty list to avoid errors in len/get
        elif len(hf_dataset) == 0 and len(indices) > 0 :
             print("Warning: GraphEEGDataset initialized with empty hf_dataset but non-empty indices.")
             self.hf = []
        else:
             self.hf = hf_dataset.select(indices)


    def len(self):
        return len(self.hf)

    def get(self, idx):
        if not self.hf: # Handle empty dataset case
            raise IndexError("Attempting to get item from an empty or uninitialized hf_dataset in GraphEEGDataset")
        row = self.hf[idx]
        x = torch.tensor(row["features"], dtype=torch.float)  # (19,9) - Ensure dtype
        y = torch.tensor([row["label"]], dtype=torch.float) # shape (1,) - Ensure dtype

        # --- Construct the new composite signal_id ---
        base_id = row["signal_id"]      # e.g., "pqejgcpt_s001_t000"
        segment_num = row["segment"]    # e.g., 0, 1, 2

        # Create the composite ID by appending segment number to the base ID
        # This matches your example: "pqejgcpt_s001_t000" + "0" -> "pqejgcpt_s001_t0000"
        composite_id = f"{base_id}_{segment_num}"
        
        # Make sure EDGE_INDEX is loaded, otherwise this will fail
        if EDGE_INDEX is None:
            raise ValueError("EDGE_INDEX is not loaded. Cannot create Data object.")
        
        # Optional: Add a debug print for the first few IDs to verify
        if idx < 5 and self.hf.split == 'test': # Print only for test set and first 5 items
             print(f"DEBUG GraphEEGDataset.get() (Test Split): base_id='{base_id}', segment_num='{segment_num}', CREATED composite_id='{composite_id}'")
        
        data = Data(x=x, edge_index=EDGE_INDEX, y=y,
                    signal_id=composite_id) # Make sure 'signal_id' exists
        return data

print("GraphEEGDataset class defined.")

Attempting to load edge index from: ../EEG_nml/edge_index/edge_index_knn.pt
Edge index loaded successfully. Shape: torch.Size([2, 152])
GraphEEGDataset class defined.


In [150]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 4.B CREATE PYG DATASETS AND DATALOADERS (TRAIN & VALIDATION)         ║
# ╚══════════════════════════════════════════════════════════════════════╝
if hf_train and EDGE_INDEX is not None and train_idx is not None and val_idx is not None:
    train_ds = GraphEEGDataset(hf_train, train_idx)
    val_ds   = GraphEEGDataset(hf_train, val_idx)

    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, drop_last=True if len(train_ds) > config["batch_size"] else False)
    val_loader   = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False, drop_last=False)

    print(f"PyG Datasets and DataLoaders created.")
    print(f"Loaded: {len(train_ds)} train | {len(val_ds)} validation examples.")
    if len(train_ds) > 0:
      print(f"Sample data from train_ds[0]: {train_ds[0]}")
      print(f"Node features shape: {train_ds[0].x.shape}, Label: {train_ds[0].y}")
    if len(val_ds) > 0:
      print(f"Sample data from val_ds[0]: {val_ds[0]}")
else:
    print("Skipping DataLoader creation due to previous errors (hf_train, EDGE_INDEX, or splits not available).")
    train_loader, val_loader = None, None # Ensure they are defined for later checks

PyG Datasets and DataLoaders created.
Loaded: 10394 train | 2599 validation examples.
Sample data from train_ds[0]: Data(x=[19, 9], edge_index=[2, 152], y=[1], signal_id='pqejgetp_s011_t001_40')
Node features shape: torch.Size([19, 9]), Label: tensor([0.])
Sample data from val_ds[0]: Data(x=[19, 9], edge_index=[2, 152], y=[1], signal_id='pqejgnkx_s002_t006_26')


In [167]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 5. (OLD) MODULAR GNN CLASS                                           ║
# ╚══════════════════════════════════════════════════════════════════════╝
CONV_MAP = dict(
    gcn  = GCNConv,
    sage = SAGEConv,
    gat  = lambda in_c, out_c: GATConv(in_c, out_c, heads=4, concat=True) # Single head, no dim expansion
)

class EEGGNN(nn.Module):
    def __init__(self, gnn_type, in_dim=9, hidden_dim=64, num_layers=2, dropout=0.3, num_classes=1):
        super().__init__()
        if gnn_type not in CONV_MAP:
            raise ValueError(f"Unsupported GNN type: {gnn_type}. Supported types are {list(CONV_MAP.keys())}")

        ConvLayer = CONV_MAP[gnn_type]
        self.convs = nn.ModuleList()

        # First layer
        self.convs.append(ConvLayer(in_dim, hidden_dim))

        # Hidden layers
        for _ in range(num_layers - 1):
            self.convs.append(ConvLayer(hidden_dim, hidden_dim))

        self.dropout_rate = dropout
        self.cls = nn.Linear(hidden_dim, num_classes) # num_classes usually 1 for binary_cross_entropy_with_logits

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None or edge_index is None:
             raise ValueError("Input data (x or edge_index) is None in EEGGNN forward pass.")

        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Global pooling
        if batch is None : # Handle cases where batch might be missing if a single graph is passed without a loader
             if x.size(0) > 0 : # Check if x has any nodes
                  g = global_mean_pool(x, torch.zeros(x.size(0), dtype=torch.long, device=x.device))
             else: # Handle empty graph case after convolutions
                  # Return a zero tensor of appropriate shape or handle error
                  # For now, let's assume this case should not happen with proper data
                  # Or, if it does, the classifier input needs to be handled.
                  # This might indicate an issue upstream if graphs become empty.
                  print("Warning: Graph has no nodes after convolutions before global_mean_pool.")
                  # Create a zero tensor for the classifier input, matching hidden_dim
                  # This part might need adjustment based on how num_classes is handled
                  # If num_classes is 1, then shape is (1, hidden_dim) for pool, then (1,1) for cls output
                  g = torch.zeros((data.num_graphs if hasattr(data, 'num_graphs') and data.num_graphs > 0 else 1, self.cls.in_features), device=x.device if x.device else device)
        else:
             g = global_mean_pool(x, batch)  # (batch_size, hidden_dim)

        out = self.cls(g) # (batch_size, num_classes)

        # If num_classes is 1 (for BCEWithLogitsLoss), squeeze the last dimension
        if self.cls.out_features == 1:
            out = out.squeeze(1) # (batch_size,)

        return out

print("EEGGNN class defined.")

EEGGNN class defined.


In [173]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 5. MULTI-HEAD-READY GNN CLASS                                        ║
# ╚══════════════════════════════════════════════════════════════════════╝



import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool, BatchNorm

# Assume CONV_MAP and GAT_HEADS are defined as above

class EEGGNN(nn.Module):
    def __init__(self, gnn_type, in_dim=9, hidden_dim=64, num_layers=2, dropout=0.3, num_classes=1):
        super().__init__()
        if gnn_type not in CONV_MAP:
            raise ValueError(f"Unsupported GNN type: {gnn_type}. Supported types are {list(CONV_MAP.keys())}")

        ConvLayer = CONV_MAP[gnn_type]
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.dropout_rate = dropout
        self.num_layers = num_layers
        self.gnn_type = gnn_type # Store gnn_type

        current_dim = in_dim

        for i in range(num_layers):
            # For GAT with concat=True, the output dim of the conv layer is hidden_dim * heads
            # The input to the *next* GAT layer's internal linear projection will handle this.
            # The GATConv layer itself will output hidden_dim * heads if concat=True.
            
            output_channels_for_conv = hidden_dim # This is 'out_channels' per head for GAT

            self.convs.append(ConvLayer(current_dim, output_channels_for_conv))
            
            # Determine the dimension after the convolution layer
            if gnn_type == "gat" and GAT_HEADS > 1 and getattr(self.convs[-1], 'concat', False): # Check if concat is True
                dim_after_conv = output_channels_for_conv * GAT_HEADS
            else:
                dim_after_conv = output_channels_for_conv
            
            self.bns.append(BatchNorm(dim_after_conv))
            current_dim = dim_after_conv # Input dim for the next layer

        # The final classifier's input dimension must match the output of the last GNN block
        self.cls = nn.Linear(current_dim, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        if x is None or edge_index is None:
             raise ValueError("Input data (x or edge_index) is None in EEGGNN forward pass.")

        for i in range(self.num_layers):
            x_input_for_skip = x 

            x = self.convs[i](x, edge_index)
            x = self.bns[i](x)
            x = F.relu(x)
            
            # Skip connection: Ensure dimensions match or use a projection
            # This basic skip assumes the output of ReLU (x) has the same dim as x_input_for_skip
            # This might need adjustment if dimensions change significantly due to GAT concat.
            # A common pattern for GAT with concat is that the *next* layer takes hidden_dim * heads as input.
            # If you add skip connections, project x_input_for_skip to match x's dimension if they differ.
            # For simplicity, let's assume skip connections are primarily for non-GAT or GAT with average
            if self.gnn_type != "gat" or not getattr(self.convs[i], 'concat', False) or GAT_HEADS == 1:
                 if i > 0 and x_input_for_skip.shape == x.shape:
                     x = x + x_input_for_skip
            # If GAT with concat, managing skip connections requires careful dimension handling
            # or a projection layer for the skip path.

            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        if batch is None:
             g = global_mean_pool(x, torch.zeros(x.size(0), dtype=torch.long, device=x.device)) \
                 if x.size(0) > 0 else torch.zeros((1, self.cls.in_features), device=x.device if x.device else 'cpu')
        else:
             g = global_mean_pool(x, batch)

        out = self.cls(g)

        if self.cls.out_features == 1:
            out = out.squeeze(1)
        return out
print("EEGGNN multi-head class defined.")

EEGGNN multi-head class defined.


In [168]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 5.B INSTANTIATE MODEL & OPTIMIZER                                     ║
# ╚══════════════════════════════════════════════════════════════════════╝
# Determine input dimension from data if possible, otherwise default to 9
# This assumes train_ds has been successfully created.
# Sample one data point to get in_dim (node feature dimension)
in_dim_data = 9 # Default
if train_loader and len(train_ds) > 0:
    try:
        sample_data_for_dim = train_ds.get(0)
        if hasattr(sample_data_for_dim, 'x') and sample_data_for_dim.x is not None:
            in_dim_data = sample_data_for_dim.x.shape[1]
        print(f"Input dimension automatically detected from data: {in_dim_data}")
    except Exception as e:
        print(f"Could not determine in_dim from data: {e}. Using default {in_dim_data}.")
else:
    print(f"train_loader or train_ds not available or empty. Using default in_dim: {in_dim_data}")


try:
    model = EEGGNN(
        gnn_type=config["gnn_type"],
        in_dim=in_dim_data, # Use detected or default in_dim
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        dropout=config["dropout"]
        # num_classes is 1 by default for binary classification with BCEWithLogitsLoss
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"]
    )
    print("Model and Optimizer instantiated.")
    print(model)
except Exception as e:
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    print(f"ERROR instantiating model or optimizer: {e}")
    print(f"This could be due to an issue with GNN type: '{config['gnn_type']}' or other parameters.")
    print(f"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model, optimizer = None, None # Ensure they are None if instantiation fails
    # raise

Input dimension automatically detected from data: 9
Model and Optimizer instantiated.
EEGGNN(
  (convs): ModuleList(
    (0): SAGEConv(9, 64, aggr=mean)
    (1-5): 5 x SAGEConv(64, 64, aggr=mean)
  )
  (cls): Linear(in_features=64, out_features=1, bias=True)
)


In [180]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 6a. Loss for imbalanced dataset                                      ║
# ╚══════════════════════════════════════════════════════════════════════╝



import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean', pos_weight=None):
        """
        Focal Loss for binary classification.
        Args:
            alpha (float): Weighting factor for the rare class (e.g., if class 1 is rare, alpha for class 1).
                           Can be thought of as similar to pos_weight's role but applied differently.
                           Set to 0.25 for positive class if positive class is minority, 
                           or 1-alpha for negative class.
                           If None, no alpha weighting is applied.
            gamma (float): Focusing parameter. Higher values give more weight to hard, misclassified examples.
            reduction (str): 'mean', 'sum', or 'none'.
            pos_weight (torch.Tensor, optional): A weight of positive examples. If provided,
                                                 this will be used by the underlying BCEWithLogitsLoss.
                                                 If you use pos_weight here, you might not need alpha, or
                                                 alpha can be used to further tune. For simplicity,
                                                 you might start by using EITHER a good pos_weight OR alpha.
                                                 If using alpha, pos_weight should ideally be None or 1.
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.pos_weight = pos_weight

    def forward(self, logits, targets):
        # logits: model outputs BEFORE sigmoid (N,)
        # targets: binary ground truth labels (N,)

        # BCEWithLogitsLoss combines sigmoid and BCE loss for numerical stability
        # If pos_weight is provided, it handles the weighting for class 1.
        BCE_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none', pos_weight=self.pos_weight)
        
        # Calculate pt (probability of the true class)
        # targets * logits for class 1, (1-targets) * (-logits) for class 0
        # p_t = exp(-BCE_loss) would be one way if BCE_loss was -log(pt)
        # A more direct way with logits:
        p = torch.sigmoid(logits)
        # For true positives (targets=1), pt = p. For true negatives (targets=0), pt = 1-p
        pt = p * targets + (1 - p) * (1 - targets)
        
        # Calculate Focal Loss
        focal_term = (1.0 - pt).pow(self.gamma)
        loss = focal_term * BCE_loss

        if self.alpha is not None:
            # If targets=1, alpha_t = alpha. If targets=0, alpha_t = 1-alpha
            # This assumes alpha is the weight for the positive class (targets=1)
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            loss = alpha_t * loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

print("Focal_loss function defined.")

Focal_loss function defined.


In [153]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 6. OLD EPOCH FUNCTION (other loss)                                   ║
# ╚══════════════════════════════════════════════════════════════════════╝
#def run_epoch(loader, train_mode=True):
#    if model is None or optimizer is None:
#        print("Model or optimizer not initialized. Skipping run_epoch.")
#        return {"loss": float('nan'), "acc": 0.0, "f1": 0.0, "auroc": 0.0}
#
#    model.train(train_mode) # Set model to train or eval mode
#
#    all_labels, all_predictions_prob, all_losses = [], [], []
#
#    for batch_data in loader:
#        batch_data = batch_data.to(device)
#
#        # Forward pass
#        logits = model(batch_data) # Shape: (batch_size,)
#
#        # Ensure labels are also (batch_size,) and float for BCEWithLogitsLoss
#        pos_w = torch.tensor([4.16], device=device) # ~ num_class_0 / num_class_1
#        labels = batch_data.y.squeeze().float()
#        
#        if logits.shape != labels.shape:
#            # This can happen if batch size is 1 and squeeze leads to 0-dim tensor
#            # Or if there's a mismatch in expected output.
#            # For BCEWithLogits, both logits and labels should be [N] or [N, C]
#            # Given model outputs (batch_size,), labels should be (batch_size,)
#            print(f"Warning: Shape mismatch. Logits: {logits.shape}, Labels: {labels.shape}. Attempting to reshape labels.")
#            if labels.ndim == 0 and logits.ndim == 1 and logits.shape[0] == 1: # labels is scalar, logits is [1]
#                 labels = labels.unsqueeze(0)
#            elif labels.ndim == 2 and labels.shape[1] == 1 and logits.ndim == 1 : # labels is [N,1]
#                 labels = labels.squeeze(1)
#            # Add more sophisticated checks if needed
#            if logits.shape != labels.shape:
#                  raise ValueError(f"Corrected label shape {labels.shape} still does not match logits shape {logits.shape}")
#
#
#        loss = F.binary_cross_entropy_with_logits(logits, labels, pos_weight=pos_w)
#
#        if train_mode:
#            optimizer.zero_grad()
#            loss.backward()
#            optimizer.step()
#
#        all_labels.append(labels.cpu().detach())
#        all_predictions_prob.append(torch.sigmoid(logits).cpu().detach()) # Probabilities for AUROC
#        all_losses.append(loss.item())
#
#    if not all_labels or not all_predictions_prob: # If loader was empty
#        print("Warning: No data processed in run_epoch. Returning NaN/0 metrics.")
#        return {"loss": float('nan'), "acc": 0.0, "f1": 0.0, "auroc": 0.0}
#
#    # Concatenate all batch results
#    final_labels = torch.cat(all_labels).numpy()
#    final_predictions_prob = torch.cat(all_predictions_prob).numpy()
#
#    # Convert probabilities to binary predictions (0 or 1)
#    final_predictions_binary = (final_predictions_prob > 0.5).astype(int)
#
#    # Inside run_epoch function in Cell 10,
#    # right before: metrics = dict(...) or f1_score(...)
#    
#    print(f"Unique values in final_labels: {np.unique(final_labels)}")
#    print(f"dtype of final_labels: {final_labels.dtype}")
#    print(f"Shape of final_labels: {final_labels.shape}")
#    print(f"Unique values in final_predictions_binary: {np.unique(final_predictions_binary)}")
#    print(f"dtype of final_predictions_binary: {final_predictions_binary.dtype}")
#    
#    # Then your existing f1_score calculation:
#    # f1_val = f1_score(final_labels, final_predictions_binary, zero_division=0) # Renamed to f1_val to avoid conflict if 'f1' is a key
#    # metrics["f1"] = f1_val
#
#    metrics = dict(
#        loss = np.mean(all_losses),
#        acc  = accuracy_score(final_labels, final_predictions_binary),
#        f1   = f1_score(final_labels, final_predictions_binary, zero_division=0),
#    )
#    try:
#        # AUROC requires at least one sample from each class in y_true for non-trivial calculation
#        if len(np.unique(final_labels)) > 1:
#            metrics["auroc"] = roc_auc_score(final_labels, final_predictions_prob)
#        else:
#            metrics["auroc"] = 0.0 # Or nan, or skip logging if only one class present
#            # print(f"Warning: Only one class present in labels for {'train' if train_mode else 'validation/test'} set. AUROC set to 0.0.")
#    except ValueError as e_auroc:
#        # print(f"Could not calculate AUROC: {e_auroc}. Setting to 0.0.")
#        metrics["auroc"] = 0.0 # Or handle as NaN
#
#    return metrics
#
#print("run_epoch function defined.")

run_epoch function defined.


In [176]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 6b. TRAIN / VALIDATE EPOCH FUNCTION                                  ║
# ╚══════════════════════════════════════════════════════════════════════╝

import pandas as pd # Ensure pandas is imported if you use it for submission later


def run_epoch(loader, model, train_mode=True, is_predict_mode=False, pos_weight_tensor=None): # Added is_predict_mode and pos_weight_tensor
    if model is None: # Check if model is defined (should be)
        print("Model not initialized. Skipping run_epoch.")
        if is_predict_mode:
            return {"predictions_prob": np.array([]), "predictions_binary": np.array([]), "signal_ids": []}
        return {"loss": float('nan'), "acc": 0.0, "f1": 0.0, "auroc": 0.0}

    model.train(train_mode if not is_predict_mode else False) # Eval mode for prediction

    all_predictions_prob_list = [] # Use lists to append tensors of varying batch sizes
    all_signal_ids_list = []
    
    # Variables for metrics if not in predict mode
    all_labels_list = []
    all_losses_list = []

    for batch_data in loader:
        batch_data = batch_data.to(device)
        
        with torch.no_grad() if not train_mode or is_predict_mode else torch.enable_grad(): # No gradients needed for val or predict
            logits = model(batch_data) # Shape: (batch_size,)

        # Collect predictions (probabilities)
        current_probs = torch.sigmoid(logits).cpu().detach()
        all_predictions_prob_list.append(current_probs)

        if hasattr(batch_data, 'signal_id'):
            # Assuming signal_id is a list of strings already (from your GraphEEGDataset)
            all_signal_ids_list.extend(batch_data.signal_id)

        if not is_predict_mode:
            labels = batch_data.y.squeeze().float()
            if labels.ndim == 0: labels = labels.unsqueeze(0) # Ensure labels are at least 1D

            # Handle potential -1 labels if they accidentally slip into validation
            # This is a safeguard; ideally, val labels are 0 or 1.
            valid_label_mask = (labels == 0) | (labels == 1)
            if not torch.all(valid_label_mask):
                print(f"Warning: Invalid labels found in {'train' if train_mode else 'validation'} batch. Only using 0/1 for loss/metrics.")
            
            valid_logits = logits[valid_label_mask]
            valid_labels = labels[valid_label_mask]

            if valid_logits.numel() > 0: # Ensure there are valid samples
                current_pos_weight = None
                if train_mode and pos_weight_tensor is not None:
                     current_pos_weight = pos_weight_tensor
                
                focal_loss_fn = FocalLoss(alpha=None, gamma=2.0, reduction='mean', pos_weight=pos_weight_tensor_for_train)
                #loss = F.binary_cross_entropy_with_logits(valid_logits, valid_labels, pos_weight=current_pos_weight)
                loss = focal_loss_fn(valid_logits, valid_labels)
                
                all_losses_list.append(loss.item())
                all_labels_list.append(valid_labels.cpu()) # Store only valid labels for metrics

                if train_mode:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            elif train_mode: # If no valid labels in a training batch, still need to handle optimizer
                optimizer.zero_grad() # Clear gradients even if no backward pass occurred on this batch


    if is_predict_mode:
        if not all_predictions_prob_list: # Handle empty loader case
            return {"predictions_prob": np.array([]), "predictions_binary": np.array([]), "signal_ids": []}
            
        final_predictions_prob = torch.cat(all_predictions_prob_list).numpy()
        final_predictions_binary = (final_predictions_prob > 0.5).astype(int)
        return {
            "predictions_prob": final_predictions_prob,
            "predictions_binary": final_predictions_binary,
            "signal_ids": all_signal_ids_list # This is already a flat list of strings
        }

    # --- Metric Calculation for Training/Validation ---
    if not all_labels_list: # If loader was empty or no valid labels processed
        print("Warning: No data with valid labels (0 or 1) processed in run_epoch. Returning NaN/0 metrics.")
        return {"loss": float('nan'), "acc": 0.0, "f1": 0.0, "auroc": 0.0}

    final_labels = torch.cat(all_labels_list).numpy()
    # For metrics, use predictions corresponding to valid labels
    # This requires aligning predictions with valid_labels if filtering occurred batch-wise.
    # For simplicity here, assuming all_predictions_prob was for all samples,
    # and we're evaluating on the subset that had valid labels.
    # A more robust way would be to filter predictions alongside labels within the loop.
    # However, since we expect val labels to be 0/1, this might be okay.
    # Let's re-evaluate: we need predictions that correspond to final_labels
    
    # Re-think: if labels were filtered, predictions need to be filtered too.
    # For now, let's assume validation set has proper 0/1 labels and this filtering is minor.
    # The f1_score will be based on the collected valid labels.
    # We need the predictions that corresponded to these valid labels.
    # This part is tricky if `valid_label_mask` was used extensively.
    # Assuming for validation, all labels are 0 or 1.
    
    final_predictions_prob_for_metrics = torch.cat(all_predictions_prob_list).numpy() # These are all predictions
    if len(final_predictions_prob_for_metrics) != len(final_labels):
        # This would happen if some batches had no valid labels but still contributed predictions.
        # This indicates a need to more carefully collect predictions corresponding to valid labels.
        # For now, this will likely cause an error in metric calculation if lengths differ.
        # The robust way is to append to a list of predictions_for_metrics only when valid_labels exist.
        # Let's assume for val, all labels are valid 0/1.
        print(f"Warning: Length mismatch between predictions ({len(final_predictions_prob_for_metrics)}) and labels ({len(final_labels)}). Metrics might be incorrect.")
        # Fallback or error needed here. For simplicity, let's proceed but acknowledge this.


    final_predictions_binary_for_metrics = (final_predictions_prob_for_metrics > 0.5).astype(int)
    
    # Ensure final_labels only contains 0 and 1 for sklearn metrics if this check is still needed
    if not np.all(np.isin(final_labels, [0, 1])):
        print("Error: final_labels for metric calculation still contain non-binary values.")
        return {"loss": np.mean(all_losses_list) if all_losses_list else float('nan'), "acc": 0.0, "f1": 0.0, "auroc": 0.0}

    metrics = dict(
        loss = np.mean(all_losses_list) if all_losses_list else float('nan'),
        acc  = accuracy_score(final_labels, final_predictions_binary_for_metrics),
        f1   = f1_score(final_labels, final_predictions_binary_for_metrics, zero_division=0), # Default average='binary' if labels are [0,1]
    )
    try:
        if len(np.unique(final_labels)) > 1: # AUROC requires at least one sample from each class
            metrics["auroc"] = roc_auc_score(final_labels, final_predictions_prob_for_metrics)
        else:
            metrics["auroc"] = 0.0 
    except ValueError as e_auroc:
        metrics["auroc"] = 0.0
    return metrics

print("run_epoch_final with focal_loss function defined.")

run_epoch_final with focal_loss function defined.


In [177]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 7. MAIN TRAINING LOOP (just one)                                     ║
# ╚══════════════════════════════════════════════════════════════════════╝
if train_loader is not None and val_loader is not None and model is not None:
    best_val_f1 = 0
    best_model_state = None

    print(f"\n--- Starting Training for {config['epochs']} epochs ---")
    wandb.watch(model, log="all", log_freq=100) # Log model gradients and parameters

    for epoch in range(1, config["epochs"] + 1):
        epoch_start_time = time.time()

        train_metrics = run_epoch(train_loader, train_mode=True)
        val_metrics   = run_epoch(val_loader,   train_mode=False)

        epoch_duration = time.time() - epoch_start_time

        # Log metrics to WandB
        wandb_log_data = {}
        for k, v in train_metrics.items(): wandb_log_data[f"train_{k}"] = v
        for k, v in val_metrics.items(): wandb_log_data[f"val_{k}"] = v
        wandb_log_data["epoch"] = epoch
        wandb_log_data["epoch_duration_sec"] = epoch_duration
        wandb.log(wandb_log_data)

        print(f"[Epoch {epoch:03d}/{config['epochs']}] "
              f"Train Loss: {train_metrics['loss']:.4f}, Train F1: {train_metrics['f1']:.4f}, Train AUROC: {train_metrics['auroc']:.4f} | "
              f"Val Loss: {val_metrics['loss']:.4f}, Val F1: {val_metrics['f1']:.4f}, Val AUROC: {val_metrics['auroc']:.4f} | "
              f"Time: {epoch_duration:.2f}s")

        if val_metrics["f1"] > best_val_f1:
            best_val_f1 = val_metrics["f1"]
            best_model_state = model.state_dict().copy() # Get a copy of the state dict
            print(f"  New best validation F1: {best_val_f1:.4f}. Saving model state.")

            # Save the best model checkpoint to Google Drive
            best_model_filename = f"best_model_{config['gnn_type']}_{config['edge_variant']}_epoch{epoch}.pt"
            best_model_save_path = os.path.join(config["base_dir"], best_model_filename)
            try:
                torch.save(best_model_state, best_model_save_path)
                print(f"  Best model saved to: {best_model_save_path}")
                # Optionally, tell WandB about the saved model artifact
                # best_model_artifact = wandb.Artifact(f"{config['gnn_type']}-{config['edge_variant']}-best-model", type="model")
                # best_model_artifact.add_file(best_model_save_path)
                # wandb.log_artifact(best_model_artifact)

            except Exception as e_save:
                print(f"  Error saving model: {e_save}")

    print(f"\n--- Training Finished ---")
    print(f"Best Validation F1 achieved: {best_val_f1:.4f}")

    # Save the final best model state explicitly if not done above or if you want a generic name
    if best_model_state is not None:
        final_best_model_filename = f"best_overall_model_{config['gnn_type']}_{config['edge_variant']}.pt"
        final_best_model_save_path = os.path.join(config["base_dir"], final_best_model_filename)
        try:
            torch.save(best_model_state, final_best_model_save_path)
            print(f"Final best model state saved to: {final_best_model_save_path}")
        except Exception as e_save_final:
            print(f"  Error saving final best model: {e_save_final}")
    else:
        print("No best model state was saved during training (e.g., validation F1 never improved or training was skipped).")

else:
    print("Skipping training loop due to earlier errors (loaders or model not available).")


--- Starting Training for 25 epochs ---


Error: You must call wandb.init() before wandb.watch()

RuntimeError: Error(s) in loading state_dict for EEGGNN:
	Missing key(s) in state_dict: "convs.0.lin_l.weight", "convs.0.lin_l.bias", "convs.0.lin_r.weight", "convs.1.lin_l.weight", "convs.1.lin_l.bias", "convs.1.lin_r.weight", "convs.2.lin_l.weight", "convs.2.lin_l.bias", "convs.2.lin_r.weight", "convs.3.lin_l.weight", "convs.3.lin_l.bias", "convs.3.lin_r.weight", "convs.4.lin_l.weight", "convs.4.lin_l.bias", "convs.4.lin_r.weight", "convs.5.lin_l.weight", "convs.5.lin_l.bias", "convs.5.lin_r.weight". 
	Unexpected key(s) in state_dict: "convs.0.bias", "convs.0.lin.weight", "convs.1.bias", "convs.1.lin.weight". 

In [181]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 8. TRAIN ALL COMBINATIONS AND SAVE CSVs                              ║
# ╚══════════════════════════════════════════════════════════════════════╝

import os
import time
import torch
import wandb
import numpy as np
import pandas as pd
from torch_geometric.loader import DataLoader # Ensure DataLoader is imported

# --- ASSUMPTIONS: Ensure these are defined from your notebook before this block ---
# PROJECT_BASE_DIR, SCRIPT_PATH, device, set_seed (function)
# GraphEEGDataset (class), EEGGNN (class), run_epoch (function)
# hf_train, train_idx, val_idx, hf_test (Hugging Face dataset objects and indices)
# pos_weight_tensor_for_train (calculated tensor for weighted loss)
# global EDGE_INDEX (this will be updated in the loop)
# ---------------------------------------------------------------------------------

# --- Configurations to iterate through ---
gnn_types_to_run = ["gat", "sage", "gcn"]
edge_variants_to_run = ["grid", "knn"] # Ensure 'edge_index_knn.pt' exists in your 'edge_index' folder
GAT_HEADS = 4

# Example instantiation (tune alpha, gamma, and decide on pos_weight vs alpha)
# Option 1: Using alpha (pos_weight inside FocalLoss should be None or 1)
# focal_loss_fn = FocalLoss(alpha=0.75, gamma=2.0, reduction='mean') # Assuming class 1 is minority and we weight it by 0.75

# Option 2: Using pos_weight directly within FocalLoss's BCE_loss calculation (alpha=None)
# Ensure pos_weight_tensor_for_train is calculated and on the correct device
focal_loss_fn = FocalLoss(alpha=None, gamma=2.0, reduction='mean', pos_weight=pos_weight_tensor_for_train)

# --- Base configuration (parts will be overridden) ---
# (Using your provided config, you can adjust epochs for quicker tests initially)
base_run_config = {
    "project": "eeg-gnn-group-project",
    "entity": "danielebelfiore7-epfl", # Your WandB entity
    "hidden_dim": 64,
    "num_layers": 4,  # Using num_layers from your latest config
    "dropout": 0.05,
    "lr": 1e-5,       # Using lr from your latest config
    "weight_decay": 1e-6,
    "batch_size": 64,
    "epochs": 15,     # Using epochs from your latest config
    "seed": 1,
    "base_dir": PROJECT_BASE_DIR
}

# --- To store results for the final table ---
experiment_results = []

# --- Main Experimental Loop ---
for gnn_type_val in gnn_types_to_run:
    for edge_variant_val in edge_variants_to_run:
        
        current_experiment_config = base_run_config.copy()
        current_experiment_config["gnn_type"] = gnn_type_val
        current_experiment_config["edge_variant"] = edge_variant_val
        
        run_id_str = f"{gnn_type_val}_{edge_variant_val}_layers{current_experiment_config['num_layers']}_lr{current_experiment_config['lr']}"
        print(f"\n\n{'='*35}\n🚀 STARTING EXPERIMENT: {run_id_str}\n{'='*35}")

        # 1. Initialize WandB
        if wandb.run is not None:
            wandb.finish() # Ensure any previous run is closed
        try:
            run = wandb.init(
                project=current_experiment_config["project"],
                entity=current_experiment_config["entity"],
                config=current_experiment_config,
                name=run_id_str,
                reinit=True,
                mode="online" # Set to "disabled" if you want to test without logging
            )
            print(f"WandB initialized for: {run_id_str}")
        except Exception as e:
            print(f"WandB initialization failed for {run_id_str}: {e}. Running in 'disabled' mode.")
            run = wandb.init(mode="disabled", config=current_experiment_config, name=run_id_str, reinit=True)

        # 2. Set Seed
        set_seed(current_experiment_config["seed"])

        # 3. Load Edge Index (crucially, this needs to be the global EDGE_INDEX if GraphEEGDataset uses it globally)
        edge_file_path = os.path.join(current_experiment_config["base_dir"], "edge_index", f"edge_index_{current_experiment_config['edge_variant']}.pt")
        print(f"Loading edge index: {edge_file_path}")
        try:
            EDGE_INDEX = torch.load(edge_file_path, map_location=device).long()
            print(f"Edge index '{current_experiment_config['edge_variant']}' loaded. Shape: {EDGE_INDEX.shape}")
        except FileNotFoundError:
            print(f"🛑 ERROR: Edge index file NOT FOUND at {edge_file_path}. SKIPPING this configuration.")
            if run.mode != "disabled": run.finish()
            continue
        except Exception as e:
            print(f"🛑 ERROR loading edge index {edge_file_path}: {e}. SKIPPING this configuration.")
            if run.mode != "disabled": run.finish()
            continue

        # 4. Create DataLoaders (these will now use the correct, globally updated EDGE_INDEX when GraphEEGDataset.get is called)
        # Assumes hf_train, train_idx, val_idx are already loaded/defined.
        iter_train_ds = GraphEEGDataset(hf_train, train_idx)
        iter_val_ds = GraphEEGDataset(hf_train, val_idx)
        iter_train_loader = DataLoader(iter_train_ds, batch_size=current_experiment_config["batch_size"], shuffle=True, drop_last=len(iter_train_ds) > current_experiment_config["batch_size"])
        iter_val_loader = DataLoader(iter_val_ds, batch_size=current_experiment_config["batch_size"], shuffle=False)
        print(f"DataLoaders created. Train: {len(iter_train_ds)}, Val: {len(iter_val_ds)}")

        # 5. Instantiate Model & Optimizer
        # Determine input dimension from data if possible, otherwise default.
        model_in_dim = iter_train_ds.get(0).x.shape[1] if len(iter_train_ds) > 0 else 9 
        
        try:
            model = EEGGNN(
                gnn_type=current_experiment_config["gnn_type"],
                in_dim=model_in_dim,
                hidden_dim=current_experiment_config["hidden_dim"],
                num_layers=current_experiment_config["num_layers"],
                dropout=current_experiment_config["dropout"]
            ).to(device)

            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=current_experiment_config["lr"],
                weight_decay=current_experiment_config["weight_decay"]
            )
            print(f"Model ({current_experiment_config['gnn_type']}) & Optimizer instantiated.")
        except Exception as e:
            print(f"🛑 ERROR instantiating model/optimizer for {run_id_str}: {e}. SKIPPING this configuration.")
            if run.mode != "disabled": run.finish()
            continue
            
        # 6. Training Loop (adapting your existing loop)
        current_best_val_f1 = 0.0
        current_best_model_state = None
        
        print(f"--- 🏋️ Starting Training for: {run_id_str} ---")
        if run.mode != "disabled":
            wandb.watch(model, log="all", log_freq=100)

        for epoch_num in range(1, current_experiment_config["epochs"] + 1):
            epoch_time_start = time.time()
            
            # IMPORTANT: Ensure your run_epoch function takes model, optimizer, device as args
            # e.g., def run_epoch(loader, model, optimizer, device, train_mode=True, is_predict_mode=False, pos_weight_tensor=None):
            # def run_epoch(loader, model, train_mode=True, is_predict_mode=False, pos_weight_tensor=None): # Added is_predict_mode and pos_weight_tensor

            
            metrics_train = run_epoch(iter_train_loader, model, train_mode=True, is_predict_mode=False, pos_weight_tensor=pos_weight_tensor_for_train)
            metrics_val = run_epoch(iter_val_loader, model, train_mode=False, is_predict_mode=False, pos_weight_tensor=pos_weight_tensor_for_train)
            epoch_time_duration = time.time() - epoch_time_start

            if run.mode != "disabled":
                wandb_metrics_log = {f"train_{k_met}": v_met for k_met, v_met in metrics_train.items()}
                wandb_metrics_log.update({f"val_{k_met}": v_met for k_met, v_met in metrics_val.items()})
                wandb_metrics_log["epoch"] = epoch_num
                wandb.log(wandb_metrics_log)

            print(f"[Epoch {epoch_num:03d}/{current_experiment_config['epochs']}] "
                  f"Tr Loss: {metrics_train['loss']:.4f}, Tr F1: {metrics_train['f1']:.4f}, Tr AUROC: {metrics_train['auroc']:.4f} | "
                  f"Val Loss: {metrics_val['loss']:.4f}, Val F1: {metrics_val['f1']:.4f}, Val AUROC: {metrics_val['auroc']:.4f} | "
                  f"Time: {epoch_time_duration:.2f}s")

            if metrics_val["f1"] > current_best_val_f1:
                current_best_val_f1 = metrics_val["f1"]
                current_best_model_state = model.state_dict().copy()
                print(f"  ⭐ New best validation F1: {current_best_val_f1:.4f} for {run_id_str}.")
        
        print(f"--- ✅ Training Finished for {run_id_str}. Best Val F1: {current_best_val_f1:.4f} ---")

        # Save the best model for this specific run
        saved_model_path_for_run = None
        if current_best_model_state:
            best_model_file = f"best_model_{run_id_str}.pt"
            saved_model_path_for_run = os.path.join(current_experiment_config["base_dir"], best_model_file)
            torch.save(current_best_model_state, saved_model_path_for_run)
            print(f"Best model for {run_id_str} saved to: {saved_model_path_for_run}")
            if run.mode != "disabled":
                artifact_model = wandb.Artifact(f"{run_id_str}-model", type="model")
                artifact_model.add_file(saved_model_path_for_run)
                wandb.log_artifact(artifact_model)
        else:
            print(f"No best model state was saved for {run_id_str} (Val F1 did not improve).")

        # Store results for this run (customize what you want to store)
        run_summary = {
            "config_run_id": run_id_str,
            "gnn_type": current_experiment_config["gnn_type"],
            "edge_variant": current_experiment_config["edge_variant"],
            "num_layers": current_experiment_config["num_layers"],
            "lr": current_experiment_config["lr"],
            "best_val_f1": current_best_val_f1,
            "final_train_loss": metrics_train['loss'] if 'metrics_train' in locals() else float('nan'), # Last epoch's train loss
            "final_train_f1": metrics_train['f1'] if 'metrics_train' in locals() else float('nan'),
            "final_val_loss": metrics_val['loss'] if 'metrics_val' in locals() else float('nan'), # Last epoch's val loss
            "model_path": saved_model_path_for_run if saved_model_path_for_run else "N/A"
        }
        experiment_results.append(run_summary)

        # 7. Test Set Prediction (Optional - Generate submission CSV for each run)
        if saved_model_path_for_run and os.path.exists(saved_model_path_for_run):
            print(f"\n--- 🧪 Generating Predictions for {run_id_str} ---")
            model.load_state_dict(torch.load(saved_model_path_for_run, map_location=device))
            model.eval()

            if 'hf_test' not in globals() or hf_test is None: # Load hf_test if not already global
                 print("Loading hf_test for predictions...")
                 hf_test = load_dataset(path=SCRIPT_PATH, data_dir=current_experiment_config["base_dir"], split="test", name="default")

            if len(hf_test) > 0:
                iter_test_ds = GraphEEGDataset(hf_test, list(range(len(hf_test)))) # Uses current global EDGE_INDEX
                iter_test_loader = DataLoader(iter_test_ds, batch_size=current_experiment_config["batch_size"], shuffle=False)
                
                # Pass optimizer=None as it's not used in predict_mode
                output_test = run_epoch(iter_test_loader, model, train_mode=False, is_predict_mode=True)
                
                df_submission = pd.DataFrame({
                    'id': output_test['signal_ids'],
                    'label': output_test['predictions_binary']
                })
                submission_file = f"submission_{run_id_str}.csv"
                submission_file_path = os.path.join(current_experiment_config["base_dir"], submission_file)
                df_submission.to_csv(submission_file_path, index=False)
                print(f"Submission CSV for {run_id_str} saved to: {submission_file_path}")
                if run.mode != "disabled":
                    artifact_submission = wandb.Artifact(f"{run_id_str}-submission", type="predictions")
                    artifact_submission.add_file(submission_file_path)
                    wandb.log_artifact(artifact_submission)
            else:
                print(f"Skipping prediction for {run_id_str}: hf_test is empty.")
        else:
            print(f"Skipping prediction for {run_id_str}: No best model was saved or found.")

        # 8. Finish WandB run
        if run.mode != "disabled":
            run.finish()
        
        print(f"\n{'='*35}\n🏁 COMPLETED EXPERIMENT: {run_id_str}\n{'='*35}")

print("\n\n🎉 All experiment configurations processed! 🎉")

# --- Display results in a table ---
if experiment_results:
    results_df = pd.DataFrame(experiment_results)
    print("\n\n--- 📊 Experiment Summary Table ---")
    print(results_df.to_string()) # Print full DataFrame
else:
    print("No results were collected (all runs may have been skipped).")



🚀 STARTING EXPERIMENT: gat_grid_layers4_lr1e-05


WandB initialized for: gat_grid_layers4_lr1e-05
Loading edge index: ../EEG_nml/edge_index/edge_index_grid.pt
Edge index 'grid' loaded. Shape: torch.Size([2, 48])
DataLoaders created. Train: 10394, Val: 2599
Model (gat) & Optimizer instantiated.
--- 🏋️ Starting Training for: gat_grid_layers4_lr1e-05 ---
[Epoch 001/15] Tr Loss: 0.2768, Tr F1: 0.3567, Tr AUROC: 0.6348 | Val Loss: 0.2501, Val F1: 0.4191, Val AUROC: 0.7466 | Time: 8.10s
  ⭐ New best validation F1: 0.4191 for gat_grid_layers4_lr1e-05.
[Epoch 002/15] Tr Loss: 0.2547, Tr F1: 0.4147, Tr AUROC: 0.7143 | Val Loss: 0.2411, Val F1: 0.4372, Val AUROC: 0.7616 | Time: 8.42s
  ⭐ New best validation F1: 0.4372 for gat_grid_layers4_lr1e-05.
[Epoch 003/15] Tr Loss: 0.2481, Tr F1: 0.4274, Tr AUROC: 0.7323 | Val Loss: 0.2374, Val F1: 0.4564, Val AUROC: 0.7686 | Time: 8.11s
  ⭐ New best validation F1: 0.4564 for gat_grid_layers4_lr1e-05.
[Epoch 004/15] Tr Loss: 0.2473, Tr F1: 0.4281, Tr AUROC: 0.7292 | Val Loss: 0.2391, Val F1: 0.4446, Val A

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_acc,▁▅▆▇▇█▇██▇█████
train_auroc,▁▆▇▇▇▇▇▇▇▇█████
train_f1,▁▅▆▆▇▇▆▇█▇█▇██▇
train_loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁
val_acc,▁▄▅▆▅▇▇▅▆▆▇█▇▇█
val_auroc,▁▅▇▃▃▇▃▅▄▅█▄▃▆▅
val_f1,▁▄▆▅▄█▆▅▂▆▇▅▂█▇
val_loss,█▄▃▄▃▂▃▂▂▂▁▂▂▁▂

0,1
epoch,15.0
train_acc,0.65461
train_auroc,0.74691
train_f1,0.43473
train_loss,0.24121
val_acc,0.72913
val_auroc,0.76126
val_f1,0.46423
val_loss,0.23394



🏁 COMPLETED EXPERIMENT: gat_grid_layers4_lr1e-05


🚀 STARTING EXPERIMENT: gat_knn_layers4_lr1e-05


WandB initialized for: gat_knn_layers4_lr1e-05
Loading edge index: ../EEG_nml/edge_index/edge_index_knn.pt
Edge index 'knn' loaded. Shape: torch.Size([2, 152])
DataLoaders created. Train: 10394, Val: 2599
Model (gat) & Optimizer instantiated.
--- 🏋️ Starting Training for: gat_knn_layers4_lr1e-05 ---
[Epoch 001/15] Tr Loss: 0.2767, Tr F1: 0.3559, Tr AUROC: 0.6369 | Val Loss: 0.2494, Val F1: 0.4126, Val AUROC: 0.7512 | Time: 8.06s
  ⭐ New best validation F1: 0.4126 for gat_knn_layers4_lr1e-05.
[Epoch 002/15] Tr Loss: 0.2538, Tr F1: 0.4138, Tr AUROC: 0.7154 | Val Loss: 0.2405, Val F1: 0.4422, Val AUROC: 0.7647 | Time: 8.09s
  ⭐ New best validation F1: 0.4422 for gat_knn_layers4_lr1e-05.
[Epoch 003/15] Tr Loss: 0.2475, Tr F1: 0.4287, Tr AUROC: 0.7330 | Val Loss: 0.2372, Val F1: 0.4514, Val AUROC: 0.7663 | Time: 7.91s
  ⭐ New best validation F1: 0.4514 for gat_knn_layers4_lr1e-05.
[Epoch 004/15] Tr Loss: 0.2461, Tr F1: 0.4340, Tr AUROC: 0.7322 | Val Loss: 0.2383, Val F1: 0.4481, Val AUROC: 

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_acc,▁▅▆▇▇▇▇██▇█████
train_auroc,▁▆▇▇▇▇▇▇▇▇█████
train_f1,▁▅▆▇▇▇▆▇▇▇█▇██▇
train_loss,█▄▃▃▂▂▂▂▂▂▁▁▁▁▁
val_acc,▁▄▅▅▆▇▇▅▇▇█▇▆▇█
val_auroc,▁▅▆▃▅▇▃▅▄▅█▄▄▆▅
val_f1,▁▄▅▄▆▇▅▅▄▇▇▃▅▅█
val_loss,█▄▃▃▃▂▃▂▂▂▁▂▂▁▁

0,1
epoch,15.0
train_acc,0.66107
train_auroc,0.75027
train_f1,0.43812
train_loss,0.23977
val_acc,0.72797
val_auroc,0.76523
val_f1,0.49027
val_loss,0.23276



🏁 COMPLETED EXPERIMENT: gat_knn_layers4_lr1e-05


🚀 STARTING EXPERIMENT: sage_grid_layers4_lr1e-05


WandB initialized for: sage_grid_layers4_lr1e-05
Loading edge index: ../EEG_nml/edge_index/edge_index_grid.pt
Edge index 'grid' loaded. Shape: torch.Size([2, 48])
DataLoaders created. Train: 10394, Val: 2599
Model (sage) & Optimizer instantiated.
--- 🏋️ Starting Training for: sage_grid_layers4_lr1e-05 ---
[Epoch 001/15] Tr Loss: 0.5666, Tr F1: 0.2584, Tr AUROC: 0.4199 | Val Loss: 0.3724, Val F1: 0.2554, Val AUROC: 0.4125 | Time: 7.24s
  ⭐ New best validation F1: 0.2554 for sage_grid_layers4_lr1e-05.
[Epoch 002/15] Tr Loss: 0.3327, Tr F1: 0.3020, Tr AUROC: 0.4751 | Val Loss: 0.2966, Val F1: 0.3010, Val AUROC: 0.5445 | Time: 7.22s
  ⭐ New best validation F1: 0.3010 for sage_grid_layers4_lr1e-05.
[Epoch 003/15] Tr Loss: 0.2925, Tr F1: 0.3262, Tr AUROC: 0.5690 | Val Loss: 0.2777, Val F1: 0.3379, Val AUROC: 0.6215 | Time: 7.23s
  ⭐ New best validation F1: 0.3379 for sage_grid_layers4_lr1e-05.
[Epoch 004/15] Tr Loss: 0.2781, Tr F1: 0.3526, Tr AUROC: 0.6162 | Val Loss: 0.2685, Val F1: 0.3765,

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_acc,▂▁▃▅▆▇▇████████
train_auroc,▁▂▅▆▆▇▇▇███████
train_f1,▁▃▄▆▆▇▇▇▇█▇▇███
train_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▂▁▄▆▆▇▇▇███▇▇█▇
val_auroc,▁▄▆▇▇▇▇████████
val_f1,▁▃▄▆▆▆▇▆██▇█▇██
val_loss,█▄▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
epoch,15.0
train_acc,0.61912
train_auroc,0.681
train_f1,0.39921
train_loss,0.26134
val_acc,0.62178
val_auroc,0.71529
val_f1,0.43925
val_loss,0.25252



🏁 COMPLETED EXPERIMENT: sage_grid_layers4_lr1e-05


🚀 STARTING EXPERIMENT: sage_knn_layers4_lr1e-05


WandB initialized for: sage_knn_layers4_lr1e-05
Loading edge index: ../EEG_nml/edge_index/edge_index_knn.pt
Edge index 'knn' loaded. Shape: torch.Size([2, 152])
DataLoaders created. Train: 10394, Val: 2599
Model (sage) & Optimizer instantiated.
--- 🏋️ Starting Training for: sage_knn_layers4_lr1e-05 ---
[Epoch 001/15] Tr Loss: 0.5724, Tr F1: 0.2574, Tr AUROC: 0.4201 | Val Loss: 0.3725, Val F1: 0.2543, Val AUROC: 0.4146 | Time: 7.34s
  ⭐ New best validation F1: 0.2543 for sage_knn_layers4_lr1e-05.
[Epoch 002/15] Tr Loss: 0.3329, Tr F1: 0.3028, Tr AUROC: 0.4765 | Val Loss: 0.2969, Val F1: 0.3073, Val AUROC: 0.5502 | Time: 7.26s
  ⭐ New best validation F1: 0.3073 for sage_knn_layers4_lr1e-05.
[Epoch 003/15] Tr Loss: 0.2917, Tr F1: 0.3246, Tr AUROC: 0.5744 | Val Loss: 0.2772, Val F1: 0.3368, Val AUROC: 0.6210 | Time: 7.25s
  ⭐ New best validation F1: 0.3368 for sage_knn_layers4_lr1e-05.
[Epoch 004/15] Tr Loss: 0.2774, Tr F1: 0.3564, Tr AUROC: 0.6215 | Val Loss: 0.2679, Val F1: 0.3724, Val A

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_acc,▂▁▃▅▆▇▇████████
train_auroc,▁▂▅▆▆▇▇▇███████
train_f1,▁▃▄▆▆▇▇▇▇█▇▇███
train_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▂▁▅▆▇██████████
val_auroc,▁▄▆▇▇▇▇▇███████
val_f1,▁▃▄▅▇▆▆▇███████
val_loss,█▄▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
epoch,15.0
train_acc,0.61941
train_auroc,0.68416
train_f1,0.40501
train_loss,0.26079
val_acc,0.61524
val_auroc,0.71948
val_f1,0.4363
val_loss,0.25162



🏁 COMPLETED EXPERIMENT: sage_knn_layers4_lr1e-05


🚀 STARTING EXPERIMENT: gcn_grid_layers4_lr1e-05


WandB initialized for: gcn_grid_layers4_lr1e-05
Loading edge index: ../EEG_nml/edge_index/edge_index_grid.pt
Edge index 'grid' loaded. Shape: torch.Size([2, 48])
DataLoaders created. Train: 10394, Val: 2599
Model (gcn) & Optimizer instantiated.
--- 🏋️ Starting Training for: gcn_grid_layers4_lr1e-05 ---
[Epoch 001/15] Tr Loss: 0.4342, Tr F1: 0.0641, Tr AUROC: 0.5659 | Val Loss: 0.3711, Val F1: 0.1094, Val AUROC: 0.5652 | Time: 7.46s
  ⭐ New best validation F1: 0.1094 for gcn_grid_layers4_lr1e-05.
[Epoch 002/15] Tr Loss: 0.3533, Tr F1: 0.1507, Tr AUROC: 0.5883 | Val Loss: 0.3154, Val F1: 0.1712, Val AUROC: 0.5952 | Time: 7.48s
  ⭐ New best validation F1: 0.1712 for gcn_grid_layers4_lr1e-05.
[Epoch 003/15] Tr Loss: 0.3147, Tr F1: 0.2144, Tr AUROC: 0.6141 | Val Loss: 0.2900, Val F1: 0.3038, Val AUROC: 0.6391 | Time: 7.49s
  ⭐ New best validation F1: 0.3038 for gcn_grid_layers4_lr1e-05.
[Epoch 004/15] Tr Loss: 0.2932, Tr F1: 0.2825, Tr AUROC: 0.6325 | Val Loss: 0.2801, Val F1: 0.3621, Val A

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_acc,█▇▇▆▅▅▄▄▃▃▂▂▁▁▁
train_auroc,▁▃▄▆▆▆▇▇▇█▇████
train_f1,▁▃▄▆▇▇▇▇███████
train_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁
val_acc,██▇▆▆▅▅▄▄▄▃▄▂▁▃
val_auroc,▁▃▅▅▆▆▇▇▇▇▇█▇██
val_f1,▁▂▅▇▇▇▇▇▇▇█████
val_loss,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁

0,1
epoch,15.0
train_acc,0.63484
train_auroc,0.66475
train_f1,0.39151
train_loss,0.26552
val_acc,0.63601
val_auroc,0.70102
val_f1,0.42597
val_loss,0.2567



🏁 COMPLETED EXPERIMENT: gcn_grid_layers4_lr1e-05


🚀 STARTING EXPERIMENT: gcn_knn_layers4_lr1e-05


WandB initialized for: gcn_knn_layers4_lr1e-05
Loading edge index: ../EEG_nml/edge_index/edge_index_knn.pt
Edge index 'knn' loaded. Shape: torch.Size([2, 152])
DataLoaders created. Train: 10394, Val: 2599
Model (gcn) & Optimizer instantiated.
--- 🏋️ Starting Training for: gcn_knn_layers4_lr1e-05 ---
[Epoch 001/15] Tr Loss: 0.4347, Tr F1: 0.0726, Tr AUROC: 0.5660 | Val Loss: 0.3726, Val F1: 0.0998, Val AUROC: 0.5680 | Time: 7.48s
  ⭐ New best validation F1: 0.0998 for gcn_knn_layers4_lr1e-05.
[Epoch 002/15] Tr Loss: 0.3526, Tr F1: 0.1610, Tr AUROC: 0.5919 | Val Loss: 0.3145, Val F1: 0.1712, Val AUROC: 0.6017 | Time: 7.49s
  ⭐ New best validation F1: 0.1712 for gcn_knn_layers4_lr1e-05.
[Epoch 003/15] Tr Loss: 0.3151, Tr F1: 0.2241, Tr AUROC: 0.6164 | Val Loss: 0.2893, Val F1: 0.3038, Val AUROC: 0.6436 | Time: 7.47s
  ⭐ New best validation F1: 0.3038 for gcn_knn_layers4_lr1e-05.
[Epoch 004/15] Tr Loss: 0.2933, Tr F1: 0.2922, Tr AUROC: 0.6355 | Val Loss: 0.2791, Val F1: 0.3621, Val AUROC: 

0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
train_acc,█▇▇▆▅▅▄▄▃▃▂▂▁▁▁
train_auroc,▁▃▄▆▆▆▇▇▇█▇████
train_f1,▁▃▄▆▇▇▇▇███████
train_loss,█▅▃▂▂▁▁▁▁▁▁▁▁▁▁
val_acc,██▇▆▆▅▅▄▄▄▂▃▁▁▂
val_auroc,▁▃▅▅▆▆▇▇▇▇▇█▇▇█
val_f1,▁▂▅▆▇▇▇▇██▇█▇██
val_loss,█▅▃▂▂▂▁▁▁▁▁▁▁▁▁

0,1
epoch,15.0
train_acc,0.63281
train_auroc,0.66548
train_f1,0.39253
train_loss,0.26529
val_acc,0.64025
val_auroc,0.70721
val_f1,0.43368
val_loss,0.25549



🏁 COMPLETED EXPERIMENT: gcn_knn_layers4_lr1e-05


🎉 All experiment configurations processed! 🎉


--- 📊 Experiment Summary Table ---
               config_run_id gnn_type edge_variant  num_layers       lr  best_val_f1  final_train_loss  final_train_f1  final_val_loss                                          model_path
0   gat_grid_layers4_lr1e-05      gat         grid           4  0.00001     0.469727          0.241210        0.434728        0.233941   ../EEG_nml/best_model_gat_grid_layers4_lr1e-05.pt
1    gat_knn_layers4_lr1e-05      gat          knn           4  0.00001     0.490267          0.239774        0.438120        0.232758    ../EEG_nml/best_model_gat_knn_layers4_lr1e-05.pt
2  sage_grid_layers4_lr1e-05     sage         grid           4  0.00001     0.441430          0.261336        0.399209        0.252516  ../EEG_nml/best_model_sage_grid_layers4_lr1e-05.pt
3   sage_knn_layers4_lr1e-05     sage          knn           4  0.00001     0.439279          0.260789        0.405006 

In [None]:
# ╔══════════════════════════════════════════════════════════════════════╗
# ║ 9. FINALIZE WANDB RUN                                                 ║
# ╚══════════════════════════════════════════════════════════════════════╝
if wandb.run: # Check if a wandb run is active
    wandb.finish()
    print("WandB run finished.")
else:
    print("No active WandB run to finish.")

print("\n--- End of Notebook ---")