In [1]:
%load_ext autoreload
%autoreload 1
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch.utils.data import Subset, WeightedRandomSampler
# from torch.utils.data import DataLoader
from src.utils.seeder import seed_everything

# set seaborn theme
sns.set_theme()

# create useful constants
RANDOM_SEED = 42
IS_SCITAS = True # set to True if running on SCITAS cluster
LOCAL_DATA_ROOT = Path("./data")
DATA_ROOT = Path("/home/ogut/data") if IS_SCITAS else LOCAL_DATA_ROOT
CHECKPOINT_ROOT = Path("./.checkpoints")
SUBMISSION_ROOT = Path("./.submissions")

# create directories if they do not exist
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
SUBMISSION_ROOT.mkdir(parents=True, exist_ok=True)

# set dataset root
seed_everything(RANDOM_SEED)

# setup torch device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
import subprocess

# execute feature extraction script
try:
    process = subprocess.Popen(["python3", "scripts/feature_extractor.py"])
    process.wait()
except KeyboardInterrupt:
    print("Process interrupted, terminating...")
    process.terminate()
    process.wait()
except Exception as e:
    print(f"Error occurred: {e}")
    if 'process' in locals():
        process.terminate()
        process.wait()

Traceback (most recent call last):
  File "/home/ldibello/NeuroGraphNet/scripts/feature_extractor.py", line 25, in <module>
    clips_tr = pd.read_parquet(DATA_ROOT / "train/segments.parquet")
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/pandas/io/parquet.py", line 667, in read_parquet
    return impl.read(
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/pandas/io/parquet.py", line 267, in read
    path_or_handle, handles, filesystem = _get_path_or_handle(
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/pandas/io/parquet.py", line 140, in _get_path_or_handle
    handles = get_handle(
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/pandas/io/common.py", line 882, in get_handle
    handle = open(handle, ioargs.mode)
FileNotFoundError: [Errno 2] No such file or directory: 'data/train/segments.parquet'


In [3]:
train_dir = DATA_ROOT / "train"
train_dir_metadata = train_dir / "segments.parquet"
train_dataset_dir = LOCAL_DATA_ROOT / "graph_dataset_train"
spatial_distance_file = LOCAL_DATA_ROOT / "distances_3d.csv"
extracted_features_dir = LOCAL_DATA_ROOT / "extracted_features"
embeddings_dir =  LOCAL_DATA_ROOT / "embeddings"

In [4]:
# Initialize wandb

In [5]:
%aimport
from src.data.dataset_graph import GraphEEGDataset
from src.utils.index import ensure_eeg_multiindex 

# ----------------- Prepare training data -----------------#
clips_tr = pd.read_parquet(train_dir_metadata)

# Ensure multiindex is correct
clips_tr = ensure_eeg_multiindex(clips_tr)

clips_tr = clips_tr[~clips_tr.label.isna()].reset_index()  # Filter NaN values out of clips_tr

# dataset settings
batch_size = 64
selected_features = []
embeddings = []
edge_strategy = "spatial"
correlation_threshold = 0.5
top_k = None
low_bandpass_frequency = 0.5
high_bandpass_frequency = 50

# additional settings
oversampling_power = 1.0

# -------------- Dataset definition -------------- #
dataset = GraphEEGDataset(
    root=train_dataset_dir,
    clips=clips_tr,
    signal_folder=train_dir,
    extracted_features_dir=extracted_features_dir,
    selected_features_train=selected_features,
    embeddings_dir=embeddings_dir,
    embeddings_train=embeddings,
    edge_strategy=edge_strategy,
    spatial_distance_file=(
        spatial_distance_file if edge_strategy == "spatial" else None
    ),
    top_k=top_k,
    correlation_threshold=correlation_threshold,
    force_reprocess=True,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=False,
    apply_normalization=False,
    sampling_rate=250,
)

# Check the length of the dataset
print(f"Length of train_dataset: {len(dataset)}")
print(f' Eliminated IDs:{dataset.ids_to_eliminate}')

# Eliminate ids that did not have electrodes above correlation threshols
clips_tr = clips_tr[~clips_tr.index.isin(dataset.ids_to_eliminate)].reset_index(drop=True)

Modules to reload:


Modules to skip:



2025-06-05 11:30:19 - INFO - Initializing GraphEEGDataset...
2025-06-05 11:30:19 - INFO - Dataset parameters:
2025-06-05 11:30:19 - INFO -   - Root directory: data/graph_dataset_train
2025-06-05 11:30:19 - INFO -   - Edge strategy: spatial
2025-06-05 11:30:19 - INFO -   - Top-k neighbors: None
2025-06-05 11:30:19 - INFO -   - Correlation threshold: 0.5
2025-06-05 11:30:19 - INFO -   - Force reprocess: True
2025-06-05 11:30:19 - INFO -   - Bandpass frequencies: (0.5, 50)
2025-06-05 11:30:19 - INFO -   - Segment length: 3000
2025-06-05 11:30:19 - INFO -   - Apply filtering: True
2025-06-05 11:30:19 - INFO -   - Apply rereferencing: False
2025-06-05 11:30:19 - INFO -   - Apply normalization: False
2025-06-05 11:30:19 - INFO -   - Sampling rate: 250
2025-06-05 11:30:19 - INFO - Number of EEG channels: 19
2025-06-05 11:30:19 - INFO - Setting up signal filters...
2025-06-05 11:30:19 - INFO - Loading spatial distances from data/distances_3d.csv
2025-06-05 11:30:19 - INFO - Loading spatial dis

Length of train_dataset: 12993
 Eliminated IDs:[]


In [6]:
for batch in dataset:
    print(batch)
    break

Data(x=[19, 3000], edge_index=[2, 342], y=[1])


In [7]:
from sklearn.model_selection import GroupKFold
from src.utils.general_funcs import labels_stats

cv = GroupKFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED)
groups = clips_tr.patient.values
y = clips_tr["label"].values
X = np.zeros(len(y))  # Dummy X (not used); just placeholder for the Kfold
train_ids, val_ids = next(cv.split(X, y, groups=groups))  # Just select one split
print('Labels before Kfold', flush=True)
print(y,flush=True)

# Print stats for class 0 and 1
labels_stats(y, train_ids, val_ids)

# 2. From dataset generate train and val datasets
train_dataset = Subset(dataset, train_ids)
val_dataset = Subset(dataset, val_ids)

Labels before Kfold
[1 1 1 ... 1 1 0]
[11:31:29] Train labels: 0 -> 8389, 1 -> 2093
[11:31:29] Val labels:   0 -> 2087, 1 -> 424


In [19]:
# 3. Compute sample weights for oversampling
train_labels = [clips_tr.iloc[i]["label"] for i in train_ids]
class_counts = np.bincount(train_labels)
class_weights = (1. / class_counts) ** oversampling_power # Higher weights for not frequent classes
sample_weights = [class_weights[label] for label in train_labels] # Assign weight to each sample based on its class

# 4. Define sampler
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True) # Still train on N samples per epoch, but instead of sampling uniformly takes more from minority class

# Define dataloaders
BATCH_SIZE = 64
train_loader = GeoDataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, shuffle=False)
val_loader = GeoDataLoader(val_dataset, batch_size=BATCH_SIZE)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

Train batches: 164
Val batches: 40


In [20]:
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[1216, 3000], edge_index=[2, 21888], y=[64], batch=[1216], ptr=[65])


In [23]:
%aimport
from src.layers.eeggcn import EEGGCN
from src.layers.cnn_lstm_gnn import LSTM_GNN_Model
from src.utils.train import train_model

SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_best_model.pt"
SUBMISSION_PATH = SUBMISSION_ROOT / "lstm_gnn_submission.csv"

config = {
    "learning_rate": 3e-4,
    "weight_decay": 1e-5,
    "patience": 5,
    "epochs": 100,
}

# build model with current parameters
model = LSTM_GNN_Model(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout = 0.25,
    lstm_hidden_dim = 64,
    lstm_out_dim = 64,  # This will be the time_encoder_output_dim for the GCN
    lstm_dropout = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 64,
    gcn_out_channels = 32,
    num_gcn_layers = 3,
    gcn_dropout = 0.5,
    num_classes = 1,  # For binary classification (seizure/non-seizure)
    num_channels = 19,  # Number of EEG channels
)

# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs for this combination.")
#     model = nn.DataParallel(model)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

adjusted_pos_weight = torch.tensor([1.5], dtype=torch.float32).to(device)
print(f'pos_weight:{adjusted_pos_weight}')
loss = nn.BCEWithLogitsLoss(pos_weight=adjusted_pos_weight)

# train model
train_model(
    wandb_config=None,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=loss,
    optimizer=optimizer,
    device=device,
    num_epochs=config["epochs"],
    patience=config["patience"],
    save_path=SAVE_PATH,
    use_gnn=True,
    use_oversampling=True
)

2025-06-05 11:39:42 - INFO - Starting training setup...
2025-06-05 11:39:42 - INFO - Model type: GNN
2025-06-05 11:39:42 - INFO - Device: cuda
2025-06-05 11:39:42 - INFO - Batch size: 64
2025-06-05 11:39:42 - INFO - Number of epochs: 100
2025-06-05 11:39:42 - INFO - Patience: 5
2025-06-05 11:39:42 - INFO - Monitor metric: val_f1
2025-06-05 11:39:42 - INFO - Initializing wandb...


Modules to reload:


Modules to skip:

pos_weight:tensor([1.5000], device='cuda:0')


2025-06-05 11:39:44 - INFO - Total training batches per epoch: 164
2025-06-05 11:39:44 - INFO - Starting training from epoch 1 to 100


🔗 Wandb initialized: fluent-wood-45


Epochs:   1%|▊                                                                                  | 1/100 [00:00<?, ?it/s]2025-06-05 11:39:44 - INFO - 
Epoch 1/100 - Training phase
2025-06-05 11:39:46 - INFO - Processing batch 1/164
2025-06-05 11:39:46 - INFO - Batch shapes - x: torch.Size([1216, 3000]), edge_index: torch.Size([2, 21888]), y: torch.Size([64])
2025-06-05 11:39:46 - ERROR - Error in forward pass for batch 0: The expanded size of the tensor (1216) must match the existing size (64) at non-singleton dimension 0.  Target sizes: [1216, 32].  Tensor sizes: [64, 1]
2025-06-05 11:39:46 - ERROR - Edge index shape: torch.Size([2, 21888])
2025-06-05 11:39:46 - ERROR - Edge index content: tensor([[   0,    0,    0,  ..., 1215, 1215, 1215],
        [   1,    2,    3,  ..., 1212, 1213, 1214]], device='cuda:0')
Epochs:   1%|▊                                                                                  | 1/100 [00:01<?, ?it/s]

Batch labels: tensor([1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1,
        1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1,
        0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], device='cuda:0') Type: <class 'torch.Tensor'>





RuntimeError: The expanded size of the tensor (1216) must match the existing size (64) at non-singleton dimension 0.  Target sizes: [1216, 32].  Tensor sizes: [64, 1]

In [17]:
# cuda clear cache
torch.cuda.empty_cache()

In [15]:
%aimport
from sklearn.metrics import f1_score, recall_score, precision_score
import wandb
from src.utils.general_funcs import confusion_matrix_plot

best_val_loss = float("inf")
best_val_f1 = 0
best_val_f1_epoch = 0
patience = 10
counter = 0
num_epochs = 100
print("Training started")

for epoch in range(1, num_epochs + 1):
    # ------- Training ------- #
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        y_targets = batch.y.reshape(-1, 1)
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = loss_fn(
            out, y_targets
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    avg_train_loss = total_loss / len(train_loader)  # Average loss per batch

    # ------- Validation ------- #
    model.eval()
    val_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            print("Batch batch:", batch.batch)
            out = model(
                batch.x, batch.edge_index, batch.batch
            )
            loss = loss_fn(out, batch.y.reshape(-1, 1))
            val_loss += loss.item()
            probs = torch.sigmoid(out).squeeze()  # [batch_size, 1] -> [batch_size]
            preds = (probs > 0.5).int()
            all_preds.extend(preds.cpu().numpy().ravel())
            all_labels.extend(
                batch.y.int().cpu().numpy().ravel()
            )
            

    avg_val_loss = val_loss / len(val_loader)  # Average loss per batch
    #scheduler.step(avg_val_loss)
    val_f1 = f1_score(all_labels, all_preds, average="macro")

    all_labels = np.array(all_labels).astype(int)
    all_preds = np.array(all_preds).astype(int)

    # for name, param in model.named_parameters():
    #     if param.grad is not None:
    #         print(f"{name} grad mean: {param.grad.abs().mean()}")
    
    # Monitor progress
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val F1: {val_f1:.4f}")
    
    # Additional metrics

    # Confusion matrix
    confusion_matrix_plot(all_preds, all_labels)
    # Compute metrics per class (0 and 1)
    precision = precision_score(all_labels, all_preds, average=None)
    recall = recall_score(all_labels, all_preds, average=None)
    f1 = f1_score(all_labels, all_preds, average=None)

    # Print only for class 1
    print(f"Class 1 — Precision: {precision[1]:.2f}, Recall: {recall[1]:.2f}, F1: {f1[1]:.2f}")
    
    # W&B
    # wandb.log(
    #     {
    #         "epoch": epoch,
    #         "train_loss": avg_train_loss,
    #         "val_loss": avg_val_loss,
    #         "val_f1": val_f1,
    #         "val_f1_class_1":f1[1],
    #             "val_f1_class_0":f1[0]
    #     }
    # )
    print(f"Epoch {epoch} — Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val F1: {val_f1:.4f}", flush=True)
    # ------- Record best F1 score ------- #
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_val_f1_epoch = epoch
        best_preds = all_preds.copy()
        best_labels = all_labels.copy()
        # Load best stats in wandb
        wandb.summary["best_f1_score"] = val_f1
        wandb.summary["f1_score_epoch"] = epoch
    # ------- Early Stopping ------- #
    if avg_val_loss < best_val_loss:
        # Save best statistics and model
        best_val_loss = avg_val_loss
        counter = 0
        best_state_dict = model.state_dict().copy()  # Save the best model state
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break

print(f"Best validation F1: {best_val_f1:.4f} at epoch {best_val_f1_epoch}")

Modules to reload:


Modules to skip:

Training started
Batch batch: tensor([ 0,  0,  0,  ..., 63, 63, 63], device='cuda:0')
Batch labels: tensor([ 0,  0,  0,  ..., 63, 63, 63], device='cuda:0') Type: <class 'torch.Tensor'>


NameError: name 'loss_fn' is not defined