In [1]:
!pip install mne==1.10.1 mne-bids==0.17.0 --quiet
!pip install scikit-learn --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m58.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.9/168.9 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [28]:
from pathlib import Path

def get_actual_release_path(data_root, release_name):
    """
    Returns the path to the release by checking for double nesting.
    """
    outer = data_root / release_name
    inner = outer / release_name
    if inner.exists():
        return inner
    return outer


In [37]:
from pathlib import Path
import os, math, random
from collections import defaultdict
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr

import mne
from mne_bids import BIDSPath, read_raw_bids
mne.set_log_level('ERROR')

In [46]:
# Configuration
SFREQ = 100  # Sampling rate (Hz)
WIN_SEC = 4  # Window size (seconds)
CROP_SEC = 2  # Random crop size (seconds)
STRIDE_SEC = 2  # Window stride (seconds)
TASK = "contrastChangeDetection"  # Primary task

# Kaggle paths - UPDATE THIS with your dataset name
DATA_ROOT = Path("/kaggle/input/eeg-dataset/R6_L100_bdf/R6_L100_bdf")

TRAIN_RELEASES = ["R6_L100_bdf"]
VAL_RELEASE = "R5_L100_bdf"

SUB_RM = ["NDARAC350XUM", "NDARAJ689BVN"] 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

def seed_all(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_all(42)

Using device: cuda
CUDA available: True
GPU: Tesla T4


In [47]:
def resolve_double_nested_path(data_root, release):
    outer = data_root / release
    inner = outer / release
    # If inner folder exists and contains participants.tsv, use it
    if (inner.exists() and (inner / "participants.tsv").exists()):
        return inner
    elif (outer.exists() and (outer / "participants.tsv").exists()):
        return outer
    else:
        # If neither has participants.tsv, return None
        return None

def load_participants_data(release_path):
    """Load participants.tsv file with demographics and labels"""
    participants_file = release_path / "participants.tsv"
    if not participants_file.exists():
        print(f"Warning: {participants_file} not found")
        return pd.DataFrame()
    
    df = pd.read_csv(participants_file, sep='\t')
    return df


def load_eeg_file(subject_path, subject_id, task, run=None):
    """Load a single EEG file using MNE-BIDS"""
    try:
        # root is the BIDS root; for sub-XXX/eeg/*.bdf this is subject_path.parent.parent
        bids_path = BIDSPath(
            subject=subject_id,
            task=task,
            run=run,
            datatype='eeg',
            extension='.bdf',
            root=subject_path.parent.parent
        )
        raw = read_raw_bids(bids_path, verbose=False)
        raw.load_data()
        return raw
    except Exception:
        # Silently skip files that can't be loaded
        return None


def create_windows_from_raw(raw, win_samples, stride_samples):
    """Create fixed-length windows from raw EEG data"""
    data = raw.get_data()  # (n_channels, n_times)
    n_channels, n_times = data.shape

    windows = []
    starts = range(0, n_times - win_samples + 1, stride_samples)

    for start in starts:
        end = start + win_samples
        window = data[:, start:end]  # (n_channels, win_samples)
        windows.append(window)

    return np.array(windows) if windows else np.array([]).reshape(0, n_channels, win_samples)


def _resolve_release_path(data_root, release):
    """
    Handle double nesting like:
      /kaggle/input/dataset/R5_L100_bdf/R5_L100_bdf/...
    Returns the *inner* folder if present, else the outer one.
    """
    outer = data_root / release
    inner = outer / release
    if inner.exists():
        return inner
    return outer


def load_release_data(release, task=TASK, data_root=DATA_ROOT):
    # Resolve actual folder that contains participants.tsv and sub-XXX
    release_path = _resolve_release_path(data_root, release)

    if not release_path.exists():
        print(f"Warning: Release path {release_path} not found")
        return []

    # Load demographics/labels
    participants_df = load_participants_data(release_path)

    if participants_df.empty:
        print(f"No participants data found for {release}")
        return []

    dataset = []

    # Iterate through subjects
    for _, row in tqdm(
        participants_df.iterrows(),
        total=len(participants_df),
        desc=f"Loading {release}"
    ):
        subject_id = row['participant_id'].replace('sub-', '')

        # Skip excluded subjects
        if subject_id in SUB_RM:
            continue

        # Extract demographics and label
        age = row.get('age', np.nan)
        sex = row.get('sex', np.nan)
        handedness = row.get('handedness', np.nan)
        externalizing = row.get('externalizing', np.nan)

        # Validate externalizing score
        try:
            externalizing = float(externalizing)
            if not math.isfinite(externalizing):
                continue
        except Exception:
            continue

        # Process sex encoding
        sex_str = str(sex).strip().lower()
        if sex_str in ['female', 'f', '2']:
            sex_encoded = 1.0
        elif sex_str in ['male', 'm', '1']:
            sex_encoded = 0.0
        else:
            sex_encoded = np.nan

        # Path to subject directory inside the resolved release folder
        subject_path = release_path / f"sub-{subject_id}"

        # For contrastChangeDetection, try all runs
        runs = [1, 2, 3] if task == "contrastChangeDetection" else [None]

        for run in runs:
            raw = load_eeg_file(subject_path, subject_id, task, run)

            if raw is None:
                continue

            # Check valid length
            if raw.n_times < 4 * SFREQ:
                continue

            # Check channel count (128 EEG + 1 reference = 129)
            if len(raw.ch_names) != 129:
                continue

            # Create windows
            win_samples = int(WIN_SEC * SFREQ)
            stride_samples = int(STRIDE_SEC * SFREQ)
            windows = create_windows_from_raw(raw, win_samples, stride_samples)

            if windows.shape[0] == 0:
                continue

            # Store each window with metadata
            for window in windows:
                dataset.append(
                    {
                        "eeg": window,  # (129, 400)
                        "subject": subject_id,
                        "task": task,
                        "run": run,
                        "age": age,
                        "sex": sex_encoded,
                        "handedness": handedness,
                        "externalizing": externalizing,
                    }
                )

    print(f"Loaded {len(dataset)} windows from {release}")
    return dataset


In [48]:
class EEGWindowsDataset(torch.utils.data.Dataset):
    """Dataset for windowed EEG data with demographics"""
    
    def __init__(self, data_list, crop_samples, keep_idx, seed=42):
        """
        Args:
            data_list: List of dicts with keys ['eeg', 'subject', 'externalizing', etc.]
            crop_samples: Number of samples for random crop
            keep_idx: Indices of demographic features to keep
            seed: Random seed
        """
        self.data_list = data_list
        self.crop_samples = crop_samples
        self.keep_idx = keep_idx
        self.rng = random.Random(seed)
    
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        item = self.data_list[idx]
        
        # Get EEG data
        eeg = item['eeg']  # (n_channels, n_times)
        
        # Take first 128 channels if 129 present
        if eeg.shape[0] == 129:
            eeg = eeg[:128, :]
        
        # Convert to tensor
        eeg = torch.from_numpy(eeg.copy()).float()
        C, T = eeg.shape
        
        # Random crop
        if T < self.crop_samples:
            # Pad if too short
            pad_amount = self.crop_samples - T
            eeg = torch.nn.functional.pad(eeg, (0, pad_amount), mode='constant', value=0)
            start = 0
            stop = self.crop_samples
        else:
            start = self.rng.randint(0, T - self.crop_samples)
            stop = start + self.crop_samples
            eeg = eeg[:, start:stop]
        
        # Per-window z-score normalization
        mu = eeg.mean(dim=1, keepdim=True)
        sd = eeg.std(dim=1, keepdim=True)
        eeg = (eeg - mu) / (sd + 1e-6)
        eeg = torch.nan_to_num(eeg, nan=0.0, posinf=0.0, neginf=0.0)
        eeg = torch.clamp(eeg, min=-1e3, max=1e3)
        
        # Get label
        y = torch.tensor([item['externalizing']], dtype=torch.float32)
        
        # Get demographics
        age = item['age']
        try:
            age = float(age) if age is not None and math.isfinite(float(age)) else np.nan
        except:
            age = np.nan
        
        sex = item['sex']
        hand = item['handedness']
        try:
            hand = float(hand) if hand is not None and math.isfinite(float(hand)) else np.nan
        except:
            hand = np.nan
        
        # Build demo array
        full_demo = np.array([age, sex, hand], dtype=np.float32)
        if len(self.keep_idx) > 0:
            demo = torch.from_numpy(full_demo[self.keep_idx])
        else:
            demo = torch.empty(0, dtype=torch.float32)
        
        info = {
            'subject': item['subject'],
            'task': item['task'],
            'run': item['run'],
        }
        crop_idx = (start, stop)
        
        return eeg, y, demo, crop_idx, info

In [49]:


def extract_unique_demographics(data_list):
    """Extract unique subject demographics"""
    seen = {}
    for item in data_list:
        sid = item['subject']
        if sid in seen:
            continue
        
        age = item['age']
        try:
            age = float(age) if age is not None and math.isfinite(float(age)) else np.nan
        except:
            age = np.nan
        
        sex = item['sex']
        hand = item['handedness']
        try:
            hand = float(hand) if hand is not None and math.isfinite(float(hand)) else np.nan
        except:
            hand = np.nan
        
        seen[sid] = [age, sex, hand]
    
    arr = np.array(list(seen.values()), dtype=np.float32) if seen else np.zeros((0, 3), dtype=np.float32)
    return arr

class SafeStandardScaler(StandardScaler):
    """Robust scaler with NaN/Inf safeguards"""
    def fit(self, X, y=None):
        super().fit(X, y)
        if hasattr(self, "scale_"):
            bad = ~np.isfinite(self.scale_) | (self.scale_ == 0)
            self.scale_[bad] = 1.0
        if hasattr(self, "var_"):
            self.var_[~np.isfinite(self.var_)] = 0.0
        if hasattr(self, "mean_"):
            self.mean_[~np.isfinite(self.mean_)] = 0.0
        return self

def build_demo_transform(train_data):
    """Build demographic transformation pipeline"""
    unique = extract_unique_demographics(train_data)
    
    if unique.shape[0] == 0:
        print("No demographics found; disabling late fusion.")
        return 0, None, np.zeros((0,), dtype=np.float32), np.array([], dtype=int)
    
    # Detect all-NaN columns
    all_nan = np.isnan(unique).all(axis=0)
    keep_mask = ~all_nan
    keep_idx = np.where(keep_mask)[0]
    keep_names = [n for n, k in zip(['age', 'sex', 'hand'], keep_mask) if k]
    print("Keeping demo columns:", keep_names)
    
    kept = unique[:, keep_idx] if keep_idx.size > 0 else np.zeros((unique.shape[0], 0), dtype=np.float32)
    
    if kept.shape[1] == 0:
        print("No usable demographic columns; disabling late fusion.")
        return 0, None, np.zeros((0,), dtype=np.float32), np.array([], dtype=int)
    
    # Compute column medians for imputation
    with np.errstate(all='ignore'):
        col_medians = np.nanmedian(kept, axis=0).astype(np.float32)
        col_medians[~np.isfinite(col_medians)] = 0.0
    
    # Impute and fit scaler
    def impute_cols(arr, meds):
        out = arr.copy()
        for j in range(out.shape[1]):
            mask = ~np.isfinite(out[:, j])
            out[mask, j] = meds[j]
        return out
    
    kept_imp = impute_cols(kept, col_medians)
    scaler = SafeStandardScaler().fit(kept_imp)
    print(f"Demo scaler fitted on {kept_imp.shape[0]} subjects | dims: {keep_idx.size}")
    
    def transform_batch(demo_tensor):
        if demo_tensor.numel() == 0 or keep_idx.size == 0:
            return demo_tensor.to(device=device, dtype=torch.float32)
        
        demo_np = demo_tensor.detach().cpu().numpy().astype(np.float32)
        for j in range(demo_np.shape[1]):
            mask = ~np.isfinite(demo_np[:, j])
            demo_np[mask, j] = col_medians[j]
        
        demo_np = scaler.transform(demo_np)
        out = torch.from_numpy(demo_np).to(device=device, dtype=torch.float32)
        out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
        return out
    
    return keep_idx.size, transform_batch, col_medians, keep_idx

In [50]:
print("\n" + "="*60)
print("Loading training releases...")
print("="*60)

DATA_ROOT = Path("/kaggle/input/eeg-dataset")
TRAIN_RELEASES = ["R6_L100_bdf", "R5_L100_bdf"]  # List the releases you need

train_data = []
for release in TRAIN_RELEASES:
    resolved = resolve_double_nested_path(DATA_ROOT, release)
    if resolved is None:
        print(f"Could not find a valid release folder for {release}")
        continue
    print(f"Using release path for {release}: {resolved}")
    train_data.extend(load_release_data(release, task=TASK, data_root=resolved.parent))



print(f"\n✓ Total training windows: {len(train_data)}")

print("\n" + "="*60)
print("Loading validation release...")
print("="*60)

val_data = load_release_data(VAL_RELEASE, TASK, DATA_ROOT)
print(f"✓ Total validation windows: {len(val_data)}")

# Build demographics transform
print("\n" + "="*60)
print("Building demographic transformations...")
print("="*60)
demodim, transform_demo_batch, demo_medians, keep_idx = build_demo_transform(train_data)

# Create datasets
print("\n" + "="*60)
print("Creating PyTorch datasets...")
print("="*60)

train_dataset = EEGWindowsDataset(
    train_data,
    crop_samples=int(CROP_SEC * SFREQ),
    keep_idx=keep_idx,
    seed=42
)

val_dataset = EEGWindowsDataset(
    val_data,
    crop_samples=int(CROP_SEC * SFREQ),
    keep_idx=keep_idx,
    seed=42
)
if len(train_data) == 0:
    raise ValueError("No training windows found – check your folder nesting, names, and participants.tsv presence!")

# DataLoaders
BATCH_SIZE = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"\n✓ Train batches: {len(train_loader)}")
print(f"✓ Val batches: {len(val_loader)}")

# Infer shapes
sample_X, sample_y, sample_demo, _, _ = train_dataset[0]
C, T = sample_X.shape
print(f"\n✓ Sample EEG shape: ({C}, {T})")
print(f"✓ Demographic features: {demodim}")
print("\n" + "="*60)
print("DATA LOADING COMPLETE!")
print("="*60)



Loading training releases...
Using release path for R6_L100_bdf: /kaggle/input/eeg-dataset/R6_L100_bdf/R6_L100_bdf


Loading R6_L100_bdf: 100%|██████████| 135/135 [00:00<00:00, 1245.35it/s]


Loaded 0 windows from R6_L100_bdf
Using release path for R5_L100_bdf: /kaggle/input/eeg-dataset/R5_L100_bdf/R5_L100_bdf


Loading R5_L100_bdf: 100%|██████████| 330/330 [00:00<00:00, 1198.50it/s]


Loaded 0 windows from R5_L100_bdf

✓ Total training windows: 0

Loading validation release...


Loading R5_L100_bdf: 100%|██████████| 330/330 [00:00<00:00, 1246.91it/s]

Loaded 0 windows from R5_L100_bdf
✓ Total validation windows: 0

Building demographic transformations...
No demographics found; disabling late fusion.

Creating PyTorch datasets...





ValueError: No training windows found – check your folder nesting, names, and participants.tsv presence!

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention for EEG channel relationships"""
    def __init__(self, embed_dim: int, num_heads: int = 8, dropout: float = 0.3):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )
    
    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        return attn_output

class EnhancedEEGNetRegressor(nn.Module):
    """
    Enhanced EEGNet for continuous externalizing score regression.
    
    Key modifications from standard EEGNet:
    - Regression output (1 continuous value)
    - Batch normalization after conv blocks
    - Multi-head self-attention
    - Increased dropout (0.5)
    - Late fusion with demographic features
    """
    def __init__(
        self,
        n_channels: int = 128,
        n_times: int = 200,
        n_demographic_features: int = 3,
        dropout: float = 0.5,
        F1: int = 16,
        D: int = 2,
        num_heads: int = 8,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.dropout_rate = dropout
        
        # EEG processing branch
        self.temporal_conv = nn.Conv2d(
            1, F1, kernel_size=(1, 51), stride=(1, 1),
            padding=(0, 25), bias=False
        )
        self.bn1 = nn.BatchNorm2d(F1)
        
        self.spatial_conv = nn.Conv2d(
            F1, F1 * D, kernel_size=(n_channels, 1), stride=(1, 1),
            groups=F1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(F1 * D)
        
        self.elu = nn.ELU()
        self.pool1 = nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4))
        self.dropout1 = nn.Dropout(p=dropout)
        
        # Multi-head attention
        self.attention = MultiHeadAttention(
            embed_dim=F1 * D,
            num_heads=num_heads,
            dropout=dropout
        )
        
        self.pool2 = nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2))
        self.dropout2 = nn.Dropout(p=dropout)
        
        # Calculate EEG feature dimension
        self.eeg_feature_dim = F1 * D * 25  # 32 * 25 = 800
        
        # Demographic fusion branch
        self.demographic_encoder = nn.Sequential(
            nn.Linear(n_demographic_features, 16),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(16, 32)
        )
        
        # Fusion and regression head
        fusion_input_dim = self.eeg_feature_dim + 32
        self.fusion = nn.Sequential(
            nn.Linear(fusion_input_dim, 64),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(p=dropout)
        )
        
        self.regression_head = nn.Linear(32, 1)
    
    def forward(self, eeg: torch.Tensor, demographics: torch.Tensor = None):
        # EEG branch
        x = self.temporal_conv(eeg)
        x = self.bn1(x)
        x = self.elu(x)
        
        x = self.spatial_conv(x)
        x = self.bn2(x)
        x = self.elu(x)
        
        x = self.pool1(x)
        x = self.dropout1(x)
        
        # Reshape for attention
        batch_size = x.shape[0]
        x = x.squeeze(2)
        x = x.transpose(1, 2)
        
        # Apply attention
        x = self.attention(x)
        x = x.transpose(1, 2)
        x = x.unsqueeze(2)
        
        x = self.pool2(x)
        x = self.dropout2(x)
        
        # Flatten
        eeg_features = x.view(batch_size, -1)
        
        # Demographic branch
        if demographics is not None and demographics.numel() > 0:
            demo_features = self.demographic_encoder(demographics)
            combined_features = torch.cat([eeg_features, demo_features], dim=1)
        else:
            combined_features = eeg_features
        
        # Fusion and regression
        fused = self.fusion(combined_features)
        output = self.regression_head(fused)
        
        return output

class RMSELoss(nn.Module):
    """Root Mean Square Error Loss"""
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
    
    def forward(self, pred, target):
        return torch.sqrt(self.mse(pred, target) + 1e-8)

class NRMSELoss(nn.Module):
    """Normalized RMSE Loss"""
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
    
    def forward(self, pred, target):
        rmse = torch.sqrt(self.mse(pred, target) + 1e-8)
        target_range = target.max() - target.min() + 1e-8
        return rmse / target_range

def calculate_metrics(predictions, targets):
    """Calculate regression metrics"""
    mse = np.mean((predictions - targets) ** 2)
    rmse = np.sqrt(mse)
    nrmse = rmse / (targets.max() - targets.min() + 1e-8)
    mae = np.mean(np.abs(predictions - targets))
    
    # Pearson correlation
    if len(predictions) > 1:
        corr, _ = pearsonr(predictions, targets)
    else:
        corr = 0.0
    
    return {
        'mse': mse,
        'rmse': rmse,
        'nrmse': nrmse,
        'mae': mae,
        'pearson_r': corr
    }


In [None]:
class EarlyStoppingCallback:
    """Early stopping with model checkpointing"""
    def __init__(self, patience: int = 10, verbose: bool = True):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
    
    def __call__(self, val_loss: float, model: nn.Module, save_path: str = "best_model.pt"):
        if self.best_loss is None:
            self.best_loss = val_loss
            torch.save(model.state_dict(), save_path)
        elif val_loss < self.best_loss * 0.99:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), save_path)
            if self.verbose:
                print(f"✓ Validation loss improved to {val_loss:.6f}. Model saved.")
        else:
            self.counter += 1
            if self.verbose:
                print(f"No improvement for {self.counter}/{self.patience} epochs")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print("Early stopping triggered!")

def train_epoch(model, train_loader, criterion, optimizer, device, grad_clip=1.0):
    """Single training epoch"""
    model.train()
    total_loss = 0.0
    
    for eeg_batch, target_batch, demo_batch, _, _ in tqdm(train_loader, desc="Train", leave=False):
        # Add channel dimension: (B, C, T) -> (B, 1, C, T)
        eeg_batch = eeg_batch.unsqueeze(1).to(device)
        target_batch = target_batch.to(device)
        demo_batch = demo_batch.to(device) if demo_batch.numel() > 0 else None
        
        optimizer.zero_grad()
        
        # Forward pass
        predictions = model(eeg_batch, demo_batch)
        loss = criterion(predictions, target_batch)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(train_loader)

def validate_epoch(model, val_loader, criterion, device):
    """Validation epoch"""
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for eeg_batch, target_batch, demo_batch, _, _ in tqdm(val_loader, desc="Val", leave=False):
            eeg_batch = eeg_batch.unsqueeze(1).to(device)
            target_batch = target_batch.to(device)
            demo_batch = demo_batch.to(device) if demo_batch.numel() > 0 else None
            
            predictions = model(eeg_batch, demo_batch)
            loss = criterion(predictions, target_batch)
            
            total_loss += loss.item()
            all_predictions.append(predictions.cpu().numpy())
            all_targets.append(target_batch.cpu().numpy())
    
    avg_loss = total_loss / len(val_loader)
    predictions = np.concatenate(all_predictions, axis=0).flatten()
    targets = np.concatenate(all_targets, axis=0).flatten()
    
    return avg_loss, predictions, targets


In [None]:
print("\n" + "="*60)
print("INITIALIZING MODEL")
print("="*60)

model = EnhancedEEGNetRegressor(
    n_channels=128,
    n_times=200,
    n_demographic_features=demodim,
    dropout=0.5,
    F1=16,
    D=2,
    num_heads=8
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Model created")
print(f"✓ Total parameters: {total_params:,}")
print(f"✓ Trainable parameters: {trainable_params:,}")
print("="*60)


In [None]:

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

# Hyperparameters
N_EPOCHS = 100
LEARNING_RATE = 0.001
EARLY_STOPPING_PATIENCE = 15
USE_NRMSE = True

# Loss function
if USE_NRMSE:
    criterion = NRMSELoss()
    print("Using nRMSE loss")
else:
    criterion = RMSELoss()
    print("Using RMSE loss")

# Optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

# Early stopping
early_stopping = EarlyStoppingCallback(patience=EARLY_STOPPING_PATIENCE, verbose=True)

# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'val_rmse': [],
    'val_mae': [],
    'val_pearson_r': []
}

best_model_path = "/kaggle/working/best_enhanced_eegnet.pt"

# Training loop
for epoch in range(N_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{N_EPOCHS}")
    print(f"{'='*60}")
    
    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    history['train_loss'].append(train_loss)
    print(f"Train Loss: {train_loss:.6f}")
    
    # Validate
    val_loss, val_preds, val_targets = validate_epoch(model, val_loader, criterion, device)
    history['val_loss'].append(val_loss)
    
    # Calculate metrics
    metrics = calculate_metrics(val_preds, val_targets)
    history['val_rmse'].append(metrics['rmse'])
    history['val_mae'].append(metrics['mae'])
    history['val_pearson_r'].append(metrics['pearson_r'])
    
    print(f"Val Loss: {val_loss:.6f}")
    print(f"Val RMSE: {metrics['rmse']:.6f}")
    print(f"Val nRMSE: {metrics['nrmse']:.6f}")
    print(f"Val MAE: {metrics['mae']:.6f}")
    print(f"Val Pearson r: {metrics['pearson_r']:.4f}")
    
    # Learning rate scheduling
    scheduler.step(val_loss)
    
    # Early stopping
    early_stopping(val_loss, model, best_model_path)
    if early_stopping.early_stop:
        print(f"\n✓ Early stopping at epoch {epoch+1}")
        break

# Load best model
model.load_state_dict(torch.load(best_model_path))
print(f"\n✓ Loaded best model from {best_model_path}")

print("\n" + "="*60)
print("TRAINING COMPLETE!")
print("="*60)


In [None]:

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# RMSE
axes[0, 1].plot(history['val_rmse'], label='Val RMSE', linewidth=2, color='orange')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('RMSE')
axes[0, 1].set_title('Validation RMSE')
axes[0, 1].legend()
axes[0, 1].grid(True)

# MAE
axes[1, 0].plot(history['val_mae'], label='Val MAE', linewidth=2, color='green')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('MAE')
axes[1, 0].set_title('Validation MAE')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Pearson correlation
axes[1, 1].plot(history['val_pearson_r'], label='Val Pearson r', linewidth=2, color='red')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Pearson r')
axes[1, 1].set_title('Validation Pearson Correlation')
axes[1, 1].legend()
axes[1, 1].grid(True)

plt.tight_layout()
plt.savefig('/kaggle/working/training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training curves saved to training_history.png")


In [None]:


print("\n" + "="*60)
print("FINAL EVALUATION")
print("="*60)

# Final validation metrics
val_loss, val_preds, val_targets = validate_epoch(model, val_loader, criterion, device)

In [45]:
# exploration_eeg2025.py

import os
from pathlib import Path
import json
import pprint

import numpy as np
import pandas as pd

# Root directory for your uploaded dataset (change as needed)
DATA_ROOT = Path("/kaggle/input/eeg-dataset/R5_L100_bdf/R5_L100_bdf")  
# e.g. "/kaggle/input/eegchallenge2025/R11_L100_bdf"

def scan_dataset(root_dir: Path):
    """
    Walk the dataset directory, and collect statistics:
      - number of subjects
      - tasks per subject
      - EEG file counts
      - behavioral / metadata files
    """
    subjects = sorted([p for p in root_dir.iterdir() if p.is_dir() and p.name.startswith("sub-")])
    print(f"Found {len(subjects)} subjects")

    summary = {}
    for sub in subjects:
        summary[sub.name] = {}
        # tasks: likely subfolders under eeg/ or beh/ — inspect
        eeg_dir = sub / "eeg"
        beh_dir = sub / "beh"
        if eeg_dir.exists():
            eeg_files = list(eeg_dir.glob(".bdf")) + list(eeg_dir.glob(".set")) + list(eeg_dir.glob("*.npy"))
            summary[sub.name]["n_eeg_files"] = len(eeg_files)
            summary[sub.name]["eeg_files"] = [f.name for f in eeg_files]
        else:
            summary[sub.name]["n_eeg_files"] = 0

        if beh_dir.exists():
            beh_files = [f.name for f in beh_dir.iterdir() if f.is_file()]
            summary[sub.name]["beh_files"] = beh_files
        else:
            summary[sub.name]["beh_files"] = []

    return summary

def load_one_subject_sample(sub_dir: Path):
    """
    Try to load one EEG + metadata from a subject folder to inspect content.
    """
    eeg_dir = sub_dir / "eeg"
    beh_dir = sub_dir / "beh"
    # pick first EEG file
    eeg_files = list(eeg_dir.glob(".npy")) + list(eeg_dir.glob(".bdf")) + list(eeg_dir.glob("*.set"))
    if not eeg_files:
        print(f"No EEG files for {sub_dir.name}")
        return

    eeg_file = eeg_files[0]
    print("Loading EEG:", eeg_file)
    # If it's .npy, load via numpy
    if eeg_file.suffix == ".npy":
        eeg = np.load(eeg_file)
    else:
        print("Non-npy EEG format:", eeg_file.suffix,
              "- you need MNE or other library to load BDF/SET")
        return
    print(" EEG shape:", eeg.shape, " dtype:", eeg.dtype)

    # Load behavior/metadata if exists
    beh_files = list(beh_dir.glob("."))
    if beh_files:
        print("Behavior / metadata files:", [f.name for f in beh_files])
        for f in beh_files:
            if f.suffix.lower() in [".json", ".csv", ".tsv"]:
                try:
                    if f.suffix.lower() == ".json":
                        with open(f, "r") as fp:
                            beh = json.load(fp)
                        print(" Example keys:", list(beh.keys())[:10])
                    else:
                        df = pd.read_csv(f)
                        print(" Dataframe columns:", df.columns.tolist())
                except Exception as e:
                    print(" Error loading ", f, e)
    else:
        print("No behavioral/metadata files.")

def main():
    summary = scan_dataset(DATA_ROOT)
    pprint.pprint(summary)

    # load sample from first subject
    subjects = sorted([p for p in DATA_ROOT.iterdir() if p.is_dir() and p.name.startswith("sub-")])
    if subjects:
        load_one_subject_sample(subjects[0])

if __name__ == "__main__":
    main()

Found 330 subjects
{'sub-NDARAC350XUM': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAC857HDB': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAH304ED7': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAH793FBF': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAJ689BVN': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAK738BGC': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAM848GTE': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAP522AFK': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAP785CTE': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAT358XM9': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAU708TL8': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAV187GJ5': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-NDARAX358NB5': {'beh_files': [], 'eeg_files': [], 'n_eeg_files': 0},
 'sub-