In [None]:
# --- Imports & project setup -------------------------------------------------
# NOTE(review): Dataset, StratifiedGroupKFold and Path appear unused in this
# notebook (folds come from KFoldCreator instead) — candidates for removal.
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedGroupKFold
from tqdm.auto import tqdm 
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR
import sys
from pathlib import Path

from pytorch_tcn import TCN


# Make the repository root (three directories above the notebook's working
# directory) importable so the `src.*` imports below resolve. This line must
# stay before those imports.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))
from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path
from src.utils.constants import Constants 
from src.datasets.eeg_dataset import EEGDataset
# Authenticate with Weights & Biases; the recorded output shows it reusing an
# existing login (`wandb login --relogin` would force a new one).
wandb.login()




2025-10-16 16:56:36,735 :: root :: INFO :: Initialising Utils
2025-10-16 16:56:36,736 :: root :: INFO :: Initialising Datasets
[34m[1mwandb[0m: Currently logged in as: [33mmaikotrede[0m ([33mhms-hslu-aicomp-hs25[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Preparing data and creating folds...
Train shape: (17089, 12)
Targets ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
Folds created. Value counts per fold:
fold
0    3741
1    3703
2    3527
4    3081
3    3037
Name: count, dtype: int64
Using device: cuda



   --- Epoch 1/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]


--- Data Shape Verification ---
Shape of one training batch (batch, seq_len, channels): torch.Size([32, 3334, 20])
Calculated sequence length in CFG: 3300
Downsampling enabled: True, Factor: 3
--- End Verification ---



Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 1: Train Loss = 1.2291, Valid Loss = 1.2715, LR = 0.000100
   New best model saved with validation loss: 1.2715
   --- Epoch 2/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 2: Train Loss = 1.0890, Valid Loss = 1.1452, LR = 0.000099
   New best model saved with validation loss: 1.1452
   --- Epoch 3/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 3: Train Loss = 0.9202, Valid Loss = 0.9583, LR = 0.000096
   New best model saved with validation loss: 0.9583
   --- Epoch 4/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 4: Train Loss = 0.8215, Valid Loss = 0.8821, LR = 0.000090
   New best model saved with validation loss: 0.8821
   --- Epoch 5/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 5: Train Loss = 0.7537, Valid Loss = 0.8540, LR = 0.000083
   New best model saved with validation loss: 0.8540
   --- Epoch 6/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 6: Train Loss = 0.6955, Valid Loss = 0.8483, LR = 0.000075
   New best model saved with validation loss: 0.8483
   --- Epoch 7/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 7: Train Loss = 0.6355, Valid Loss = 0.8334, LR = 0.000065
   New best model saved with validation loss: 0.8334
   --- Epoch 8/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 8: Train Loss = 0.5816, Valid Loss = 0.8710, LR = 0.000055
   --- Epoch 9/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 9: Train Loss = 0.5086, Valid Loss = 0.8955, LR = 0.000045
   --- Epoch 10/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 10: Train Loss = 0.4527, Valid Loss = 0.8902, LR = 0.000035
   --- Epoch 11/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 11: Train Loss = 0.3971, Valid Loss = 0.8962, LR = 0.000025
   --- Epoch 12/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 12: Train Loss = 0.3489, Valid Loss = 0.9366, LR = 0.000017
   --- Epoch 13/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 13: Train Loss = 0.3174, Valid Loss = 0.9355, LR = 0.000010
   --- Epoch 14/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 14: Train Loss = 0.2971, Valid Loss = 0.9708, LR = 0.000004
   --- Epoch 15/15 ---


Training:   0%|          | 0/417 [00:00<?, ?it/s]

Validation:   0%|          | 0/117 [00:00<?, ?it/s]

   Epoch 15: Train Loss = 0.2875, Valid Loss = 0.9620, LR = 0.000001

Logged artifact for fold 0 with best validation loss: 0.8334


0,1
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
gradients/fc_layer_norm,▃▂▂▁▄▃▂▅▂▂▃█▃▅█▄▄▄▄▆▂▂▂▃▃▅▇▂▃▃▂▃▁▃▃▃▃▂▂▂
gradients/total_norm,▂▂▂▁▂▅▅▃▃▄▄▃▆▅▆▃▄▅▅▄▅▆▅▅▅▄▅██▄▇█▃▅▅▄▃▃▇▄
train/epoch_loss,█▇▆▅▄▄▄▃▃▂▂▁▁▁▁
train/epoch_lr,███▇▇▆▆▅▄▃▃▂▂▁▁
train/loss,██▇█▇▆▆▆▆▅▄▃▆▅▄▆▆▃▅▄▄▃▃▄▃▄▃▂▃▂▂▂▃▂▂▂▁▁▁▂
val/kl_div,█▆▃▂▁▁▁▂▂▂▂▃▃▃▃
val/loss,█▆▃▂▁▁▁▂▂▂▂▃▃▃▃

0,1
best_val_kl_div,0.83338
epoch,15.0
gradients/fc_layer_norm,1.20085
gradients/total_norm,4.8048
train/epoch_loss,0.28753
train/epoch_lr,0.0
train/loss,0.28411
val/kl_div,0.96197
val/loss,0.96197





   --- Epoch 1/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]


--- Data Shape Verification ---
Shape of one training batch (batch, seq_len, channels): torch.Size([32, 3334, 20])
Calculated sequence length in CFG: 3300
Downsampling enabled: True, Factor: 3
--- End Verification ---



Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 1: Train Loss = 1.2256, Valid Loss = 1.2559, LR = 0.000100
   New best model saved with validation loss: 1.2559
   --- Epoch 2/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 2: Train Loss = 1.0115, Valid Loss = 1.0490, LR = 0.000099
   New best model saved with validation loss: 1.0490
   --- Epoch 3/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 3: Train Loss = 0.8693, Valid Loss = 0.9116, LR = 0.000096
   New best model saved with validation loss: 0.9116
   --- Epoch 4/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 4: Train Loss = 0.7986, Valid Loss = 0.9638, LR = 0.000090
   --- Epoch 5/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 5: Train Loss = 0.7223, Valid Loss = 0.8651, LR = 0.000083
   New best model saved with validation loss: 0.8651
   --- Epoch 6/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 6: Train Loss = 0.6687, Valid Loss = 0.8967, LR = 0.000075
   --- Epoch 7/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 7: Train Loss = 0.6079, Valid Loss = 0.8706, LR = 0.000065
   --- Epoch 8/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 8: Train Loss = 0.5476, Valid Loss = 0.8631, LR = 0.000055
   New best model saved with validation loss: 0.8631
   --- Epoch 9/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 9: Train Loss = 0.4732, Valid Loss = 0.8731, LR = 0.000045
   --- Epoch 10/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 10: Train Loss = 0.4045, Valid Loss = 0.8692, LR = 0.000035
   --- Epoch 11/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 11: Train Loss = 0.3542, Valid Loss = 0.9398, LR = 0.000025
   --- Epoch 12/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 12: Train Loss = 0.3077, Valid Loss = 0.9299, LR = 0.000017
   --- Epoch 13/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 13: Train Loss = 0.2746, Valid Loss = 0.9884, LR = 0.000010
   --- Epoch 14/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

Validation:   0%|          | 0/116 [00:00<?, ?it/s]

   Epoch 14: Train Loss = 0.2586, Valid Loss = 0.9513, LR = 0.000004
   --- Epoch 15/15 ---


Training:   0%|          | 0/418 [00:00<?, ?it/s]

In [None]:

class CFG:
    """Hyperparameters and data settings for the raw-EEG TCN experiment.

    ``sampling_rate`` and ``sequence_length`` are derived right below the class
    body, because they depend on the downsampling switch.
    """
    # Input Sequence Length:
    #   - Original Sampling Rate: 200 Hz
    #   - Downsample Factor: 3 (2 would be preferable, but 3 is needed to fit in GPU memory)
    #   - Effective Sampling Rate: 200 Hz / 3 ≈ 66.7 Hz
    #   - Clip Duration: 50 seconds -> 200 Hz * 50 s = 10000 raw samples
    #   - Model Input Length: ceil(10000 / 3) = 3334 timesteps
    #
    # Receptive Field Calculation TCN:
    #   - Each dilated convolution adds (kernel_size - 1) * dilation timesteps.
    #   - With 8 blocks, the dilations are: 1, 2, 4, 8, 16, 32, 64, 128 (sum = 255).
    #   - Kernel Size: 21, so (kernel_size - 1) = 20.
    #   - One dilated conv per block:  RF = 1 + 20 * 255     = 5101 timesteps
    #   - Two dilated convs per block: RF = 1 + 2 * 20 * 255 = 10201 timesteps
    #   NOTE(review): pytorch_tcn residual blocks appear to stack two dilated convs
    #   (conv1/conv2), which would make the larger figure apply — confirm for the
    #   installed version. Either figure exceeds the 3334-step input length.
    seed = 42
    n_splits = 5
    data_path = '../../../data/'

    model_name = 'TCN'
    num_tcn_channels = [64, 128, 128, 256, 256, 512, 512, 512]
    kernel_size = 21
    dropout = 0.35
    target_size = 6

    original_sampling_rate = 200  # Hz
    sequence_duration = 50  # seconds per clip
    num_channels = 20  # Number of input EEG channels

    enable_downsampling = True
    downsample_factor = 3  # 2 would be optimal, but 3 is used to fit in memory. maybe if we can get a better GPU...

    batch_size = 32
    num_workers = 0
    epochs = 15
    lr = 1e-4


# Factor actually applied to the raw signal (1 == no downsampling).
_effective_downsample = CFG.downsample_factor if CFG.enable_downsampling else 1

# Effective sampling rate, integer-truncated (200 // 3 == 66 Hz; the exact rate is ≈ 66.7 Hz).
CFG.sampling_rate = CFG.original_sampling_rate // _effective_downsample

# Number of timesteps the model actually receives. It is derived from the raw sample
# count rather than from the truncated ``sampling_rate``, which gave 3300 instead of the
# observed batch length of 3334. Assumes EEGDataset downsamples by stride slicing
# (x[::factor]), which keeps ceil(n_raw / factor) samples — TODO confirm in eeg_dataset.
_raw_samples = CFG.sequence_duration * CFG.original_sampling_rate
CFG.sequence_length = (_raw_samples + _effective_downsample - 1) // _effective_downsample

def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    Args:
        seed: integer seed applied to every generator (CPU and all CUDA devices).
    """
    import random  # local import keeps this cell self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all seeds every visible GPU, not only the current device.
        torch.cuda.manual_seed_all(seed)

TARGETS = Constants.TARGETS  # names of the six vote columns the model predicts
set_seed(CFG.seed)



Unnamed: 0,eeg_id,patient_id,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote,expert_consensus,fold
0,0,52,0.077368,0.016019,0.010173,0.055574,0.404570,0.436296,other_vote,4
1,1,93,0.054295,0.412413,0.295168,0.219599,0.018326,0.000200,lpd_vote,0
2,2,15,0.094554,0.068652,0.091909,0.254528,0.241059,0.249297,lrda_vote,1
3,3,72,0.042833,0.269521,0.253119,0.242989,0.119773,0.071766,lpd_vote,4
4,4,61,0.031128,0.054718,0.163076,0.203813,0.219944,0.327320,other_vote,2
...,...,...,...,...,...,...,...,...,...,...
495,495,93,0.232811,0.187740,0.120637,0.240649,0.129826,0.088337,lrda_vote,0
496,496,67,0.104876,0.196337,0.201738,0.162458,0.182024,0.152567,gpd_vote,1
497,497,76,0.154515,0.139353,0.103543,0.277283,0.081564,0.243741,lrda_vote,0
498,498,26,0.680393,0.032345,0.092773,0.001559,0.182094,0.010837,seizure_vote,4


In [None]:

class TCNModel(nn.Module):
    """Sequence classifier: a pytorch_tcn ``TCN`` backbone followed by a linear head.

    Args:
        num_inputs: number of input channels fed to the TCN.
        num_outputs: number of output logits produced by the head.
        channel_sizes: per-block channel widths; the head reads ``channel_sizes[-1]`` features.
        kernel_size: convolution kernel size forwarded to the TCN.
        dropout: dropout probability forwarded to the TCN.
        causal: forwarded to the TCN.
        use_skip_connections: forwarded to the TCN.
    """

    def __init__(self, num_inputs, num_outputs, channel_sizes, kernel_size, dropout, causal=False, use_skip_connections=True):
        super().__init__()
        backbone_kwargs = dict(
            num_inputs=num_inputs,
            num_channels=channel_sizes,
            kernel_size=kernel_size,
            dropout=dropout,
            causal=causal,
            use_skip_connections=use_skip_connections,
        )
        # ``tcn`` / ``fc`` are the state_dict prefixes of saved checkpoints, and the
        # backbone is built before the head so weight initialisation consumes the RNG
        # in the same order — keep both.
        self.tcn = TCN(**backbone_kwargs)
        self.fc = nn.Linear(channel_sizes[-1], num_outputs)

    def forward(self, x):
        """Map a (batch, seq_len, channels) tensor to (batch, num_outputs) logits."""
        # The dataloader yields (batch, seq_len, channels); the TCN is fed (batch, channels, seq_len).
        channels_first = x.permute(0, 2, 1)
        features = self.tcn(channels_first)  # assumed (batch, channel_sizes[-1], time)
        # NOTE(review): with causal=False the final timestep only summarises what its
        # receptive field covers; mean-pooling over time is a common alternative.
        final_step = features[:, :, -1]
        return self.fc(final_step)

def get_dataloaders(df, fold_id):
    """Build the train/validation DataLoaders for one cross-validation fold.

    Rows whose ``fold`` equals ``fold_id`` form the validation split; every other
    row is used for training. Training batches are shuffled and the last partial
    batch is dropped; validation keeps order and every sample.

    Returns:
        (train_loader, valid_loader)
    """
    in_valid_fold = df['fold'] == fold_id
    factor = CFG.downsample_factor if CFG.enable_downsampling else 1

    def _make_loader(split_df, mode, is_training):
        # One EEGDataset + DataLoader per split; ``is_training`` toggles shuffle/drop_last.
        dataset = EEGDataset(
            df=split_df.reset_index(drop=True), data_path=CFG.data_path, mode=mode,
            downsample_factor=factor
        )
        return DataLoader(
            dataset, batch_size=CFG.batch_size, shuffle=is_training,
            num_workers=CFG.num_workers, pin_memory=True, drop_last=is_training
        )

    train_loader = _make_loader(df[~in_valid_fold], 'train', True)
    valid_loader = _make_loader(df[in_valid_fold], 'valid', False)
    return train_loader, valid_loader


In [None]:
def _print_shape_verification(signals):
    """Print a one-off sanity check comparing a real training batch with CFG."""
    print(f"\n--- Data Shape Verification ---")
    print(f"Shape of one training batch (batch, seq_len, channels): {signals.shape}")
    print(f"Calculated sequence length in CFG: {CFG.sequence_length}")
    print(f"Downsampling enabled: {CFG.enable_downsampling}, Factor: {CFG.downsample_factor if CFG.enable_downsampling else 'N/A'}")
    print(f"--- End Verification ---\n")


def _first_conv_weight_params(model):
    """Return the weight tensors of the TCN's first convolution (possibly an empty list).

    The earlier exact substring match on "tcn.network.0.conv1.weight" never fired in the
    logged runs (the metric is missing from the wandb summary). Parameter names depend on
    the weight-norm implementation (``weight``, ``weight_g``/``weight_v``, or
    ``parametrizations.weight.original0/1``), so match on the module prefix instead.
    NOTE(review): the "tcn.network.0.conv1" prefix itself is assumed from the original code.
    """
    prefix = "tcn.network.0.conv1"
    return [p for name, p in model.named_parameters() if name.startswith(prefix) and "weight" in name]


def _train_one_epoch(model, loader, optimizer, loss_fn, device, first_conv_params, verify_shape=False):
    """Run one optimisation pass over ``loader``.

    Logs one wandb entry per batch (loss plus pre-clipping gradient norms). Returns the
    mean per-sample training loss over the samples actually seen. Because the loader uses
    ``drop_last=True``, that count can be smaller than ``len(loader.dataset)``.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for signals, labels in tqdm(loader, desc="Training"):
        signals, labels = signals.to(device), labels.to(device)
        if verify_shape:
            _print_shape_verification(signals)
            verify_shape = False

        optimizer.zero_grad()
        log_probs = F.log_softmax(model(signals), dim=1)
        loss = loss_fn(log_probs, labels)
        loss.backward()

        step_log = {"train/loss": loss.item()}
        # Per-layer norms must be read before clipping rescales the gradients.
        if first_conv_params:
            conv_sq = sum(p.grad.detach().norm(2) ** 2 for p in first_conv_params if p.grad is not None)
            step_log["gradients/tcn_layer_0_norm"] = float(conv_sq) ** 0.5
        if model.fc.weight.grad is not None:
            step_log["gradients/fc_layer_norm"] = model.fc.weight.grad.detach().norm(2).item()
        # clip_grad_norm_ returns the total L2 norm of all gradients *before* clipping,
        # which is what the old hand-rolled loop computed, without a host sync per parameter.
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        step_log["gradients/total_norm"] = total_norm.item()
        optimizer.step()

        # A single log call per batch keeps all per-step metrics on the same wandb step.
        wandb.log(step_log)

        batch_size = signals.size(0)
        loss_sum += step_log["train/loss"] * batch_size
        n_seen += batch_size
    return loss_sum / n_seen if n_seen else float("nan")


def _validate(model, loader, loss_fn, device):
    """Return the mean per-sample validation loss over ``loader``, without gradients."""
    model.eval()
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for signals, labels in tqdm(loader, desc="Validation"):
            signals, labels = signals.to(device), labels.to(device)
            log_probs = F.log_softmax(model(signals), dim=1)
            batch_size = signals.size(0)
            loss_sum += loss_fn(log_probs, labels).item() * batch_size
            n_seen += batch_size
    return loss_sum / n_seen if n_seen else float("nan")


def _train_fold(df, fold, device, vote_method):
    """Train one fold inside the active wandb run.

    Saves the best checkpoint, logs it as an artifact and returns its validation loss.
    Model and optimizer are locals here, so they are released when the function
    returns, before the next fold builds its own.
    """
    model = TCNModel(
        num_inputs=CFG.num_channels,
        num_outputs=CFG.target_size,
        channel_sizes=CFG.num_tcn_channels,
        kernel_size=CFG.kernel_size,
        dropout=CFG.dropout,
        causal=False,  # causality is a limitation we don't need for our prediction task
        use_skip_connections=True
    )
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=CFG.epochs)
    loss_fn = nn.KLDivLoss(reduction='batchmean')
    train_loader, valid_loader = get_dataloaders(df, fold)

    first_conv_params = _first_conv_weight_params(model)
    save_path = get_models_save_path() / "TCNModel" / vote_method / f'best_model_fold{fold}.pth'

    best_val_loss = float('inf')
    best_model_path = None
    for epoch in range(CFG.epochs):
        print(f"   --- Epoch {epoch+1}/{CFG.epochs} ---")

        train_loss = _train_one_epoch(
            model, train_loader, optimizer, loss_fn, device, first_conv_params,
            verify_shape=(epoch == 0)
        )
        valid_loss = _validate(model, valid_loader, loss_fn, device)

        epoch_lr = optimizer.param_groups[0]['lr']
        print(f"   Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Valid Loss = {valid_loss:.4f}, LR = {epoch_lr:.6f}")
        wandb.log({
            "epoch": epoch + 1, "train/epoch_loss": train_loss, "val/loss": valid_loss,
            "val/kl_div": valid_loss, "train/epoch_lr": epoch_lr
        })

        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            save_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            best_model_path = save_path
            print(f"   New best model saved with validation loss: {best_val_loss:.4f}")
        scheduler.step()

    wandb.summary['best_val_kl_div'] = best_val_loss
    if best_model_path is not None:
        artifact = wandb.Artifact(f'model-fold{fold}', type='model')
        artifact.add_file(str(best_model_path))
        wandb.log_artifact(artifact)
        print(f"\nLogged artifact for fold {fold} with best validation loss: {best_val_loss:.4f}")
    else:
        print("\nNo best model was saved during training for this fold.")
    return best_val_loss


def run_training(df, DATA_PREPARATION_VOTE_METHOD):
    """Train one TCN per cross-validation fold and collect the best validation scores.

    Each fold gets its own wandb run and a fresh model, optimizer and scheduler. Each
    fold saves its lowest-validation-loss checkpoint under
    ``get_models_save_path() / "TCNModel" / DATA_PREPARATION_VOTE_METHOD``.

    Args:
        df: DataFrame with a ``fold`` column (as produced by ``KFoldCreator.create_folds``).
        DATA_PREPARATION_VOTE_METHOD: name of the label-preparation scheme; only used
            as a sub-directory of the checkpoint path.

    Returns:
        list[float]: the best validation loss (KL divergence) of each fold, in fold order.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    fold_scores = []
    for fold in range(CFG.n_splits):
        print(f"\n========== FOLD {fold} ==========")

        config = {
            "architecture": CFG.model_name, "tcn_channels": CFG.num_tcn_channels,
            "kernel_size": CFG.kernel_size, "dropout": CFG.dropout,
            "fold": fold, "features": "raw_eeg", "sequence_duration": f"{CFG.sequence_duration}s",
            "optimizer": "AdamW", "learning_rate": CFG.lr, "batch_size": CFG.batch_size,
            "epochs": CFG.epochs, "seed": CFG.seed, "Scheduler": "CosineAnnealingLR"
        }

        wandb.init(
            project="hms-aicomp-tcn",
            name=f"tcn-raw-eeg-fold{fold}",
            tags=['tcn-baseline', f'fold{fold}'],
            config=config
        )
        try:
            best_val_loss = _train_fold(df, fold, device, DATA_PREPARATION_VOTE_METHOD)
        finally:
            # Close the run even when training raises (e.g. CUDA OOM or an interrupt), so
            # a later fold or cell does not log into a stale run.
            wandb.finish()
        fold_scores.append(best_val_loss)
    return fold_scores



In [None]:
if __name__ == '__main__':
    # Label-preparation scheme; run_training uses it to name the checkpoint sub-directory.
    DATA_PREPARATION_VOTE_METHOD = "sum_and_normalize"

    print("Preparing data and creating folds...")
    df = pd.read_csv(CFG.data_path + 'processed_data_sum_votes_window.csv')

    label_map = {name: idx for idx, name in enumerate(TARGETS)}
    # Consensus label = vote column with the largest share; it is the stratification key below.
    df['expert_consensus'] = df[TARGETS].idxmax(axis=1)

    print('Train shape:', df.shape)
    print('Targets', list(TARGETS))

    # Patient-grouped, consensus-stratified folds (no patient spans two folds).
    fold_creator = KFoldCreator(n_splits=CFG.n_splits, seed=CFG.seed)
    df = fold_creator.create_folds(df, stratify_col='expert_consensus', group_col='patient_id')

    print("Folds created. Value counts per fold:")
    print(df['fold'].value_counts())

    all_fold_scores = run_training(df, DATA_PREPARATION_VOTE_METHOD)
    cv_score = np.mean(all_fold_scores)
    print(f"\nCross-Validation Score (Mean KL Divergence across folds): {cv_score:.4f}")