In [1]:
!pip install mne==1.10.1 mne-bids==0.17.0 --quiet
!pip install scikit-learn --quiet
from pathlib import Path
import os, math, random
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr

import mne
from mne_bids import BIDSPath, read_raw_bids
mne.set_log_level('ERROR')

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m55.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.9/168.9 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# --- Windowing configuration -------------------------------------------------
# SFREQ converts the second-based sizes below into sample counts
# (e.g. int(WIN_SEC * SFREQ)). NOTE(review): the recordings themselves are not
# resampled anywhere in this notebook -- confirm the files really are 100 Hz.
SFREQ = 100  # Sampling rate (Hz)
WIN_SEC = 4  # Window size (seconds)
CROP_SEC = 2  # Random crop size (seconds)
STRIDE_SEC = 2  # Window stride (seconds)
TASK = "contrastChangeDetection"  # Primary task

# Parent folder that holds one sub-folder per release listed in
# TRAIN_RELEASES / VAL_RELEASE below.
DATA_ROOT = Path("/kaggle/input/eeg-dataset")

# Release folders used for training and for validation, respectively.
TRAIN_RELEASES = ["R6_L100_bdf","R7_L100_bdf","R8_L100_bdf"]
VAL_RELEASE = "R5_L100_bdf"

# Subject IDs (without the "sub-" prefix) that are skipped while scanning releases.
SUB_RM = ["NDARAC350XUM", "NDARAJ689BVN"]

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def seed_all(seed=42):
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and
            PyTorch (CPU and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed_all(42)


def resolve_double_nested_path(data_root, release):
    """Locate the folder of a release that actually contains ``participants.tsv``.

    The doubly nested layout ``<data_root>/<release>/<release>`` is tried
    first, then the flat layout ``<data_root>/<release>``.

    Args:
        data_root: ``Path`` of the folder holding the release folders.
        release: Name of the release folder.

    Returns:
        The first candidate ``Path`` that exists and contains
        ``participants.tsv``, or ``None`` when neither candidate qualifies.
    """
    flat = data_root / release
    for candidate in (flat / release, flat):
        if candidate.exists() and (candidate / "participants.tsv").exists():
            return candidate
    return None

def load_participants_data(release_path):
    """Read ``participants.tsv`` of a release into a DataFrame.

    Args:
        release_path: ``Path`` of the release folder.

    Returns:
        The tab-separated table as a ``pandas.DataFrame``; an empty
        DataFrame when the file does not exist.
    """
    tsv_path = release_path / "participants.tsv"
    if tsv_path.exists():
        return pd.read_csv(tsv_path, sep='\t')
    return pd.DataFrame()

# Helper to load RAW data for a subject/run (Used by the cache and the dataset)
def load_raw_eeg(subject_path, subject_id, task, run=None):
    """Read one BIDS ``.bdf`` EEG recording fully into memory.

    Args:
        subject_path: ``Path`` of the subject folder; its parent is used as
            the BIDS root.
        subject_id: Subject label without the ``sub-`` prefix.
        task: BIDS task name.
        run: BIDS run index, or ``None`` when the task has no runs.

    Returns:
        The loaded ``mne.io.Raw`` object, or ``None`` when reading fails for
        any reason, when the recording has fewer than ``4 * SFREQ`` samples,
        or when it does not have exactly 129 channels.
    """
    try:
        recording = read_raw_bids(
            BIDSPath(
                root=subject_path.parent,
                subject=subject_id,
                task=task,
                run=run,
                datatype='eeg',
                extension='.bdf',
            ),
            verbose=False,
        )
        recording.load_data()

        # Reject recordings that are too short or have an unexpected channel count.
        long_enough = recording.n_times >= 4 * SFREQ
        if long_enough and len(recording.ch_names) == 129:
            return recording
        return None
    except Exception:
        # Best effort: any read/parse failure is reported as "no recording".
        return None

def _encode_sex(sex):
    """Map a participants.tsv sex value to 1.0 (female), 0.0 (male) or NaN (anything else)."""
    sex_str = str(sex).strip().lower()
    if sex_str in ('female', 'f', '2'):
        return 1.0
    if sex_str in ('male', 'm', '1'):
        return 0.0
    return np.nan


def load_release_data_lazy(release, task=TASK, data_root=DATA_ROOT):
    """Scan one release and build window *pointers* (metadata only).

    Each recording is loaded once to learn its length and is released right
    away; only the sample ranges of the sliding windows are kept, so the EEG
    data itself does not accumulate in RAM.

    Args:
        release: Name of the release folder under ``data_root``.
        task: BIDS task name. ``"contrastChangeDetection"`` is scanned for
            runs 1-3; any other task is scanned without a run index.
        data_root: ``Path`` of the folder holding the release folders.

    Returns:
        A list of dicts, one per window, with the keys ``bids_root``,
        ``subject``, ``task``, ``run``, ``start_sample``, ``end_sample``,
        ``age``, ``sex``, ``handedness`` and ``externalizing``. The list is
        empty when the release cannot be found or has no participants table.
    """
    release_path = resolve_double_nested_path(data_root, release)

    # resolve_double_nested_path returns None (not a Path) when no candidate
    # folder contains participants.tsv, so it must be checked before use.
    if release_path is None or not release_path.exists():
        print(f"Warning: Release path for {release} not found under {data_root}")
        return []

    participants_df = load_participants_data(release_path)

    if participants_df.empty:
        print(f"No participants data found for {release}")
        return []

    window_pointers = []
    win_samples = int(WIN_SEC * SFREQ)
    stride_samples = int(STRIDE_SEC * SFREQ)

    # Stored as a string so every pointer stays a plain, picklable dict.
    bids_root = str(release_path)

    runs = [1, 2, 3] if task == "contrastChangeDetection" else [None]

    for _, row in tqdm(
        participants_df.iterrows(),
        total=len(participants_df),
        desc=f"Scanning {release} for windows"
    ):
        participant = row.get('participant_id')
        if not isinstance(participant, str):
            # Missing/NaN identifier: nothing to look up on disk.
            continue
        subject_id = participant.replace('sub-', '')

        if subject_id in SUB_RM:
            continue

        # Demographics are carried through unparsed; the label must be a finite float.
        age = row.get('age', np.nan)
        handedness = row.get('handedness', np.nan)
        sex_encoded = _encode_sex(row.get('sex', np.nan))

        try:
            externalizing = float(row.get('externalizing', np.nan))
        except (TypeError, ValueError):
            continue
        if not math.isfinite(externalizing):
            continue

        # Subject directory inside the resolved release folder.
        subject_path = release_path / f"sub-{subject_id}"

        for run in runs:
            try:
                # Cheap existence check before paying for a full read.
                bids_path_check = BIDSPath(
                    subject=subject_id, task=task, run=run, datatype='eeg', extension='.bdf', root=release_path
                )
                if not bids_path_check.fpath.exists():
                    continue

                # The recording has to be loaded to learn its length; keep
                # only n_times and drop the data immediately.
                raw = load_raw_eeg(subject_path, subject_id, task, run)
                if raw is None:
                    continue

                n_times = raw.n_times
                del raw

                for start in range(0, n_times - win_samples + 1, stride_samples):
                    window_pointers.append({
                        "bids_root": bids_root,
                        "subject": subject_id,
                        "task": task,
                        "run": run,
                        "start_sample": start,
                        "end_sample": start + win_samples,
                        "age": age,
                        "sex": sex_encoded,
                        "handedness": handedness,
                        "externalizing": externalizing,
                    })

            except Exception:
                # Best effort: one unreadable run must not abort the whole scan.
                continue

    print(f"Scanned and generated {len(window_pointers)} window pointers from {release}")
    return window_pointers


class EEGWindowsDataset(torch.utils.data.Dataset):
    """Dataset of EEG windows that reads recordings from disk on demand.

    Items are window pointers (dicts with ``bids_root``, ``subject``,
    ``task``, ``run``, ``start_sample``, ``end_sample``, ``age``, ``sex``,
    ``handedness`` and ``externalizing``). The most recently used recordings
    are kept in a small LRU cache keyed by ``(subject, run)``.

    Args:
        data_list: List of window pointers.
        crop_samples: Length (in samples) of the random crop returned per item.
        keep_idx: Indices into ``[age, sex, hand]`` selecting the demographic
            columns to return; empty to return no demographics.
        seed: Seed of the private RNG used to pick crop offsets.
        cache_size: Maximum number of recordings kept in memory.

    NOTE(review): when used with DataLoader workers each worker presumably
    holds its own copy of this object (cache and RNG included) -- confirm
    against the DataLoader documentation.
    """
    def __init__(self, data_list, crop_samples, keep_idx, seed=42, cache_size=3):
        self.data_list = data_list
        self.crop_samples = crop_samples
        self.keep_idx = keep_idx
        self.rng = random.Random(seed)

        # LRU cache: raw_cache maps key -> Raw, cache_keys is ordered oldest -> newest.
        self.raw_cache = {}
        self.cache_keys = []
        self.cache_size = cache_size

        # Per-recording lookup tables built once, so a cache miss never has
        # to scan data_list to find the subject folder or the task.
        self._subject_paths = {}
        self._tasks = {}
        for item in data_list:
            key = (item['subject'], item['run'])
            if key not in self._subject_paths:
                self._subject_paths[key] = Path(item['bids_root']) / f"sub-{item['subject']}"
                self._tasks[key] = item['task']

    def __len__(self):
        return len(self.data_list)

    @staticmethod
    def _as_finite_float(value):
        """Return ``value`` as a finite float, or NaN when it cannot be parsed."""
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return np.nan
        return number if math.isfinite(number) else np.nan

    def _get_raw_from_cache(self, subject_id, run):
        """Return the recording for ``(subject_id, run)``, loading it on a cache miss.

        Raises:
            KeyError: If no window pointer refers to this recording.
            FileNotFoundError: If the recording cannot be loaded from disk.
        """
        key = (subject_id, run)

        # Hit: mark the key as most recently used.
        if key in self.raw_cache:
            self.cache_keys.remove(key)
            self.cache_keys.append(key)
            return self.raw_cache[key]

        # Miss: load from disk using the lookup tables built in __init__.
        raw = load_raw_eeg(self._subject_paths[key], subject_id, self._tasks[key], run)

        if raw is None:
            raise FileNotFoundError(f"Could not load raw file for {subject_id}, run {run}")

        # Evict least recently used entries until there is room; the guard on
        # cache_keys keeps a non-positive cache_size from popping an empty list.
        while self.cache_keys and len(self.raw_cache) >= self.cache_size:
            lru_key = self.cache_keys.pop(0)
            del self.raw_cache[lru_key]

        self.raw_cache[key] = raw
        self.cache_keys.append(key)

        return raw

    def __getitem__(self, idx):
        """Return ``(eeg, y, demo, crop_idx, info)`` for window ``idx``.

        ``eeg`` is a float tensor of shape ``(channels, crop_samples)``,
        z-scored per channel; ``y`` is the label as a 1-element tensor;
        ``demo`` holds the demographic columns selected by ``keep_idx``
        (NaN where missing); ``crop_idx`` is the ``(start, stop)`` sample
        range of the crop within the recording; ``info`` identifies the
        recording.
        """
        item = self.data_list[idx]

        raw = self._get_raw_from_cache(item['subject'], item['run'])

        # Indexing a Raw yields (data, times); data is (n_channels, n_times).
        eeg, _ = raw[:, item['start_sample']:item['end_sample']]

        # Keep only the first 128 rows when 129 are present.
        if eeg.shape[0] == 129:
            eeg = eeg[:128, :]

        eeg = torch.from_numpy(eeg.copy()).float()
        C, T = eeg.shape

        if T < self.crop_samples:
            # Window shorter than the crop: zero-pad on the right instead.
            pad_amount = self.crop_samples - T
            eeg = torch.nn.functional.pad(eeg, (0, pad_amount), mode='constant', value=0)
            start_crop = 0
            stop_crop = self.crop_samples
        else:
            start_crop = self.rng.randint(0, T - self.crop_samples)
            stop_crop = start_crop + self.crop_samples
            eeg = eeg[:, start_crop:stop_crop]

        # Per-channel z-score over the crop, then scrub non-finite values and outliers.
        mu = eeg.mean(dim=1, keepdim=True)
        sd = eeg.std(dim=1, keepdim=True)
        eeg = (eeg - mu) / (sd + 1e-6)
        eeg = torch.nan_to_num(eeg, nan=0.0, posinf=0.0, neginf=0.0)
        eeg = torch.clamp(eeg, min=-1e3, max=1e3)

        y = torch.tensor([item['externalizing']], dtype=torch.float32)

        age = self._as_finite_float(item['age'])
        sex = item['sex']
        hand = self._as_finite_float(item['handedness'])

        full_demo = np.array([age, sex, hand], dtype=np.float32)
        if len(self.keep_idx) > 0:
            demo = torch.from_numpy(full_demo[self.keep_idx])
        else:
            demo = torch.empty(0, dtype=torch.float32)

        info = {
            'subject': item['subject'],
            'task': item['task'],
            'run': item['run'],
        }
        crop_idx = (item['start_sample'] + start_crop, item['start_sample'] + stop_crop)

        return eeg, y, demo, crop_idx, info


def _finite_float_or_nan(value):
    """Return ``value`` as a finite float, or NaN when it is missing or unparsable."""
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are treated as "missing"; interrupts and
        # other unrelated exceptions must propagate.
        return np.nan
    return number if math.isfinite(number) else np.nan


def extract_unique_demographics(data_list):
    """Collect one ``[age, sex, hand]`` row per subject from window pointers.

    Only the first pointer seen for each subject is used. ``age`` and
    ``handedness`` are converted to finite floats (NaN otherwise); ``sex`` is
    taken as stored in the pointer.

    Args:
        data_list: Iterable of dicts with the keys ``subject``, ``age``,
            ``sex`` and ``handedness``.

    Returns:
        A float32 array of shape ``(n_subjects, 3)``; shape ``(0, 3)`` when
        ``data_list`` is empty.
    """
    seen = {}
    for item in data_list:
        sid = item['subject']
        if sid in seen:
            continue
        seen[sid] = [
            _finite_float_or_nan(item['age']),
            item['sex'],
            _finite_float_or_nan(item['handedness']),
        ]

    if not seen:
        return np.zeros((0, 3), dtype=np.float32)
    return np.array(list(seen.values()), dtype=np.float32)

class SafeStandardScaler(StandardScaler):
    """``StandardScaler`` whose fitted statistics are scrubbed of unusable values.

    After the regular fit, zero or non-finite scales become 1.0 and
    non-finite variances and means become 0.0 (all edited in place).
    """
    def fit(self, X, y=None):
        super().fit(X, y)

        if hasattr(self, "scale_"):
            scale = self.scale_
            unusable = np.logical_or(~np.isfinite(scale), scale == 0)
            scale[unusable] = 1.0

        for stat_name in ("var_", "mean_"):
            if hasattr(self, stat_name):
                stat = getattr(self, stat_name)
                stat[~np.isfinite(stat)] = 0.0

        return self

def build_demo_transform(train_data):
    """Fit median imputation and standard scaling for the demographic inputs.

    Statistics are computed on one row per training subject. Columns that are
    NaN for every subject are dropped.

    Args:
        train_data: List of training window pointers.

    Returns:
        A tuple ``(n_dims, transform_batch, col_medians, keep_idx)``:
        the number of kept columns, a callable that imputes, scales and moves
        a demographics batch to ``device`` (``None`` when nothing is kept),
        the per-column medians used for imputation, and the indices of the
        kept columns within ``[age, sex, hand]``.
    """
    column_names = ['age', 'sex', 'hand']
    per_subject = extract_unique_demographics(train_data)

    def _fusion_disabled():
        # Sentinel result: no demographic columns, no transform.
        return 0, None, np.zeros((0,), dtype=np.float32), np.array([], dtype=int)

    def _fill_missing(values, medians):
        # Replace non-finite entries with the median of their column.
        return np.where(np.isfinite(values), values, medians)

    if per_subject.shape[0] == 0:
        print("No demographics found; disabling late fusion.")
        return _fusion_disabled()

    keep_mask = ~np.isnan(per_subject).all(axis=0)
    keep_idx = np.flatnonzero(keep_mask)
    keep_names = [name for name, keep in zip(column_names, keep_mask) if keep]
    print("Keeping demo columns:", keep_names)

    if keep_idx.size > 0:
        kept = per_subject[:, keep_idx]
    else:
        kept = np.zeros((per_subject.shape[0], 0), dtype=np.float32)

    if kept.shape[1] == 0:
        print("No usable demographic columns; disabling late fusion.")
        return _fusion_disabled()

    with np.errstate(all='ignore'):
        col_medians = np.nanmedian(kept, axis=0).astype(np.float32)
        col_medians[~np.isfinite(col_medians)] = 0.0

    kept_imp = _fill_missing(kept, col_medians)
    scaler = SafeStandardScaler().fit(kept_imp)
    print(f"Demo scaler fitted on {kept_imp.shape[0]} subjects | dims: {keep_idx.size}")

    def transform_batch(demo_tensor):
        nothing_to_do = demo_tensor.numel() == 0 or keep_idx.size == 0
        if nothing_to_do:
            return demo_tensor.to(device=device, dtype=torch.float32)

        # The scaler works on NumPy arrays, so round-trip through the CPU.
        batch_np = demo_tensor.detach().cpu().numpy().astype(np.float32)
        batch_np = scaler.transform(_fill_missing(batch_np, col_medians))
        scaled = torch.from_numpy(batch_np).to(device=device, dtype=torch.float32)
        return torch.nan_to_num(scaled, nan=0.0, posinf=0.0, neginf=0.0)

    return keep_idx.size, transform_batch, col_medians, keep_idx



# ---------------------------------------------------------------------------
# Data loading driver: scan releases into window pointers, fit the
# demographic transform on the training pointers, then build datasets/loaders.
# ---------------------------------------------------------------------------
print("\n" + "="*60)
print("Loading training releases...")
print("="*60)

# Concatenate the window pointers of every training release into one list.
train_data_pointers = []
for release in TRAIN_RELEASES:
    train_data_pointers.extend(load_release_data_lazy(release, task=TASK, data_root=DATA_ROOT)) 

print(f"\n✓ Total training windows (pointers): {len(train_data_pointers)}")

print("\n" + "="*60)
print("Loading validation release...")
print("="*60)

val_data_pointers = load_release_data_lazy(VAL_RELEASE, TASK, DATA_ROOT)
print(f"✓ Total validation windows (pointers): {len(val_data_pointers)}")

# Build demographics: imputation medians and scaler come from the training
# pointers only and are reused unchanged for validation.
print("\n" + "="*60)
print("Building demographic transformations...")
print("="*60)
demodim, transform_demo_batch, demo_medians, keep_idx = build_demo_transform(train_data_pointers)

# Create datasets
print("\n" + "="*60)
print("Creating PyTorch datasets...")
print("="*60)

# Fail early: an empty training set would only surface later as an obscure error.
if len(train_data_pointers) == 0:
    raise ValueError("No training window pointers found – check your BIDS structure and data validity!")

# NOTE(review): train_loader below uses shuffle=True while each dataset only
# caches 5 recordings; random access across many recordings presumably causes
# a disk read on most samples -- profile one epoch and consider grouping
# windows by recording.
train_dataset = EEGWindowsDataset(
    train_data_pointers,
    crop_samples=int(CROP_SEC * SFREQ),
    keep_idx=keep_idx,
    seed=42,
    cache_size=5 # Cache the 5 most recently accessed raw files
)

# NOTE(review): the validation dataset also applies the random crop in
# __getitem__, so validation inputs depend on the dataset's RNG state.
val_dataset = EEGWindowsDataset(
    val_data_pointers,
    crop_samples=int(CROP_SEC * SFREQ),
    keep_idx=keep_idx,
    seed=42,
    cache_size=5 
)

# DataLoaders
BATCH_SIZE = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4, 
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# Model input shape: 128 channels x crop length in samples (CROP_SEC * SFREQ).
C, T = 128, int(CROP_SEC * SFREQ) 
print(f"\n✓ Train batches: {len(train_loader)}")
print(f"✓ Val batches: {len(val_loader)}")
print(f"\n✓ Sample EEG shape: ({C}, {T})")
print(f"✓ Demographic features: {demodim}")
print("\n" + "="*60)
print("DATA LOADING COMPLETE!")
print("="*60)


Loading training releases...


Scanning R6_L100_bdf for windows: 100%|██████████| 135/135 [01:35<00:00,  1.41it/s]


Scanned and generated 32789 window pointers from R6_L100_bdf


Scanning R7_L100_bdf for windows: 100%|██████████| 381/381 [03:37<00:00,  1.75it/s]


Scanned and generated 81065 window pointers from R7_L100_bdf


Scanning R8_L100_bdf for windows: 100%|██████████| 257/257 [03:45<00:00,  1.14it/s]


Scanned and generated 93674 window pointers from R8_L100_bdf

✓ Total training windows (pointers): 207528

Loading validation release...


Scanning R5_L100_bdf for windows: 100%|██████████| 330/330 [03:55<00:00,  1.40it/s]

Scanned and generated 106562 window pointers from R5_L100_bdf
✓ Total validation windows (pointers): 106562

Building demographic transformations...
Keeping demo columns: ['age', 'sex']
Demo scaler fitted on 486 subjects | dims: 2

Creating PyTorch datasets...

✓ Train batches: 6486
✓ Val batches: 3331

✓ Sample EEG shape: (128, 200)
✓ Demographic features: 2

DATA LOADING COMPLETE!





In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, sequence, embed_dim)`` tensor.

    Args:
        embed_dim: Size of the last dimension of the input.
        num_heads: Number of attention heads.
        dropout: Dropout probability passed to ``nn.MultiheadAttention``.
    """
    def __init__(self, embed_dim: int, num_heads: int = 8, dropout: float = 0.3):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True
        )

    def forward(self, x):
        """Attend ``x`` to itself and return the attended tensor (weights are discarded)."""
        attn_output, _ = self.attention(x, x, x)
        return attn_output


class EnhancedEEGNetRegressor(nn.Module):
    """EEGNet-style convolutional regressor with self-attention and optional
    late fusion of demographic features.

    Args:
        n_channels: Number of EEG channels (height of the input).
        n_times: Number of time samples (width of the input).
        n_demographic_features: Width of the demographics vector; ``0``
            disables late fusion.
        dropout: Dropout probability used throughout the network.
        F1: Number of temporal filters.
        D: Depth multiplier of the spatial convolution.
        num_heads: Number of attention heads.
    """
    def __init__(self, n_channels: int = 128, n_times: int = 200, n_demographic_features: int = 3, dropout: float = 0.5, F1: int = 16, D: int = 2, num_heads: int = 8):
        super().__init__()
        self.n_channels = n_channels
        self.n_demographic_features = n_demographic_features
        self.dropout_rate = dropout
        # Temporal filtering; padding 25 with kernel 51 keeps the time length.
        self.temporal_conv = nn.Conv2d(1, F1, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        # Grouped convolution spanning all channels collapses the channel axis to 1.
        self.spatial_conv = nn.Conv2d(F1, F1 * D, kernel_size=(n_channels, 1), stride=(1, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1 * D)
        self.elu = nn.ELU()
        self.pool1 = nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout1 = nn.Dropout(p=dropout)
        self.attention = MultiHeadAttention(embed_dim=F1 * D, num_heads=num_heads, dropout=dropout)
        self.pool2 = nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.dropout2 = nn.Dropout(p=dropout)
        # Two poolings (x4, then x2) shrink the time axis by 8.
        self.eeg_feature_dim = F1 * D * (n_times // 8)

        self.demographic_encoder = nn.Sequential(
            nn.Linear(n_demographic_features, 16), nn.ReLU(), nn.Dropout(p=dropout), nn.Linear(16, 32)
        )
        fusion_input_dim = self.eeg_feature_dim + 32 if n_demographic_features > 0 else self.eeg_feature_dim
        self.fusion = nn.Sequential(
            nn.Linear(fusion_input_dim, 64), nn.ReLU(), nn.Dropout(p=dropout),
            nn.Linear(64, 32), nn.ReLU(), nn.Dropout(p=dropout)
        )
        self.regression_head = nn.Linear(32, 1)

    def forward(self, eeg: torch.Tensor, demographics: torch.Tensor = None):
        """Predict one value per sample.

        Args:
            eeg: Tensor of shape ``(batch, 1, n_channels, n_times)``.
            demographics: Tensor of shape
                ``(batch, n_demographic_features)``, or ``None``/empty. When
                the model was built with demographic features but none are
                supplied, an all-zero vector is used instead so the fusion
                layer still receives the width it was built for.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        x = self.temporal_conv(eeg)
        x = self.bn1(x)
        x = self.elu(x)
        x = self.spatial_conv(x)
        x = self.bn2(x)
        x = self.elu(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        batch_size = x.shape[0]
        # (B, F, 1, T) -> (B, T, F): attention runs over the time axis.
        x = x.squeeze(2).transpose(1, 2)
        x = self.attention(x)
        x = x.transpose(1, 2).unsqueeze(2)
        x = self.pool2(x)
        x = self.dropout2(x)
        eeg_features = x.reshape(batch_size, -1)

        if self.n_demographic_features > 0:
            if demographics is None or demographics.numel() == 0:
                demographics = eeg_features.new_zeros(batch_size, self.n_demographic_features)
            demo_features = self.demographic_encoder(demographics)
            combined_features = torch.cat([eeg_features, demo_features], dim=1)
        else:
            combined_features = eeg_features

        fused = self.fusion(combined_features)
        output = self.regression_head(fused)
        return output

# Loss and Metric functions
class RMSELoss(nn.Module):
    """Root-mean-square error loss.

    A small constant (1e-8) is added under the square root so the result
    stays well defined when the mean squared error is zero.
    """
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred, target):
        mean_squared_error = self.mse(pred, target)
        stabilised = mean_squared_error + 1e-8
        return stabilised.sqrt()

class NRMSELoss(nn.Module):
    """RMSE normalised by the range of the targets.

    Args:
        target_range: Fixed range to normalise by. When ``None`` (default)
            the range of the targets in the current batch is used.
        eps: Small constant added under the square root and to the range.

    When the batch range is used and every target in the batch is identical,
    the range is zero. Dividing by ``eps`` alone would inflate the loss by
    many orders of magnitude, so such a batch falls back to plain RMSE
    (normalisation by 1). Pass ``target_range`` to get values that are
    comparable across batches.
    """
    def __init__(self, target_range=None, eps=1e-8):
        super().__init__()
        self.mse = nn.MSELoss()
        self.target_range = target_range
        self.eps = eps

    def forward(self, pred, target):
        rmse = torch.sqrt(self.mse(pred, target) + self.eps)

        if self.target_range is not None:
            return rmse / (float(self.target_range) + self.eps)

        batch_range = target.max() - target.min()
        # Degenerate batch (all targets equal): normalise by 1 instead of ~eps.
        denominator = torch.where(
            batch_range > self.eps,
            batch_range + self.eps,
            torch.ones_like(batch_range),
        )
        return rmse / denominator

def calculate_metrics(predictions, targets):
    """Compute regression metrics for two equally shaped 1-D arrays.

    Args:
        predictions: NumPy array of predicted values.
        targets: NumPy array of true values.

    Returns:
        Dict with ``mse``, ``rmse``, ``nrmse`` (RMSE divided by the target
        range plus 1e-8), ``mae`` and ``pearson_r``. The correlation is
        reported as ``0.0`` when there are fewer than two samples or when
        either array has zero standard deviation.
    """
    errors = predictions - targets
    mse = np.mean(errors ** 2)
    rmse = np.sqrt(mse)
    span = targets.max() - targets.min()

    has_variation = (
        len(predictions) > 1
        and np.std(predictions) > 0
        and np.std(targets) > 0
    )
    corr = pearsonr(predictions, targets)[0] if has_variation else 0.0

    return {
        'mse': mse,
        'rmse': rmse,
        'nrmse': rmse / (span + 1e-8),
        'mae': np.mean(np.abs(errors)),
        'pearson_r': corr,
    }


class EarlyStoppingCallback:
    """Checkpoint the best model and signal when training should stop.

    A validation loss counts as an improvement only when it is below 99% of
    the best loss seen so far. After ``patience`` consecutive calls without
    improvement, ``early_stop`` is set to ``True``.

    Args:
        patience: Number of non-improving calls tolerated before stopping.
        verbose: Whether to print progress messages.
    """
    def __init__(self, patience: int = 3, verbose: bool = True):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss: float, model: nn.Module, save_path: str = "best_model.pt"):
        """Record ``val_loss``; save ``model.state_dict()`` to ``save_path`` on improvement."""
        is_first_call = self.best_loss is None

        if is_first_call or val_loss < self.best_loss * 0.99:
            self.best_loss = val_loss
            if not is_first_call:
                self.counter = 0
            torch.save(model.state_dict(), save_path)
            # The very first checkpoint is written silently.
            if self.verbose and not is_first_call:
                print(f"✓ Validation loss improved to {val_loss:.6f}. Model saved.")
            return

        self.counter += 1
        if self.verbose:
            print(f"No improvement for {self.counter}/{self.patience} epochs")
        if self.counter < self.patience:
            return

        self.early_stop = True
        if self.verbose:
            print("Early stopping triggered!")

def train_epoch(model, train_loader, criterion, optimizer, device, transform_demo_batch=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network called as ``model(eeg, demographics)``.
        train_loader: Iterable of ``(eeg, target, demo, crop_idx, info)``
            batches, with ``eeg`` shaped ``(B, C, T)``.
        criterion: Loss called as ``criterion(preds, target)``.
        optimizer: Optimiser stepping the model's parameters.
        device: Device the batches are moved to.
        transform_demo_batch: Optional callable applied to non-empty
            demographics batches before they reach the model.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    batch_losses = []

    for eeg, target, demo, _crop_idx, _info in train_loader:
        # Add the singleton "image channel" axis: (B, C, T) -> (B, 1, C, T).
        eeg = eeg.unsqueeze(1).to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        has_demo = demo is not None and demo.numel() > 0
        demo = demo.to(device, non_blocking=True) if has_demo else None
        if has_demo and transform_demo_batch is not None:
            demo = transform_demo_batch(demo)

        optimizer.zero_grad()
        batch_loss = criterion(model(eeg, demo), target)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(train_loader)

def validate_epoch(model, val_loader, criterion, device, transform_demo_batch=None):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Network called as ``model(eeg, demographics)``.
        val_loader: Iterable of ``(eeg, target, demo, crop_idx, info)``
            batches, with ``eeg`` shaped ``(B, C, T)``.
        criterion: Loss called as ``criterion(preds, target)``.
        device: Device the batches are moved to.
        transform_demo_batch: Optional callable applied to non-empty
            demographics batches before they reach the model.

    Returns:
        ``(avg_loss, predictions, targets)``: the mean of the per-batch
        losses and the flattened NumPy arrays of all predictions and targets.
    """
    model.eval()
    batch_losses = []
    pred_chunks = []
    target_chunks = []

    with torch.no_grad():
        for eeg, target, demo, _crop_idx, _info in val_loader:
            # Add the singleton "image channel" axis: (B, C, T) -> (B, 1, C, T).
            eeg = eeg.unsqueeze(1).to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            has_demo = demo is not None and demo.numel() > 0
            demo = demo.to(device, non_blocking=True) if has_demo else None
            if has_demo and transform_demo_batch is not None:
                demo = transform_demo_batch(demo)

            preds = model(eeg, demo)
            batch_losses.append(criterion(preds, target).item())
            pred_chunks.append(preds.detach().cpu().numpy())
            target_chunks.append(target.detach().cpu().numpy())

    avg_loss = sum(batch_losses) / len(val_loader)
    flat_preds = np.concatenate(pred_chunks, axis=0).flatten()
    flat_targets = np.concatenate(target_chunks, axis=0).flatten()
    return avg_loss, flat_preds, flat_targets


In [11]:
# ---------------------------------------------------------------------------
# Training driver: build the model, train with early stopping, then reload
# the best checkpoint. Relies on C, T, demodim, train_loader, val_loader and
# transform_demo_batch created by the data-loading cell.
# ---------------------------------------------------------------------------
print("\n" + "="*60)
print("INITIALIZING MODEL")
print("="*60)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

model = EnhancedEEGNetRegressor(
    n_channels=C,
    n_times=T,
    n_demographic_features=demodim,
    dropout=0.25,
    F1=16,
    D=2,
    num_heads=8,
).to(device)

# NOTE(review): this criterion is reassigned by the USE_NRMSE branch below
# before it is ever used.
criterion = NRMSELoss().to(device)

print("Model device:", next(model.parameters()).device)


total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Model created on {device}")
print(f"✓ Total parameters: {total_params:,}")
print(f"✓ Trainable parameters: {trainable_params:,}")
print("="*60)

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

# Training hyper-parameters.
N_EPOCHS = 100
LEARNING_RATE = 0.001
EARLY_STOPPING_PATIENCE = 15
USE_NRMSE = True  # True: range-normalised RMSE loss; False: plain RMSE loss

if USE_NRMSE:
    criterion = NRMSELoss().to(device)   # loss on GPU
    print("Using nRMSE loss")
else:
    criterion = RMSELoss().to(device)    # loss on GPU
    print("Using RMSE loss")

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
# Halve the learning rate after 5 epochs without a lower validation loss.
# NOTE(review): `verbose` is reported as deprecated for LR schedulers in
# recent PyTorch releases -- confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5, verbose=True
)

early_stopping = EarlyStoppingCallback(patience=EARLY_STOPPING_PATIENCE, verbose=True)

# Per-epoch curves consumed by the plotting cell.
history = {
    "train_loss": [],
    "val_loss": [],
    "val_rmse": [],
    "val_mae": [],
    "val_pearson_r": [],
}

best_model_path = "/kaggle/working/best_enhanced_eegnet.pt"

for epoch in range(N_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{N_EPOCHS}")
    print(f"{'='*60}")

    train_loss = train_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        device,
        transform_demo_batch=transform_demo_batch,
    )
    history["train_loss"].append(train_loss)
    print(f"Train Loss: {train_loss:.6f}")

    # validate_epoch moves each batch to `device` itself and returns the
    # flattened predictions/targets for metric computation.
    val_loss, val_preds, val_targets = validate_epoch(
        model,
        val_loader,
        criterion,
        device,
        transform_demo_batch=transform_demo_batch,
    )
    history["val_loss"].append(val_loss)

    # Metrics here are computed over the whole validation set, unlike
    # val_loss, which is an average of per-batch criterion values.
    metrics = calculate_metrics(val_preds, val_targets)
    history["val_rmse"].append(metrics["rmse"])
    history["val_mae"].append(metrics["mae"])
    history["val_pearson_r"].append(metrics["pearson_r"])

    print(f"Val Loss: {val_loss:.6f}")
    print(f"Val RMSE: {metrics['rmse']:.6f}")
    print(f"Val nRMSE: {metrics['nrmse']:.6f}")
    print(f"Val MAE: {metrics['mae']:.6f}")
    print(f"Val Pearson r: {metrics['pearson_r']:.4f}")

    scheduler.step(val_loss)

    # Saves a checkpoint on improvement and sets .early_stop once patience runs out.
    early_stopping(val_loss, model, best_model_path)
    if early_stopping.early_stop:
        print(f"\n✓ Early stopping at epoch {epoch + 1}")
        break

# Restore the best checkpoint; on failure keep the weights of the last epoch.
try:
    state_dict = torch.load(best_model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    print(f"\n✓ Loaded best model from {best_model_path} onto {device}")
except Exception as e:
    print(f"\nWarning: Could not load best model state dict ({e}). Using model from last epoch.")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)



INITIALIZING MODEL
Device: cuda
Model device: cuda:0
✓ Model created on cuda
✓ Total parameters: 65,249
✓ Trainable parameters: 65,249

STARTING TRAINING
Using nRMSE loss

Epoch 1/100


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x78f0d3b06840>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x78f0d3b06840>
self._shutdown_workers()Traceback (most recent call last):

  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1618, in __del__
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
self._shutdown_workers()
      File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
if w.is_alive():Exception ignored in: 
    <function _MultiProcessingDataLoaderIter.__del__ at 0x78f0d3b06840> if w.is_alive(): 

 Traceback (most recent call last):
   File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py", line 161

KeyboardInterrupt: 

In [None]:
# Plot the four training curves recorded in `history` on a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# One entry per panel: (axis, [(history key, legend label, extra plot kwargs)], y-label, title)
panels = [
    (axes[0, 0],
     [('train_loss', 'Train Loss', {}), ('val_loss', 'Val Loss', {})],
     'Loss', 'Training and Validation Loss'),
    (axes[0, 1],
     [('val_rmse', 'Val RMSE', {'color': 'orange'})],
     'RMSE', 'Validation RMSE'),
    (axes[1, 0],
     [('val_mae', 'Val MAE', {'color': 'green'})],
     'MAE', 'Validation MAE'),
    (axes[1, 1],
     [('val_pearson_r', 'Val Pearson r', {'color': 'red'})],
     'Pearson r', 'Validation Pearson Correlation'),
]

for ax, curves, ylabel, title in panels:
    for key, label, extra in curves:
        ax.plot(history[key], label=label, linewidth=2, **extra)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig('/kaggle/working/training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training curves saved to training_history.png")


# Final evaluation of the current model weights on the validation loader.
banner = "=" * 60
print("\n" + banner)
print("FINAL EVALUATION")
print(banner)

val_loss, val_preds, val_targets = validate_epoch(
    model,
    val_loader,
    criterion,
    device,
    transform_demo_batch=transform_demo_batch,
)
final_metrics = calculate_metrics(val_preds, val_targets)

# (label, metrics key, number format) for each reported value.
for label, key, spec in (
    ("RMSE", "rmse", ".6f"),
    ("MAE", "mae", ".6f"),
    ("Pearson r", "pearson_r", ".4f"),
):
    print(f"Final Val {label}: {format(final_metrics[key], spec)}")
print(banner)