In [7]:
%load_ext autoreload
%autoreload 1
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch.utils.data import Subset, WeightedRandomSampler
# from torch.utils.data import DataLoader
from src.utils.seeder import seed_everything

# set seaborn theme
sns.set_theme()

# create useful constants
RANDOM_SEED = 42
IS_SCITAS = True # set to True if running on SCITAS cluster
LOCAL_DATA_ROOT = Path("./data")
DATA_ROOT = Path("/home/ogut/data") if IS_SCITAS else LOCAL_DATA_ROOT
CHECKPOINT_ROOT = Path("./.checkpoints")
SUBMISSION_ROOT = Path("./.submissions")

# create directories if they do not exist
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
SUBMISSION_ROOT.mkdir(parents=True, exist_ok=True)

# set dataset root
seed_everything(RANDOM_SEED)

# setup torch device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cuda


In [3]:
import subprocess

# execute feature extraction script
try:
    process = subprocess.Popen(["python3", "scripts/feature_extractor.py"])
    process.wait()
except KeyboardInterrupt:
    print("Process interrupted, terminating...")
    process.terminate()
    process.wait()
except Exception as e:
    print(f"Error occurred: {e}")
    if 'process' in locals():
        process.terminate()
        process.wait()

Processing sessions: 100%|██████████| 177/177 [00:09<00:00, 19.18it/s]
Processing sessions: 100%|██████████| 50/50 [00:00<00:00, 306.35it/s]


Final dataset shapes:
X_train: (12993, 95)
y_train: (12993,)
X_test: (3614, 95)


In [8]:
train_dir = DATA_ROOT / "train"
train_dir_metadata = train_dir / "segments.parquet"
train_dataset_dir = LOCAL_DATA_ROOT / "graph_dataset_train"
spatial_distance_file = LOCAL_DATA_ROOT / "distances_3d.csv"
extracted_features_dir = LOCAL_DATA_ROOT / "extracted_features"
embeddings_dir =  LOCAL_DATA_ROOT / "embeddings"

In [3]:
# Initialize wandb

In [9]:
from src.data.dataset_graph import GraphEEGDataset
from src.utils.index import ensure_eeg_multiindex 

# ----------------- Prepare training data -----------------#
clips_tr = pd.read_parquet(train_dir_metadata)

# Ensure multiindex is correct
clips_tr = ensure_eeg_multiindex(clips_tr)

clips_tr = clips_tr[~clips_tr.label.isna()].reset_index()  # Filter NaN values out of clips_tr

# dataset settings
batch_size = 512
selected_features = []
embeddings = []
edge_strategy = "spatial"
correlation_threshold = 0.5
top_k = None
low_bandpass_frequency = 0.5
high_bandpass_frequency = 50

# additional settings
oversampling_power = 1.0

# -------------- Dataset definition -------------- #
dataset = GraphEEGDataset(
    root=train_dataset_dir,
    clips=clips_tr,
    signal_folder=train_dir,
    extracted_features_dir=extracted_features_dir,
    selected_features_train=selected_features,
    embeddings_dir=embeddings_dir,
    embeddings_train=embeddings,
    edge_strategy=edge_strategy,
    spatial_distance_file=(
        spatial_distance_file if edge_strategy == "spatial" else None
    ),
    top_k=top_k,
    correlation_threshold=correlation_threshold,
    force_reprocess=True,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=False,
    apply_normalization=False,
    sampling_rate=250,
)

# Check the length of the dataset
print(f"Length of train_dataset: {len(dataset)}")
print(f' Eliminated IDs:{dataset.ids_to_eliminate}')

# Eliminate ids that did not have electrodes above correlation threshols
clips_tr = clips_tr[~clips_tr.index.isin(dataset.ids_to_eliminate)].reset_index(drop=True)

[09:39:06] Processing sessions
Length of train_dataset: 12993
 Eliminated IDs:[]


In [10]:
from sklearn.model_selection import GroupKFold
from src.utils.general_funcs import labels_stats

cv = GroupKFold(n_splits=5, shuffle=True, random_state=RANDOM_SEED)
groups = clips_tr.patient.values
y = clips_tr["label"].values
X = np.zeros(len(y))  # Dummy X (not used); just placeholder for the Kfold
train_ids, val_ids = next(cv.split(X, y, groups=groups))  # Just select one split
print('Labels before Kfold', flush=True)
print(y,flush=True)

# Print stats for class 0 and 1
labels_stats(y, train_ids, val_ids)

# 2. From dataset generate train and val datasets
train_dataset = Subset(dataset, train_ids)
val_dataset = Subset(dataset, val_ids)

Labels before Kfold
[1 1 1 ... 1 1 0]
[09:45:28] Train labels: 0 -> 8389, 1 -> 2093
[09:45:28] Val labels:   0 -> 2087, 1 -> 424


In [11]:
# 3. Compute sample weights for oversampling
train_labels = [clips_tr.iloc[i]["label"] for i in train_ids]
class_counts = np.bincount(train_labels)
class_weights = (1. / class_counts) ** oversampling_power # Higher weights for not frequent classes
sample_weights = [class_weights[label] for label in train_labels] # Assign weight to each sample based on its class

# 4. Define sampler
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True) # Still train on N samples per epoch, but instead of sampling uniformly takes more from minority class

# Define dataloaders
train_loader = GeoDataLoader(train_dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
val_loader = GeoDataLoader(val_dataset, batch_size=batch_size)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

Train batches: 21
Val batches: 5


In [13]:
from src.layers.gat import EEGGAT

SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_best_model.pt"
SUBMISSION_PATH = SUBMISSION_ROOT / "lstm_gnn_submission.csv"

config = {
    "learning_rate": 3e-4,
    "weight_decay": 1e-5,
    "patience": 5,
    "epochs": 100,
}

# build model with current parameters
model = EEGGAT(
    in_channels = 3000,
)

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs for this combination.")
    model = nn.DataParallel(model)
model = model.to(device)

criterion = nn.BCEWithLogitsLoss()  # Assuming this remains constant
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=config["patience"], factor=0.5)

adjusted_pos_weight = torch.tensor([1.5], dtype=torch.float32).to(device)
print(f'pos_weight:{adjusted_pos_weight}')
loss_fn = nn.BCEWithLogitsLoss(pos_weight=adjusted_pos_weight)

Using 2 GPUs for this combination.
pos_weight:tensor([1.5000], device='cuda:0')


In [17]:
%aimport
from src.utils.train import train_model

train_model(
    clips=clips_tr,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=loss_fn,
    optimizer=optimizer,
    device=device,
    num_epochs=config["epochs"],
    patience=config["patience"],
    save_path=SAVE_PATH,
    use_gnn=True,
    use_oversampling=True
)

Modules to reload:


Modules to skip:



🔗 Wandb initialized: major-water-3
💪 Starting training from epoch 1 to 100...


Epochs:   1%|▊                                                                                  | 1/100 [00:11<?, ?it/s]


IndexError: Caught IndexError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ldibello/NeuroGraphNet/src/layers/gat.py", line 20, in forward
    x = self.gat1(x, edge_index)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch_geometric/nn/conv/gat_conv.py", line 347, in forward
    edge_index, edge_attr = remove_self_loops(
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch_geometric/utils/loop.py", line 112, in remove_self_loops
    mask = edge_index[0] != edge_index[1]
IndexError: index 1 is out of bounds for dimension 0 with size 1


In [14]:
from sklearn.metrics import f1_score, recall_score, precision_score
import wandb

from src.utils.general_funcs import confusion_matrix_plot

best_val_loss = float("inf")
best_val_f1 = 0
best_val_f1_epoch = 0
patience = 10
counter = 0
num_epochs = 100
print("Training started")

for epoch in range(1, num_epochs + 1):
    # ------- Training ------- #
    model.train()
    total_loss = 0
    # for batch in tqdm(train_loader,desc=f"Epoch {epoch} — Training" ):
    for batch in train_loader:
        batch = batch.to(device)  # Move batch to GPU
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = loss_fn(
            out, batch.y.reshape(-1, 1)
        )  # y: [batch_size] ->[batch_size, 1]
        loss.backward()
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f"{name}: grad norm = {param.grad.norm().item():.2e}")
        optimizer.step()
        total_loss += loss.item()
        
    avg_train_loss = total_loss / len(train_loader)  # Average loss per batch

    # ------- Validation ------- #
    model.eval()
    val_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)  # Move batch to GPU
            out = model(
                batch.x, batch.edge_index, batch.batch
            )  # batch.batch: [num_nodes_batch] = 19*batch_size -> tells the model which graph each node belongs to
            loss = loss_fn(out, batch.y.reshape(-1, 1))
            val_loss += loss.item()
            

            probs = torch.sigmoid(out).squeeze()  # [batch_size, 1] -> [batch_size]
            preds = (probs > 0.5).int()
            all_preds.extend(preds.cpu().numpy().ravel())
            all_labels.extend(
                batch.y.int().cpu().numpy().ravel()
            )  # Labels: stored as float in dataset
            # print(f"Val logits stats — min: {out.min().item():.4f}, max: {out.max().item():.4f}, mean: {out.mean().item():.4f}, std: {out.std().item():.4f}")
            # print(f"Predictions:{preds.cpu().numpy()}")
            # print(f"Sigmoid outputs: { torch.sigmoid(out).detach().cpu().numpy()}")
            # print(f"Labels:{batch.y}")
            

    avg_val_loss = val_loss / len(val_loader)  # Average loss per batch
    #scheduler.step(avg_val_loss)
    val_f1 = f1_score(all_labels, all_preds, average="macro")

    all_labels = np.array(all_labels).astype(int)
    all_preds = np.array(all_preds).astype(int)

    # for name, param in model.named_parameters():
    #     if param.grad is not None:
    #         print(f"{name} grad mean: {param.grad.abs().mean()}")
    
    # Monitor progress
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val F1: {val_f1:.4f}")
    
    # Additional metrics

    # Confusion matrix
    confusion_matrix_plot(all_preds, all_labels)
    # Compute metrics per class (0 and 1)
    precision = precision_score(all_labels, all_preds, average=None)
    recall = recall_score(all_labels, all_preds, average=None)
    f1 = f1_score(all_labels, all_preds, average=None)

    # Print only for class 1
    print(f"Class 1 — Precision: {precision[1]:.2f}, Recall: {recall[1]:.2f}, F1: {f1[1]:.2f}")
    

    # W&B
    # wandb.log(
    #     {
    #         "epoch": epoch,
    #         "train_loss": avg_train_loss,
    #         "val_loss": avg_val_loss,
    #         "val_f1": val_f1,
    #         "val_f1_class_1":f1[1],
    #             "val_f1_class_0":f1[0]
    #     }
    # )
    print(f"Epoch {epoch} — Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val F1: {val_f1:.4f}", flush=True)
    # ------- Record best F1 score ------- #
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_val_f1_epoch = epoch
        best_preds = all_preds.copy()
        best_labels = all_labels.copy()
        # Load best stats in wandb
        wandb.summary["best_f1_score"] = val_f1
        wandb.summary["f1_score_epoch"] = epoch
    # ------- Early Stopping ------- #
    if avg_val_loss < best_val_loss:
        # Save best statistics and model
        best_val_loss = avg_val_loss
        counter = 0
        best_state_dict = model.state_dict().copy()  # Save the best model state
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break

print(f"Best validation F1: {best_val_f1:.4f} at epoch {best_val_f1_epoch}")

Training started


  return F.linear(x, self.weight, self.bias)


IndexError: Caught IndexError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ldibello/NeuroGraphNet/src/layers/gat.py", line 20, in forward
    x = self.gat1(x, edge_index)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch_geometric/nn/conv/gat_conv.py", line 347, in forward
    edge_index, edge_attr = remove_self_loops(
  File "/home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch_geometric/utils/loop.py", line 112, in remove_self_loops
    mask = edge_index[0] != edge_index[1]
IndexError: index 1 is out of bounds for dimension 0 with size 1


In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay