In [2]:
#Dependencies: 

from pathlib import Path
import os
import math
import random
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr

import mne
mne.set_log_level('ERROR')

In [3]:
# ---------------------------------------------------------------------------
# Configuration: every tunable constant for the pipeline lives in this cell.
# ---------------------------------------------------------------------------

# BIDS task whose recordings are scanned
TASK = "contrastChangeDetection"

# Signal / windowing parameters (durations are in seconds)
SFREQ = 100       # recordings are resampled to this rate (Hz)
WIN_SEC = 4       # length of each pointer window
STRIDE_SEC = 2    # hop between consecutive pointer windows
CROP_SEC = 2      # length of the crop actually fed to the model

# Dataset locations
DATA_ROOT = Path("/kaggle/input/eeg-mini")
# More releases (e.g. "R9_mini_L100_bdf", "R8_mini_L100_bdf") can be appended here.
TRAIN_RELEASES = [
    "R10_mini_L100_bdf",
]
VAL_RELEASE = "R11_mini_L100_bdf"


In [4]:
# Checkpoints code
class CheckpointSaver:
    """Write training checkpoints (epoch, model state, optimizer state) to disk.

    ``save_if_best`` overwrites ``save_path`` whenever the validation loss
    improves; ``save_periodic`` writes ``checkpoint_epoch_<N>.pt`` into
    ``periodic_dir``.
    """

    def __init__(self, save_path="/kaggle/working/best_model.pt", periodic_dir=None):
        """
        save_path: file that receives the best checkpoint.
        periodic_dir: directory for periodic checkpoints. Defaults to the
            directory containing ``save_path`` (``/kaggle/working`` for the
            default path) instead of a hard-coded location.
        """
        self.best_val_loss = float("inf")
        self.save_path = save_path
        if periodic_dir is None:
            periodic_dir = os.path.dirname(save_path) or "."
        self.periodic_dir = periodic_dir

    def save_if_best(self, model, optimizer, epoch, val_loss):
        """Save a full checkpoint when ``val_loss`` beats the best loss seen so far.

        Non-finite losses (NaN / inf) never count as an improvement.
        Returns True if a checkpoint was written, False otherwise.
        """
        val_loss = float(val_loss)
        if not math.isfinite(val_loss) or val_loss >= self.best_val_loss:
            return False
        self.best_val_loss = val_loss
        checkpoint = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "best_val_loss": val_loss,
        }
        torch.save(checkpoint, self.save_path)
        print(f"✓ Saved BEST checkpoint at epoch {epoch+1} (val_loss={val_loss:.6f})")
        return True

    def save_periodic(self, model, optimizer, epoch):
        """Unconditionally save an epoch-stamped checkpoint and return its path.

        The caller decides how often to call this (e.g. every N epochs).
        ``epoch`` is 0-based; the file name uses ``epoch + 1``.
        """
        path = os.path.join(self.periodic_dir, f"checkpoint_epoch_{epoch+1}.pt")
        checkpoint = {
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        }
        torch.save(checkpoint, path)
        print(f"✓ Periodic checkpoint saved: {path}")
        return path



In [5]:
# Subjects to remove if corrupted; compared against IDs with the "sub-" prefix stripped
SUB_RM = []

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def seed_all(seed=42):
    """Seed every RNG used in this notebook and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), then turns off cuDNN autotuning so that convolution
    algorithms are picked deterministically.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed_all(42)


In [6]:
# Data loading functions

# Some releases are unpacked twice (e.g. /.../R6_L100_bdf/R6_L100_bdf/...), so
# this helper looks for the dataset root at both nesting levels.
def resolve_double_nested_path(data_root: Path, release: str):
    """Return the directory that acts as the BIDS root for ``release``.

    Preference order:
      1. ``data_root/release/release`` if it contains ``participants.tsv``;
      2. ``data_root/release`` if it contains ``participants.tsv``;
      3. ``data_root/release`` if it merely exists;
      4. ``None`` otherwise.
    """
    outer = data_root / release
    for candidate in (outer / release, outer):
        if candidate.exists() and (candidate / "participants.tsv").exists():
            return candidate
    return outer if outer.exists() else None

def load_participants_data(release_path: Path):
    """Read ``participants.tsv`` (tab-separated) from a release directory.

    Returns an empty DataFrame when the file is absent, so callers can test
    ``.empty`` instead of handling a missing-file error.
    """
    tsv_path = release_path / "participants.tsv"
    return pd.read_csv(tsv_path, sep='\t') if tsv_path.exists() else pd.DataFrame()

# Data loading function provided by the competition (adapted: incomplete files are
# rejected from the header alone, before any samples are loaded).
def load_raw_eeg(subject_path: Path, subject_id: str, task: str, run=None, min_sec: float = 4.0):
    """Load one BDF recording and resample it to ``SFREQ`` if needed.

    Parameters
    ----------
    subject_path : directory ``.../sub-<id>`` that contains an ``eeg/`` folder.
    subject_id : subject ID without the ``sub-`` prefix.
    task : BIDS task label used in the file name.
    run : run number, or None for tasks recorded as a single file.
    min_sec : recordings shorter than this many seconds are rejected.

    Returns
    -------
    An mne Raw with samples loaded (resampled to ``SFREQ`` when its native rate
    differs), or None if the file is missing, shorter than ``min_sec``, has
    fewer than 128 channels, or fails to load (the error is printed).
    """
    try:
        if run is None:
            fname = f"sub-{subject_id}_task-{task}_eeg.bdf"
        else:
            fname = f"sub-{subject_id}_task-{task}_run-{run}_eeg.bdf"

        file_path = subject_path / "eeg" / fname

        if not file_path.exists():
            return None

        # preload=False reads only the header: n_times, sfreq and ch_names are
        # available without pulling the samples into memory.
        raw = mne.io.read_raw_bdf(file_path, preload=False, verbose=False)
        native_sfreq = float(raw.info['sfreq'])

        # Reject incomplete data before paying for load_data(). n_times counts
        # samples at the file's native rate, so the duration threshold must use
        # that rate too (comparing against a sample count at SFREQ is only correct
        # when the two rates happen to match).
        if raw.n_times < min_sec * native_sfreq or len(raw.ch_names) < 128:
            return None

        raw.load_data()

        # Resample if needed
        if not math.isclose(native_sfreq, SFREQ):
            raw.resample(SFREQ)

        return raw

    except Exception as e:
        print(f"EEG load failed for sub-{subject_id} run {run}:", e)
        return None
        

In [7]:
# Lazy-loading functions
"""
Lazy loading: instead of reading every recording into RAM (which crashes the
Kaggle kernel on large releases), store lightweight pointers (file location,
run and sample offsets) for each window and load the EEG only when a window
is requested.
"""

def _encode_sex(sex):
    """Map a participants.tsv ``sex`` value to 1.0 (female), 0.0 (male) or NaN (unknown)."""
    sex_str = str(sex).strip().lower()
    if sex_str in ('female', 'f', '2'):
        return 1.0
    if sex_str in ('male', 'm', '1'):
        return 0.0
    return np.nan


def _finite_float_or_none(value):
    """Return ``float(value)`` if it parses to a finite number, else None."""
    try:
        out = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return out if math.isfinite(out) else None


def load_release_data_lazy(release, task=TASK, data_root=DATA_ROOT):
    """Scan one release and return a list of window pointers (no EEG kept in RAM).

    For every participant with a finite ``externalizing`` score (and not listed in
    ``SUB_RM``), each available run of ``task`` is opened once to read its length.
    One pointer is then emitted per sliding window (``WIN_SEC`` long, ``STRIDE_SEC``
    apart, counted at ``SFREQ``). Each pointer is a dict with keys
    bids_root, subject, task, run, start_sample, end_sample, age, sex,
    handedness and externalizing.

    Returns an empty list if the release directory or participants.tsv is missing.
    """
    release_path = resolve_double_nested_path(data_root, release)
    if release_path is None or not Path(release_path).exists():
        print(f"Warning: Release path not found for {release} under {data_root}")
        return []

    release_path = Path(release_path)
    participants_df = load_participants_data(release_path)
    if participants_df.empty:
        print(f"No participants.tsv for {release}")
        return []

    win_samples = int(WIN_SEC * SFREQ)
    stride_samples = int(STRIDE_SEC * SFREQ)
    bids_root = str(release_path)
    # contrastChangeDetection is stored as three runs; other tasks as a single file.
    runs = [1, 2, 3] if task == "contrastChangeDetection" else [None]

    window_pointers = []
    for _, row in tqdm(participants_df.iterrows(), total=len(participants_df), desc=f"Scanning {release}"):
        subj_col = row.get('participant_id', None)
        if subj_col is None:
            continue
        subject_id = str(subj_col).replace('sub-', '')

        if subject_id in SUB_RM:
            continue

        # The regression target: participants without a finite score are skipped.
        externalizing = _finite_float_or_none(row.get('externalizing', np.nan))
        if externalizing is None:
            continue

        # age / handedness are passed through unparsed; the dataset cleans them.
        age = row.get('age', np.nan)
        sex_encoded = _encode_sex(row.get('sex', np.nan))
        handedness = row.get('handedness', np.nan)

        subject_path = release_path / f"sub-{subject_id}"

        for run in runs:
            try:
                # load_raw_eeg returns None for missing, too-short or malformed files.
                # NOTE(review): this loads (and resamples) the full recording only to
                # read its length, which dominates scan time.
                raw = load_raw_eeg(subject_path, subject_id, task, run)
                if raw is None:
                    continue
                n_times = raw.n_times
                del raw
            except Exception as e:
                # Best-effort scan: skip this run but say why instead of hiding it.
                print(f"Skipping sub-{subject_id} run {run}: {e}")
                continue

            for start in range(0, n_times - win_samples + 1, stride_samples):
                window_pointers.append({
                    "bids_root": bids_root,
                    "subject": subject_id,
                    "task": task,
                    "run": run,
                    "start_sample": int(start),
                    "end_sample": int(start + win_samples),
                    "age": age,
                    "sex": sex_encoded,
                    "handedness": handedness,
                    "externalizing": externalizing,
                })

    print(f"Scanned {len(window_pointers)} windows from {release}")
    return window_pointers

class EEGWindowsDataset(torch.utils.data.Dataset):
    """Map-style dataset that materializes EEG windows on demand from lazy pointers.

    Each item is ``(eeg, y, demo, crop_idx, info)``:
      - eeg: float tensor (channels, crop_samples), keeping at most the first
        128 channels, z-scored per channel over time;
      - y: float tensor of shape (1,) holding the externalizing score;
      - demo: float tensor of the columns of ``[age, sex, handedness]`` selected
        by ``keep_idx`` (NaN where unknown), or an empty tensor;
      - crop_idx: (start, stop) sample indices of the crop within the recording;
      - info: dict with subject / task / run.
    Loaded recordings are kept in a small LRU cache (one per process).
    """

    def __init__(self, data_list, crop_samples, keep_idx, seed=42, cache_size=3):
        """
        data_list: list of pointers generated by load_release_data_lazy
        crop_samples: number of time samples to crop (CROP_SEC * SFREQ)
        keep_idx: indices into [age, sex, handedness] to return as demographics;
            an empty value yields an empty demographics tensor (placeholder,
            since the model does not use demographics).
        seed: seed for the RNG that picks random crop offsets.
        cache_size: max number of loaded recordings kept in the LRU cache (>= 1).
        """
        self.data_list = data_list
        self.crop_samples = int(crop_samples)
        has_idx = hasattr(keep_idx, '__len__') and len(keep_idx) > 0
        self.keep_idx = np.array(keep_idx, dtype=int) if has_idx else np.array([], dtype=int)
        # NOTE(review): with DataLoader num_workers > 0 every worker starts from a copy
        # of this RNG in its main-process state, so workers draw identical crop-offset
        # sequences (and, with non-persistent workers, repeat them every epoch). The
        # validation set also uses random crops. Consider worker-aware seeding for
        # training and a deterministic crop for validation.
        self.rng = random.Random(seed)
        # LRU cache keyed by (subject, run). Dicts keep insertion order, so the first
        # key is always the least recently used one.
        self.raw_cache = {}
        self.cache_size = max(1, int(cache_size))

        # (subject, run) -> subject directory / task, built once so a cache miss is an
        # O(1) lookup instead of a scan over the whole data_list.
        self._subject_paths = {}
        self._run_tasks = {}
        for item in data_list:
            key = (item['subject'], item['run'])
            if key not in self._subject_paths:
                self._subject_paths[key] = Path(item['bids_root']) / f"sub-{item['subject']}"
                self._run_tasks[key] = item['task']

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _finite_or_nan(value):
        """Return ``float(value)`` if finite, else NaN (also for None / unparsable values)."""
        try:
            out = float(value)
        except (TypeError, ValueError, OverflowError):
            return np.nan
        return out if math.isfinite(out) else np.nan

    def _get_raw_from_cache(self, subject_id, run):
        """Return the loaded recording for (subject_id, run), loading it on a cache miss.

        Raises FileNotFoundError if no pointer references that recording or it
        cannot be loaded.
        """
        key = (subject_id, run)
        if key in self.raw_cache:
            # Re-insert to mark the entry as most recently used.
            raw = self.raw_cache.pop(key)
            self.raw_cache[key] = raw
            return raw
        if key not in self._subject_paths:
            raise FileNotFoundError(f"No pointer found for {key}")
        raw = load_raw_eeg(self._subject_paths[key], subject_id, self._run_tasks[key], run)
        if raw is None:
            raise FileNotFoundError(f"Could not load raw for {subject_id} run {run}")
        # Evict least-recently-used entries to make room.
        while len(self.raw_cache) >= self.cache_size:
            del self.raw_cache[next(iter(self.raw_cache))]
        self.raw_cache[key] = raw
        return raw

    def __getitem__(self, idx):
        item = self.data_list[idx]
        raw = self._get_raw_from_cache(item['subject'], item['run'])
        # Extract the requested window (all channels).
        eeg_np, _ = raw[:, item['start_sample']:item['end_sample']]
        # Keep only the first 128 channels when extra channels (129+) are present.
        if eeg_np.shape[0] >= 129:
            eeg_np = eeg_np[:128, :]
        eeg = torch.from_numpy(eeg_np.copy()).float()  # (channels, time)
        n_time = eeg.shape[1]

        if n_time < self.crop_samples:
            # Shorter than the crop: right-pad with zeros.
            eeg = torch.nn.functional.pad(eeg, (0, self.crop_samples - n_time), mode='constant', value=0.0)
            start_crop, stop_crop = 0, self.crop_samples
        else:
            start_crop = int(self.rng.randint(0, n_time - self.crop_samples))
            stop_crop = start_crop + self.crop_samples
            eeg = eeg[:, start_crop:stop_crop]

        # Per-channel z-score over time, then scrub non-finite values and clip outliers.
        mu = eeg.mean(dim=1, keepdim=True)
        sd = eeg.std(dim=1, keepdim=True)
        eeg = (eeg - mu) / (sd + 1e-6)
        eeg = torch.nan_to_num(eeg, nan=0.0, posinf=0.0, neginf=0.0)
        eeg = torch.clamp(eeg, min=-1e3, max=1e3)

        y = torch.tensor([item['externalizing']], dtype=torch.float32)

        # Demographics are kept as a placeholder; sex is already encoded by the scanner.
        full_demo = np.array(
            [self._finite_or_nan(item['age']), item['sex'], self._finite_or_nan(item['handedness'])],
            dtype=np.float32,
        )
        if self.keep_idx.size > 0:
            demo = torch.from_numpy(full_demo[self.keep_idx]).float()
        else:
            demo = torch.empty(0, dtype=torch.float32)

        info = {'subject': item['subject'], 'task': item['task'], 'run': item['run']}
        crop_idx = (int(item['start_sample'] + start_crop), int(item['start_sample'] + stop_crop))

        return eeg, y, demo, crop_idx, info




In [9]:
# Scaler and demographics helper (kept for compatibility but fusion disabled)

class SafeStandardScaler(StandardScaler):
    """StandardScaler whose fitted statistics are always finite.

    After fitting, any non-finite or zero entry of ``scale_`` becomes 1.0 (that
    feature is only centered), and non-finite entries of ``var_`` / ``mean_``
    become 0.0. Attributes that sklearn sets to ``None`` (e.g. ``scale_`` and
    ``var_`` when ``with_std=False``) are left untouched instead of crashing.
    """

    def fit(self, X, y=None):
        super().fit(X, y)
        scale = getattr(self, "scale_", None)
        if scale is not None:
            scale[~np.isfinite(scale) | (scale == 0)] = 1.0
        var = getattr(self, "var_", None)
        if var is not None:
            var[~np.isfinite(var)] = 0.0
        mean = getattr(self, "mean_", None)
        if mean is not None:
            mean[~np.isfinite(mean)] = 0.0
        return self

def extract_unique_demographics(data_list):
    """Collect one ``[age, sex, handedness]`` row per subject (the first pointer wins).

    age and handedness are coerced to float, with NaN for missing, non-finite or
    unparsable values; sex is taken as stored in the pointer. Returns a float32
    array of shape (n_subjects, 3), or (0, 3) when ``data_list`` is empty.
    """
    def to_finite_float(value):
        # Narrow except: only parse failures map to NaN, not e.g. KeyboardInterrupt.
        try:
            out = float(value)
        except (TypeError, ValueError, OverflowError):
            return np.nan
        return out if math.isfinite(out) else np.nan

    rows = {}
    for item in data_list:
        sid = item['subject']
        if sid not in rows:
            rows[sid] = [to_finite_float(item['age']), item['sex'], to_finite_float(item['handedness'])]
    if not rows:
        return np.zeros((0, 3), dtype=np.float32)
    return np.array(list(rows.values()), dtype=np.float32)


def build_demo_transform(train_data):
    """Fit a median-impute + standardize transform on per-subject demographics.

    Columns of ``[age, sex, handedness]`` that are NaN for every training
    subject are dropped. The remaining columns are imputed with their training
    median and standardized with ``SafeStandardScaler``.

    Returns ``(demodim, transform_fn, medians, keep_idx)``:
      - demodim: number of kept columns (0 when there is no usable data);
      - transform_fn: maps a (B, demodim) tensor to an imputed, scaled float32
        tensor on ``device``; None when demodim == 0;
      - medians: float32 per-kept-column training medians;
      - keep_idx: int indices of the kept columns into ``[age, sex, handedness]``.
    demodim is generally > 0 even though the model in this notebook ignores
    demographics; the transform is kept for API compatibility.
    """
    no_demo = (0, None, np.zeros((0,), dtype=np.float32), np.array([], dtype=int))

    unique = extract_unique_demographics(train_data)
    if unique.shape[0] == 0:
        return no_demo

    # Drop columns that carry no information at all in the training set.
    keep_idx = np.where(~np.isnan(unique).all(axis=0))[0]
    if keep_idx.size == 0:
        return no_demo
    kept = unique[:, keep_idx]

    with np.errstate(all='ignore'):
        col_medians = np.nanmedian(kept, axis=0).astype(np.float32)
    col_medians[~np.isfinite(col_medians)] = 0.0

    def impute(arr):
        """Float32 copy of ``arr`` (rows x kept columns) with non-finite entries set to the column medians."""
        out = np.array(arr, dtype=np.float32, copy=True)
        bad = ~np.isfinite(out)
        out[bad] = np.broadcast_to(col_medians, out.shape)[bad]
        return out

    scaler = SafeStandardScaler().fit(impute(kept))

    def transform_batch(demo_tensor):
        """Impute + standardize a (B, demodim) demographics tensor and move it to ``device``."""
        if demo_tensor.numel() == 0:
            return demo_tensor.to(device=device, dtype=torch.float32)
        scaled = scaler.transform(impute(demo_tensor.detach().cpu().numpy()))
        out = torch.from_numpy(scaled).to(device=device, dtype=torch.float32)
        return torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)

    return keep_idx.size, transform_batch, col_medians, keep_idx


In [10]:
# Build data pointers and DataLoaders (EEG windows are loaded lazily by the dataset)
print("\n" + "="*60)
print("Scanning mini-release and building pointers (lazy)...")
print("="*60)

train_data_pointers = []
for release in TRAIN_RELEASES:
    train_data_pointers.extend(load_release_data_lazy(release, task=TASK, data_root=DATA_ROOT))

print(f"\nTotal training window pointers: {len(train_data_pointers)}")
# Fail fast: nothing downstream works without training windows.
if len(train_data_pointers) == 0:
    raise ValueError("No training windows found — check your mini release BIDS structure and TASK.")

print("\n" + "="*60)
print("Loading validation pointers...")
print("="*60)
val_data_pointers = load_release_data_lazy(VAL_RELEASE, task=TASK, data_root=DATA_ROOT)
print(f"Total validation window pointers: {len(val_data_pointers)}")
# An empty validation set would otherwise surface later as a division by zero
# inside validate_epoch.
if len(val_data_pointers) == 0:
    raise ValueError("No validation windows found — check VAL_RELEASE and TASK.")

# Build (placeholder) demographic transform; the model does not consume demographics.
demodim, transform_demo_batch, demo_medians, keep_idx = build_demo_transform(train_data_pointers)

CROP_SAMPLES = int(CROP_SEC * SFREQ)
train_dataset = EEGWindowsDataset(train_data_pointers, crop_samples=CROP_SAMPLES, keep_idx=keep_idx, seed=42, cache_size=5)
val_dataset   = EEGWindowsDataset(val_data_pointers, crop_samples=CROP_SAMPLES, keep_idx=keep_idx, seed=42, cache_size=5)

BATCH_SIZE = 32
NUM_WORKERS = 4
# Pinned host memory only speeds up host->GPU copies; on CPU it is wasted work.
PIN_MEMORY = device.type == "cuda"
# NOTE(review): with shuffle=True, consecutive samples usually come from different
# recordings, so each worker's 5-recording cache likely rarely hits and most items
# reload a full BDF file. A sampler that groups windows by recording would help.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

# Input shape the model is built for: 128 channels x CROP_SAMPLES time points.
C, T = 128, CROP_SAMPLES
print(f"\nSample EEG shape: ({C}, {T})")
print(f"Demographic dims (kept): {demodim}")
print("\nDATA LOADING COMPLETE\n")



Scanning mini-release and building pointers (lazy)...


Scanning R10_mini_L100_bdf: 100%|██████████| 20/20 [00:11<00:00,  1.77it/s]


Scanned 9988 windows from R10_mini_L100_bdf

Total training window pointers: 9988

Loading validation pointers...


Scanning R11_mini_L100_bdf: 100%|██████████| 20/20 [00:09<00:00,  2.07it/s]

Scanned 8294 windows from R11_mini_L100_bdf
Total validation window pointers: 8294

Sample EEG shape: (128, 200)
Demographic dims (kept): 2

DATA LOADING COMPLETE






In [14]:
# Model
class EnhancedEEGNetRegressor(nn.Module):
    """EEGNet-style CNN that regresses one scalar per EEG window.

    Pipeline: temporal conv (kernel 51, length-preserving padding) -> BN -> ELU
    -> depthwise spatial conv spanning all channels -> BN -> ELU -> avg-pool /4
    -> dropout -> avg-pool /2 -> dropout -> MLP (64 -> 32) -> linear head.

    Args:
        n_channels: number of EEG channels (input height).
        n_times: number of time samples per window (input width).
        dropout: dropout probability used throughout.
        F1: number of temporal filters.
        D: depth multiplier of the depthwise spatial conv.
    """

    def __init__(self, n_channels=128, n_times=200, dropout=0.25, F1=16, D=2):
        super().__init__()
        self.n_channels = n_channels
        self.dropout_rate = dropout

        # temporal conv: kernel 51 with padding 25 keeps the time length unchanged
        self.temporal_conv = nn.Conv2d(1, F1, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)

        # spatial depthwise conv: collapses the channel axis to height 1
        self.spatial_conv = nn.Conv2d(F1, F1 * D, kernel_size=(n_channels, 1), stride=(1, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1 * D)

        self.elu = nn.ELU()
        self.pool1 = nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout1 = nn.Dropout(p=dropout)

        self.pool2 = nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.dropout2 = nn.Dropout(p=dropout)

        # floor(floor(n_times / 4) / 2) == n_times // 8 time steps remain after pooling
        self.eeg_feature_dim = F1 * D * (n_times // 8)

        # regression MLP (named "fusion" for checkpoint compatibility; no demographics are fused)
        self.fusion = nn.Sequential(
            nn.Linear(self.eeg_feature_dim, 64),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(p=dropout)
        )
        self.regression_head = nn.Linear(32, 1)

    def forward(self, eeg, demographics=None):
        """Return predictions of shape (B, 1).

        eeg: (B, 1, C, T), or (B, C, T), in which case the singleton
            feature-map axis is added.
        demographics: accepted for API compatibility and ignored.
        """
        if eeg.dim() == 3:
            eeg = eeg.unsqueeze(1)
        x = self.elu(self.bn1(self.temporal_conv(eeg)))   # (B, F1, C, T)
        x = self.elu(self.bn2(self.spatial_conv(x)))      # (B, F1*D, 1, T)
        x = self.dropout1(self.pool1(x))                  # (B, F1*D, 1, T//4)
        x = self.dropout2(self.pool2(x))                  # (B, F1*D, 1, T//8)
        fused = self.fusion(x.flatten(start_dim=1))
        return self.regression_head(fused)



In [15]:
# Loss, metrics, early stopping

class RMSELoss(nn.Module):
    """Root-mean-squared error.

    A 1e-8 epsilon is added inside the square root so the gradient stays
    finite when the error is exactly zero.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        mean_sq_err = self.mse(pred, target)
        return (mean_sq_err + 1e-8).sqrt()

class NRMSELoss(nn.Module):
    """RMSE divided by the range (max - min) of the targets in the batch.

    When the batch targets are (nearly) constant, which happens whenever a batch
    holds windows from a single subject (e.g. an unshuffled validation loader),
    the range is ~0 and the ratio would explode by ~1e8. In that case plain RMSE
    is returned instead.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, pred, target):
        rmse = torch.sqrt(self.mse(pred, target) + 1e-8)
        target_range = target.max() - target.min()
        denom = torch.where(
            target_range > self.eps,
            target_range + self.eps,
            torch.ones_like(target_range),  # degenerate batch: fall back to RMSE
        )
        return rmse / denom

def calculate_metrics(predictions, targets):
    """Regression metrics for matching sequences of predictions and targets.

    Inputs may be any array-like and are flattened. Returns a dict with
    mse, rmse, nrmse (RMSE / target range), mae and pearson_r. pearson_r is
    0.0 when it is undefined (fewer than two points or a constant input).
    Empty input yields NaN for every error metric instead of raising.
    """
    predictions = np.asarray(predictions).ravel()
    targets = np.asarray(targets).ravel()
    if predictions.size == 0:
        nan = float("nan")
        return {'mse': nan, 'rmse': nan, 'nrmse': nan, 'mae': nan, 'pearson_r': 0.0}

    err = predictions - targets
    mse = float(np.mean(err ** 2))
    rmse = float(np.sqrt(mse))
    target_range = float(targets.max() - targets.min())
    nrmse = rmse / (target_range + 1e-8)
    mae = float(np.mean(np.abs(err)))
    if predictions.size > 1 and np.std(predictions) > 0 and np.std(targets) > 0:
        corr, _ = pearsonr(predictions, targets)
        corr = float(corr)
    else:
        corr = 0.0
    return {'mse': mse, 'rmse': rmse, 'nrmse': nrmse, 'mae': mae, 'pearson_r': corr}

class EarlyStoppingCallback:
    """Stop training after ``patience`` consecutive epochs without a relative improvement.

    An epoch improves when ``val_loss < best_loss * (1 - min_rel_delta)``
    (default: at least 1% lower); the first finite loss always counts. On every
    improvement the model's ``state_dict`` is saved to ``save_path``. Non-finite
    losses (NaN / inf) never count as an improvement, so a diverged epoch cannot
    become the "best" loss and block all later improvements.
    """

    def __init__(self, patience: int = 10, verbose: bool = True, min_rel_delta: float = 0.01):
        self.patience = patience
        self.verbose = verbose
        self.min_rel_delta = min_rel_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss: float, model: nn.Module, save_path: str = "best_model.pt"):
        val_loss = float(val_loss)
        improved = math.isfinite(val_loss) and (
            self.best_loss is None or val_loss < self.best_loss * (1.0 - self.min_rel_delta)
        )
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), save_path)
            if self.verbose:
                print(f"✓ Validation loss improved to {val_loss:.6f}. Model saved.")
            return

        self.counter += 1
        if self.verbose:
            print(f"No improvement for {self.counter}/{self.patience} epochs")
        if self.counter >= self.patience:
            self.early_stop = True
            if self.verbose:
                print("Early stopping triggered!")


In [16]:
# Train / Validate functions

def train_epoch(model, train_loader, criterion, optimizer, device, grad_clip=1.0, transform_demo_batch=None):
    """Run one optimization pass over ``train_loader`` and return the mean loss.

    The result is the per-sample weighted mean of the batch losses, so a smaller
    final batch is not over-weighted; it is NaN for an empty loader. Gradients
    are clipped to L2 norm ``grad_clip`` before each step. Demographics (and
    ``transform_demo_batch``) are ignored because the model is EEG-only.
    """
    model.train()
    total_loss = 0.0
    n_seen = 0
    for eeg_batch, target_batch, _demo, _crop, _info in tqdm(train_loader, desc="Train", leave=False):
        # (B, C, T) -> (B, 1, C, T): add the singleton feature-map axis the convs expect.
        eeg_batch = eeg_batch.unsqueeze(1).to(device, non_blocking=True)
        target_batch = target_batch.to(device, non_blocking=True)
        optimizer.zero_grad()
        preds = model(eeg_batch)
        loss = criterion(preds, target_batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        batch_n = target_batch.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
    return total_loss / n_seen if n_seen else float("nan")

def validate_epoch(model, val_loader, criterion, device, transform_demo_batch=None):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Returns ``(avg_loss, preds, targets)``:
      - avg_loss: per-sample weighted mean of the batch losses (NaN if the loader is empty);
      - preds / targets: flat arrays in loader order (empty arrays for an empty loader).
    ``transform_demo_batch`` is accepted for API compatibility and ignored.
    """
    model.eval()
    total_loss = 0.0
    n_seen = 0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for eeg_batch, target_batch, _demo, _crop, _info in tqdm(val_loader, desc="Val", leave=False):
            eeg_batch = eeg_batch.unsqueeze(1).to(device, non_blocking=True)
            target_batch = target_batch.to(device, non_blocking=True)
            preds = model(eeg_batch)
            batch_n = target_batch.size(0)
            total_loss += criterion(preds, target_batch).item() * batch_n
            n_seen += batch_n
            all_preds.append(preds.cpu().numpy().ravel())
            all_targets.append(target_batch.cpu().numpy().ravel())
    if n_seen == 0:
        return float("nan"), np.empty(0, dtype=np.float32), np.empty(0, dtype=np.float32)
    return total_loss / n_seen, np.concatenate(all_preds), np.concatenate(all_targets)



In [17]:
# Model init, optimizer, scheduler

print("\n" + "="*60)
print("INITIALIZING MODEL")
print("="*60)

model = EnhancedEEGNetRegressor(n_channels=C, n_times=T, dropout=0.25, F1=16, D=2).to(device)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"✓ Model created | total params: {total_params:,} | trainable: {trainable_params:,}")

N_EPOCHS = 10
LEARNING_RATE = 1e-3
EARLY_STOPPING_PATIENCE = 5
# Must be smaller than EARLY_STOPPING_PATIENCE. ReduceLROnPlateau only lowers the LR
# after MORE than `patience` bad epochs, while early stopping fires after
# EARLY_STOPPING_PATIENCE epochs without a (stricter, 1%) improvement. With both set
# to 5, training stops before the LR is ever reduced.
SCHEDULER_PATIENCE = 2
USE_NRMSE = False

criterion = NRMSELoss() if USE_NRMSE else RMSELoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
# `verbose` is deprecated for PyTorch LR schedulers; read the current LR from
# optimizer.param_groups[0]['lr'] if needed.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=SCHEDULER_PATIENCE)
early_stopping = EarlyStoppingCallback(patience=EARLY_STOPPING_PATIENCE, verbose=True)

history = {'train_loss': [], 'val_loss': [], 'val_rmse': [], 'val_mae': [], 'val_pearson_r': []}
best_model_path = "/kaggle/working/best_eegnet_mini.pt"




INITIALIZING MODEL
✓ Model created | total params: 58,385 | trainable: 58,385


In [None]:
# Training Loop with checkpoints

print("\n" + "="*60)
print("STARTING TRAINING (with checkpoints)")
print("="*60)

# Initialize the checkpoint saver
checkpoint_saver = CheckpointSaver("/kaggle/working/best_model.pt")
CHECKPOINT_EVERY = 5  # periodic checkpoint after epochs 5, 10, ... (1-based)

for epoch in range(N_EPOCHS):

    print(f"\n{'='*60}\nEpoch {epoch+1}/{N_EPOCHS}\n{'='*60}")

    # TRAIN (DataLoader workers are configured on the loaders, not here)
    train_loss = train_epoch(
        model=model,
        train_loader=train_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        grad_clip=1.0,
    )

    # VALIDATE
    val_loss, val_preds, val_targets = validate_epoch(
        model=model,
        val_loader=val_loader,
        criterion=criterion,
        device=device,
    )

    # METRICS & LOGGING
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    metrics = calculate_metrics(val_preds, val_targets)

    history['val_rmse'].append(metrics['rmse'])
    history['val_mae'].append(metrics['mae'])
    history['val_pearson_r'].append(metrics['pearson_r'])

    print(f"Train Loss: {train_loss:.6f}")
    print(f"Val Loss:   {val_loss:.6f} | RMSE: {metrics['rmse']:.6f} | "
          f"MAE: {metrics['mae']:.6f} | r: {metrics['pearson_r']:.4f}")

    # LR SCHEDULING
    scheduler.step(val_loss)

    # CHECKPOINT SAVING
    checkpoint_saver.save_if_best(
        model=model,
        optimizer=optimizer,
        epoch=epoch,
        val_loss=val_loss
    )

    # `epoch` is 0-based, so test epoch + 1 to save after every CHECKPOINT_EVERY-th epoch.
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        checkpoint_saver.save_periodic(model, optimizer, epoch)

    # EARLY STOPPING
    early_stopping(val_loss, model, best_model_path)

    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break


# LOAD BEST CHECKPOINT AFTER TRAINING (best-effort: keep the last weights on failure)
try:
    checkpoint = torch.load(checkpoint_saver.save_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    print("✓ Loaded BEST model checkpoint.")
except Exception as e:
    print(f"⚠ Could not load best checkpoint! ({e})")


In [None]:
# Final evaluation + plotting

val_loss, val_preds, val_targets = validate_epoch(
    model,
    val_loader,
    criterion,
    device,
    transform_demo_batch=transform_demo_batch
)

final_metrics = calculate_metrics(val_preds, val_targets)

print("\n" + "="*60)
print("TRAINING COMPLETE")
print("="*60)
print(f"Final Val RMSE:   {final_metrics['rmse']:.6f}")
print(f"Final Val nRMSE:  {final_metrics['nrmse']:.6f}")
print(f"Final Val MAE:    {final_metrics['mae']:.6f}")
print(f"Final Val Pearson r: {final_metrics['pearson_r']:.4f}")

# Plot training curves against 1-based epoch numbers.
epochs = range(1, len(history['train_loss']) + 1)
fig, axes = plt.subplots(2, 2, figsize=(12, 9))

# Loss curves (train and validation on the same axes)
ax = axes[0, 0]
ax.plot(epochs, history['train_loss'], label='Train Loss')
ax.plot(epochs, history['val_loss'], label='Val Loss')
ax.set(title='Loss', xlabel='Epoch', ylabel='Loss')
ax.legend()

# Single-series validation metrics
panels = [
    (axes[0, 1], 'val_rmse', 'Val RMSE'),
    (axes[1, 0], 'val_mae', 'Val MAE'),
    (axes[1, 1], 'val_pearson_r', 'Val Pearson r'),
]
for ax, key, label in panels:
    ax.plot(epochs, history[key], label=label)
    ax.set(title=label, xlabel='Epoch', ylabel=label)
    ax.legend()

plt.tight_layout()
plt.savefig('/kaggle/working/training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("Saved training curves to /kaggle/working/training_history.png")
