# Model Training

In [68]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import roc_auc_score, accuracy_score, precision_recall_fscore_support
from typing import Dict

In [69]:
# Dataset
class CustomEEGDataset(Dataset):
    """Tabular EEG dataset loaded from a CSV file.

    Column layout, as sliced in ``__init__``:
    - The first column is skipped (presumably a row id / index — TODO confirm).
    - Columns ``1..-2`` are the float32 features.
    - The last column is the integer label.

    Indexing with a single int (or 0-dim tensor) returns ``(x, y)``:
    - ``x`` has shape ``[1, L]``, i.e. one input channel for ``Conv1d``.
    - ``y`` is a 0-dim int64 tensor.

    Indexing with a sequence of indices (list, 1-D array or 1-D tensor)
    returns a batch:
    - ``x`` has shape ``[N, 1, L]``.
    - ``y`` has shape ``[N]``.
    """

    def __init__(self, annotations_file):
        df = pd.read_csv(annotations_file)

        self.X = df.iloc[:, 1:-1].to_numpy(dtype=np.float32)
        self.y = df.iloc[:, -1].astype(np.int64).to_numpy()

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            # tolist() gives an int for 0-dim tensors and a list for 1-D ones;
            # .item() would fail on the latter.
            idx = idx.tolist()
        # Insert the channel axis directly before the length axis.
        # A single row becomes [1, L]; a gathered batch becomes [N, 1, L]
        # rather than [1, N, L].
        x = torch.from_numpy(self.X[idx]).unsqueeze(-2)
        y = torch.as_tensor(self.y[idx])

        return x, y

In [70]:
@dataclass
class TrainConfig:
    """Hyper-parameters and file paths for a single training run.

    The field order below defines the positional constructor signature,
    so keep it stable.
    """

    # --- data ---
    csv_path: str = "../data/data_preprocessed.csv"
    batch_size: int = 64
    # --- optimisation ---
    epochs: int = 20
    lr: float = 0.001
    weight_decay: float = 0.0001
    grad_clip: float = 1.0  # max global gradient norm used by the training loop
    # --- runtime (evaluated once, when this class body executes) ---
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
    num_workers: int = 2  # NOTE(review): not passed to the DataLoaders visible here
    # --- output artefact paths ---
    model_out: str = "model.pt"
    threshold_out: str = "threshold.json"

class TinyEEGCNN(nn.Module):
    """Small 1-D CNN that emits one raw logit per sample.

    Input: ``[B, in_ch, L]`` (``in_ch`` defaults to 1). Two stride-2 max-pools
    mean ``L`` must be at least 4.
    Output: ``[B, 1]`` logits. Pair with ``BCEWithLogitsLoss``; no sigmoid is
    applied here.
    """

    # (out_channels, kernel_size, followed_by_maxpool) for each conv stage.
    _STAGES = ((16, 7, True), (32, 5, True), (64, 3, False))

    def __init__(self, in_ch=1):
        super().__init__()
        layers = []
        width = in_ch
        for out_ch, k, pooled in self._STAGES:
            # padding = k // 2 keeps the length unchanged for odd kernels.
            layers.extend([
                nn.Conv1d(width, out_ch, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
            ])
            if pooled:
                layers.append(nn.MaxPool1d(2))
            width = out_ch
        # Collapse the length axis: [B, 64, L'] -> [B, 64, 1].
        layers.append(nn.AdaptiveAvgPool1d(1))
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(width, 1)

    def forward(self, x):
        # [B, 64, 1] -> [B, 64] -> [B, 1]
        return self.classifier(self.features(x).squeeze(-1))

In [71]:
def train_one_epoch(model, loader, optimizer, loss_fn, device, grad_clip=1.0):
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    Args:
        model: Module mapping a batch ``x`` to logits shaped ``[B, 1]``.
        loader: Iterable of ``(x, y)`` batches with one label per sample in ``y``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Called as ``loss_fn(logit, target)``, with ``target`` a float
            tensor shaped ``[B, 1]``. Assumed to use mean reduction (the
            ``BCEWithLogitsLoss`` default), because each batch loss is
            re-weighted by batch size below.
        device: Device each batch is moved to.
        grad_clip: Max global gradient norm; ``None`` disables clipping.

    Returns:
        Sample-weighted mean loss over the samples actually processed.
        Returns NaN if the loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for x, y in loader:
        x = x.to(device)                        # [B, 1, L]
        y = y.to(device).float().unsqueeze(1)   # [B, 1]
        optimizer.zero_grad()
        logit = model(x)                        # [B, 1]
        loss = loss_fn(logit, y)
        loss.backward()
        if grad_clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size
    # Normalise by the samples actually processed, not len(loader.dataset).
    # This stays correct with drop_last=True or a partial sampler, and avoids
    # dividing by zero on an empty loader.
    if n_seen == 0:
        return float("nan")
    return total_loss / n_seen

@torch.no_grad()
def infer_probs(model, loader, device):
    """Score every batch of ``loader`` with ``model`` in eval mode.

    Returns:
        ``(labels, probs)`` as 1-D numpy arrays in loader order. ``probs`` is
        the sigmoid of the model's ``[B, 1]`` logits. Note that ``model`` is
        left in eval mode.
    """
    model.eval()
    label_chunks = []
    prob_chunks = []
    for inputs, labels in loader:
        logits = model(inputs.to(device))
        prob_chunks.append(torch.sigmoid(logits).squeeze(1).cpu().numpy())
        label_chunks.append(labels.numpy())
    all_labels = np.concatenate(label_chunks)
    all_probs = np.concatenate(prob_chunks)
    return all_labels, all_probs

def evaluate_probs(y_true: np.ndarray, probs: np.ndarray, threshold: float = 0.5) -> Dict[str, float]:
    """Compute binary-classification metrics from predicted probabilities.

    Args:
        y_true: Ground-truth labels. They are cast to int and expected to be 0/1.
        probs: Predicted positive-class probabilities, aligned with ``y_true``.
        threshold: Probabilities ``>= threshold`` count as positive predictions
            for acc/prec/rec/f1. AUROC is threshold-free. Defaults to 0.5.

    Returns:
        Dict with keys ``"auc"``, ``"acc"``, ``"prec"``, ``"rec"`` and ``"f1"``.
        ``"auc"`` is NaN when ``roc_auc_score`` rejects the input, e.g. when
        only one class is present. Precision/recall/F1 fall back to 0 when
        undefined (``zero_division=0``).
    """
    y_true = np.asarray(y_true).astype(int)
    probs = np.asarray(probs)
    try:
        auc = roc_auc_score(y_true, probs)
    except ValueError:
        # AUROC is undefined when only one class is present in y_true.
        # Catch only ValueError so genuine bugs (TypeError etc.) still surface.
        auc = float("nan")
    preds = (probs >= threshold).astype(int)
    acc = accuracy_score(y_true, preds)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, preds, average="binary", zero_division=0)
    # Plain floats keep the dict JSON-serialisable.
    return {"auc": float(auc), "acc": float(acc), "prec": float(prec), "rec": float(rec), "f1": float(f1)}

In [72]:
cfg = TrainConfig()

eeg_dataset = CustomEEGDataset(cfg.csv_path)
# Seeded generator so the 80/20 train/test split is reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(eeg_dataset, [0.8, 0.2], generator=split_generator)
train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False)

# Count classes straight from the label array through the subset's indices.
# `train_dataset[:][1]` would also gather (copy) the whole training feature
# matrix just to read the labels.
train_labels = eeg_dataset.y[train_dataset.indices]
neg = int((train_labels == 0).sum())
pos = int((train_labels == 1).sum())
# pos_weight = neg / pos up-weights positive examples in BCEWithLogitsLoss.
# max(pos, 1) avoids division by zero when the split has no positives.
pos_weight = torch.tensor([neg / max(pos, 1)], dtype=torch.float32, device=cfg.device)

In [73]:
model = TinyEEGCNN().to(cfg.device)
# The model emits raw logits. pos_weight (derived from the training-split class
# counts) re-weights the positive class in the loss.
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)
# Higher-is-better monitored metric: halve the LR once it has failed to improve
# for more than 2 consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2
)

In [75]:
# NOTE(review): the LR scheduler is driven by *test-set* F1, so test metrics
# influence training and the reported test scores are optimistically biased.
# A separate validation split for scheduling/model selection would avoid this.
for epoch in range(1, cfg.epochs + 1):
    train_loss = train_one_epoch(
        model, train_loader, optimizer, loss_fn, cfg.device, cfg.grad_clip
    )
    y_test, p_test = infer_probs(model, test_loader, cfg.device)
    metrics = evaluate_probs(y_test, p_test)
    f1, auc, acc = metrics["f1"], metrics["auc"], metrics["acc"]

    # Step the plateau scheduler on F1 (metrics["auc"] is an alternative).
    scheduler.step(f1)

    print(f"[{epoch:02d}] loss={train_loss:.4f} test_f1={f1:.3f} test_auc={auc:.3f} test_acc={acc:.3f}")

[01] loss=0.1731 test_f1=0.951 test_auc=0.998 test_acc=0.979
[02] loss=0.1402 test_f1=0.959 test_auc=0.998 test_acc=0.983
[03] loss=0.1114 test_f1=0.963 test_auc=0.998 test_acc=0.985
[04] loss=0.1150 test_f1=0.959 test_auc=0.999 test_acc=0.983
[05] loss=0.1009 test_f1=0.954 test_auc=0.998 test_acc=0.980
[06] loss=0.1006 test_f1=0.962 test_auc=0.998 test_acc=0.984
[07] loss=0.0829 test_f1=0.969 test_auc=0.999 test_acc=0.987
[08] loss=0.0854 test_f1=0.968 test_auc=0.999 test_acc=0.987
[09] loss=0.0781 test_f1=0.974 test_auc=0.999 test_acc=0.989
[10] loss=0.0860 test_f1=0.977 test_auc=0.999 test_acc=0.990
[11] loss=0.0827 test_f1=0.972 test_auc=0.999 test_acc=0.988
[12] loss=0.0817 test_f1=0.975 test_auc=0.999 test_acc=0.990
[13] loss=0.0733 test_f1=0.963 test_auc=0.999 test_acc=0.984
[14] loss=0.0707 test_f1=0.981 test_auc=0.999 test_acc=0.992
[15] loss=0.0689 test_f1=0.977 test_auc=0.999 test_acc=0.990
[16] loss=0.0654 test_f1=0.978 test_auc=0.999 test_acc=0.991
[17] loss=0.0664 test_f1