In [1]:
import os
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import butter, sosfiltfilt
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm


from pytorch_tcn import TCN

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))


from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path
from src.utils.constants import Constants
from src.utils.feature_extraction import FeatureExtraction

2025-11-02 21:21:28,096 :: root :: INFO :: Initialising Utils
2025-11-02 21:21:28,139 :: root :: INFO :: Initialising Models
2025-11-02 21:21:28,141 :: root :: INFO :: Initialising Datasets


In [2]:


class ClassifierHead(nn.Module):
    """Small MLP head: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.

    Args:
        input_size: Width of each incoming feature vector.
        hidden_size: Width of the hidden layer (also the BatchNorm width).
        output_size: Number of raw scores produced per sample.
        dropout: Drop probability applied to the hidden activations.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout):
        super().__init__()
        # Attribute names double as state_dict keys, so they are kept stable;
        # layers are created in this order so seeded initialisation is unchanged.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return un-normalised scores for ``x``; no softmax is applied here."""
        hidden = self.relu(self.bn1(self.fc1(x)))
        return self.fc2(self.dropout(hidden))

class FeatureDataset(Dataset):
    """In-memory dataset that pairs each feature row with its label row.

    Both inputs are converted once, up front, to ``float32`` tensors.
    """

    def __init__(self, features, labels):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        """Number of samples, taken from the first dimension of the features."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(features, labels)`` pair stored at ``idx``."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target



In [3]:

class CFG:
    """Notebook-wide settings, read as class attributes (``CFG.NAME``)."""

    # --- Cross-validation ---------------------------------------------------
    SEED = 42        # handed to set_seed() and to KFoldCreator
    N_SPLITS = 5     # number of folds the training loop iterates over

    # --- Data locations (relative to the notebook's working directory) ------
    DATA_PATH = '../../../data/'
    TRAIN_CSV_NAME = 'processed_data_sum_votes_window.csv'
    DATA_PREP_VOTE_METHOD = "sum_and_normalize"   # also used as a sub-directory in model paths
    FEATURE_STORE_PATH = f'{DATA_PATH}extracted_feature/'

    # --- TCN backbone -------------------------------------------------------
    # NOTE(review): apart from TCN_MODEL_DIR and TCN_DOWNSAMPLE_FACTOR these are
    # not referenced in this notebook; presumably FeatureExtraction reads them
    # from CFG when rebuilding the network -- confirm.
    TCN_MODEL_DIR = "TCNModel"
    TCN_NUM_CHANNELS = 20
    TCN_CHANNEL_SIZES = [64, 128, 128, 256, 256, 512, 512, 512]
    TCN_KERNEL_SIZE = 21
    TCN_DROPOUT = 0.35
    TCN_DOWNSAMPLE_FACTOR = 3

    # --- CNN backbone -------------------------------------------------------
    CNN_MODEL_DIR = "MultiSpectCNN"
    CNN_MODEL_NAME = 'tf_efficientnet_b0_ns'
    CNN_IN_CHANNELS = 8
    CNN_IMG_SIZE = (128, 256)
    CNN_EEG_SPEC_PATH = f'{DATA_PATH}custom_eegs/cwt'

    # --- Classifier head ----------------------------------------------------
    # Checkpoints for the head are written below this directory.
    HEAD_MODEL_SAVE_DIR = get_models_save_path() / "MultiModalHead" / DATA_PREP_VOTE_METHOD
    HEAD_HIDDEN_SIZE = 512
    HEAD_DROPOUT = 0.4
    HEAD_BATCH_SIZE = 64
    HEAD_EPOCHS = 20
    HEAD_LR = 1e-4
    TARGET_SIZE = 6   # output width of the classifier head

    # --- Backbone inference -------------------------------------------------
    INFERENCE_BATCH_SIZE = 64
    NUM_WORKERS = 0

def set_seed(seed):
    """Seed every random number generator this notebook can draw from.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch
    (CPU, plus every visible CUDA device when CUDA is available).

    Args:
        seed: Integer seed applied to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all reaches every GPU, not only the current device.
        torch.cuda.manual_seed_all(seed)

# Fix the RNG state before any weights are initialised or data is shuffled.
set_seed(CFG.SEED)

# Head checkpoints are written here, so create the directory tree up front.
CFG.HEAD_MODEL_SAVE_DIR.mkdir(exist_ok=True, parents=True)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')




# Training

In [None]:
def run_head_training(train_loader, valid_loader, input_size, fold_k):
    """Train a fresh ``ClassifierHead`` for one fold and checkpoint its best epoch.

    The head is optimised with AdamW against a KL-divergence loss between
    ``log_softmax`` of the head's outputs and the labels. ``nn.KLDivLoss`` is
    used with its default ``log_target=False``, so labels are expected to be
    probability distributions. After every epoch the per-sample mean
    validation loss is computed, and whenever it improves the model's
    ``state_dict`` is saved to
    ``CFG.HEAD_MODEL_SAVE_DIR / f'best_head_fold{fold_k}.pth'``.

    Args:
        train_loader: Iterable of ``(features, labels)`` training batches.
        valid_loader: Iterable of ``(features, labels)`` validation batches;
            its ``.dataset`` must support ``len()``.
        input_size: Width of each feature vector fed to the head.
        fold_k: Fold identifier, used in the checkpoint name and in log lines.

    Returns:
        The lowest validation loss observed over ``CFG.HEAD_EPOCHS`` epochs.

    Raises:
        ValueError: If the validation dataset is empty (the per-sample mean
            would otherwise divide by zero after a full training epoch).
    """
    num_valid_samples = len(valid_loader.dataset)
    if num_valid_samples == 0:
        raise ValueError(f"Fold {fold_k}: validation dataset is empty.")

    model = ClassifierHead(
        input_size=input_size,
        hidden_size=CFG.HEAD_HIDDEN_SIZE,
        output_size=CFG.TARGET_SIZE,
        dropout=CFG.HEAD_DROPOUT
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.HEAD_LR)
    loss_fn = nn.KLDivLoss(reduction='batchmean')
    best_val_loss = float('inf')
    save_path = CFG.HEAD_MODEL_SAVE_DIR / f'best_head_fold{fold_k}.pth'

    for epoch in range(CFG.HEAD_EPOCHS):
        model.train()
        for features, labels in train_loader:
            if features.size(0) < 2:
                # BatchNorm1d cannot compute batch statistics from a single
                # sample in train mode and raises; skip a stray singleton
                # batch (e.g. the tail when len(dataset) % batch_size == 1).
                continue
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(features)
            log_probs = F.log_softmax(outputs, dim=1)
            loss = loss_fn(log_probs, labels)
            loss.backward()
            optimizer.step()

        model.eval()
        valid_loss = 0
        with torch.no_grad():
            for features, labels in valid_loader:
                features, labels = features.to(device), labels.to(device)
                outputs = model(features)
                log_probs = F.log_softmax(outputs, dim=1)
                loss = loss_fn(log_probs, labels)
                # 'batchmean' already divides by the batch size; undo that so
                # the epoch total can be averaged over all validation samples.
                valid_loss += loss.item() * features.size(0)

        valid_loss /= num_valid_samples

        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            torch.save(model.state_dict(), save_path)

        print(f"  Epoch {epoch+1}/{CFG.HEAD_EPOCHS}, Val Loss: {valid_loss:.4f}")

    print(f"  Fold {fold_k} Best Val Loss: {best_val_loss:.4f}")
    return best_val_loss


def _load_or_extract_fold_features(extractor, eeg_loader, spec_loader, fold_k, cache_dir):
    """Return ``(tcn_features, cnn_features, labels)`` for one fold's backbones.

    If all three ``.npy`` cache files for ``fold_k`` exist in ``cache_dir``
    they are loaded and returned. Otherwise the fold's TCN and CNN
    checkpoints are loaded through ``extractor``, features are extracted over
    the given loaders, and the results are written to the cache.

    The two label arrays produced by extraction are compared BEFORE anything
    is saved, so a misaligned run can never leave a cache behind that would
    later be loaded without any check.

    Raises:
        ValueError: If the labels returned by the two extractions differ.
    """
    tcn_cache_file = cache_dir / f"tcn_features_fold{fold_k}.npy"
    cnn_cache_file = cache_dir / f"cnn_features_fold{fold_k}.npy"
    labels_cache_file = cache_dir / f"labels_fold{fold_k}.npy"

    if (tcn_cache_file.exists() and
            cnn_cache_file.exists() and
            labels_cache_file.exists()):
        print(f"Loading features for Fold {fold_k}...")
        return (
            np.load(tcn_cache_file),
            np.load(cnn_cache_file),
            np.load(labels_cache_file),
        )

    print(f"Cached features not found. Generating for Fold {fold_k}...")

    print(f"Loading TCN Model for Fold {fold_k}...")
    tcn_model_path = (
        get_models_save_path() / CFG.TCN_MODEL_DIR /
        CFG.DATA_PREP_VOTE_METHOD / f'best_model_fold{fold_k}.pth'
    )
    tcn_extractor = extractor.build_tcn_feature_extractor(tcn_model_path)
    tcn_features, tcn_labels = extractor.extract_features(tcn_extractor, eeg_loader)
    # Release the backbone before the next one is loaded.
    del tcn_extractor; torch.cuda.empty_cache()

    print(f"Loading CNN Model for Fold {fold_k}...")
    cnn_model_path = (
        get_models_save_path() / CFG.CNN_MODEL_DIR /
        CFG.DATA_PREP_VOTE_METHOD / f'best_model_fold{fold_k}.pth'
    )
    cnn_extractor = extractor.build_cnn_feature_extractor(cnn_model_path)
    cnn_features, cnn_labels = extractor.extract_features(cnn_extractor, spec_loader)
    del cnn_extractor; torch.cuda.empty_cache()

    # Validate alignment first; only aligned results may reach the cache.
    if not np.array_equal(tcn_labels, cnn_labels):
        raise ValueError(f"Label mismatch in Fold {fold_k}!")

    print(f"Saving features to cache: {cache_dir}")
    np.save(tcn_cache_file, tcn_features)
    np.save(cnn_cache_file, cnn_features)
    np.save(labels_cache_file, tcn_labels)

    return tcn_features, cnn_features, tcn_labels


if __name__ == '__main__':

    # Set on CFG before FeatureExtraction is constructed.
    # NOTE(review): FeatureExtraction may read this attribute -- confirm.
    CFG.FEATURE_CACHE_PATH = Path(CFG.FEATURE_STORE_PATH)
    Path(CFG.FEATURE_CACHE_PATH).mkdir(parents=True, exist_ok=True)
    cache_dir = Path(CFG.FEATURE_CACHE_PATH)

    df_path = Path(CFG.DATA_PATH) / CFG.TRAIN_CSV_NAME
    df = pd.read_csv(df_path)

    fold_creator = KFoldCreator(n_splits=CFG.N_SPLITS, seed=CFG.SEED)
    df = fold_creator.create_folds(df, stratify_col='expert_consensus', group_col='patient_id')

    extractor = FeatureExtraction(CFG, device)

    # The loaders cover the whole dataframe; each fold's backbones are run on
    # every row and the train/validation split is applied afterwards.
    eeg_loader_all = extractor.get_eeg_inference_loader(
        df, CFG.DATA_PATH, CFG.TCN_DOWNSAMPLE_FACTOR
    )
    spec_loader_all = extractor.get_spec_inference_loader(
        df, Constants.TARGETS, CFG.DATA_PATH, CFG.CNN_IMG_SIZE, CFG.CNN_EEG_SPEC_PATH
    )

    # Fold assignment per dataframe row, by position. The feature arrays are
    # plain NumPy arrays, so they must be indexed by row position rather than
    # by the dataframe's index labels.
    fold_ids = df['fold'].to_numpy()
    n_rows = len(df)

    all_fold_scores = []

    for fold_k in range(CFG.N_SPLITS):
        print("\n" + "="*50)
        print(f"Processing Fold {fold_k}")
        print("="*50)

        tcn_features_all, cnn_features_all, labels_all = _load_or_extract_fold_features(
            extractor, eeg_loader_all, spec_loader_all, fold_k, cache_dir
        )

        print(f"TCN Features Extracted. Shape: {tcn_features_all.shape}")
        print(f"CNN Features Extracted. Shape: {cnn_features_all.shape}")

        # Catches stale cache files left over from a different CSV or run.
        if not (len(tcn_features_all) == len(cnn_features_all) == len(labels_all) == n_rows):
            raise ValueError(
                f"Fold {fold_k}: feature/label row counts "
                f"({len(tcn_features_all)}, {len(cnn_features_all)}, {len(labels_all)}) "
                f"do not match the {n_rows} rows of the dataframe; "
                f"remove stale files in {cache_dir} and re-run."
            )

        combined_features_all = np.concatenate([tcn_features_all, cnn_features_all], axis=1)
        input_size = combined_features_all.shape[1]

        print("Splitting features into train/validation for this fold...")
        train_indices = np.flatnonzero(fold_ids != fold_k)
        val_indices = np.flatnonzero(fold_ids == fold_k)

        train_features = combined_features_all[train_indices]
        train_labels = labels_all[train_indices]
        val_features = combined_features_all[val_indices]
        val_labels = labels_all[val_indices]

        train_dataset = FeatureDataset(train_features, train_labels)
        valid_dataset = FeatureDataset(val_features, val_labels)

        # A trailing batch of exactly one sample would make BatchNorm1d raise
        # in train mode; drop the tail only in that specific case.
        drop_singleton_tail = len(train_dataset) % CFG.HEAD_BATCH_SIZE == 1
        train_loader = DataLoader(
            train_dataset, batch_size=CFG.HEAD_BATCH_SIZE, shuffle=True,
            drop_last=drop_singleton_tail
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=CFG.HEAD_BATCH_SIZE, shuffle=False
        )

        print(f"Training classifier head for Fold {fold_k}...")
        best_fold_loss = run_head_training(
            train_loader, valid_loader, input_size, fold_k
        )
        all_fold_scores.append(best_fold_loss)

    print("\n" + "="*50)
    print(f"Full {CFG.N_SPLITS}-Fold Cross-Validation Complete.")
    print(f"Scores per fold: {all_fold_scores}")

    mean_cv_score = np.mean(all_fold_scores)
    print(f"\nMean CV Score: {mean_cv_score:.4f}")
    print("="*50)


Processing Fold 0
Cached features not found. Generating for Fold 0...
Loading TCN Model for Fold 0...


Extracting Features:   0%|          | 0/268 [00:00<?, ?it/s]