In [None]:
!pip install  tqdm
!pip install -U braindecode==1.2.0 mne==1.10.1 mne-bids==0.17.0 --quiet
!pip install -U eegdash==0.3.8 s3fs==2025.9.0 fsspec==2025.9.0 pandas==2.3.3 --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m305.2/305.2 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.9/168.9 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m263.1/263.1 kB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.8/163.8 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m147.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the follo

In [None]:
from pathlib import Path
import os, math, random
from collections import defaultdict
import importlib.util

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from sklearn.preprocessing import StandardScaler

from braindecode.preprocessing import create_fixed_length_windows
from braindecode.datasets.base import BaseConcatDataset, BaseDataset, EEGWindowsDataset
from braindecode.models import EEGNetv4

from eegdash import EEGChallengeDataset

# Configuration and Device Setup
# All durations below are in seconds and are converted to samples with SFREQ
# where windows are created.

SFREQ = 100                # sampling rate (Hz) of the challenge's downsampled data
WIN_SEC = 4                # length of each fixed-length window
CROP_SEC = 2               # length of the random crop taken inside each window
STRIDE_SEC = 2             # stride between consecutive window starts
TASK = "contrastChangeDetection"  # task whose recordings are loaded
DATADIR = Path("./data_cache")    # local cache directory passed to the dataset loader
DATADIR.mkdir(parents=True, exist_ok=True)

# Releases used for fitting vs. the single held-out release used for validation.
TRAIN_RELEASES = ["R1", "R2", "R3", "R4"]
VAL_RELEASE = "R5"

# Description fields to keep in dataset (ids + demographics + label).
# "externalizing" is the field that filter_dataset validates as the label.
DESC_FIELDS = [
    "subject", "task", "session", "run",
    "age", "sex", "gender", "handedness",
    "externalizing"
]

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def seed_all(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then switches cuDNN to deterministic mode with autotuning disabled.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Fix all RNG seeds once, before any data loading or model construction.
seed_all(42)

# Filtering and Loading Functions
# Subject IDs that filter_dataset drops unconditionally.
SUB_RM = ["NDARAH793FBF", "NDARAJ689BVN"]

def filter_dataset(bcd: BaseConcatDataset):
    """Return a new concat dataset containing only usable recordings.

    A recording is dropped when any of the following holds (checked in this
    order, so ``raw`` is only accessed for recordings that pass the cheap
    description checks):

    * its subject is listed in ``SUB_RM``;
    * its ``externalizing`` value is missing, non-numeric, or non-finite;
    * it has fewer than ``4 * SFREQ`` samples;
    * it does not have exactly 129 channels.

    Prints how many recordings survived and returns them wrapped in a
    ``BaseConcatDataset``.
    """
    def _is_finite_number(value):
        # Anything that cannot be converted to a finite float is rejected.
        try:
            return math.isfinite(float(value))
        except Exception:
            return False

    min_samples = 4 * SFREQ
    survivors = []
    for recording in bcd.datasets:
        if recording.description["subject"] in SUB_RM:
            continue
        if not _is_finite_number(recording.description.get("externalizing", np.nan)):
            continue
        if recording.raw.n_times < min_samples:
            continue
        if len(recording.raw.ch_names) != 129:
            continue
        survivors.append(recording)

    print(f"Filtered rows: {len(survivors)} (from {len(bcd.datasets)})")
    return BaseConcatDataset(survivors)


def load_split(releases, task=TASK, mini=True):
    """Load each release in *releases*, concatenate them, and filter the result.

    Loading is best-effort per release: a release that fails to load is
    reported with a warning and skipped. Raises ``ValueError`` when no release
    could be loaded at all. The concatenated dataset is passed through
    ``filter_dataset`` before being returned.
    """
    loaded = []
    for release in releases:
        try:
            release_ds = EEGChallengeDataset(
                release=release,
                task=task,
                mini=mini,
                description_fields=DESC_FIELDS,
                cache_dir=str(DATADIR),
            )
            loaded.append(release_ds)
            print(f"Loaded release {release}")
        except Exception as err:
            print(f"Warning: Could not load release {release}: {err}")

    if len(loaded) == 0:
        raise ValueError("No datasets loaded successfully")

    return filter_dataset(BaseConcatDataset(loaded))

# Demographics Processing
def extract_unique_demo(bcd: BaseConcatDataset):
    """Collect one ``[age, sex, handedness]`` row per subject.

    Only the first recording seen for each subject contributes. Values that
    are missing, non-numeric, or non-finite become NaN. Sex is read from the
    ``sex`` field (falling back to ``gender``) and encoded as 1.0 for
    female / 0.0 for male, NaN otherwise.

    Returns a float32 array of shape ``(n_subjects, 3)``; ``(0, 3)`` when the
    dataset contains no recordings.
    """
    sex_codes = {
        "female": 1.0, "f": 1.0, "2": 1.0,
        "male": 0.0, "m": 0.0, "1": 0.0,
    }

    def _finite_or_nan(value):
        # Mirror of the "float if finite else NaN" rule, tolerant of any junk.
        try:
            if value is not None and math.isfinite(float(value)):
                return float(value)
            return np.nan
        except Exception:
            return np.nan

    rows_by_subject = {}
    for recording in bcd.datasets:
        description = recording.description
        subject_id = description["subject"]
        if subject_id in rows_by_subject:
            continue

        age = _finite_or_nan(description.get("age", np.nan))
        sex_raw = description.get("sex", description.get("gender", ""))
        sex = sex_codes.get(str(sex_raw).strip().lower(), np.nan)
        hand = _finite_or_nan(description.get("handedness", np.nan))

        rows_by_subject[subject_id] = [age, sex, hand]

    if not rows_by_subject:
        return np.zeros((0, 3), dtype=np.float32)
    return np.array(list(rows_by_subject.values()), dtype=np.float32)


class SafeStandardScaler(StandardScaler):
    """StandardScaler whose fitted statistics are sanitized after fitting.

    After the regular fit, any non-finite or zero entry of ``scale_`` is
    replaced by 1.0 and any non-finite entry of ``var_`` / ``mean_`` by 0.0,
    so a later ``transform`` never divides by zero or propagates NaN/Inf from
    the fitted statistics.
    """

    def fit(self, X, y=None):
        super().fit(X, y)

        if hasattr(self, "scale_"):
            scale = self.scale_
            scale[(scale == 0) | ~np.isfinite(scale)] = 1.0

        # var_ and mean_ get the same treatment: non-finite -> 0.0 (in place).
        for stat_name in ("var_", "mean_"):
            if hasattr(self, stat_name):
                stat = getattr(self, stat_name)
                stat[~np.isfinite(stat)] = 0.0

        return self


def build_demo_transform(train_bcd: BaseConcatDataset):
    """Fit the demographic preprocessing on the training subjects.

    Steps: gather one ``[age, sex, hand]`` row per subject, drop columns that
    are NaN for every subject, impute remaining gaps with the per-column
    median, and fit a ``SafeStandardScaler`` on the imputed matrix.

    Returns a 4-tuple ``(n_dims, transform_batch, col_medians, keep_idx)``:

    * ``n_dims`` - number of demographic columns kept (0 disables fusion);
    * ``transform_batch`` - callable mapping a raw ``(B, n_dims)`` tensor to an
      imputed, scaled float32 tensor on ``device`` (``None`` when disabled);
    * ``col_medians`` - float32 medians used for imputation;
    * ``keep_idx`` - indices of the kept columns within ``[age, sex, hand]``.
    """
    def _disabled():
        # Sentinel result signalling "no demographic features available".
        return 0, None, np.zeros((0,), dtype=np.float32), np.array([], dtype=int)

    def _fill_non_finite(values, fill_values):
        # Column-wise replacement of NaN/Inf by the matching fill value.
        filled = values.copy()
        for col in range(filled.shape[1]):
            missing = ~np.isfinite(filled[:, col])
            filled[missing, col] = fill_values[col]
        return filled

    # One row per training subject.
    subject_rows = extract_unique_demo(train_bcd)
    if subject_rows.shape[0] == 0:
        print("No demographics found; disabling late fusion.")
        return _disabled()

    # Columns that carry at least one observed value are kept.
    keep_mask = ~np.isnan(subject_rows).all(axis=0)
    keep_idx = np.where(keep_mask)[0]
    keep_names = [name for name, keep in zip(["age", "sex", "hand"], keep_mask) if keep]
    print("Keeping demo columns:", keep_names)

    if keep_idx.size == 0:
        print("No usable demographic columns; disabling late fusion.")
        return _disabled()
    kept_cols = subject_rows[:, keep_idx]

    # Per-column medians used for imputation (non-finite medians fall back to 0).
    with np.errstate(all="ignore"):
        col_medians = np.nanmedian(kept_cols, axis=0).astype(np.float32)
        col_medians[~np.isfinite(col_medians)] = 0.0

    imputed = _fill_non_finite(kept_cols, col_medians)
    scaler = SafeStandardScaler().fit(imputed)
    print(f"Demo scaler fitted on {imputed.shape[0]} subjects | dims: {keep_idx.size}")

    def transform_batch(demo_tensor: torch.Tensor):
        if demo_tensor.numel() == 0 or keep_idx.size == 0:
            return demo_tensor.to(device=device, dtype=torch.float32)

        raw = demo_tensor.detach().cpu().numpy().astype(np.float32)
        scaled = scaler.transform(_fill_non_finite(raw, col_medians))
        as_tensor = torch.from_numpy(scaled).to(device=device, dtype=torch.float32)
        return torch.nan_to_num(as_tensor, nan=0.0, posinf=0.0, neginf=0.0)

    return keep_idx.size, transform_batch, col_medians, keep_idx


# Dataset Wrapper with Windowing
class WindowDatasetWrapper(torch.utils.data.Dataset):
    """Wrap a windowed concat dataset and return ``(X, y, demo, crop_idx, info)``.

    For every window the wrapper:

    * finds the recording the window belongs to and reads the regression
      label and demographics from that recording's ``description``;
    * takes a random crop of ``crop_size_samples`` samples along time;
    * z-scores each channel of the crop over time and sanitizes the result.

    Parameters
    ----------
    windows_dataset:
        Concatenation of per-recording window datasets. It must expose
        ``datasets`` (a sequence of sized objects with a ``description``
        mapping) and support integer indexing returning ``(X, y, window_ind)``.
    crop_size_samples:
        Length of the random crop, in samples.
    keep_idx:
        Indices into ``[age, sex, hand]`` selecting the demographic columns to
        return.
    seed:
        Seed of the private RNG that draws crop offsets. The RNG is shared by
        all items, so the offsets depend on the order in which items are read.
    target_name:
        Key of the regression label in each recording's ``description``.
    """

    # String codes accepted for the sex/gender field (female -> 1, male -> 0).
    _SEX_CODES = {
        "female": 1.0, "f": 1.0, "2": 1.0,
        "male": 0.0, "m": 0.0, "1": 0.0,
    }

    def __init__(self, windows_dataset: "BaseConcatDataset", crop_size_samples: int,
                 keep_idx: np.ndarray, seed: int = 42,
                 target_name: str = "externalizing"):
        self.ds = windows_dataset
        self.crop = crop_size_samples
        self.keep_idx = keep_idx
        self.rng = random.Random(seed)
        self.target_name = target_name
        # Running total of windows per recording; used to map a global window
        # index back to the recording that produced it without rebuilding the
        # full metadata frame on every access.
        self._cum_sizes = np.cumsum([len(d) for d in windows_dataset.datasets])

    def __len__(self):
        return len(self.ds)

    @staticmethod
    def _finite_or_nan(value):
        """Return ``float(value)`` when it is a finite number, else NaN."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return np.nan
        return number if math.isfinite(number) else np.nan

    def _description_for(self, index):
        """Return the description of the recording that owns window *index*."""
        dataset_idx = int(np.searchsorted(self._cum_sizes, index, side="right"))
        return self.ds.datasets[dataset_idx].description

    def __getitem__(self, index):
        n_windows = len(self)
        if index < 0:
            index += n_windows
        if not 0 <= index < n_windows:
            raise IndexError(f"Window index {index} out of range for {n_windows} windows")

        # The target stored with the window is ignored; the label comes from
        # the recording-level description instead.
        X, _, _ = self.ds[index]
        desc = self._description_for(index)

        # Regression label.
        y = torch.tensor(
            [self._finite_or_nan(desc.get(self.target_name, np.nan))],
            dtype=torch.float32,
        )

        # Raw demographics vector [age, sex, hand]; NaN marks missing values.
        age = self._finite_or_nan(desc.get("age", np.nan))
        sex_raw = desc.get("sex", desc.get("gender", ""))
        sex = self._SEX_CODES.get(str(sex_raw).strip().lower(), np.nan)
        hand = self._finite_or_nan(desc.get("handedness", np.nan))

        full_demo = np.array([age, sex, hand], dtype=np.float32)
        if len(self.keep_idx) > 0:
            demo = torch.from_numpy(full_demo[self.keep_idx])
        else:
            demo = torch.empty(0, dtype=torch.float32)

        # The crop below indexes X as (channels, time).
        if not torch.is_tensor(X):
            X = torch.from_numpy(np.asarray(X))
        X = X.to(dtype=torch.float32)

        n_times = X.shape[-1]
        if n_times < self.crop:
            raise ValueError(f"Window too short: {n_times} < crop {self.crop}")

        # Random crop; randint is inclusive on both ends.
        start = self.rng.randint(0, n_times - self.crop)
        stop = start + self.crop
        X = X[:, start:stop]

        # Per-channel z-score over time, then remove NaN/Inf and clip outliers.
        mu = X.mean(dim=1, keepdim=True)
        sd = X.std(dim=1, keepdim=True)
        X = (X - mu) / (sd + 1e-6)
        X = torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)
        X = torch.clamp(X, min=-1e3, max=1e3)

        info = dict(
            subject=desc["subject"],
            task=desc.get("task", ""),
            session=desc.get("session", ""),
            run=desc.get("run", ""),
        )
        crop_idx = (start, stop)

        return X, y, demo, crop_idx, info


def make_windows(bcd: BaseConcatDataset, keep_idx: np.ndarray,
                 win_sec=WIN_SEC, crop_sec=CROP_SEC):
    """Cut *bcd* into fixed-length windows and wrap them for random cropping.

    Windows are ``win_sec`` seconds long with a stride of ``STRIDE_SEC``
    seconds (both converted to samples with ``SFREQ``); a trailing partial
    window is dropped. The result is a ``WindowDatasetWrapper`` that crops
    ``crop_sec`` seconds from each window using a crop RNG seeded with 42.
    """
    window_samples = int(win_sec * SFREQ)
    stride_samples = int(STRIDE_SEC * SFREQ)
    crop_samples = int(crop_sec * SFREQ)

    fixed_windows = create_fixed_length_windows(
        bcd,
        window_size_samples=window_samples,
        window_stride_samples=stride_samples,
        drop_last_window=True,
    )
    return WindowDatasetWrapper(
        fixed_windows,
        crop_size_samples=crop_samples,
        keep_idx=keep_idx,
        seed=42,
    )


# Load and Process Data
# Both splits are loaded with mini=True and passed through filter_dataset
# (inside load_split).
print("\n" + "="*60)
print("Loading training releases...")
print("="*60)
train_bcd = load_split(TRAIN_RELEASES, task=TASK, mini=True)

print("\n" + "="*60)
print("Loading validation release...")
print("="*60)
val_bcd = load_split([VAL_RELEASE], task=TASK, mini=True)

# Fit demographics transform on TRAIN ONLY; the same medians, scaler and kept
# columns are then reused for the validation split.
demodim, transform_demo_batch, demo_medians, keep_idx = build_demo_transform(train_bcd)

# Create windows
print("\n" + "="*60)
print("Creating windows...")
print("="*60)
train_windows = make_windows(train_bcd, keep_idx, WIN_SEC, CROP_SEC)
val_windows = make_windows(val_bcd, keep_idx, WIN_SEC, CROP_SEC)

# DataLoaders: training batches are shuffled, validation batches are not.
# num_workers=0 keeps loading in the main process.
BATCH_SIZE = 32
train_loader = DataLoader(train_windows, batch_size=BATCH_SIZE, shuffle=True,
                         num_workers=0, pin_memory=True)
val_loader = DataLoader(val_windows, batch_size=BATCH_SIZE, shuffle=False,
                       num_workers=0, pin_memory=True)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

# Infer shapes for model init from one training sample: C channels, T time
# samples after cropping.
sample_X, sample_y, sample_demo, _, _ = train_windows[0]
C, T = sample_X.shape[-2], sample_X.shape[-1]
print(f"Sample window shape: ({C}, {T}) | demodim: {demodim}")





Using device: cuda

Loading training releases...



[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Loaded release R1



[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Loaded release R2



[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Loaded release R3



[EEGChallengeDataset] EEG 2025 Competition Data Notice:
-------------------------------------------------------
This object loads the HBN dataset that has been preprocessed for the EEG Challenge:
  - Downsampled from 500Hz to 100Hz
  - Bandpass filtered (0.5–50 Hz)

For full preprocessing details, see:
  https://github.com/eeg2025/downsample-datasets

IMPORTANT: The data accessed via `EEGChallengeDataset` is NOT identical to what you get from `EEGDashDataset` directly.
If you are participating in the competition, always use `EEGChallengeDataset` to ensure consistency with the challenge data.


  warn(


Loaded release R4



Downloading sub-NDARAP359UM6_task-contrastChangeDetection_run-3_eeg.bdf:   0%|          | 0.00/1.00 [00:00<?, ?B/s][A
Downloading sub-NDARAP359UM6_task-contrastChangeDetection_run-3_eeg.bdf: 100%|██████████| 1.00/1.00 [00:00<00:00, 1.20B/s]

Downloading sub-NDARAP359UM6_task-contrastChangeDetection_run-3_eeg.bdf:   0%|          | 0.00/1.00 [00:00<?, ?B/s][A
Downloading sub-NDARAP359UM6_task-contrastChangeDetection_run-3_eeg.bdf: 100%|██████████| 1.00/1.00 [00:00<00:00, 2.60B/s]

Downloading sub-NDARBD879MBX_task-ThePresent_events.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 7.00B/s]

Downloading sub-NDARBD879MBX_task-RestingState_events.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 5.37B/s]

Downloading sub-NDARBD879MBX_task-DespicableMe_events.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 6.26B/s]

Downloading sub-NDARBD879MBX_task-FunwithFractals_events.tsv: 100%|██████████| 1.00/1.00 [00:00<00:00, 6.05B/s]

Downloading sub-NDARBD879MBX_task-surroundSupp_run-1_events.tsv: 100%|██

In [None]:
# Load Enhanced Model
# The model code lives in a separate file that must be uploaded manually; it is
# imported by path so it does not need to be on sys.path.
MODEL_PATH = "/content/enhanced_eegnet.py"
if not Path(MODEL_PATH).exists():
    # Nothing is defined in this branch (no `model`, no loss classes); the
    # training cell re-checks the path before using them.
    print(f"\n{'='*60}")
    print("ERROR: Please upload enhanced_eegnet.py to /content/")
    print("="*60)
else:
    # Import the file as a module named "enhanced_eegnet".
    spec = importlib.util.spec_from_file_location("enhanced_eegnet", MODEL_PATH)
    ee_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(ee_module)

    # The regressor class is required; the loss classes and the BatchNorm
    # adaptation helper are optional and default to None when absent.
    EnhancedEEGNetRegressor = ee_module.EnhancedEEGNetRegressor
    RMSELoss = getattr(ee_module, "RMSELoss", None)
    NRMSELoss = getattr(ee_module, "NRMSELoss", None)
    adapt_batch_norm = getattr(ee_module, "adapt_batch_norm", None)

    # C, T and demodim were inferred from the training data in the previous cell.
    model = EnhancedEEGNetRegressor(
        n_channels=C,           # e.g., 129
        n_times=T,              # 2 s * 100 Hz = 200
        n_demographic_features=demodim,
        dropout=0.5,
        F1=16,
        D=2,
        num_heads=8
    ).to(device)

    print(f"\nModel initialized with {sum(p.numel() for p in model.parameters()):,} parameters")


# Evaluation Function
@torch.no_grad()
def evaluate_subject_nrmse(model, loader):
    """Compute subject-level normalized RMSE of *model* on *loader*.

    Window predictions are averaged per subject and compared with that
    subject's label (the last label seen for the subject). The RMSE across
    subjects is divided by the population standard deviation of the labels.

    Side effect: writes the per-subject predictions to
    ``./outputs/val_subject_predictions.csv``.

    Returns ``(nrmse, rmse, std_y, n_subjects)``; ``(nan, nan, nan, 0)`` when
    the loader yields nothing.
    """
    model.eval()
    preds_by_subject = defaultdict(list)
    label_by_subject = {}

    for X, y, demo, crop_idx, info in tqdm(loader, desc="Validation", leave=False):
        # (B, C, T) -> (B, 1, C, T)
        X = X.to(device=device, dtype=torch.float32).unsqueeze(1)
        y = y.to(device=device, dtype=torch.float32)            # (B, 1)
        demo = demo.to(device=device, dtype=torch.float32)

        if demodim > 0:
            demo = transform_demo_batch(demo)
        else:
            demo = torch.empty((X.size(0), 0), device=device, dtype=torch.float32)

        batch_pred = model(X, demographics=demo).squeeze(1).detach().cpu().numpy()
        batch_true = y.squeeze(1).detach().cpu().numpy()

        # `info` is a dict of per-field lists after default collation, but a
        # list of per-sample dicts is accepted as well.
        for i, pred in enumerate(batch_pred):
            if isinstance(info, dict):
                sid = info["subject"][i]
            else:
                sid = info[i]["subject"]
            preds_by_subject[sid].append(float(pred))
            label_by_subject[sid] = float(batch_true[i])

    if not preds_by_subject:
        return np.nan, np.nan, np.nan, 0

    sids = [
        sid for sid, plist in preds_by_subject.items()
        if sid in label_by_subject and len(plist) > 0
    ]
    ytrue = np.array([label_by_subject[sid] for sid in sids], dtype=np.float64)
    yhat = np.array(
        [float(np.mean(preds_by_subject[sid])) for sid in sids],
        dtype=np.float64,
    )

    if ytrue.size > 0:
        rmse = float(np.sqrt(np.mean((ytrue - yhat) ** 2)))
        stdy = float(np.std(ytrue, ddof=0))
    else:
        rmse = np.nan
        stdy = np.nan
    nrmse = rmse / stdy if stdy and stdy > 0 else np.nan

    # Save per-subject predictions
    out_dir = Path("./outputs")
    out_dir.mkdir(parents=True, exist_ok=True)
    report = pd.DataFrame({
        "subject": sids,
        "p_factor_pred": yhat,
        "p_factor_true": ytrue
    })
    report.to_csv(out_dir / "val_subject_predictions.csv", index=False)

    return nrmse, rmse, stdy, len(sids)

# Training Function
def train_model(model, train_loader, val_loader, max_epochs=40, lr=1e-3,
                weight_decay=1e-2, patience=8, use_nrmse_loss=True):
    """Train *model* with AdamW, LR-on-plateau scheduling and early stopping.

    Parameters
    ----------
    model:
        Network called as ``model(X, demographics=demo)``.
    train_loader, val_loader:
        Loaders yielding ``(X, y, demo, crop_idx, info)`` batches.
    max_epochs:
        Upper bound on the number of epochs.
    lr, weight_decay:
        AdamW hyper-parameters.
    patience:
        Number of consecutive epochs without a better validation nRMSE after
        which training stops.
    use_nrmse_loss:
        Prefer ``NRMSELoss`` when it is available; otherwise fall back to
        ``RMSELoss`` and finally to ``nn.MSELoss``.

    Returns
    -------
    ``(model, best_nrmse)``. When at least one checkpoint was written during
    this call the model carries the best checkpoint's weights; otherwise it
    keeps its final-epoch weights and ``best_nrmse`` is ``inf``.
    """
    # Loss function
    if use_nrmse_loss and NRMSELoss is not None:
        criterion = NRMSELoss()
        print("Using NRMSELoss for training.")
    elif RMSELoss is not None:
        criterion = RMSELoss()
        print("Using RMSELoss for training.")
    else:
        criterion = nn.MSELoss()
        print("Using MSELoss for training.")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )

    best_nrmse, wait = float("inf"), 0
    ckpt_dir = Path("./checkpoints")
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = ckpt_dir / "best_enhanced_eegnet.pt"
    # Tracks whether THIS call wrote the checkpoint, so a file left behind by
    # an earlier run is never mistaken for this run's best model.
    saved_this_run = False

    for epoch in range(1, max_epochs + 1):
        model.train()
        train_loss_sum, n_seen = 0.0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch:02d} • Train", leave=False)
        for X, y, demo, crop_idx, info in pbar:
            X = X.to(device=device, dtype=torch.float32)   # (B, C, T)
            X = X.unsqueeze(1)                             # (B, 1, C, T)
            y = y.to(device=device, dtype=torch.float32)   # (B, 1)
            demo = demo.to(device=device, dtype=torch.float32)

            if demodim > 0:
                demo = transform_demo_batch(demo)
            else:
                demo = torch.empty((X.size(0), 0), device=device, dtype=torch.float32)

            optimizer.zero_grad(set_to_none=True)
            y_pred = model(X, demographics=demo)
            loss = criterion(y_pred, y)

            # A NaN/Inf loss would poison the weights; drop the batch instead.
            if not torch.isfinite(loss):
                print("Non-finite loss; skipping batch.")
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()
            train_loss_sum += batch_loss * X.size(0)
            n_seen += X.size(0)
            pbar.set_postfix(loss=batch_loss)

        # Sample-weighted mean loss; NaN makes it obvious when every batch of
        # the epoch was skipped and nothing was actually trained.
        train_loss = train_loss_sum / n_seen if n_seen else float("nan")

        # Validate subject-level metrics
        nrmse, rmse, stdy, nsubj = evaluate_subject_nrmse(model, val_loader)
        nrmse_is_valid = not math.isnan(nrmse)
        # The scheduler cannot digest NaN; feed it a neutral value instead.
        scheduler.step(nrmse if nrmse_is_valid else 1.0)

        print(f"Epoch {epoch:02d} • train loss {train_loss:.4f} • val nRMSE {nrmse:.4f} • " +
              f"RMSE {rmse:.4f} • std(y) {stdy:.4f} • subjects {nsubj}")

        if nrmse_is_valid and nrmse < best_nrmse:
            best_nrmse, wait = nrmse, 0
            torch.save(model.state_dict(), ckpt_path)
            saved_this_run = True
            print(f"✓ Saved {ckpt_path} • best nRMSE {best_nrmse:.4f}")
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch} • best nRMSE {best_nrmse:.4f}")
                break

    # Load best checkpoint, but only one produced by this training run.
    if saved_this_run and ckpt_path.exists():
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        print(f"Loaded best checkpoint: {ckpt_path}")

    return model, best_nrmse


# Run Training
# Only runs when the model file exists, i.e. when the model-loading cell took
# the branch that defines `model`, the loss classes and `adapt_batch_norm`.
if Path(MODEL_PATH).exists():
    print("\n" + "="*60)
    print("Starting training...")
    print("="*60)
    model, best_nrmse = train_model(
        model, train_loader, val_loader,
        max_epochs=40, lr=1e-3, weight_decay=1e-2,
        patience=8, use_nrmse_loss=True
    )

    # AdaBN on validation domain (skipped when the model file does not
    # provide an adapt_batch_norm helper).
    if adapt_batch_norm is not None:
        print("\n" + "="*60)
        print("Adapting BatchNorm on validation domain (AdaBN)...")
        print("="*60)
        adapt_batch_norm(model, val_loader, device)

    # Final validation report; this call also rewrites the per-subject CSV.
    print("\n" + "="*60)
    print("Final evaluation...")
    print("="*60)
    nrmse, rmse, stdy, nsubj = evaluate_subject_nrmse(model, val_loader)
    print(f"Final • val nRMSE {nrmse:.4f} • RMSE {rmse:.4f} • " +
          f"std(y) {stdy:.4f} • subjects {nsubj}")
    print("Results saved to: ./outputs/val_subject_predictions.csv")
else:
    print("\nUpload enhanced_eegnet.py to /content/ and re-run the cells.")