# Model Training

In [68]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import roc_auc_score, accuracy_score, precision_recall_fscore_support
from typing import Dict

## Dataset

In [69]:
# Dataset
class CustomEEGDataset(Dataset):
    """In-memory dataset built from a CSV file of EEG samples.

    The first CSV column is skipped, the last column is read as the int64
    label, and every column in between is loaded as a float32 feature.
    Each item is a ``(signal, label)`` pair where ``signal`` carries a
    leading axis of size one added by ``unsqueeze(0)``.
    """

    def __init__(self, annotations_file):
        frame = pd.read_csv(annotations_file)

        feature_cols = frame.iloc[:, 1:-1]
        label_col = frame.iloc[:, -1]

        self.X = feature_cols.to_numpy(dtype=np.float32)
        self.y = label_col.astype(np.int64).to_numpy()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # A tensor index is unwrapped to a plain Python number before indexing NumPy.
        index = idx.item() if torch.is_tensor(idx) else idx

        signal = torch.from_numpy(self.X[index]).unsqueeze(0)
        label = torch.as_tensor(self.y[index])
        return signal, label

## Model and Training Config

In [70]:
@dataclass
class TrainConfig:
    """Hyperparameters and file paths for a training run.

    Every field has a default, so ``TrainConfig()`` yields a usable configuration.
    """
    csv_path: str = "../data/data_preprocessed.csv"  # CSV handed to CustomEEGDataset
    batch_size: int = 64  # samples per DataLoader batch
    epochs: int = 20  # number of passes over the training split
    lr: float = 1e-3  # Adam learning rate
    weight_decay: float = 1e-4  # Adam weight decay
    grad_clip: float = 1.0  # max gradient norm forwarded to train_one_epoch
    # The default is computed once, when this class body executes, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 2  # NOTE(review): not passed to the DataLoaders in the visible cells
    model_out: str = "model.pt"  # NOTE(review): the training loop writes 'best_model.pth' instead
    threshold_out: str = "threshold.json"  # NOTE(review): not referenced in the visible cells

class TinyEEGCNN(nn.Module):
    """
    Three-stage 1-D convolutional network ending in a single linear unit.

    Input: [B, 1, L]
    Output: [B, 1] raw logits (use BCEWithLogitsLoss)
    """
    def __init__(self, in_ch=1):
        super().__init__()

        def conv_stage(c_in, c_out, kernel):
            # Length-preserving convolution for odd kernels, then batch norm and ReLU.
            return [
                nn.Conv1d(c_in, c_out, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]

        # Layers are appended in a fixed order so the Sequential indices
        # (and therefore the state_dict keys) stay stable.
        layers = []
        layers += conv_stage(in_ch, 16, 7)
        layers.append(nn.MaxPool1d(2))
        layers += conv_stage(16, 32, 5)
        layers.append(nn.MaxPool1d(2))
        layers += conv_stage(32, 64, 3)
        layers.append(nn.AdaptiveAvgPool1d(1))  # -> [B, 64, 1]

        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(64, 1)  # -> [B, 1]

    def forward(self, x):
        pooled = self.features(x)          # [B, 64, 1]
        embedding = pooled.squeeze(-1)     # [B, 64]
        return self.classifier(embedding)  # [B, 1]

## Training/Evaluation Functions

In [71]:
def train_one_epoch(model, loader, optimizer, loss_fn, device, grad_clip=1.0):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Args:
        model: module mapping a ``[B, 1, L]`` batch to ``[B, 1]`` logits.
        loader: iterable of ``(x, y)`` batches with one label per sample in ``y``.
        optimizer: optimizer that updates ``model``'s parameters.
        loss_fn: callable invoked as ``loss_fn(logit, y)`` with ``y`` as float ``[B, 1]``.
        device: device each batch is moved to before the forward pass.
        grad_clip: maximum gradient norm; pass ``None`` to disable clipping.

    Returns:
        The loss summed over batches (each weighted by its batch size) divided by
        the number of samples actually processed, or ``nan`` when the loader
        yields no batches.
    """
    model.train()
    total_loss = 0.0
    seen = 0
    for x, y in loader:
        x = x.to(device)                       # [B, 1, L]
        y = y.to(device).float().unsqueeze(1)  # [B, 1]
        optimizer.zero_grad()
        logit = model(x)                       # [B, 1]
        loss = loss_fn(logit, y)
        loss.backward()
        if grad_clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        # Weight by batch size so a smaller final batch does not skew the average
        # (assumes loss_fn returns a per-batch mean).
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        seen += batch_size
    # Divide by the samples seen rather than len(loader.dataset): the two differ
    # when the loader drops or resamples items, and plain iterables have no .dataset.
    return total_loss / seen if seen else float("nan")

def infer_probs(model, loader, device):
    """Collect labels and sigmoid probabilities for every batch in ``loader``.

    Puts ``model`` in eval mode and runs without gradient tracking.

    Returns:
        ``(labels, probabilities)`` as two NumPy arrays concatenated across batches.
    """
    model.eval()
    label_chunks = []
    prob_chunks = []
    with torch.no_grad():
        for inputs, targets in loader:
            logits = model(inputs.to(device))
            # [B, 1] logits -> [B] probabilities on the CPU.
            prob_chunks.append(torch.sigmoid(logits).squeeze(1).cpu().numpy())
            label_chunks.append(targets.numpy())
    return np.concatenate(label_chunks), np.concatenate(prob_chunks)

def evaluate_probs(y_true: np.ndarray, probs: np.ndarray, threshold: float = 0.5) -> Dict[str, float]:
    """Compute binary-classification metrics from predicted probabilities.

    Args:
        y_true: ground-truth labels; cast to ``int`` before scoring.
        probs: predicted positive-class probabilities, aligned with ``y_true``.
        threshold: probabilities ``>= threshold`` are counted as positive
            predictions. Defaults to 0.5.

    Returns:
        Dict with keys ``auc``, ``acc``, ``prec``, ``rec`` and ``f1``. ``auc`` is
        ``nan`` when ``roc_auc_score`` rejects the input with ``ValueError``.
    """
    y_true = y_true.astype(int)
    # AUROC is undefined if only one class is present. Only ValueError is caught,
    # so unrelated failures (wrong types, shape bugs) still surface.
    try:
        auc = roc_auc_score(y_true, probs)
    except ValueError:
        auc = float("nan")
    preds = (probs >= threshold).astype(int)
    acc = accuracy_score(y_true, preds)
    # zero_division=0 reports 0 rather than warning when there are no positive predictions/labels.
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, preds, average="binary", zero_division=0)
    return {"auc": auc, "acc": acc, "prec": prec, "rec": rec, "f1": f1}

## Training

In [72]:
cfg = TrainConfig()

# Fixed seed so the train/test partition is the same on every Restart & Run All.
SPLIT_SEED = 42

eeg_dataset = CustomEEGDataset(cfg.csv_path)
train_dataset, test_dataset = random_split(
    eeg_dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False)

# Class balance of the training split, read straight from the label array via the
# subset's indices so the feature matrix is not materialised just to count labels.
train_labels = torch.as_tensor(eeg_dataset.y[train_dataset.indices])
neg = (train_labels == 0).sum()
pos = (train_labels == 1).sum()
# pos_weight = #negatives / #positives, shape [1]; the clamp avoids dividing by
# zero when the training split contains no positive samples.
pos_weight = (neg / pos.clamp(min=1)).to(dtype=torch.float32, device=cfg.device).reshape(1)

In [73]:
# Build the network and move it to the configured device.
model = TinyEEGCNN().to(cfg.device)
# Adam with the learning rate and weight decay taken from the config.
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
# Binary cross-entropy on raw logits; pos_weight rescales the positive-class term.
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# mode="max" because the training loop steps this scheduler with F1 (higher is better);
# the learning rate is multiplied by 0.5 once the metric has stalled for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)

In [87]:
# Best F1 seen so far in this run of the cell.
# NOTE(review): this is reset on every execution of the cell, while `model`, `optimizer`
# and `scheduler` are created in an earlier cell; re-running only this cell continues
# training from the current weights and can overwrite a better checkpoint.
best_f1 = 0.0

for epoch in range(1, cfg.epochs + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, cfg.device, cfg.grad_clip)
    y_test, p_test = infer_probs(model, test_loader, cfg.device)
    metrics = evaluate_probs(y_test, p_test)

    # Checkpoint whenever F1 strictly improves.
    # NOTE(review): checkpoint selection and LR scheduling both use metrics computed on
    # test_loader, so the reported test scores are not from held-out data; consider a
    # separate validation split.
    if metrics["f1"] > best_f1:
        best_f1 = metrics["f1"]
        # Persist model and optimizer state together with the epoch, training loss and metrics.
        # NOTE(review): the path is a literal here; cfg.model_out is not used.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'metrics': metrics
        }, 'best_model.pth')
        print(f"New best model saved with F1: {best_f1:.4f}")

    # Feed F1 to the plateau scheduler so it can lower the learning rate when F1 stalls.
    scheduler.step(metrics["f1"])

    # One progress line per epoch.
    print(f"[{epoch:02d}] loss={train_loss:.4f} test_f1={metrics['f1']:.3f} test_auc={metrics['auc']:.3f} "
          f"test_acc={metrics['acc']:.3f}")

New best model saved with F1: 0.9809
[01] loss=0.0591 test_f1=0.981 test_auc=0.999 test_acc=0.992
[02] loss=0.0619 test_f1=0.980 test_auc=0.999 test_acc=0.992
New best model saved with F1: 0.9820
[03] loss=0.0610 test_f1=0.982 test_auc=1.000 test_acc=0.993
[04] loss=0.0618 test_f1=0.981 test_auc=0.999 test_acc=0.992
[05] loss=0.0614 test_f1=0.980 test_auc=0.999 test_acc=0.992
[06] loss=0.0603 test_f1=0.981 test_auc=1.000 test_acc=0.992
New best model saved with F1: 0.9841
[07] loss=0.0603 test_f1=0.984 test_auc=1.000 test_acc=0.993
[08] loss=0.0574 test_f1=0.983 test_auc=1.000 test_acc=0.993
[09] loss=0.0546 test_f1=0.982 test_auc=1.000 test_acc=0.993
[10] loss=0.0568 test_f1=0.981 test_auc=0.999 test_acc=0.992
[11] loss=0.0580 test_f1=0.984 test_auc=1.000 test_acc=0.993
[12] loss=0.0533 test_f1=0.981 test_auc=1.000 test_acc=0.992
[13] loss=0.0527 test_f1=0.983 test_auc=1.000 test_acc=0.993
[14] loss=0.0560 test_f1=0.980 test_auc=0.999 test_acc=0.992
[15] loss=0.0518 test_f1=0.981 test

# Model Training v2: Grouping By Patient

In [86]:
# Reload the raw table so the identifier column is available alongside the features.
cfg = TrainConfig()
df = pd.read_csv(cfg.csv_path)
# Derive a group id from the sample identifier by capturing everything before the
# ".V<digits>" part (e.g. "X21.V1.791" -> "X21").
# NOTE(review): the trailing `\.*` matches zero or more literal dots, so it does not
# constrain the match.
df['Patient'] = df['Unnamed'].str.extract(r"(.*)\.V[0-9]+\.*")
# Number of rows per extracted group id.
df['Patient'].value_counts()

Patient
X21    500
X15    500
X8     500
X16    500
X20    500
X14    500
X3     500
X11    500
X19    500
X7     500
X1     500
X22    500
X9     500
X23    500
X18    500
X2     500
X12    500
X5     500
X10    500
X13    500
X4     500
X17    500
X6     500
Name: count, dtype: int64

In [82]:
# Peek at the raw sample identifiers that the Patient column is extracted from.
df['Unnamed'].head()

0    X21.V1.791
1    X15.V1.924
2       X8.V1.1
3     X16.V1.60
4     X20.V1.54
Name: Unnamed, dtype: object