In [1]:
import os
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import timm
from scipy.signal import butter, sosfiltfilt

from pytorch_tcn import TCN

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))


from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path
from src.utils.constants import Constants
from src.utils.feature_extraction import FeatureExtraction

2025-11-02 22:06:15,928 :: root :: INFO :: Initialising Utils
2025-11-02 22:06:16,014 :: root :: INFO :: Initialising Datasets
2025-11-02 22:06:16,017 :: root :: INFO :: Initialising Models


In [2]:


class ClassifierHead(nn.Module):
    """Two-layer MLP head: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.

    Produces ``output_size`` unnormalised logits per input row. The submodule
    attribute names (``fc1``, ``bn1``, ``relu``, ``dropout``, ``fc2``) are the
    ``state_dict`` keys of saved head checkpoints, so they must stay stable.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout):
        super().__init__()
        # Registration order is significant: it fixes the order in which the
        # seeded RNG initialises parameters and the order of state_dict keys.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Map a batch of feature rows to logits.

        In this notebook ``x`` is a 2-D ``(batch, input_size)`` float tensor
        coming from ``FeatureDataset`` batches.
        """
        hidden = self.relu(self.bn1(self.fc1(x)))
        return self.fc2(self.dropout(hidden))

class FeatureDataset(Dataset):
    """In-memory dataset of (feature row, label row) pairs as float32 tensors.

    Both inputs are copied into ``torch.float32`` tensors once at construction;
    indexing returns the matching rows of each.
    """

    def __init__(self, features, labels):
        def _as_float32(values):
            return torch.tensor(values, dtype=torch.float32)

        self.features, self.labels = _as_float32(features), _as_float32(labels)

    def __len__(self):
        """Number of samples (first dimension of ``features``)."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(features[idx], labels[idx])``."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target



In [3]:

class CFG:
    """Notebook-wide configuration, read as class attributes (never instantiated).

    Sections: cross-validation, data locations, the per-fold TCN backbone
    (EEG-signal input), the per-fold CNN backbone (spectrogram input) and the
    trainable classifier head.
    """

    # ---- Cross-validation ----
    SEED = 42
    N_SPLITS = 5

    # ---- Data locations (relative to the notebook's working directory) ----
    DATA_PATH = '../../../data/'
    TRAIN_CSV_NAME = 'processed_data_sum_votes_window.csv'
    # Also used as the sub-directory name for backbone and head checkpoints.
    DATA_PREP_VOTE_METHOD = "sum_and_normalize"
    FEATURE_STORE_PATH = f"{DATA_PATH}extracted_feature/"

    # ---- TCN backbone ----
    TCN_MODEL_DIR = "TCNModel"
    TCN_NUM_CHANNELS = 20  # presumably the number of EEG input channels — TODO confirm
    TCN_CHANNEL_SIZES = [64, 128, 128, 256, 256, 512, 512, 512]
    TCN_KERNEL_SIZE = 21
    TCN_DROPOUT = 0.35
    TCN_DOWNSAMPLE_FACTOR = 3

    # ---- CNN backbone ----
    CNN_MODEL_DIR = "MultiSpectCNN"
    # NOTE(review): legacy timm model name; the `model = create_fn(` line in the
    # run log looks like timm's deprecated-name warning — confirm before renaming.
    CNN_MODEL_NAME = 'tf_efficientnet_b0_ns'
    CNN_IN_CHANNELS = 8
    CNN_IMG_SIZE = (128, 256)
    CNN_EEG_SPEC_PATH = '../../../data/custom_eegs/cwt'

    # ---- Classifier head ----
    HEAD_MODEL_SAVE_DIR = get_models_save_path().joinpath("MultiModalHead", DATA_PREP_VOTE_METHOD)
    HEAD_HIDDEN_SIZE = 512
    HEAD_DROPOUT = 0.4
    HEAD_BATCH_SIZE = 64
    HEAD_EPOCHS = 20
    HEAD_LR = 1e-4
    TARGET_SIZE = 6  # width of the head's output logits

    # ---- Feature-extraction inference ----
    INFERENCE_BATCH_SIZE = 64
    NUM_WORKERS = 4

    # General Inference
    INFERENCE_BATCH_SIZE = 64
    NUM_WORKERS = 4

def set_seed(seed):
    """Seed every random number generator the notebook can draw from.

    Covers Python's ``random`` module, NumPy's legacy global RNG and torch's
    generators. DataLoader shuffling and dropout masks draw from torch's
    generators.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # torch.manual_seed already seeds CUDA in current PyTorch; seeding all
        # visible devices explicitly keeps multi-GPU runs reproducible too.
        torch.cuda.manual_seed_all(seed)

set_seed(CFG.SEED)
CFG.HEAD_MODEL_SAVE_DIR.mkdir(exist_ok=True, parents=True)

# Training code below moves the head and each batch onto this device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')




# Training

In [4]:
def run_head_training(train_loader, valid_loader, input_size, fold_k, patience=None):
    """Train a ClassifierHead on pre-extracted features for one CV fold.

    The head is optimised with KL divergence between its log-softmax output and
    the target label distributions. After every epoch it is evaluated on
    ``valid_loader``. Whenever the sample-weighted validation loss improves,
    its ``state_dict`` is written to
    ``CFG.HEAD_MODEL_SAVE_DIR / f'best_head_fold{fold_k}.pth'``.

    Args:
        train_loader: DataLoader yielding ``(features, labels)`` batches.
        valid_loader: DataLoader yielding ``(features, labels)`` batches; its
            dataset must be non-empty.
        input_size: Feature dimension fed to the head.
        fold_k: Fold index, used in the checkpoint filename and log lines.
        patience: Optional early-stopping patience. If set, training stops once
            this many consecutive epochs pass without a validation improvement.
            ``None`` (default) always runs ``CFG.HEAD_EPOCHS`` epochs.

    Returns:
        The lowest validation loss observed across epochs.

    Raises:
        ValueError: If ``valid_loader``'s dataset is empty.
    """
    n_valid = len(valid_loader.dataset)
    if n_valid == 0:
        # Otherwise the per-sample average below divides by zero, and only
        # after a full training epoch has been wasted.
        raise ValueError(f"Validation set for Fold {fold_k} is empty.")

    model = ClassifierHead(
        input_size=input_size,
        hidden_size=CFG.HEAD_HIDDEN_SIZE,
        output_size=CFG.TARGET_SIZE,
        dropout=CFG.HEAD_DROPOUT
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.HEAD_LR)
    loss_fn = nn.KLDivLoss(reduction='batchmean')
    save_path = CFG.HEAD_MODEL_SAVE_DIR / f'best_head_fold{fold_k}.pth'
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(CFG.HEAD_EPOCHS):
        model.train()
        for features, labels in train_loader:
            # BatchNorm1d cannot compute batch statistics from one sample in
            # train mode and raises. This happens when the last batch of an epoch
            # holds exactly one row, so skip that batch.
            if features.size(0) < 2:
                continue
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(features)
            log_probs = F.log_softmax(outputs, dim=1)
            loss = loss_fn(log_probs, labels)
            loss.backward()
            optimizer.step()

        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for features, labels in valid_loader:
                features, labels = features.to(device), labels.to(device)
                outputs = model(features)
                log_probs = F.log_softmax(outputs, dim=1)
                loss = loss_fn(log_probs, labels)
                # 'batchmean' averages within the batch; re-weight by batch size
                # so the epoch figure is a true per-sample mean.
                valid_loss += loss.item() * features.size(0)

        valid_loss /= n_valid

        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), save_path)
            print(f"Saved Best Model for Fold {fold_k} at Epoch {epoch+1}")
        else:
            epochs_without_improvement += 1

        print(f"  Epoch {epoch+1}/{CFG.HEAD_EPOCHS}, Val Loss: {valid_loss:.4f}")

        if patience is not None and epochs_without_improvement >= patience:
            print(f"  Early stopping Fold {fold_k} at Epoch {epoch+1}")
            break

    print(f"  Fold {fold_k} Best Val Loss: {best_val_loss:.4f}")
    return best_val_loss


def _fold_cache_files(cache_dir, fold_k):
    """Paths of the (TCN features, CNN features, labels) cache files for one fold."""
    return (
        cache_dir / f"tcn_features_fold{fold_k}.npy",
        cache_dir / f"cnn_features_fold{fold_k}.npy",
        cache_dir / f"labels_fold{fold_k}.npy",
    )


def _load_cached_fold_features(cache_dir, fold_k, n_rows):
    """Load (tcn_features, cnn_features, labels) from cache, or None if absent/stale."""
    cache_files = _fold_cache_files(cache_dir, fold_k)
    if not all(path.exists() for path in cache_files):
        return None
    print(f"Loading features for Fold {fold_k}...")
    tcn_features, cnn_features, labels = [np.load(path) for path in cache_files]
    if not (len(tcn_features) == len(cnn_features) == len(labels) == n_rows):
        # A cache built from a different CSV would otherwise be silently
        # misaligned with the current dataframe's fold assignments.
        print(f"Cached features for Fold {fold_k} do not match {n_rows} dataframe rows; regenerating.")
        return None
    return tcn_features, cnn_features, labels


def _extract_fold_features(extractor, fold_k, eeg_loader, spec_loader):
    """Run fold_k's TCN and CNN backbones over every row; return features and labels of both."""
    print(f"Loading TCN Model for Fold {fold_k}...")
    tcn_model_path = (
        get_models_save_path() / CFG.TCN_MODEL_DIR /
        CFG.DATA_PREP_VOTE_METHOD / f'best_model_fold{fold_k}.pth'
    )
    tcn_extractor = extractor.build_tcn_feature_extractor(tcn_model_path)
    tcn_features, tcn_labels = extractor.extract_features(tcn_extractor, eeg_loader)
    del tcn_extractor
    torch.cuda.empty_cache()

    print(f"Loading CNN Model for Fold {fold_k}...")
    cnn_model_path = (
        get_models_save_path() / CFG.CNN_MODEL_DIR /
        CFG.DATA_PREP_VOTE_METHOD / f'best_model_fold{fold_k}.pth'
    )
    cnn_extractor = extractor.build_cnn_feature_extractor(cnn_model_path)
    cnn_features, cnn_labels = extractor.extract_features(cnn_extractor, spec_loader)
    del cnn_extractor
    torch.cuda.empty_cache()

    return tcn_features, tcn_labels, cnn_features, cnn_labels


if __name__ == '__main__':

    # The backbone checkpoints are namespaced by vote method (see the model paths
    # in _extract_fold_features), so the feature cache must be too. Otherwise
    # switching DATA_PREP_VOTE_METHOD silently reuses stale features.
    CFG.FEATURE_CACHE_PATH = Path(CFG.FEATURE_STORE_PATH) / CFG.DATA_PREP_VOTE_METHOD
    cache_dir = Path(CFG.FEATURE_CACHE_PATH)
    cache_dir.mkdir(parents=True, exist_ok=True)

    df_path = Path(CFG.DATA_PATH) / CFG.TRAIN_CSV_NAME
    df = pd.read_csv(df_path)

    fold_creator = KFoldCreator(n_splits=CFG.N_SPLITS, seed=CFG.SEED)
    df = fold_creator.create_folds(df, stratify_col='expert_consensus', group_col='patient_id')

    # Extracted feature arrays are row-aligned with df by *position*. Select rows
    # positionally instead of using df.index labels, which only coincide with
    # positions when df has a default RangeIndex.
    fold_ids = df['fold'].to_numpy()
    n_rows = len(df)

    extractor = FeatureExtraction(CFG, device)

    eeg_loader_all = extractor.get_eeg_inference_loader(
        df, CFG.DATA_PATH, CFG.TCN_DOWNSAMPLE_FACTOR
    )
    spec_loader_all = extractor.get_spec_inference_loader(
        df, Constants.TARGETS, CFG.DATA_PATH, CFG.CNN_IMG_SIZE, CFG.CNN_EEG_SPEC_PATH
    )

    all_fold_scores = []

    for fold_k in range(CFG.N_SPLITS):
        print("\n" + "="*50)
        print(f"Processing Fold {fold_k}")
        print("="*50)

        cached = _load_cached_fold_features(cache_dir, fold_k, n_rows)
        if cached is not None:
            tcn_features_all, cnn_features_all, labels_all = cached
        else:
            print(f"Cached features not found. Generating for Fold {fold_k}...")
            tcn_features_all, tcn_labels_all, cnn_features_all, cnn_labels_all = _extract_fold_features(
                extractor, fold_k, eeg_loader_all, spec_loader_all
            )

            # Validate BEFORE caching. On a later cache hit there is only one
            # labels file, so a mismatch could no longer be detected.
            if not np.array_equal(tcn_labels_all, cnn_labels_all):
                raise ValueError(f"Label mismatch in Fold {fold_k}!")
            if not (len(tcn_features_all) == len(cnn_features_all) == n_rows):
                raise ValueError(
                    f"Fold {fold_k}: extracted {len(tcn_features_all)} TCN / "
                    f"{len(cnn_features_all)} CNN rows, expected {n_rows}."
                )
            labels_all = tcn_labels_all

            print(f"Saving features to cache: {cache_dir}")
            tcn_cache_file, cnn_cache_file, labels_cache_file = _fold_cache_files(cache_dir, fold_k)
            np.save(tcn_cache_file, tcn_features_all)
            np.save(cnn_cache_file, cnn_features_all)
            np.save(labels_cache_file, labels_all)

        print(f"TCN Features Extracted. Shape: {tcn_features_all.shape}")
        print(f"CNN Features Extracted. Shape: {cnn_features_all.shape}")

        combined_features_all = np.concatenate([tcn_features_all, cnn_features_all], axis=1)
        input_size = combined_features_all.shape[1]

        print("Splitting features into train/validation for this fold...")
        # NOTE(review): the fold-k backbones were presumably trained on the rows
        # with fold != fold_k. If so, training-row features are in-sample while
        # validation-row features are not. That mismatch may explain why the
        # validation loss is lowest after 1-2 epochs in the logged runs.
        # Out-of-fold features would avoid it — confirm how the backbones were trained.
        train_indices = np.flatnonzero(fold_ids != fold_k)
        val_indices = np.flatnonzero(fold_ids == fold_k)

        train_features = combined_features_all[train_indices]
        train_labels = labels_all[train_indices]
        val_features = combined_features_all[val_indices]
        val_labels = labels_all[val_indices]

        train_dataset = FeatureDataset(train_features, train_labels)
        valid_dataset = FeatureDataset(val_features, val_labels)

        train_loader = DataLoader(
            train_dataset, batch_size=CFG.HEAD_BATCH_SIZE, shuffle=True
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=CFG.HEAD_BATCH_SIZE, shuffle=False
        )

        print(f"Training classifier head for Fold {fold_k}...")
        best_fold_loss = run_head_training(
            train_loader, valid_loader, input_size, fold_k
        )
        all_fold_scores.append(best_fold_loss)

    print("\n" + "="*50)
    print(f"Full {CFG.N_SPLITS}-Fold Cross-Validation Complete.")
    print(f"Scores per fold: {all_fold_scores}")

    mean_cv_score = np.mean(all_fold_scores)
    print(f"\nMean CV Score: {mean_cv_score:.4f}")
    print("="*50)


Processing Fold 0
Cached features not found. Generating for Fold 0...
Loading TCN Model for Fold 0...


Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Loading CNN Model for Fold 0...


  model = create_fn(


Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Saving features to cache: ../../../data/extracted_feature
TCN Features Extracted. Shape: (17089, 512)
CNN Features Extracted. Shape: (17089, 1280)
Splitting features into train/validation for this fold...
Training classifier head for Fold 0...
Saved Best Model for Fold 0 at Epoch 1
  Epoch 1/20, Val Loss: 0.5044
  Epoch 2/20, Val Loss: 0.5231
  Epoch 3/20, Val Loss: 0.5227
  Epoch 4/20, Val Loss: 0.5420
  Epoch 5/20, Val Loss: 0.5709
  Epoch 6/20, Val Loss: 0.5406
  Epoch 7/20, Val Loss: 0.5557
  Epoch 8/20, Val Loss: 0.5445
  Epoch 9/20, Val Loss: 0.6014
  Epoch 10/20, Val Loss: 0.5819
  Epoch 11/20, Val Loss: 0.5870
  Epoch 12/20, Val Loss: 0.5798
  Epoch 13/20, Val Loss: 0.5807
  Epoch 14/20, Val Loss: 0.5817
  Epoch 15/20, Val Loss: 0.5636
  Epoch 16/20, Val Loss: 0.5774
  Epoch 17/20, Val Loss: 0.5894
  Epoch 18/20, Val Loss: 0.5793
  Epoch 19/20, Val Loss: 0.6122
  Epoch 20/20, Val Loss: 0.6120
  Fold 0 Best Val Loss: 0.5044

Processing Fold 1
Cached features not found. Generatin

Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Loading CNN Model for Fold 1...


Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>
Traceback (most recent call last):
Exception ignored in:   File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>    
self._shutdown_workers()Traceback (most recent call last):

  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Exception ignored in:         self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>
if w.is_alive():
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
Traceback (most recent call last):

      File "/home/

Saving features to cache: ../../../data/extracted_feature
TCN Features Extracted. Shape: (17089, 512)
CNN Features Extracted. Shape: (17089, 1280)
Splitting features into train/validation for this fold...
Training classifier head for Fold 1...
Saved Best Model for Fold 1 at Epoch 1
  Epoch 1/20, Val Loss: 0.5126
Saved Best Model for Fold 1 at Epoch 2
  Epoch 2/20, Val Loss: 0.5054
  Epoch 3/20, Val Loss: 0.5236
  Epoch 4/20, Val Loss: 0.5154
  Epoch 5/20, Val Loss: 0.5078
  Epoch 6/20, Val Loss: 0.5066
  Epoch 7/20, Val Loss: 0.5197
  Epoch 8/20, Val Loss: 0.5212
  Epoch 9/20, Val Loss: 0.5315
  Epoch 10/20, Val Loss: 0.5193
  Epoch 11/20, Val Loss: 0.5181
  Epoch 12/20, Val Loss: 0.5139
  Epoch 13/20, Val Loss: 0.5297
  Epoch 14/20, Val Loss: 0.5173
  Epoch 15/20, Val Loss: 0.5312
  Epoch 16/20, Val Loss: 0.5330
  Epoch 17/20, Val Loss: 0.5209
  Epoch 18/20, Val Loss: 0.5207
  Epoch 19/20, Val Loss: 0.5272
  Epoch 20/20, Val Loss: 0.5370
  Fold 1 Best Val Loss: 0.5054

Processing Fold

Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>
Traceback (most recent call last):
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>
Traceback (most recent call last):
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/maiko/miniconda3/envs/aicomp/lib/py

Loading CNN Model for Fold 2...


Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Saving features to cache: ../../../data/extracted_feature
TCN Features Extracted. Shape: (17089, 512)
CNN Features Extracted. Shape: (17089, 1280)
Splitting features into train/validation for this fold...
Training classifier head for Fold 2...
Saved Best Model for Fold 2 at Epoch 1
  Epoch 1/20, Val Loss: 0.5215
  Epoch 2/20, Val Loss: 0.6037
  Epoch 3/20, Val Loss: 0.5774
  Epoch 4/20, Val Loss: 0.6028
  Epoch 5/20, Val Loss: 0.5890
  Epoch 6/20, Val Loss: 0.5980
  Epoch 7/20, Val Loss: 0.5910
  Epoch 8/20, Val Loss: 0.6446
  Epoch 9/20, Val Loss: 0.6492
  Epoch 10/20, Val Loss: 0.6212
  Epoch 11/20, Val Loss: 0.6593
  Epoch 12/20, Val Loss: 0.6115
  Epoch 13/20, Val Loss: 0.6436
  Epoch 14/20, Val Loss: 0.6212
  Epoch 15/20, Val Loss: 0.6274
  Epoch 16/20, Val Loss: 0.6368
  Epoch 17/20, Val Loss: 0.6305
  Epoch 18/20, Val Loss: 0.6291
  Epoch 19/20, Val Loss: 0.6611
  Epoch 20/20, Val Loss: 0.6424
  Fold 2 Best Val Loss: 0.5215

Processing Fold 3
Cached features not found. Generatin

Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Loading CNN Model for Fold 3...


Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0><function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>
Exception ignored in: 
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>    self._shutdown_workers()
self._shutdown_workers()

Traceback (most recent call last):
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/maiko/miniconda3/envs/aic

Saving features to cache: ../../../data/extracted_feature
TCN Features Extracted. Shape: (17089, 512)
CNN Features Extracted. Shape: (17089, 1280)
Splitting features into train/validation for this fold...
Training classifier head for Fold 3...
Saved Best Model for Fold 3 at Epoch 1
  Epoch 1/20, Val Loss: 0.5560
  Epoch 2/20, Val Loss: 0.5846
  Epoch 3/20, Val Loss: 0.5881
  Epoch 4/20, Val Loss: 0.6054
  Epoch 5/20, Val Loss: 0.6213
  Epoch 6/20, Val Loss: 0.6162
  Epoch 7/20, Val Loss: 0.6290
  Epoch 8/20, Val Loss: 0.6211
  Epoch 9/20, Val Loss: 0.6313
  Epoch 10/20, Val Loss: 0.6204
  Epoch 11/20, Val Loss: 0.6145
  Epoch 12/20, Val Loss: 0.6238
  Epoch 13/20, Val Loss: 0.6419
  Epoch 14/20, Val Loss: 0.6231
  Epoch 15/20, Val Loss: 0.6417
  Epoch 16/20, Val Loss: 0.6252
  Epoch 17/20, Val Loss: 0.6259
  Epoch 18/20, Val Loss: 0.6343
  Epoch 19/20, Val Loss: 0.6574
  Epoch 20/20, Val Loss: 0.6421
  Fold 3 Best Val Loss: 0.5560

Processing Fold 4
Cached features not found. Generatin

Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Loading CNN Model for Fold 4...


Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>

Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f01aa8c8fe0>Traceback (most recent call last):
  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__

  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    Traceback (most recent call last):
    self._shutdown_workers()self._shutdown_workers()  File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__


      File "/home/maiko/miniconda3/envs/aicomp/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
  File "/home/maiko/miniconda3/envs/aicomp/li

Saving features to cache: ../../../data/extracted_feature
TCN Features Extracted. Shape: (17089, 512)
CNN Features Extracted. Shape: (17089, 1280)
Splitting features into train/validation for this fold...
Training classifier head for Fold 4...
Saved Best Model for Fold 4 at Epoch 1
  Epoch 1/20, Val Loss: 0.5885
  Epoch 2/20, Val Loss: 0.6378
  Epoch 3/20, Val Loss: 0.6518
  Epoch 4/20, Val Loss: 0.6640
  Epoch 5/20, Val Loss: 0.6455
  Epoch 6/20, Val Loss: 0.6694
  Epoch 7/20, Val Loss: 0.6690
  Epoch 8/20, Val Loss: 0.6672
  Epoch 9/20, Val Loss: 0.6504
  Epoch 10/20, Val Loss: 0.6658
  Epoch 11/20, Val Loss: 0.6904
  Epoch 12/20, Val Loss: 0.6726
  Epoch 13/20, Val Loss: 0.6733
  Epoch 14/20, Val Loss: 0.6986
  Epoch 15/20, Val Loss: 0.6820
  Epoch 16/20, Val Loss: 0.6778
  Epoch 17/20, Val Loss: 0.6879
  Epoch 18/20, Val Loss: 0.6741
  Epoch 19/20, Val Loss: 0.7022
  Epoch 20/20, Val Loss: 0.6801
  Fold 4 Best Val Loss: 0.5885

Full 5-Fold Cross-Validation Complete.
Scores per fold