In [1]:
%load_ext autoreload
%autoreload 1
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch.utils.data import Subset, WeightedRandomSampler
# from torch.utils.data import DataLoader
from src.utils.seeder import seed_everything

# set seaborn theme
sns.set_theme()

# create useful constants
RANDOM_SEED = 42
IS_SCITAS = True # set to True if running on SCITAS cluster
LOCAL_DATA_ROOT = Path("./data")
DATA_ROOT = Path("/home/ogut/data") if IS_SCITAS else LOCAL_DATA_ROOT
CHECKPOINT_ROOT = Path("./.checkpoints")
SUBMISSION_ROOT = Path("./.submissions")

# create directories if they do not exist
CHECKPOINT_ROOT.mkdir(parents=True, exist_ok=True)
SUBMISSION_ROOT.mkdir(parents=True, exist_ok=True)

# set dataset root
seed_everything(RANDOM_SEED)

# setup torch device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
# import subprocess

# # execute feature extraction script
# try:
#     process = subprocess.Popen(["python3", "scripts/feature_extractor.py"])
#     process.wait()
# except KeyboardInterrupt:
#     print("Process interrupted, terminating...")
#     process.terminate()
#     process.wait()
# except Exception as e:
#     print(f"Error occurred: {e}")
#     if 'process' in locals():
#         process.terminate()
#         process.wait()

In [3]:
# spacial distance matrix between sensors
spatial_distance_file = LOCAL_DATA_ROOT / "distances_3d.csv"

# training data
train_dir = DATA_ROOT / "train"
train_dir_metadata = train_dir / "segments.parquet"
train_dataset_dir = LOCAL_DATA_ROOT / "graph_dataset_train"

# test data
test_dir = DATA_ROOT / "test"
test_dir_metadata = test_dir / "segments.parquet"
test_dataset_dir = LOCAL_DATA_ROOT / "graph_dataset_test"

# additional features
extracted_features_dir = LOCAL_DATA_ROOT / "extracted_features"
embeddings_dir =  LOCAL_DATA_ROOT / "embeddings"

In [4]:
from src.utils.index import ensure_eeg_multiindex 

# Load clips from datasets
clips_tr = pd.read_parquet(train_dir_metadata)
clips_tr = ensure_eeg_multiindex(clips_tr)
clips_tr = clips_tr[~clips_tr.label.isna()].reset_index()  # Filter NaN values out of clips_tr

# Load clips from datasets
clips_te = pd.read_parquet(test_dir_metadata)
clips_te = ensure_eeg_multiindex(clips_te)

# Create unique IDs by converting all index components to strings and store in new column
clips_te['id'] = clips_te.index.map(lambda x: '_'.join(str(i) for i in x))
assert clips_te.id.nunique() == len(clips_te), "There are duplicate IDs"
print(clips_te["id"].head())

# sort in order to maintain the same submission order
clips_te = clips_te.sort_values(by="id")

patient   session    segment
pqejgcvm  s001_t000  0          pqejgcvm_s001_t000_0
                     1          pqejgcvm_s001_t000_1
                     2          pqejgcvm_s001_t000_2
                     3          pqejgcvm_s001_t000_3
                     4          pqejgcvm_s001_t000_4
Name: id, dtype: object


In [12]:
%aimport
from src.data.dataset_graph import GraphEEGDataset

# dataset settings
batch_size = 64
selected_features = []
embeddings = []
edge_strategy = "spatial"
correlation_threshold = 0.5
top_k = None
low_bandpass_frequency = 0.5
high_bandpass_frequency = 50

# additional settings
oversampling_power = 1.0

# load training dataset
dataset_tr = GraphEEGDataset(
    root=train_dataset_dir,
    clips=clips_tr,
    signal_folder=train_dir,
    extracted_features_dir=extracted_features_dir,
    selected_features_train=selected_features,
    embeddings_dir=embeddings_dir,
    embeddings_train=embeddings,
    edge_strategy=edge_strategy,
    spatial_distance_file=(
        spatial_distance_file if edge_strategy == "spatial" else None
    ),
    top_k=top_k,
    correlation_threshold=correlation_threshold,
    force_reprocess=False,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=False,
    apply_normalization=False,
    sampling_rate=250,
)

# Check the length of the dataset
print(f"Length of train_dataset: {len(dataset_tr)}")
print(f' Eliminated IDs: {dataset_tr.ids_to_eliminate}')

# Eliminate ids that did not have electrodes above correlation threshols
clips_tr = clips_tr[~clips_tr.index.isin(dataset_tr.ids_to_eliminate)].reset_index(drop=True)

2025-06-05 23:54:36 - INFO - Initializing GraphEEGDataset...
2025-06-05 23:54:36 - INFO - Dataset parameters:
2025-06-05 23:54:36 - INFO -   - Root directory: data/graph_dataset_train
2025-06-05 23:54:36 - INFO -   - Edge strategy: spatial
2025-06-05 23:54:36 - INFO -   - Top-k neighbors: None
2025-06-05 23:54:36 - INFO -   - Correlation threshold: 0.5
2025-06-05 23:54:36 - INFO -   - Force reprocess: False
2025-06-05 23:54:36 - INFO -   - Bandpass frequencies: (0.5, 50)
2025-06-05 23:54:36 - INFO -   - Segment length: 3000
2025-06-05 23:54:36 - INFO -   - Apply filtering: True
2025-06-05 23:54:36 - INFO -   - Apply rereferencing: False
2025-06-05 23:54:36 - INFO -   - Apply normalization: False
2025-06-05 23:54:36 - INFO -   - Sampling rate: 250
2025-06-05 23:54:36 - INFO -   - Test mode: False
2025-06-05 23:54:36 - INFO - Number of EEG channels: 19
2025-06-05 23:54:36 - INFO - Setting up signal filters...
2025-06-05 23:54:36 - INFO - Loading spatial distances from data/distances_3d.c

Modules to reload:


Modules to skip:

Length of train_dataset: 12993
 Eliminated IDs: []


In [11]:
%aimport
from src.data.dataset_graph import GraphEEGDataset

# load test dataset
te_dataset = GraphEEGDataset(
    root=test_dataset_dir,
    clips=clips_te,
    signal_folder=test_dir,
    extracted_features_dir=extracted_features_dir,
    selected_features_train=False,
    embeddings_dir=embeddings_dir,
    embeddings_train=False,
    edge_strategy="spatial",
    spatial_distance_file=spatial_distance_file,
    top_k=None,
    correlation_threshold=0.5,
    force_reprocess=False,
    bandpass_frequencies=(
        low_bandpass_frequency,
        high_bandpass_frequency,
    ),
    segment_length=3000,
    apply_filtering=True,
    apply_rereferencing=False,
    apply_normalization=False,
    sampling_rate=250,
    is_test = True,
)

# Check the length of the dataset
print(f"Length of test_dataset: {len(te_dataset)}")
print(f' Eliminated IDs:{te_dataset.ids_to_eliminate}')

# Eliminate ids that did not have electrodes above correlation threshols
clips_te = clips_te[~clips_te.index.isin(te_dataset.ids_to_eliminate)].reset_index(drop=True)

2025-06-05 23:54:32 - INFO - Initializing GraphEEGDataset...
2025-06-05 23:54:32 - INFO - Dataset parameters:
2025-06-05 23:54:32 - INFO -   - Root directory: data/graph_dataset_test
2025-06-05 23:54:32 - INFO -   - Edge strategy: spatial
2025-06-05 23:54:32 - INFO -   - Top-k neighbors: None
2025-06-05 23:54:32 - INFO -   - Correlation threshold: 0.5
2025-06-05 23:54:32 - INFO -   - Force reprocess: False
2025-06-05 23:54:32 - INFO -   - Bandpass frequencies: (0.5, 50)
2025-06-05 23:54:32 - INFO -   - Segment length: 3000
2025-06-05 23:54:32 - INFO -   - Apply filtering: True
2025-06-05 23:54:32 - INFO -   - Apply rereferencing: False
2025-06-05 23:54:32 - INFO -   - Apply normalization: False
2025-06-05 23:54:32 - INFO -   - Sampling rate: 250
2025-06-05 23:54:32 - INFO -   - Test mode: True
2025-06-05 23:54:32 - INFO - Number of EEG channels: 19
2025-06-05 23:54:32 - INFO - Setting up signal filters...
2025-06-05 23:54:32 - INFO - Loading spatial distances from data/distances_3d.csv

Modules to reload:


Modules to skip:

Length of test_dataset: 3614
 Eliminated IDs:[]


In [9]:
for batch in te_dataset:
    print(batch)
    break

Data(x=[19, 3000], edge_index=[2, 342], id='pqejgcvm_s001_t000_0')


In [13]:
from torch.utils.data import random_split
from src.utils.general_funcs import labels_stats

# Get total samples and split sizes
total_samples = len(dataset_tr)
train_size = int(0.8 * total_samples)
val_size = total_samples - train_size

# Get labels for initial split
y = clips_tr["label"].values

# Create initial train/val split
train_indices, val_indices = random_split(
    range(total_samples), 
    [train_size, val_size],
    generator=torch.Generator().manual_seed(RANDOM_SEED)
)

# Convert to numpy arrays for easier indexing
train_indices = np.array(train_indices)
val_indices = np.array(val_indices)

print('Labels before split', flush=True)
print(y, flush=True)

# Print stats for class 0 and 1
labels_stats(y, train_indices, val_indices)

# Create train and val datasets
train_dataset = Subset(dataset_tr, train_indices)
val_dataset = Subset(dataset_tr, val_indices)

# 3. Compute sample weights for oversampling
train_labels = [clips_tr.iloc[i]["label"] for i in train_indices]
class_counts = np.bincount(train_labels)
class_weights = (1. / class_counts) ** oversampling_power  # Higher weights for not frequent classes
sample_weights = [class_weights[label] for label in train_labels]  # Assign weight to each sample based on its class

# 4. Define sampler
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Define dataloaders
BATCH_SIZE = 64
train_loader = GeoDataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, shuffle=False)
val_loader = GeoDataLoader(val_dataset, batch_size=BATCH_SIZE)
te_loader = GeoDataLoader(te_dataset, batch_size=BATCH_SIZE)
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(te_loader)}")

Labels before split
[1 1 1 ... 1 1 0]
[23:54:42] Train labels: 0 -> 8375, 1 -> 2019
[23:54:42] Val labels:   0 -> 2101, 1 -> 498
Train batches: 163
Val batches: 41
Test batches: 57


In [14]:
for batch in train_loader:
    print(batch)
    break

DataBatch(x=[1216, 3000], edge_index=[2, 21888], y=[64], batch=[1216], ptr=[65])


In [17]:
%aimport
from src.layers.cnn_lstm_gnn import LSTM_GNN_Model
from src.utils.train import train_model

SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_best_model.pt"
SUBMISSION_PATH = SUBMISSION_ROOT / "lstm_gnn_submission.csv"

config = {
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,
    "patience": 15,
    "epochs": 100,
}

# NOTE: model with default parameters
model_older = LSTM_GNN_Model(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout = 0.25,
    lstm_hidden_dim = 64,
    lstm_out_dim = 64,  # This will be the time_encoder_output_dim for the GCN
    lstm_dropout = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 64,
    gcn_out_channels = 32,
    num_gcn_layers = 3,
    gcn_dropout = 0.5,
    num_classes = 1,  # For binary classification (seizure/non-seizure)
    num_channels = 19,  # Number of EEG channels
)

# build model with current parameters
# Epochs:   6%| | 6/100 [13:36<4:17:46, 164.54s/it, train_loss=0.6508, val_loss=0.6166, best_val_f1=0.2840, lr=3.00e-04, b2025-06-05 13:20:42 - INFO - 
# Epochs:   7%| | 7/100 [16:24<4:17:02, 165.83s/it, train_loss=0.6446, val_loss=0.6258, best_val_f1=0.2840, lr=3.00e-04, b2025-06-05 13:23:30 - INFO - 
model_improved = LSTM_GNN_Model(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout = 0.25,
    lstm_hidden_dim = 96, # 96 original
    lstm_out_dim = 96,  # This will be the time_encoder_output_dim for the GCN
    lstm_dropout = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 96,
    gcn_out_channels = 64,
    num_gcn_layers = 3,
    gcn_dropout = 0.5,
    num_classes = 1,  # For binary classification (seizure/non-seizure)
    num_channels = 19,  # Number of EEG channels
)


# build model with current parameters
# Epochs:   2%| | 2/100 [03:21<5:28:40, 201.23s/it, train_loss=0.7174, val_loss=0.5724, best_val_f1=0.0848, lr=3.00e-04, b2025-06-05 13:32:42 - INFO - 
# Epochs:   3%| | 3/100 [06:28<5:12:31, 193.31s/it, train_loss=0.6813, val_loss=0.5767, best_val_f1=0.1734, lr=3.00e-04, b2025-06-05 13:35:50 - INFO - 
# Epochs:   4%| | 4/100 [09:26<4:57:36, 186.00s/it, train_loss=0.6638, val_loss=0.6072, best_val_f1=0.1734, lr=3.00e-04, b2025-06-05 13:38:47 - INFO - 
# Epochs:   5%| | 5/100 [12:26<4:50:40, 183.58s/it, train_loss=0.6704, val_loss=0.5754, best_val_f1=0.2178, lr=3.00e-04, b2025-06-05 13:41:47 - INFO - 
# ...
# Epochs:   7%| | 7/100 [19:21<5:10:58, 200.62s/it, train_loss=0.6333, val_loss=0.5949, best_val_f1=0.3921, lr=3.00e-04, b2025-06-05 13:48:43 - INFO - 
# Epochs:   8%| | 8/100 [23:11<5:22:14, 210.15s/it, train_loss=0.6261, val_loss=0.5993, best_val_f1=0.3921, lr=3.00e-04, b2025-06-05 13:52:33 - INFO - 
# Epochs:   9%| | 9/100 [26:35<5:15:33, 208.06s/it, train_loss=0.6043, val_loss=0.5743, best_val_f1=0.3921, lr=3.00e-04, b2025-06-05 13:55:56 - INFO - 
# ...
# Epochs:  12%| | 12/100 [36:17<4:49:57, 197.70s/it, train_loss=0.5935, val_loss=0.5691, best_val_f1=0.5043, lr=3.00e-04, 2025-06-05 14:05:38 - INFO - 
# Epochs:  13%|▏| 13/100 [39:27<4:43:05, 195.24s/it, train_loss=0.5701, val_loss=0.5855, best_val_f1=0.5380, lr=3.00e-04, 2025-06-05 14:08:48 - INFO - 
# Epochs:  14%|▏| 14/100 [42:32<4:35:35, 192.27s/it, train_loss=0.5329, val_loss=0.6952, best_val_f1=0.5380, lr=3.00e-04, 2025-06-05 14:11:54 - INFO -
# Epochs:  18%|▏| 18/100 [55:14<4:22:12, 191.86s/it, train_loss=0.5042, val_loss=0.5616, best_val_f1=0.5623, lr=3.00e-04, 2025-06-05 14:24:36 - INFO -
# Epochs:  19%|▏| 19/100 [58:26<4:19:03, 191.89s/it, train_loss=0.5092, val_loss=0.4702, best_val_f1=0.6405, lr=3.00e-04, 2025-06-05 14:27:48 - INFO - 
# Epochs:  20%|▏| 20/100 [04:25<5:53:37, 265.22s/it, train_loss=0.5077, val_loss=0.4850, best_val_f1=0.6405, lr=3.00e-04, 2025-06-05 15:35:20 - INFO - 
# Epochs:  21%|▏| 21/100 [07:55<5:06:16, 232.62s/it, train_loss=0.4657, val_loss=0.4666, best_val_f1=0.6405, lr=3.00e-04, 2025-06-05 15:38:49 - INFO - 
# ...
# Epochs:  23%|▏| 23/100 [16:40<5:02:39, 235.83s/it, train_loss=0.4786, val_loss=0.4441, best_val_f1=0.6405, lr=3.00e-04, 2025-06-05 15:24:57 - INFO -
# Epochs:  24%|▏| 24/100 [18:00<4:20:53, 205.96s/it, train_loss=0.4688, val_loss=0.5586, best_val_f1=0.6405, lr=3.00e-04, 2025-06-05 15:48:55 - INFO - 
# Epochs:  25%|▎| 25/100 [21:08<4:09:36, 199.69s/it, train_loss=0.4521, val_loss=0.4014, best_val_f1=0.6484, lr=3.00e-04, 2025-06-05 15:52:02 - INFO - 
# Epochs:  26%|▎| 26/100 [24:09<3:58:50, 193.65s/it, train_loss=0.4378, val_loss=0.3937, best_val_f1=0.6800, lr=3.00e-04, 2025-06-05 15:55:04 - INFO - 
# ---- FROM HERE IT DOES NOT LEARN ANYTHING!!!
# ....
#
# Epochs:  31%|▎| 31/100 [39:30<3:35:33, 187.44s/it, train_loss=0.4061, val_loss=0.4341, best_val_f1=0.6800, lr=3.00e-04, 2025-06-05 16:10:25 - INFO - 
#...
# (other run)
# Epochs:  32%|▎| 32/100 [18:51<3:20:29, 176.90s/it, train_loss=0.3984, val_loss=0.4484, best_val_f1=0.6800, lr=3.00e-04, 2025-06-05 18:22:51 - INFO - 
# ...
# Epochs:  35%|▎| 35/100 [52:42<3:31:23, 195.13s/it, train_loss=0.3835, val_loss=0.4302, best_val_f1=0.6800, lr=3.00e-04, 2025-06-05 16:23:36 - INFO - 
# Epochs:  37%|▎| 37/100 [32:44<2:56:23, 168.00s/it, train_loss=0.3619, val_loss=0.4276, best_val_f1=0.6800, lr=3.00e-04, 2025-06-05 18:36:45 - INFO - 
# NOTE: BEST MODEL SO FAR!!!
# SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_best_model_epochs_.pt"
best_model = LSTM_GNN_Model(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout = 0.25,
    lstm_hidden_dim = 128, # 96 original
    lstm_out_dim = 128,  # This will be the time_encoder_output_dim for the GCN
    lstm_dropout = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 128,
    num_gcn_layers = 3,
    gcn_dropout = 0.5,
    num_classes = 1,  # For binary classification (seizure/non-seizure)
    num_channels = 19,  # Number of EEG channels
)

# Epochs:  29%|▎| 29/100 [2:32:36<6:26:13, 326.39s/it, train_loss=0.3669, val_loss=0.4361, best_val_f1=0.6758, lr=3.00e-042025-06-05 21:22:07 - INFO - 
# Epochs:  30%|▎| 30/100 [2:38:01<6:20:19, 326.00s/it, train_loss=0.3709, val_loss=0.5588, best_val_f1=0.6758, lr=3.00e-042025-06-05 21:27:32 - INFO - 
# NOTE: Not performing well....
# SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_best_model_even_bigger.pt"
new_model = LSTM_GNN_Model(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout = 0.25,
    lstm_hidden_dim = 128, # 96 original
    lstm_out_dim = 128,  # This will be the time_encoder_output_dim for the GCN
    lstm_dropout = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 64,
    num_gcn_layers = 3,
    gcn_dropout = 0.5,
    num_classes = 1,  # For binary classification (seizure/non-seizure)
    num_channels = 19,  # Number of EEG channels
)

# Same setup as best model, with bigger GCN output channels to check if it can learn something
SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_best_model_bigger_gcn_output_channels.pt"
# Epochs:  13%|▏| 13/100 [33:58<4:06:35, 170.07s/it, train_loss=0.5545, val_loss=0.5253, best_val_f1=0.5505, lr=3.00e-04, 2025-06-05 22:20:18 - INFO - 
# Epochs:  28%|▎| 28/100 [1:16:29<3:23:10, 169.32s/it, train_loss=0.4188, val_loss=0.4747, best_val_f1=0.5817, lr=3.00e-042025-06-05 23:02:49 - INFO - 
# NOTE: THIS MODEL IS NOT PERFORMING WELL + IT IS SLOW TO TRAIN
model_improved_bigger = LSTM_GNN_Model(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout = 0.25,
    lstm_hidden_dim = 128, # original 128
    lstm_out_dim = 128,  # original 128
    lstm_dropout = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 192, # original 128
    gcn_out_channels = 192, # original 64
    num_gcn_layers = 3,
    gcn_dropout = 0.5,
    num_classes = 1,
    num_channels = 19,
)

SAVE_PATH = CHECKPOINT_ROOT / "lstm_gnn_best_model_35_epochs.pt"
# Epochs:   2%| | 2/100 [04:54<8:01:46, 294.97s/it, train_loss=0.5635, val_loss=0.6869, best_val_f1=0.5291, lr=1.00e-03, b2025-06-06 00:08:26 - INFO - 
# Epochs:  26%|▎| 26/100 [1:16:00<3:35:35, 174.81s/it, train_loss=0.2740, val_loss=0.4044, best_val_f1=0.7278, lr=6.25e-052025-06-06 01:19:31 - INFO - 
new_best_model_test = LSTM_GNN_Model(
    # Parameters for the CNN_BiLSTM_Encoder (temporal encoder)
    cnn_dropout = 0.25,
    lstm_hidden_dim = 128, # 96 original
    lstm_out_dim = 128,  # This will be the time_encoder_output_dim for the GCN
    lstm_dropout = 0.25,
    # Parameters for the EEGGCN (graph neural network)
    gcn_hidden_channels = 128,
    gcn_out_channels = 64,
    num_gcn_layers = 4,
    gcn_dropout = 0.5,
    num_classes = 1,  # For binary classification (seizure/non-seizure)
    num_channels = 19,  # Number of EEG channels
)

# select model to use
model = new_best_model_test

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

adjusted_pos_weight = torch.tensor([1.5], dtype=torch.float32).to(device)
print(f'pos_weight:{adjusted_pos_weight}')
loss = nn.BCEWithLogitsLoss(pos_weight=adjusted_pos_weight)

# /home/ldibello/venvs/neuro/lib/python3.10/site-packages/torch/nn/modules/rnn.py:123: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.25 and num_layers=1

# train model
train_history, val_history = train_model(
    wandb_config=None,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=loss,
    scheduler=scheduler,
    optimizer=optimizer,
    device=device,
    num_epochs=config["epochs"],
    patience=config["patience"],
    save_path=SAVE_PATH,
    use_gnn=True,
    # hidden attribute
    try_load_checkpoint=True,
    log_wandb=False
)

2025-06-06 00:03:31 - INFO - Starting training setup...
2025-06-06 00:03:31 - INFO - Model type: GNN
2025-06-06 00:03:31 - INFO - Device: cuda
2025-06-06 00:03:31 - INFO - Batch size: 64
2025-06-06 00:03:31 - INFO - Number of epochs: 100
2025-06-06 00:03:31 - INFO - Patience: 15
2025-06-06 00:03:31 - INFO - Monitor metric: val_f1
2025-06-06 00:03:31 - INFO - Total training batches per epoch: 163
2025-06-06 00:03:31 - INFO - Starting training from epoch 1 to 100


Modules to reload:


Modules to skip:

pos_weight:tensor([1.5000], device='cuda:0')
🚀 Attempting to load checkpoint from .checkpoints/lstm_gnn_best_model_35_epochs.pt...
   - Loading checkpoint from: .checkpoints/lstm_gnn_best_model_35_epochs.pt
   - Detected full checkpoint dictionary.
 ⚠️ Could not load checkpoint: Error(s) in loading state_dict for LSTM_GNN_Model:
	Missing key(s) in state_dict: "gcn.conv_layers.3.bias", "gcn.conv_layers.3.lin.weight", "gcn.bn_layers.3.weight", "gcn.bn_layers.3.bias", "gcn.bn_layers.3.running_mean", "gcn.bn_layers.3.running_var". 
	size mismatch for gcn.conv_layers.2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for gcn.conv_layers.2.lin.weight: copying a param with shape torch.Size([64, 128]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for gcn.bn_layers.2.weight: copying a param with shape torch.Size([64]) from checkpoint, t

Epochs:   1%|▊                                                                                  | 1/100 [00:00<?, ?it/s]2025-06-06 00:03:31 - INFO - 
Epoch 1/100 - Training phase
2025-06-06 00:03:32 - INFO - Processing batch 1/163
2025-06-06 00:03:32 - INFO - Batch shapes - x: torch.Size([1216, 3000]), edge_index: torch.Size([2, 21888]), y: torch.Size([64, 1])
2025-06-06 00:03:32 - INFO - Batch 1/163 - Loss: 0.9557 - Avg batch time: 0.25s
2025-06-06 00:03:48 - INFO - Processing batch 11/163
2025-06-06 00:03:48 - INFO - Batch 11/163 - Loss: 0.5846 - Avg batch time: 0.25s
2025-06-06 00:04:04 - INFO - Processing batch 21/163
2025-06-06 00:04:04 - INFO - Batch 21/163 - Loss: 0.6532 - Avg batch time: 0.25s
2025-06-06 00:04:19 - INFO - Processing batch 31/163
2025-06-06 00:04:19 - INFO - Batch 31/163 - Loss: 0.5424 - Avg batch time: 0.25s
2025-06-06 00:04:33 - INFO - Processing batch 41/163
2025-06-06 00:04:33 - INFO - Batch 41/163 - Loss: 0.5044 - Avg batch time: 0.25s
2025-06-06 00:04:46 -

KeyboardInterrupt: 

In [24]:
from src.utils.plot import plot_training_loss

plot_training_loss(train_history["loss"], val_history["loss"])

NameError: name 'train_history' is not defined

In [21]:
%aimport src.utils.train
from src.utils.train import evaluate_model

evaluate_model(
    model=model,
    test_loader=te_loader,
    device=device,
    checkpoint_path=SAVE_PATH,
    submission_path=SUBMISSION_ROOT / "lstm_gnn_submission_new_best_model.csv",
    use_gnn=True,
)

⚙️ Evaluating model. Loading model from: .checkpoints/lstm_gnn_best_model_35_epochs.pt
   - Loading checkpoint from: .checkpoints/lstm_gnn_best_model_35_epochs.pt
   - Detected full checkpoint dictionary.
   - Model state successfully loaded.
🧪 Performing inference on the test set...
BATCH: DataBatch(x=[1216, 3000], edge_index=[2, 21888], id=[64], batch=[1216], ptr=[65]), feature: tensor([[ 13.7818,  12.0088,  10.9369,  ...,  55.7613,  54.8321,  53.2772],
        [ 17.4475,  12.0189,   8.5730,  ...,  25.5270,  23.5848,  20.9486],
        [ -1.1044,  -2.4138,  -1.8681,  ...,  12.8679,  12.2525,  11.3361],
        ...,
        [  4.4002,  -1.2584,  -4.3083,  ..., -25.9106, -30.0391, -37.5544],
        [ 26.8536,  24.6873,  23.8955,  ..., -17.1835, -21.8779, -27.2960],
        [ 37.8393,  39.9656,  40.4172,  ..., -13.8870, -16.6136, -16.9901]])
BATCH: DataBatch(x=[1216, 3000], edge_index=[2, 21888], id=[64], batch=[1216], ptr=[65]), feature: tensor([[-15.3751, -15.9740, -13.1877,  ...,  1

Unnamed: 0,id,label
0,pqejgcvm_s001_t000_0,0
1,pqejgcvm_s001_t000_1,0
2,pqejgcvm_s001_t000_10,1
3,pqejgcvm_s001_t000_11,1
4,pqejgcvm_s001_t000_12,1
...,...,...
3609,pqejgvej_s001_t000_95,0
3610,pqejgvej_s001_t000_96,0
3611,pqejgvej_s001_t000_97,0
3612,pqejgvej_s001_t000_98,0


In [15]:
%aimport
from sklearn.metrics import f1_score, recall_score, precision_score
import wandb
from src.utils.general_funcs import confusion_matrix_plot

best_val_loss = float("inf")
best_val_f1 = 0
best_val_f1_epoch = 0
patience = 10
counter = 0
num_epochs = 100
print("Training started")

for epoch in range(1, num_epochs + 1):
    # ------- Training ------- #
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        y_targets = batch.y.reshape(-1, 1)
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = loss_fn(
            out, y_targets
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    avg_train_loss = total_loss / len(train_loader)  # Average loss per batch

    # ------- Validation ------- #
    model.eval()
    val_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            print("Batch batch:", batch.batch)
            out = model(
                batch.x, batch.edge_index, batch.batch
            )
            loss = loss_fn(out, batch.y.reshape(-1, 1))
            val_loss += loss.item()
            probs = torch.sigmoid(out).squeeze()  # [batch_size, 1] -> [batch_size]
            preds = (probs > 0.5).int()
            all_preds.extend(preds.cpu().numpy().ravel())
            all_labels.extend(
                batch.y.int().cpu().numpy().ravel()
            )
            

    avg_val_loss = val_loss / len(val_loader)  # Average loss per batch
    #scheduler.step(avg_val_loss)
    val_f1 = f1_score(all_labels, all_preds, average="macro")

    all_labels = np.array(all_labels).astype(int)
    all_preds = np.array(all_preds).astype(int)

    # for name, param in model.named_parameters():
    #     if param.grad is not None:
    #         print(f"{name} grad mean: {param.grad.abs().mean()}")
    
    # Monitor progress
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val F1: {val_f1:.4f}")
    
    # Additional metrics

    # Confusion matrix
    confusion_matrix_plot(all_preds, all_labels)
    # Compute metrics per class (0 and 1)
    precision = precision_score(all_labels, all_preds, average=None)
    recall = recall_score(all_labels, all_preds, average=None)
    f1 = f1_score(all_labels, all_preds, average=None)

    # Print only for class 1
    print(f"Class 1 — Precision: {precision[1]:.2f}, Recall: {recall[1]:.2f}, F1: {f1[1]:.2f}")
    
    # W&B
    # wandb.log(
    #     {
    #         "epoch": epoch,
    #         "train_loss": avg_train_loss,
    #         "val_loss": avg_val_loss,
    #         "val_f1": val_f1,
    #         "val_f1_class_1":f1[1],
    #             "val_f1_class_0":f1[0]
    #     }
    # )
    print(f"Epoch {epoch} — Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val F1: {val_f1:.4f}", flush=True)
    # ------- Record best F1 score ------- #
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_val_f1_epoch = epoch
        best_preds = all_preds.copy()
        best_labels = all_labels.copy()
        # Load best stats in wandb
        wandb.summary["best_f1_score"] = val_f1
        wandb.summary["f1_score_epoch"] = epoch
    # ------- Early Stopping ------- #
    if avg_val_loss < best_val_loss:
        # Save best statistics and model
        best_val_loss = avg_val_loss
        counter = 0
        best_state_dict = model.state_dict().copy()  # Save the best model state
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break

print(f"Best validation F1: {best_val_f1:.4f} at epoch {best_val_f1_epoch}")