In [1]:
# Backbone identifier handed to the CNN (via CFG.model_name -> BaseCNN).
PRETRAINED_MODEL_NAME_OR_PATH = "tf_efficientnet_b0_ns"
# How the predictions of overlapping windows are aggregated during data
# preparation; also used as a directory level for saved checkpoints.
# One of: "max_vote_window", "sum_and_normalize".
DATA_PREPARATION_VOTE_METHOD = "max_vote_window"

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm 
import wandb
from torch.optim.lr_scheduler import CosineAnnealingLR
import sys

# Make the project's `src` package importable. On Kaggle the sources are an
# input dataset; locally they live three levels above the working directory.
# (`sys` is already imported above, so no re-import is needed here.)
if os.environ.get("KAGGLE_URL_BASE"):
    # running on kaggle
    # NOTE(review): dataset slug is "hsm-source-files" while the project is
    # "hms-..." — confirm this is the intended Kaggle dataset name.
    sys.path.insert(0, "/kaggle/input/hsm-source-files")
else:
    # running locally
    sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..", "..", "..")))

from src.datasets.spectrogram_dataset import SpectrogramDataset
from src.utils.k_folds_creator import KFoldCreator
from src.utils.utils import get_models_save_path, set_seeds, get_raw_data_dir, get_processed_data_dir
from src.models.base_cnn import BaseCNN
from src.utils.constants import Constants 
from src.datasets.eeg_processor import EEGDataProcessor

2025-11-01 15:51:24,907 :: root :: INFO :: Initialising Utils
2025-11-01 15:51:26,033 :: root :: INFO :: Initialising Datasets
2025-11-01 15:51:26,060 :: root :: INFO :: Initialising Models


Skipping module tcn due to missing dependency: No module named 'pytorch_tcn'


In [3]:
# Authenticate with Weights & Biases; run_training opens one wandb run per fold.
# The call's return value is left as the cell output.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mdavidhodel[0m ([33mhms-hslu-aicomp-hs25[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
class CFG:
    """Hyper-parameters and paths shared by the data, model and training cells."""
    seed = 42
    n_splits = 5  # number of cross-validation folds; one model is trained per fold
    data_path = get_raw_data_dir()

    model_name = PRETRAINED_MODEL_NAME_OR_PATH  # backbone name passed to BaseCNN
    # NOTE(review): in_channels is not referenced by the visible training code
    # (BaseCNN is constructed without it) — confirm whether it is still needed.
    in_channels = 4
    target_size = 6  # number of output classes, one per entry of TARGETS

    batch_size = 32
    num_workers = 8
    epochs = 5  # also used as T_max of the cosine learning-rate schedule
    lr = 1e-3

    # This is the base size of each channel, not the final reshaped size
    img_size = (128, 256)


# Seed from the config so that changing CFG.seed actually changes the run
# (previously a literal 42 was passed here, ignoring CFG.seed).
set_seeds(CFG.seed)

In [5]:
# Names of the per-class vote columns used as training targets.
TARGETS = Constants.TARGETS

# Reads raw competition data and writes processed artefacts to the processed dir.
processor = EEGDataProcessor(
    raw_data_path=CFG.data_path,
    processed_data_path=get_processed_data_dir(),
)

Processor initialized.
Raw data path: '/home/david/git/aicomp/data'
Processed data path: '/home/david/git/aicomp/data/processed'


In [6]:
def get_dataloaders(df, fold_id):
    """Build the train/validation DataLoaders for one cross-validation fold.

    Rows whose ``fold`` equals ``fold_id`` form the validation split; every
    other row is used for training. Both splits are re-indexed from 0.

    Args:
        df: Frame with a ``fold`` column plus the columns SpectrogramDataset reads.
        fold_id: The fold held out for validation.

    Returns:
        ``(train_loader, valid_loader)``. The training loader shuffles and drops
        the last incomplete batch; the validation loader keeps every sample, in order.
    """
    fold_col = df['fold']
    split_frames = {
        'train': df[fold_col != fold_id].reset_index(drop=True),
        'valid': df[fold_col == fold_id].reset_index(drop=True),
    }

    # NOTE(review): the validation split is also built with mode='train' —
    # presumably so labels are returned; confirm SpectrogramDataset applies no
    # train-only augmentation in that mode.
    datasets = {
        split: SpectrogramDataset(frame, TARGETS, CFG.data_path, CFG.img_size, mode='train')
        for split, frame in split_frames.items()
    }

    shared_kwargs = dict(
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(datasets['train'], shuffle=True, drop_last=True, **shared_kwargs)
    valid_loader = DataLoader(datasets['valid'], shuffle=False, drop_last=False, **shared_kwargs)

    return train_loader, valid_loader

In [7]:
def _train_one_epoch(model, loader, optimizer, loss_fn, device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    The per-batch loss is logged to wandb under ``train/loss``.
    """
    model.train()
    loss_sum = 0.0
    n_samples = 0
    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        log_probs = F.log_softmax(model(images), dim=1)
        loss = loss_fn(log_probs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # With reduction='batchmean' the loss is a per-batch mean; scale it back
        # up so batches of different sizes are weighted correctly.
        loss_sum += batch_loss * images.size(0)
        n_samples += images.size(0)
        wandb.log({"train/loss": batch_loss})

    # The training loader uses drop_last=True, so fewer than len(loader.dataset)
    # samples may be seen; normalise by the samples actually processed.
    return loss_sum / n_samples if n_samples else float('nan')


def _evaluate(model, loader, loss_fn, device) -> float:
    """Return the mean per-sample loss of ``model`` over ``loader`` (no gradients)."""
    model.eval()
    loss_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            log_probs = F.log_softmax(model(images), dim=1)
            loss_sum += loss_fn(log_probs, labels).item() * images.size(0)
            n_samples += images.size(0)
    return loss_sum / n_samples if n_samples else float('nan')


def _predict_oof(model, loader, device, fold):
    """Return ``(probabilities, labels)`` numpy arrays for every sample in ``loader``."""
    model.eval()
    fold_preds = []
    fold_labels = []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=f"OOF Prediction Fold {fold}"):
            probs = F.softmax(model(images.to(device)), dim=1).cpu()
            fold_preds.append(probs)
            fold_labels.append(labels.cpu())
    return torch.cat(fold_preds).numpy(), torch.cat(fold_labels).numpy()


def _oof_kl_score(oof_preds, oof_labels, eps=1e-15) -> float:
    """Batch-mean KL(labels || preds) over the concatenated OOF arrays.

    Predicted probabilities are clamped to ``eps`` before taking the log. A
    softmax output that underflows to exactly 0 would otherwise give
    log(0) = -inf, which becomes NaN when paired with a 0 target.
    """
    preds = torch.tensor(oof_preds, dtype=torch.float32).clamp_min(eps)
    labels = torch.tensor(oof_labels, dtype=torch.float32)
    kl_loss_fn = nn.KLDivLoss(reduction='batchmean')
    return kl_loss_fn(torch.log(preds), labels).item()


def run_training(df, data_preparation_vote_method):
    """Train one BaseCNN per cross-validation fold and report the overall OOF score.

    For each fold in ``range(CFG.n_splits)`` this function:
    - trains for ``CFG.epochs`` epochs with AdamW and a cosine LR schedule;
    - keeps the checkpoint with the lowest validation KL divergence;
    - logs metrics and that checkpoint to a wandb run;
    - collects out-of-fold (OOF) probabilities from the best checkpoint.

    Args:
        df: Training frame with a ``fold`` column (consumed by ``get_dataloaders``).
        data_preparation_vote_method: Vote aggregation used to build ``df``. It
            is logged to wandb and used as a directory level for checkpoints.

    Returns:
        The batch-mean KL divergence of all OOF predictions against their labels.
        Returns ``float('nan')`` if no predictions were produced (e.g. ``CFG.n_splits == 0``).

    Raises:
        RuntimeError: if a fold finishes without ever saving a checkpoint.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    all_oof_preds = []
    all_oof_labels = []
    for fold in range(CFG.n_splits):
        print(f"\n========== FOLD {fold} ==========")

        config = {
            # Model
            "architecture": CFG.model_name,
            "pretrained": True,
            # Data
            "fold": fold,
            "features": "spectrograms",
            # Log the vote method actually used rather than a hard-coded value.
            "window_selection": data_preparation_vote_method,
            # Training
            "optimizer": "AdamW",
            "learning_rate": CFG.lr,
            "batch_size": CFG.batch_size,
            "epochs": CFG.epochs,
            "seed": CFG.seed,
            "scheduler": "CosineAnnealingLR"
        }

        wandb.init(
            project="hms-aicomp-cnn",
            name=f"{CFG.model_name}-spec-fold{fold}",
            tags=[f'fold{fold}'],
            config=config
        )
        try:
            model = BaseCNN(CFG.model_name, pretrained=True, num_classes=CFG.target_size)
            model.to(device)

            optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.epochs)
            loss_fn = nn.KLDivLoss(reduction='batchmean')
            train_loader, valid_loader = get_dataloaders(df, fold)

            best_val_loss = float('inf')
            best_model_path = None

            for epoch in range(CFG.epochs):
                print(f"  --- Epoch {epoch+1}/{CFG.epochs} ---")

                train_loss = _train_one_epoch(model, train_loader, optimizer, loss_fn, device)
                valid_loss = _evaluate(model, valid_loader, loss_fn, device)

                epoch_lr = optimizer.param_groups[0]['lr']
                print(f"   Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Valid Loss = {valid_loss:.4f}, LR = {epoch_lr:.6f}")

                wandb.log({
                    "epoch": epoch + 1,
                    "train/epoch_loss": train_loss,
                    "val/loss": valid_loss,
                    "val/kl_div": valid_loss,
                    "train/epoch_lr": epoch_lr
                })

                if valid_loss < best_val_loss:
                    best_val_loss = valid_loss
                    best_model_path = get_models_save_path() / "base_cnn" / CFG.model_name / data_preparation_vote_method / f'best_model_fold{fold}.pth'
                    best_model_path.parent.mkdir(parents=True, exist_ok=True)
                    torch.save(model.state_dict(), best_model_path)
                    print(f"  New best model saved with validation loss: {best_val_loss:.4f}")
                scheduler.step()

            print(f"   --- Generating OOF predictions for fold {fold} ---")
            if best_model_path is None:
                raise RuntimeError("Best model path is None, cannot generate OOF predictions.")

            # map_location keeps this working when the checkpoint was written on a
            # different device; weights_only is sufficient for a plain state_dict.
            model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
            fold_preds, fold_labels = _predict_oof(model, valid_loader, device, fold)
            all_oof_preds.append(fold_preds)
            all_oof_labels.append(fold_labels)
            print(f"   Finished OOF predictions for fold {fold}")

            wandb.summary['best_val_kl_div'] = best_val_loss

            artifact = wandb.Artifact(f'{CFG.model_name}-fold{fold}', type='model')
            artifact.add_file(str(best_model_path))
            wandb.log_artifact(artifact)
            print(f"\nLogged artifact for fold {fold} with best validation loss: {best_val_loss:.4f}")
        finally:
            # Close the run even if this fold fails, so no wandb run is left dangling.
            wandb.finish()

    if not all_oof_preds:
        print("\nCould not calculate OOF score because no predictions were generated.")
        return float('nan')

    print("\nCalculating final OOF score...")
    overall_oof_score = _oof_kl_score(
        np.concatenate(all_oof_preds),
        np.concatenate(all_oof_labels),
    )
    print(f"\nOverall OOF KL Score: {overall_oof_score:.4f}")
    return overall_oof_score

## Run Training

In [8]:
print("Preparing data and creating folds...")
# Build the training frame with the configured window-vote aggregation
# (Parquet conversion is skipped).
train_df = processor.process_data(
    vote_method=DATA_PREPARATION_VOTE_METHOD,
    skip_parquet=True,
)
print(f"Train shape: {train_df.shape}")
print(f"Targets {list(TARGETS)}")

# Assign a 'fold' column, stratified on the consensus label and grouped by
# patient (presumably keeping each patient within a single fold — see KFoldCreator).
fold_creator = KFoldCreator(n_splits=CFG.n_splits, seed=CFG.seed)
train_df = fold_creator.create_folds(
    train_df,
    stratify_col='expert_consensus',
    group_col='patient_id',
)

print("Folds created. Value counts per fold:")
print(train_df['fold'].value_counts())

Preparing data and creating folds...
Starting EEG Data Processing Pipeline


Skipping Parquet file creation as requested.
Using 'max_vote_window' vote aggregation strategy.

Processed train data saved to '/home/david/git/aicomp/data/processed/train_processed.csv'.
Shape of the final dataframe: (17089, 12)

Pipeline finished successfully!
Train shape: (17089, 12)
Targets ['seizure_vote', 'lpd_vote', 'gpd_vote', 'lrda_vote', 'grda_vote', 'other_vote']
Folds created. Value counts per fold:
fold
0    4067
1    3658
2    3381
4    3358
3    2625
Name: count, dtype: int64


In [9]:
# Train all folds; the returned value is the KL divergence of the concatenated OOF predictions.
overall_oof_score = run_training(train_df, DATA_PREPARATION_VOTE_METHOD)

In [None]:
# Report the score returned by run_training in the previous cell.
print(f"Overall OOF KL Score from training: {overall_oof_score:.4f}")