# EEG Person Identification Pipeline

This notebook implements a data processing and modeling pipeline for the **EEG Motor Movement/Imagery Dataset**. 

The pipeline performs the following steps:
1.  **Data Loading**: Reads raw EDF files using MNE.
2.  **Preprocessing**: Applies bandpass filtering and channel selection.
3.  **Feature Extraction**: Segments signals into epochs and converts them into spectrograms (STFT).
4.  **Dataset Creation**: Splits data into Training (Session A) and Testing (Session B) sets and saves them as `.npy` arrays.
5.  **Modeling**: Defines and trains a PyTorch model for person identification.

In [None]:
from __future__ import annotations
import json
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

# Core Data & Plotting
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# Signal Processing & Domain Specific
import mne
from scipy.signal import stft
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

# Deep Learning (PyTorch)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Report library versions so a run can be reproduced later.
print(f"MNE version: {mne.__version__}")
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# DEVICE is read by the model-building, training and evaluation cells.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

## 1. Global Configuration

The `Config` class centralizes all pipeline parameters. This includes:
* **File Paths**: Locations for input data and processed output.
* **Signal Parameters**: Sample rate (160 Hz), bandpass frequencies (1-40 Hz).
* **STFT Settings**: Window size and overlap for spectrogram generation.
* **Execution Limits**: Options to limit the number of subjects (`subjects_limit`) or runs for faster testing.

In [None]:
# Change these values to control how much data is loaded, the exact STFT windows, and where processed files are written.

# Working directory of the notebook.
Project_Root = Path.cwd()
# Location of the raw dataset (one S* folder per subject); change as needed.
Data_Root = Path("/kaggle/input/eeg-motor-movementimagery-dataset/files")
# Where the processed arrays and metadata are written; change as desired.
Processed_Root = Path("/kaggle/working/processed")

@dataclass
class Config:
    """Settings for the preprocessing pipeline: paths, limits, windowing, filtering and STFT."""

    # --- locations ---
    data_root: Path = Data_Root  # folder scanned for S* subject directories
    processed_root: Path = Processed_Root  # root of everything the pipeline writes

    # --- how much data to process ---
    subjects_limit: int | None = 15  # set to None to use all 109 subjects
    runs_per_subject: int | None = None  # each subject has 14 runs (7 per session)

    # --- epoching ---
    window_seconds: float = 2.0  # length of each epoch (passed as tmax)
    step_seconds: float = 0.5  # spacing between consecutive epoch starts
    raw_sample_rate: int = 160  # Hz per PhysioNet documentation

    # --- filtering ---
    bandpass_low: float = 1.0  # lower band edge in Hz
    bandpass_high: float = 40.0  # upper band edge in Hz

    # --- STFT ---
    stft_nperseg: int = 128  # samples per STFT segment
    stft_noverlap: int = 96  # samples shared by neighbouring segments

    # --- epoch selection ---
    max_epochs_per_run: int | None = None  # None/0 -> keep every epoch of a run
    max_abs_uV: float = 300.0  # reject epochs with larger amplitudes
    target_channels: List[str] | None = None  # None -> keep whatever is common across runs
    random_state: int = 67

    @property
    def output_arrays_dir(self) -> Path:
        """Folder receiving the ``X_*.npy`` / ``y_*.npy`` arrays."""
        # NOTE: this is the same folder as metadata_dir; the training
        # configuration reads the arrays back from "metadata" as well.
        return self.processed_root.joinpath("metadata")

    @property
    def metadata_dir(self) -> Path:
        """Folder receiving the epoch metadata CSV and the label mapping JSON."""
        return self.processed_root.joinpath("metadata")

config = Config()

# Optional overrides taken from environment variables.
# The dictionary keys must be the exact Config field names: setattr() on a
# misspelled name would silently create an unused attribute and leave the
# real setting untouched.
env_override = {
    "subjects_limit": os.environ.get("EEG_SUBJECTS_LIMIT"),
    "runs_per_subject": os.environ.get("EEG_RUNS_PER_SUBJECT"),
    "max_epochs_per_run": os.environ.get("EEG_MAX_EPOCHS_PER_RUN"),
}
for key, value in env_override.items():
    if value is None:
        continue  # variable not set -> keep the dataclass default
    if not hasattr(config, key):
        raise AttributeError(f"Config has no field named {key!r}")
    # The literal string "none" (any case) means "no limit"; anything else must be an integer.
    casted = None if value.strip().lower() == "none" else int(value)
    setattr(config, key, casted)

# Make sure every output folder exists before the pipeline writes to it.
config.processed_root.mkdir(parents=True, exist_ok=True)
config.output_arrays_dir.mkdir(parents=True, exist_ok=True)
config.metadata_dir.mkdir(parents=True, exist_ok=True)

# Echo the effective configuration (Path objects are stringified for JSON).
print(json.dumps({k: str(v) if isinstance(v, Path) else v for k, v in asdict(config).items()}, indent= 2))

## 2. Data Discovery

These helper functions verify the input directory structure and list available subjects and runs to ensure the dataset is accessible before processing.

In [None]:
def list_subjects(data_root: Path) -> List[Path]:
    """Return the directories matching ``S*`` directly under *data_root*, sorted."""
    subject_dirs = [entry for entry in data_root.glob("S*") if entry.is_dir()]
    subject_dirs.sort()
    return subject_dirs

def list_runs_for_subject(subject_dir: Path) -> List[Path]:
    """Return the ``*.edf`` files directly inside *subject_dir*, sorted."""
    edf_files = list(subject_dir.glob("*.edf"))
    edf_files.sort()
    return edf_files

def preview_df(data_root: Path, max_subjects: int = 5) -> pd.DataFrame:
    """Summarise the first *max_subjects* subjects found under *data_root*.

    Each row holds the subject folder name, its number of runs, and the names
    of its first and last run file (``None`` when the folder has no runs).
    """

    def _summarise(subject_dir: Path) -> dict:
        run_files = list_runs_for_subject(subject_dir)
        first_name = run_files[0].name if run_files else None
        last_name = run_files[-1].name if run_files else None
        return {
            "subject": subject_dir.name,
            "num_runs": len(run_files),
            "first_run": first_name,
            "last_run": last_name,
        }

    selected = list_subjects(data_root)[:max_subjects]
    return pd.DataFrame([_summarise(subject_dir) for subject_dir in selected])

# Show the summary of the first few subjects as the cell output.
preview_df(config.data_root)

## 3. Signal Processing Utilities

We define core functions to handle the raw EEG data:
* `load_raw_run`: Loads `.edf` files.
* `bandpass_filter`: Applies FIR bandpass filtering (1-40 Hz).
* `make_epochs`: Segments continuous data into fixed-length windows.
* `epoch_to_spectrogram`: Converts time-domain epochs into frequency-domain spectrograms using Short-Time Fourier Transform (STFT).

In [None]:
# Small helper functions for loading, filtering, epoching, and spectrogram generation.

def load_raw_run(edf_path: Path) -> mne.io.BaseRaw:
    """Read one EDF recording fully into memory and attach the ``standard_1020`` montage."""
    recording = mne.io.read_raw_edf(edf_path, verbose="ERROR", preload=True)
    # on_missing="ignore": channels the montage does not know are tolerated instead of raising.
    recording.set_montage("standard_1020", on_missing="ignore")
    return recording

def prep_channels(raw: mne.io.BaseRaw, target_channels: List[str] | None) -> mne.io.BaseRaw:
    """Restrict *raw* to the requested channels that it actually contains.

    Returns *raw* untouched when *target_channels* is ``None``; raises
    ``ValueError`` when none of the requested channels exist in the recording.
    """
    if target_channels is None:
        return raw

    available = raw.ch_names
    # Keep the caller's ordering, dropping names the recording does not have.
    selected = [name for name in target_channels if name in available]
    if len(selected) == 0:
        raise ValueError("None of the requested channels were found in this rec.")
    return raw.pick(selected, verbose="ERROR")

def bandpass_filter(raw: mne.io.BaseRaw, low: float, high: float) -> mne.io.BaseRaw:
    """Apply a ``firwin``-designed band-pass between *low* and *high* Hz and return the result of ``raw.filter``."""
    filter_kwargs = dict(l_freq=low, h_freq=high, fir_design="firwin", verbose="ERROR")
    return raw.filter(**filter_kwargs)

def make_epochs(raw: mne.io.BaseRaw, config: Config) -> mne.Epochs:
    """Cut *raw* into overlapping fixed-length epochs.

    An epoch starts every ``config.step_seconds`` and spans ``tmin=0`` to
    ``tmax=config.window_seconds``. ``config.max_abs_uV`` (scaled by 1e-6) is
    passed as MNE's ``reject`` threshold for ``eeg`` channels. When
    ``config.max_epochs_per_run`` is truthy only that many leading epochs are kept.
    """
    amplitude_limit = {"eeg": config.max_abs_uV * 1e-6}
    window_starts = mne.make_fixed_length_events(
        raw,
        start=0,
        stop=None,
        duration=config.step_seconds,
        overlap=0.0,
    )
    segments = mne.Epochs(
        raw,
        window_starts,
        event_id={"segment": 1},
        tmin=0.0,
        tmax=config.window_seconds,
        baseline=None,
        preload=True,
        reject=amplitude_limit,
        verbose="ERROR",
    )
    if not config.max_epochs_per_run:
        return segments
    return segments[: config.max_epochs_per_run]

def epoch_to_spectrogram(
    epoch_data: np.ndarray,
    sfreq: int,
    nperseg: int,
    noverlap: int,
) -> np.ndarray:
    """Turn a (channels, samples) epoch into a normalised (freq, time, channels) spectrogram.

    Each channel goes through an STFT without padding or boundary extension;
    the magnitudes are log-compressed with ``log1p`` and then standardised
    using the mean and standard deviation of the whole epoch. The result is
    returned as ``float32``.
    """

    def _magnitude(trace: np.ndarray) -> np.ndarray:
        _, _, coeffs = stft(
            trace,
            fs=sfreq,
            nperseg=nperseg,
            noverlap=noverlap,
            padded=False,
            boundary=None,
        )
        return np.abs(coeffs)

    # Channels end up on the last axis: (freq, time, ch).
    stacked = np.stack([_magnitude(trace) for trace in epoch_data], axis=-1)
    compressed = np.log1p(stacked)
    centered = compressed - compressed.mean()
    # The small epsilon guards against division by zero for a constant spectrogram.
    normalized = centered / (compressed.std() + 1e-8)
    return normalized.astype(np.float32)

def map_run_to_sesh(run_path: Path) -> str:
    """Map a run file to its session using the two trailing digits of the file stem.

    Run numbers up to 7 belong to ``"Session_A"``; higher numbers to ``"Session_B"``.
    """
    run_number = int(run_path.stem[-2:])  # stems end in R01 through R14
    if run_number > 7:
        return "Session_B"
    return "Session_A"

## 4. Pipeline Logic (`build_df`)

The `build_df` function drives the processing pipeline:
1.  Iterates through subjects and runs.
2.  Applies preprocessing and feature extraction.
3.  Splits data based on session (Session A → Train, Session B → Test).
4.  Aggregates processed spectrograms into `.npy` arrays and saves metadata.

In [None]:
def build_df(config: Config) -> Dict[str, Path]:
    """Run the full preprocessing pipeline and persist its outputs.

    For every selected subject and run the recording is loaded, reduced to the
    target channels, band-pass filtered, resampled if its rate differs from
    ``config.raw_sample_rate``, cut into epochs and converted to spectrograms.
    Spectrograms from ``Session_A`` runs form the training set; all others
    form the test set.

    Files written:
        * ``X_train.npy``, ``y_train.npy``, ``X_test.npy``, ``y_test.npy`` in
          ``config.output_arrays_dir``
        * ``epochs_metadata.csv`` (one row per epoch) and ``label_mapping.json``
          (integer label -> subject folder name) in ``config.metadata_dir``

    Returns:
        A dict mapping each artifact name to the path it was written to.

    Raises:
        ValueError: if the training or the test split ends up without epochs.
    """
    subjects = list_subjects(config.data_root)
    if config.subjects_limit:
        subjects = subjects[: config.subjects_limit]

    label_encoder = LabelEncoder()
    label_encoder.fit([subj.name for subj in subjects])

    train_specs, train_labels = [], []
    test_specs, test_labels = [], []
    metadata_rows = []

    for subject_dir in tqdm(subjects, desc="Subjects"):
        # The label depends only on the subject, so encode it once per subject
        # instead of once per epoch.
        label = label_encoder.transform([subject_dir.name])[0]

        runs = list_runs_for_subject(subject_dir)
        if config.runs_per_subject:
            runs = runs[: config.runs_per_subject]

        for run_path in runs:
            session = map_run_to_sesh(run_path)
            raw = load_raw_run(run_path)
            raw = prep_channels(raw, config.target_channels)
            raw = bandpass_filter(raw, config.bandpass_low, config.bandpass_high)
            if int(raw.info["sfreq"]) != config.raw_sample_rate:
                raw.resample(config.raw_sample_rate)

            epochs = make_epochs(raw, config)
            # A run can lose every epoch to amplitude rejection; skip it.
            if not len(epochs):
                print(f"!!! WARNING: No epochs found for {run_path.name}. Skipping run.")
                continue

            epoch_arr = epochs.get_data()  # (n epochs, ch, samples)

            for epoch_idx, epoch_data in enumerate(epoch_arr):
                spec = epoch_to_spectrogram(
                    epoch_data,
                    sfreq=config.raw_sample_rate,
                    nperseg=config.stft_nperseg,
                    noverlap=config.stft_noverlap,
                )
                metadata_rows.append(
                    {
                        "subject": subject_dir.name,
                        "run": run_path.name,
                        "session": session,
                        "epoch_index": epoch_idx,
                        "label": int(label),
                    }
                )

                if session == "Session_A":
                    train_specs.append(spec)
                    train_labels.append(label)
                else:
                    test_specs.append(spec)
                    test_labels.append(label)

    # np.stack fails with an unhelpful message on an empty list; report which
    # split is empty instead (e.g. runs_per_subject excluding a whole session).
    empty_splits = [
        name for name, specs in (("train", train_specs), ("test", test_specs)) if not specs
    ]
    if empty_splits:
        raise ValueError(
            f"No epochs were produced for the {' and '.join(empty_splits)} split; "
            "check data_root, subjects_limit and runs_per_subject."
        )

    X_train = np.stack(train_specs)
    y_train = np.array(train_labels)
    X_test = np.stack(test_specs)
    y_test = np.array(test_labels)

    array_paths = {
        "X_train": config.output_arrays_dir / "X_train.npy",
        "y_train": config.output_arrays_dir / "y_train.npy",
        "X_test": config.output_arrays_dir / "X_test.npy",
        "y_test": config.output_arrays_dir / "y_test.npy",
    }
    np.save(array_paths["X_train"], X_train)
    np.save(array_paths["y_train"], y_train)
    np.save(array_paths["X_test"], X_test)
    np.save(array_paths["y_test"], y_test)

    metadata_df = pd.DataFrame(metadata_rows)
    metadata_path = config.metadata_dir / "epochs_metadata.csv"
    metadata_df.to_csv(metadata_path, index=False)

    # Persist the integer label -> subject name mapping (JSON keys are strings).
    label_mapping = {str(idx): str(name) for idx, name in enumerate(label_encoder.classes_)}
    label_mapping_path = config.metadata_dir / "label_mapping.json"
    with open(label_mapping_path, "w") as fp:
        json.dump(label_mapping, fp, indent=2)

    return {
        **array_paths,
        "metadata": metadata_path,
        "label_mapping": label_mapping_path,
    }

## 5. Execution

Set `RUN_PIPELINE = True` to execute the heavy processing. This may take some time depending on the number of subjects configured.

In [None]:
# Toggle for the expensive preprocessing step.
RUN_PIPELINE = True

if RUN_PIPELINE:
    print("Starting data pipeline build..")
    output_paths = build_df(config)
    print("Done !!!")
    # A bare expression inside an `if` block is not rendered by the notebook,
    # so print the produced artifact paths explicitly.
    for artifact_name, artifact_path in output_paths.items():
        print(f"  {artifact_name}: {artifact_path}")
else:
    print("Set RUN_PIPELINE = True !!")

## 6. Output Verification

After processing, we verify that the output files (`X_train.npy`, etc.) exist and inspect their shapes to ensure data consistency.

In [None]:
# Sanity check: list each saved array with its status, dtype, and shape.
def descArrays(config: Config) -> pd.DataFrame:
    """Describe the four saved arrays found in ``config.output_arrays_dir``.

    Returns one row per array with ``name`` and ``status`` (``"ok"`` or
    ``"missing"``); existing arrays also report ``dtype`` and ``shape``.
    Arrays are opened memory-mapped, so their contents are not read into RAM.
    """

    def _describe(name: str) -> dict:
        path = config.output_arrays_dir / f"{name}.npy"
        if not path.exists():
            return {"name": name, "status": "missing"}
        arr = np.load(path, mmap_mode="r")
        return {
            "name": name,
            "status": "ok",
            "dtype": str(arr.dtype),
            "shape": arr.shape,
        }

    array_names = ("X_train", "y_train", "X_test", "y_test")
    return pd.DataFrame([_describe(name) for name in array_names])

# Render the array summary as the cell output.
descArrays(config)

In [None]:
# Paths used by the training stage.
Project_Root = Path.cwd()
# Must point at the folder the preprocessing stage wrote to.
Processed_Root = Path("/kaggle/working/processed")
# Model checkpoints are stored under <cwd>/models.
Models_Root = Project_Root / "models"
# Training history and test metrics are stored under <cwd>/reports.
Reports_Root = Project_Root / "reports"

@dataclass
class TrainConfig:
    """Settings for model construction and training."""

    # --- locations ---
    processed_root: Path = Processed_Root  # root folder of the preprocessed data
    models_root: Path = Models_Root  # where model weights are saved
    reports_root: Path = Reports_Root  # where history / metrics JSON files are saved

    # --- optimisation ---
    batch_size: int = 32
    epochs: int = 30  # upper bound; early stopping may end training sooner
    learning_rate: float = 1e-3
    validation_split: float = 0.1  # fraction of the training set held out for validation

    # --- architecture ---
    dropout_rate: float = 0.3
    lstm_units: int = 128  # hidden size of the first BiLSTM
    conv_filters: tuple[int, ...] = (32, 64, 128)  # output channels of each conv block

    # --- early stopping ---
    patience: int = 10  # epochs without validation improvement before stopping

    @property
    def numpy_dir(self) -> Path:
        """Folder the ``.npy`` arrays are read from."""
        return self.processed_root.joinpath("metadata")

    @property
    def metadata_dir(self) -> Path:
        """Folder holding the preprocessing metadata files."""
        return self.processed_root.joinpath("metadata")

Tconfig = TrainConfig()

# Optional overrides taken from environment variables (unset -> keep default).
train_overrides = {
    "batch_size": os.environ.get("EEG_BATCH_SIZE"),
    "epochs": os.environ.get("EEG_EPOCHS"),
    "learning_rate": os.environ.get("EEG_LR"),
}

for key, value in train_overrides.items():
    if value is None:
        continue
    # The learning rate is the only fractional setting; the rest are integers.
    if key == "learning_rate":
        casted = float(value)
    else:
        casted = int(value)
    setattr(Tconfig, key, casted)

# Create the output folders up front so later saves cannot fail on a missing directory.
for output_dir in (Tconfig.models_root, Tconfig.reports_root):
    output_dir.mkdir(parents = True, exist_ok = True)

print(json.dumps({k: str(v) if isinstance(v, Path) else v for k, v in asdict(Tconfig).items()}, indent = 2))

## 7. Data Loading & Dataset Definition

We load the preprocessed `.npy` files and wrap them in a custom PyTorch `Dataset`.
* **`EEGDataset`**: Handles the conversion of NumPy arrays to PyTorch tensors. It also permutes each array from `(N, Freq, Time, Channels)` to `(N, Channels, Freq, Time)`, i.e. the `(Batch, Channels, Height, Width)` layout required by 2D CNN layers.

In [None]:
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np

class EEGDataset(Dataset):
    """In-memory dataset of spectrograms and integer subject labels.

    ``X`` is converted to a float32 tensor and rearranged from
    (N, Freq, Time, Ch) to (N, Ch, Freq, Time); ``y`` becomes an int64 tensor.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        # Conv2d expects channels first: (N, Freq, Time, Ch) -> (N, Ch, Freq, Time).
        self.X = features.permute(0, 3, 1, 2)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.y[idx]
        return sample, target

def load_data_safe(config: TrainConfig):
    """Load the saved arrays and wrap them in ``EEGDataset`` objects.

    Returns ``(train_ds, test_ds, num_classes, input_dims)`` where
    ``input_dims`` is the shape of the first training sample. When any array
    file is missing, a message is printed and ``(None, None, 0, (0, 0, 0))``
    is returned instead.
    """
    array_names = ("X_train", "y_train", "X_test", "y_test")
    try:
        X_train, y_train, X_test, y_test = [
            np.load(config.numpy_dir / f"{name}.npy") for name in array_names
        ]
    except FileNotFoundError:
        print("Data files not found. Please ensure RUN_PIPELINE=True and has completed successfully.")
        return None, None, 0, (0,0,0)

    # Re-index labels over train and test together so both splits share one
    # contiguous 0..K-1 label space.
    le = LabelEncoder().fit(np.concatenate([y_train, y_test]))
    y_train, y_test = le.transform(y_train), le.transform(y_test)

    num_classes = len(le.classes_)
    print(f"Labels Re-indexed. Mapped {len(le.classes_)} unique subjects to range 0-{num_classes-1}.")

    train_ds, test_ds = EEGDataset(X_train, y_train), EEGDataset(X_test, y_test)

    # Shape of a single sample after the channel-first permutation: (Ch, Freq, Time).
    first_sample = train_ds[0][0]
    input_dims = first_sample.shape

    return train_ds, test_ds, num_classes, input_dims

# Load the datasets; train_dataset is None when the array files are missing.
train_dataset, test_dataset, num_classes, input_dims = load_data_safe(Tconfig)

if train_dataset is not None:
    # Unpack the per-sample shape (Ch, Freq, Time); these globals are used to build the model.
    num_ch, freq_bins, time_bins = input_dims
    print(f"Model Input: {num_ch} Channels, {freq_bins} Freq, {time_bins} Time")
    print(f"Output Classes: {num_classes}")

    # Create DataLoaders
    # Hold out Tconfig.validation_split (default 0.1) of the training set for validation.
    # NOTE(review): random_split is called without a generator, so the split
    # changes between runs unless the global torch seed is fixed beforehand.
    val_size = int(Tconfig.validation_split * len(train_dataset))
    train_size = len(train_dataset) - val_size
    train_subset, val_subset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    # Only the training loader shuffles; validation and test keep a fixed order.
    dtrain = DataLoader(train_subset, batch_size=Tconfig.batch_size, shuffle=True)
    dval = DataLoader(val_subset, batch_size=Tconfig.batch_size, shuffle=False)
    dtest = DataLoader(test_dataset, batch_size=Tconfig.batch_size, shuffle=False)
    
    print(f"DataLoaders Created: Train={len(dtrain)} batches, Val={len(dval)} batches, Test={len(dtest)} batches")

## 8. Hybrid Model Architecture (CNN-BiLSTM)

We implement a hybrid architecture designed for Spatio-Temporal feature extraction:
1.  **CNN Encoder**: A stack of Convolutional blocks extracts spatial and frequency features from the spectrograms. We use *asymmetric pooling* to reduce frequency dimensions while preserving the time axis.
2.  **BiLSTM**: Two stacked Bidirectional LSTM layers process the sequence of CNN features to capture temporal dynamics.
3.  **Classifier**: A small dense head (two linear layers with dropout) maps the final LSTM hidden state to the person identity.

In [None]:
class CNNDropout(nn.Module):
    """One Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2x2) -> Dropout block.

    Args:
        filters: pair ``(in_channels, out_channels)`` for the convolution.
            (It was previously annotated as ``int`` although it is indexed.)
        dropout_rate: probability used by the trailing ``nn.Dropout``.
    """

    def __init__(self, filters: tuple[int, int], dropout_rate: float):
        super().__init__()
        in_channels, out_channels = filters[0], filters[1]
        self.conv = nn.Sequential(
            # padding="same" keeps the spatial size through the 3x3 convolution.
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), padding="same"),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # Halves both spatial dimensions.
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a (batch, channels, height, width) tensor."""
        return self.conv(x)

class CNNBiLSTM(nn.Module):
    """CNN encoder followed by two stacked BiLSTMs and a dense classifier.

    Args:
        freq_bins: size of the frequency axis (height) of each input.
        time_bins: size of the time axis (width) of each input.
        channels: number of input channels.
        num_classes: number of output logits.
        Tconfig: object providing ``conv_filters``, ``dropout_rate`` and ``lstm_units``.

    Raises:
        ValueError: if the frequency axis would shrink to zero after one
            2x pooling per entry in ``Tconfig.conv_filters``.
    """

    def __init__(self, freq_bins: int, time_bins: int, channels: int, num_classes: int, Tconfig: TrainConfig):
        super().__init__()
        self.freq_bins = freq_bins
        self.time_bins = time_bins

        # 1. Convolutional tower: one Conv-BN-ReLU-Pool group per filter count.
        conv_layers = []
        in_channels = channels
        for filters in Tconfig.conv_filters:
            conv_layers.extend(
                [
                    nn.Conv2d(in_channels, filters, kernel_size=(3, 3), padding=1),
                    nn.BatchNorm2d(filters),
                    nn.ReLU(),
                    # Asymmetric pooling: halve frequency (height) but leave
                    # time (width) untouched, because the time axis is short.
                    nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
                ]
            )
            in_channels = filters

        self.cnn_stack = nn.Sequential(*conv_layers)
        self.cnn_dropout = nn.Dropout(Tconfig.dropout_rate)

        # Each conv group halves the frequency axis once (floor division).
        num_pools = len(Tconfig.conv_filters)
        self.post_cnn_freq = freq_bins // (2**num_pools)
        # Time is never pooled.
        self.post_cnn_time = time_bins

        if self.post_cnn_freq < 1:
            raise ValueError(
                f"freq_bins={freq_bins} is too small for {num_pools} pooling stages "
                "(frequency axis would shrink to 0)."
            )

        # Every time step is described by all channels x remaining frequency bins.
        rnn_input_size = in_channels * self.post_cnn_freq

        # 2. RNN: two stacked BiLSTMs.
        # nn.LSTM's `dropout` argument only acts between stacked layers, so it
        # has no effect on a single-layer LSTM and only emits a warning; it is
        # therefore not passed here.
        self.bilstm1 = nn.LSTM(
            input_size=rnn_input_size,
            hidden_size=Tconfig.lstm_units,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        # Input is the concatenated forward/backward output of bilstm1.
        self.bilstm2 = nn.LSTM(
            input_size=Tconfig.lstm_units * 2,
            hidden_size=Tconfig.lstm_units // 2,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        # 3. Classifier on the concatenated final hidden states of bilstm2.
        dense_input_size = (Tconfig.lstm_units // 2) * 2
        self.dense_stack = nn.Sequential(
            nn.Dropout(Tconfig.dropout_rate),
            nn.Linear(dense_input_size, 256),
            nn.ReLU(),
            nn.Dropout(Tconfig.dropout_rate),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, channels, freq, time) tensor to (batch, num_classes) logits."""
        # 1. CNN
        x = self.cnn_stack(x)
        x = self.cnn_dropout(x)

        # 2. (Batch, Channels, Freq, Time) -> (Batch, Time, Freq, Channels),
        #    then flatten Freq x Channels into one feature vector per time step.
        x = x.permute(0, 3, 2, 1)
        B, T, F, C = x.shape
        x = x.reshape(B, T, F * C)

        # 3. RNN
        x, _ = self.bilstm1(x)
        x, (h_n, c_n) = self.bilstm2(x)

        # Concatenate the final forward (h_n[-2]) and backward (h_n[-1]) hidden states.
        x = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        x = self.dense_stack(x)
        return x

# Build the network from the dimensions reported by the data-loading cell and move it to DEVICE.
# NOTE(review): freq_bins, time_bins and num_ch are only assigned when the arrays
# loaded successfully; otherwise this line raises NameError.
model = CNNBiLSTM(freq_bins, time_bins, num_ch, num_classes, Tconfig).to(DEVICE)
print(f"Model created successfully. Output classes: {num_classes}")

def count_parameters(model):
    """Return the total number of elements in the parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the model size (thousands separator for readability).
print(f"Total trainable parameters: {count_parameters(model):,}")

## 9. Training & Validation Loop

We set up the training infrastructure:
* **Optimizer**: Adam with learning rate scheduling (`ReduceLROnPlateau`).
* **Loss Function**: CrossEntropyLoss for multi-class classification.
* **Checkpointing**: The model is saved (`best_model.pth`) only when validation accuracy improves.
* **Early Stopping**: Training halts if no improvement is seen for a set number of epochs.

In [None]:
# Set up optimizer, loss, and scheduler.
optimizer = optim.Adam(model.parameters(), lr=Tconfig.learning_rate, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
# mode='max' because the scheduler is stepped with validation accuracy:
# the learning rate is halved after `patience` epochs without improvement.
# The `verbose` argument is deprecated since PyTorch 2.2 (and may be rejected
# by newer releases), so it is not passed; the training loop already prints
# the current learning rate every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2,
    min_lr=1e-5,
)

# Checkpoint and history setup.
checkpoint_path = Tconfig.models_root / "best_model.pth"  # best validation accuracy so far
final_model_path = Tconfig.models_root / "final_model.pth"  # weights after the last epoch run
log_path = Tconfig.reports_root / "training_history.json"

# Per-epoch metrics; 'top5_accuracy' holds the validation top-5 accuracy.
history = {'loss': [], 'val_loss': [], 'accuracy': [], 'val_accuracy': [], 'top5_accuracy': [], 'lr': []}
best_val_accuracy = 0.0
patience_counter = 0  # epochs since the last validation-accuracy improvement

#Training and Evaluation Functions
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over *dataloader*.

    Returns ``(mean_loss, accuracy)``: the loss is weighted by batch size and
    divided by the dataset length; accuracy is correct predictions over
    samples seen.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # criterion returns a batch mean, so scale it back to a per-sample sum.
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.detach().max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    return loss_sum / len(dataloader.dataset), n_correct / n_seen

def evaluate_epoch(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    Returns ``(mean_loss, accuracy, top_k_accuracy)`` where k is 5, or the
    number of classes when there are fewer than 5.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = n_top1 = n_topk = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)

            # criterion returns a batch mean, so scale it back to a per-sample sum.
            loss_sum += criterion(logits, labels).item() * inputs.size(0)
            n_seen += labels.size(0)

            # Top-1 hits.
            n_top1 += (logits.max(dim=1).indices == labels).sum().item()

            # Top-k hits; clamp k so topk never asks for more classes than exist.
            k = min(5, logits.size(1))
            topk_idx = logits.topk(k, 1, True, True).indices
            n_topk += topk_idx.eq(labels.view(-1, 1)).sum().item()

    dataset_size = len(dataloader.dataset)
    return loss_sum / dataset_size, n_top1 / n_seen, n_topk / n_seen


# Main training loop: train, validate, log, schedule, checkpoint, early-stop.
print("Starting training...")
for epoch in range(1, Tconfig.epochs + 1):
    
    # Training step
    train_loss, train_acc = train_epoch(model, dtrain, criterion, optimizer, DEVICE)
    
    # Validation step
    val_loss, val_acc, val_top5_acc = evaluate_epoch(model, dval, criterion, DEVICE)
    
    # Update history (the learning rate recorded is the one used during this epoch,
    # because the scheduler is stepped further below).
    history['loss'].append(train_loss)
    history['accuracy'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_accuracy'].append(val_acc)
    history['top5_accuracy'].append(val_top5_acc)
    history['lr'].append(optimizer.param_groups[0]['lr'])
    
    print(f"Epoch {epoch}/{Tconfig.epochs} | LR: {history['lr'][-1]:.6f}")
    print(f"  Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Top5 Acc: {val_top5_acc:.4f}")

    # Learning-rate scheduler is driven by validation accuracy (mode='max').
    scheduler.step(val_acc) 

    # Checkpoint only on a strict improvement of validation accuracy;
    # any other epoch counts towards the early-stopping patience.
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        patience_counter = 0
        print(f"  --> Improvement! Saving model to {checkpoint_path}")
        torch.save(model.state_dict(), checkpoint_path)
    else:
        patience_counter += 1
        
    # Early stopping after Tconfig.patience consecutive epochs without improvement.
    if patience_counter >= Tconfig.patience:
        print(f"\nEarly stopping triggered after {patience_counter} epochs without improvement.")
        break

# Save the weights as they are after the last executed epoch
# (these may differ from the best checkpoint).
torch.save(model.state_dict(), final_model_path)

# Save history; values are cast to plain floats so they are JSON-serialisable.
with open(log_path, "w") as fp:
    dump_history = {k: [float(x) for x in v] for k, v in history.items()}
    json.dump(dump_history, fp, indent=2)
    
print("Training complete.")

In [None]:
def plot_history(history: dict) -> None:
    """Plot training vs. validation loss and accuracy curves side by side.

    A panel is left empty when its training column (``"loss"`` or
    ``"accuracy"``) is absent from *history*.
    """
    history_df = pd.DataFrame(history)
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # (column that must exist, columns to draw, title, y-axis label)
    panels = (
        ("loss", ["loss", "val_loss"], "Loss Curve", "Cross Entropy Loss"),
        ("accuracy", ["accuracy", "val_accuracy"], "Accuracy Curve", "Accuracy"),
    )
    for ax, (required, columns, title, ylabel) in zip(axes, panels):
        if required not in history_df.columns:
            continue
        history_df[columns].plot(ax=ax)
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.grid(True)

    plt.tight_layout()
    plt.show()

# Draw the learning curves collected during training.
plot_history(history)

## 10. Performance Metrics

We run inference on the held-out **Test Set** (Session B) and generate a classification report. This includes:
* **Precision/Recall/F1-Score**: To evaluate performance per person.
* **Accuracy**: Overall system performance.

In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import classification_report

# Instantiate a new model to ensure we are testing the clean "best" version
# Instantiate a new model so evaluation starts from a clean network, then fill it with weights.
best_model = CNNBiLSTM(freq_bins, time_bins, num_ch, num_classes, Tconfig).to(DEVICE)

if checkpoint_path.exists():
    print(f"Loading best model from {checkpoint_path}")
    # map_location lets a checkpoint saved on one device load on the current DEVICE.
    best_model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
else:
    print("Checkpoint not found, using current model state.")
    # Without this copy best_model would keep its random initial weights,
    # contradicting the message above.
    best_model.load_state_dict(model.state_dict())

best_model.eval()

def predict_and_evaluate(model, dataloader, criterion, device):
    """Score *model* on *dataloader* and collect its predictions.

    Returns ``(results, y_true, y_pred)`` where ``results`` is a dict with the
    mean ``"loss"`` and the ``"accuracy"``, and the two arrays hold the true
    and predicted class indices in dataloader order.
    """
    model.eval()
    seen_labels, seen_logits = [], []
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)

            # criterion returns a batch mean, so scale it back to a per-sample sum.
            loss_sum += criterion(logits, labels).item() * inputs.size(0)

            n_seen += labels.size(0)
            n_correct += (logits.max(dim=1).indices == labels).sum().item()

            seen_labels.extend(labels.cpu().numpy())
            seen_logits.extend(logits.cpu().numpy())

    test_results = {
        "loss": loss_sum / n_seen,
        "accuracy": n_correct / n_seen,
    }

    # The stored rows are raw logits; their argmax is the predicted class.
    y_pred_labels = np.argmax(np.array(seen_logits), axis=1)

    return test_results, np.array(seen_labels), y_pred_labels

# Run Evaluation on Test Set
criterion = nn.CrossEntropyLoss()
test_results, y_true_labels, y_pred_labels = predict_and_evaluate(best_model, dtest, criterion, DEVICE)

print(f"Test Results: {test_results}")

if 'label_mapping' not in locals():
    label_mapping = {i: f"Subject_{i}" for i in range(num_classes)}
    print(f"Created generic label mapping for {num_classes} classes.")

# Classification Report
# Ensure keys are sorted integers so they match the model's output indices 0..N
sorted_keys = sorted([int(k) for k in label_mapping.keys()])
target_names = [str(label_mapping[k]) for k in sorted_keys]

# Print Text Report
print("\nClassification Report:\n")
print(classification_report(
    y_true_labels,
    y_pred_labels,
    labels=sorted_keys,
    target_names=target_names,
    zero_division=0
))

# Save metrics to JSON
metrics_path = Tconfig.reports_root / "test_metrics.json"
with open(metrics_path, "w") as fp:
    json.dump({
        "test_results": test_results, 
        "classification_report": classification_report(
            y_true_labels, y_pred_labels, labels=sorted_keys, 
            target_names=target_names, output_dict=True, zero_division=0
        )
    }, fp, indent=2)

print(f"Metrics saved to {metrics_path}")

In [None]:
from sklearn.metrics import classification_report, f1_score, accuracy_score, confusion_matrix
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import torch

# 1. Get Predictions and Probabilities
def get_predictions(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect labels, predictions and probabilities.

    The model is switched to eval mode and gradients are disabled for the
    whole pass.

    Args:
        model: Module whose output for a batch is a 2-D tensor of class logits
            (softmax and argmax are taken over ``dim=1``).
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Tuple ``(y_true, y_pred, y_probs)`` of NumPy arrays: the true labels,
        the arg-max class per sample, and the per-class softmax probabilities.
    """
    model.eval()
    prob_rows = []
    pred_ids = []
    true_ids = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            logits = model(inputs.to(device))
            probs = torch.softmax(logits, dim=1)  # Convert logits to probs
            preds = torch.argmax(probs, dim=1)    # Get predicted class

            prob_rows.extend(probs.cpu().numpy())
            pred_ids.extend(preds.cpu().numpy())
            # .cpu() is a no-op for host tensors but is required when the
            # loader hands back labels that already live on an accelerator.
            true_ids.extend(labels.cpu().numpy())

    return np.array(true_ids), np.array(pred_ids), np.array(prob_rows)

y_true, y_pred, y_probs = get_predictions(best_model, dtest, DEVICE)

# 2. Calculate Top-5 Accuracy
# Note: If you have < 5 classes, Top-5 Acc is always 1.0
k = min(5, num_classes)
# argsort is ascending, so the last k columns are the k most probable classes.
top_k_preds = np.argsort(y_probs, axis=1)[:, -k:]
# A sample counts as a hit when its true label appears anywhere in its top-k row.
top_5_acc = np.mean((top_k_preds == y_true[:, np.newaxis]).any(axis=1))

# 3. Overall Metrics
test_acc = accuracy_score(y_true, y_pred)
# zero_division=0 yields the same score as the default but without the
# undefined-metric warning for classes that are never predicted.
weighted_f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

print(f"Test Accuracy:  {test_acc:.4f}")
print(f"Weighted F1:    {weighted_f1:.4f}")
print(f"Top-{k} Accuracy: {top_5_acc:.4f}")

# 4. Per-Subject Accuracy
# Ensure we use the correct label keys. If label_mapping isn't available, we create a dummy one.
try:
    # Normalise keys to int so they are comparable with the integer class
    # indices in y_true / y_pred even if the mapping has string keys.
    int_label_mapping = {int(key): name for key, name in label_mapping.items()}
except (NameError, AttributeError, KeyError, TypeError, ValueError):
    # Fallback if label_mapping is missing or its keys are not integer-like
    labels_list = sorted(set(y_true) | set(y_pred))
    subject_names = [f"Subject {i}" for i in labels_list]
else:
    labels_list = sorted(int_label_mapping)
    subject_names = [int_label_mapping[i] for i in labels_list]

# Compute confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(y_true, y_pred, labels=labels_list)
row_totals = cm.sum(axis=1)

# Normalize CM to get recall (accuracy) per class
with np.errstate(divide='ignore', invalid='ignore'):
    cm_norm = cm.astype('float') / row_totals[:, np.newaxis]

# Classes with no test samples produce 0/0 = NaN; report them as 0.0 instead.
cm_norm = np.nan_to_num(cm_norm)

# Create DataFrame
per_subject_df = pd.DataFrame({
    "subject": subject_names,
    "accuracy": np.diag(cm_norm),  # Diagonal elements are correct predictions
    "support": row_totals,         # Total samples per class
})

# Sort by accuracy
per_subject_df = per_subject_df.sort_values("accuracy", ascending=False)

print("\nTop 5 Performing Subjects:")
print(per_subject_df.head(5))

print("\nWorst 5 Performing Subjects:")
print(per_subject_df.tail(5))

## 11. Confusion Matrix

The confusion matrix helps identify specific misclassifications.
* **Diagonal**: Correct predictions.
* **Off-diagonal**: Errors (e.g., identifying Person A as Person B).

In [None]:
# Heatmap of the row-normalised confusion matrix.
# Tick labels come from `subject_names`, the list that was built together with
# `labels_list` when the confusion matrix was computed, so the axis labels
# always line up with the rows/columns of `cm_norm` (including the case where
# the fallback subject names were used).
plt.figure(figsize=(14, 12))
sns.heatmap(cm_norm, annot=False, fmt='.2f', cmap='Blues',
            xticklabels=subject_names,
            yticklabels=subject_names)
plt.title("Normalized Confusion Matrix (Recall)")
plt.xlabel("Predicted Subject")
plt.ylabel("True Subject")
plt.show()

# Subjects without any test samples have a meaningless accuracy of 0.0;
# leave them out of the best/worst comparison.
plot_df = per_subject_df[per_subject_df["support"] > 0]

n_plot = min(5, len(plot_df))
fig, axes = plt.subplots(1, 2, figsize=(14, 5), sharey=True)

# Top 5 (per_subject_df is already sorted by descending accuracy)
plot_df.head(n_plot).plot.bar(x="subject", y="accuracy", ax=axes[0], color="#1f77b4", legend=False)
axes[0].set_title(f"Top {n_plot} Subjects")
axes[0].set_ylabel("Accuracy")
axes[0].set_ylim(0, 1.05)  # shared with the right panel via sharey=True
axes[0].grid(axis='y', alpha=0.3)

# Bottom 5 (Reversed for visual consistency)
plot_df.tail(n_plot).sort_values("accuracy", ascending=True).plot.bar(x="subject", y="accuracy", ax=axes[1], color="#d62728", legend=False)
axes[1].set_title(f"Lowest {n_plot} Subjects")
axes[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## 12. Latent Space Visualization (t-SNE)

To understand *what* the model has learned, we extract the feature embeddings (the output of the LSTM before the final classification layer).
We use **t-SNE** to project these high-dimensional vectors into 2D space. Well-separated clusters indicate that the model has learned distinct features for each person.

In [None]:
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch

# 1. Setup Hook to extract embeddings
# Filled by hook_fn: one NumPy array per forward pass of the hooked layer.
embeddings_list = []


def hook_fn(module, input, output):
    """Forward hook that records the hooked layer's first positional input.

    The tensor is detached, moved to the CPU and stored as a NumPy array in
    the module-level ``embeddings_list``. No reshaping happens here; the
    ``module`` and ``output`` arguments are unused.
    """
    layer_input = input[0]
    embeddings_list.append(layer_input.detach().cpu().numpy())

# Attach hook to the final layer (usually named 'fc' or similar in your architecture)
# NOTE(review): this takes the last *direct* child of best_model and assumes it
# is the classification head - confirm against the model definition.
final_layer = list(best_model.children())[-1]
handle = final_layer.register_forward_hook(hook_fn)

# 2. Run a larger batch to ensure we catch enough data for our target subjects
# We scan more data initially (e.g. 3000) to find enough samples for specific subjects
scan_limit = 3000
extracted_labels = []
count = 0

best_model.eval()
try:
    with torch.no_grad():
        for inputs, labels in dtest:
            inputs = inputs.to(DEVICE)
            _ = best_model(inputs)  # Forward pass triggers the hook

            # .cpu() keeps this working if the loader yields device tensors.
            extracted_labels.extend(labels.cpu().numpy())
            count += inputs.size(0)
            if count >= scan_limit:
                break
finally:
    # Always detach the hook, even if a forward pass raised; otherwise it
    # would keep firing (and appending) on every later call of the model.
    handle.remove()

# Concatenate all batches
X_full = np.concatenate(embeddings_list, axis=0)[:count]
y_full = np.array(extracted_labels)[:count]

# Flatten if necessary
if len(X_full.shape) > 2:
    X_full = X_full.reshape(X_full.shape[0], -1)

# --- FILTER FOR 5 SUBJECTS ---
# Pick up to 5 unique subjects present in the data. np.unique returns them
# sorted, so these are the 5 smallest label ids (slicing is safe when fewer exist).
unique_subjects = np.unique(y_full)
target_subjects = unique_subjects[:5]

print(f"Visualizing Subjects: {target_subjects}")

# Create a mask to select only these subjects
mask = np.isin(y_full, target_subjects)
X_embedded = X_full[mask]
y_embedded = y_full[mask]

print(f"Embeddings shape (filtered): {X_embedded.shape}")

# 3. Compute t-SNE
print("Running t-SNE... this might take a moment.")
# scikit-learn requires perplexity < n_samples; cap it so small test sets
# (30 or fewer filtered samples) do not raise. Larger sets still use 30.
perplexity = min(30, max(1, len(X_embedded) - 1))
tsne = TSNE(n_components=2, perplexity=perplexity, init='pca', random_state=42)
tsne_coords = tsne.fit_transform(X_embedded)

# 4. Plot t-SNE
plt.figure(figsize=(10, 8))

cmap = matplotlib.colormaps['tab10']
colors = cmap.resampled(len(target_subjects))

for i, subject_id in enumerate(target_subjects):
    # Find points belonging to this subject
    idxs = (y_embedded == subject_id)

    # Get label name if available, else use ID
    try:
        label_name = label_mapping[subject_id]
    except (NameError, KeyError):
        # Mapping not defined, or it has no entry for this id.
        label_name = f"Subject {subject_id}"

    # colors(i) indexes the resampled colormap, giving one entry per subject.
    plt.scatter(tsne_coords[idxs, 0], tsne_coords[idxs, 1],
                label=label_name, alpha=0.7, s=20, color=colors(i))

plt.legend(title="Subject ID")
plt.title(f"t-SNE Visualization (Top {len(target_subjects)} Subjects)")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")
plt.grid(True, alpha=0.3)
plt.show()