In [3]:
# -*- coding: utf-8 -*-
"""
Two-Stage MI-EEG (PhysioNet) — Standalone Notebook Block
--------------------------------------------------------
- Stage-1: EEGNet-parallel + MHA (temporal) + S-TCN + Time Slicing (feature-map)
           → outputs (logits, emb128). Trained with CE (class-weighted), SGDR, logs train/val.
- Stage-2: Tabular head over [emb] ⊕ [PSD] (Welch log-bandpower). Uses TabNet if available,
           else an MLP fallback. Logs val accuracy for overfitting check.
- Optional Blending: final_logits = w * logits_stage1 + (1-w) * logits_stage2.
- Protocol: same paths and dataset construction patterns you use.

Run: execute this cell as-is in a notebook. It shares the same directory layout:
  PROJ/ data/raw  data/cache  models/folds/Kfold5.json

Notes:
- Uses only 8 MI channels you listed.
- Keeps GroupKFold(5) and GroupSplit for val inside train, like your baseline.
- WINDOW_MODE default changed to '3s' (0–3 s post-cue) for MI focus. You can switch back to '6s'.
- Adds detailed logs each epoch: train_acc and val_acc for Stage-1; and val_acc for Stage-2.
"""

import os, re, math, random, json, itertools, copy, warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, GroupShuffleSplit, StratifiedKFold, StratifiedShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =============== CONFIG ===============
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '3s'   # recomendado para MI (0–3 s post-cue). Cambia a '6s' si quieres replicar.
FS = 160.0
N_FOLDS = 5

# Entrenamiento Stage-1
BATCH_SIZE = 64
EPOCHS_STAGE1 = 70
LR_STAGE1 = 1e-3
WD_STAGE1 = 1e-5
SGDR_T0 = 6
SGDR_Tmult = 2
PATIENCE_STAGE1 = 12
LOG_EVERY = 5

# Stage-2 (Tabular)
USE_TABNET = True  # si no está instalado, cae a MLP automáticamente
EPOCHS_STAGE2 = 400
PATIENCE_STAGE2 = 40
BLEND_W = 0.5      # peso de Stage-1 en el blend final

# Fine-tuning por sujeto (opcional, simple)
DO_SUBJECT_FT = False    # pon True si quieres hacer FT: adapta Stage-2 (+último bloque encoder)
FT_EPOCHS = 10
FT_LR_ENCODER_LAST = 1e-4

# Normalización por época
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# ============== UTIL CANALES/EDF (idéntico a tu base) ==============
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

_SNR_TABLE = None

def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  # default
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch si no sobresale

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# ============== DATASET BUILD (como tu script) ==============

def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs


def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)  # (T,C)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names


def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# ============== Dataset y utils torch ==============
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

@torch.no_grad()
def apply_max_norm(layers, max_value=2.0, p=2.0):
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# ============== Augments (ligeros) ==============

def do_time_jitter(x, max_ms=50, fs=160.0):
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    def __init__(self, class_weights=None, label_smoothing=0.05):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)
    def forward(self, logits, target_probs):
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)
        loss_per_class = -(target_probs * logp)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# ============== Stage-1: Encoder (EEGNet-parallel + MHA + S-TCN + Slicing) ==============
class MHABlock(nn.Module):
    def __init__(self, dim: int, num_heads: int = 2, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout),
        )
    def forward(self, x):
        h = self.norm(x)
        y, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop(y)
        x = x + self.ffn(x)
        return x

class DSConv1d(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, k: int = 4, dilation: int = 1, dropout: float = 0.1):
        super().__init__()
        pad = (k - 1) // 2 * dilation
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, groups=in_ch, padding=pad, dilation=dilation, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU(); self.drop = nn.Dropout(dropout)
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x); x = self.act(x); x = self.drop(x); return x

class STCN(nn.Module):
    def __init__(self, dim: int, k: int = 4, dilations=(1,2,4), dropout: float = 0.1):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(
                DSConv1d(dim, dim, k=k, dilation=d, dropout=dropout),
                nn.Conv1d(dim, dim, kernel_size=1, bias=False),
                nn.BatchNorm1d(dim), nn.ELU(), nn.Dropout(dropout)
            ) for d in dilations
        ])
    def forward(self, x):  # (B,D,T)
        for b in self.blocks:
            res = x; x = b(x); x = x + res
        return x

class EEGNetParallel(nn.Module):
    def __init__(self, n_ch: int, F1: int = 16, D: int = 2, k_t_list=(32,64,96), pool1_t=4, k_sep=16, pool2_t=4, drop=0.2):
        super().__init__()
        self.n_ch = n_ch
        self.branches = nn.ModuleList()
        for k_t in k_t_list:
            self.branches.append(nn.Sequential(
                nn.Conv2d(1, F1, kernel_size=(k_t, 1), padding=(k_t // 2, 0), bias=False),
                nn.BatchNorm2d(F1), nn.ELU(),
                nn.Conv2d(F1, F1*D, kernel_size=(1, n_ch), groups=F1, bias=False),
                nn.BatchNorm2d(F1*D), nn.ELU(),
                nn.AvgPool2d(kernel_size=(pool1_t,1), stride=(pool1_t,1)),
                nn.Dropout(drop),
                nn.Conv2d(F1*D, F1*D, kernel_size=(k_sep,1), groups=F1*D, padding=(k_sep//2,0), bias=False),
                nn.Conv2d(F1*D, F1*D, kernel_size=(1,1), bias=False),
                nn.BatchNorm2d(F1*D), nn.ELU(),
                nn.AvgPool2d(kernel_size=(pool2_t,1), stride=(pool2_t,1)),
                nn.Dropout(drop),
            ))
        self.out_dim = len(k_t_list) * (F1*D)
    def forward(self, x):
        feats = [b(x) for b in self.branches]  # each: (B,F2,T',1)
        return torch.cat(feats, dim=1)         # (B,Dsum,T',1)

class TwoStageEncoder(nn.Module):
    def __init__(self, n_ch: int, n_classes: int,
                 F1=16, D=2, k_t_list=(32,64,96), attn_heads=2,
                 stcn_k=4, stcn_dils=(1,2,4), emb_dim=128, drop=0.2,
                 slice_len_steps=None, slice_stride_steps=None, avg_on='emb'):
        super().__init__()
        self.backbone = EEGNetParallel(n_ch=n_ch, F1=F1, D=D, k_t_list=k_t_list, drop=drop)
        self.proj = nn.Conv2d(self.backbone.out_dim, emb_dim, kernel_size=(1,1), bias=False)
        self.attn = MHABlock(dim=emb_dim, num_heads=attn_heads, dropout=drop)
        self.stcn = STCN(dim=emb_dim, k=stcn_k, dilations=stcn_dils, dropout=drop)
        self.cls_head = nn.Linear(emb_dim, n_classes)
        self.emb_dim = emb_dim
        self.avg_on = avg_on
        self.slice_len_steps = slice_len_steps
        self.slice_stride_steps = slice_stride_steps

    def _time_slices(self, T):
        if self.slice_len_steps is None or self.slice_stride_steps is None:
            k = 5; L = max(2, T // (k + 1)); S = max(1, (T - L) // max(1, (k - 1)))
        else:
            L, S = self.slice_len_steps, self.slice_stride_steps
        slices, s = [], 0
        while s + L <= T:
            slices.append((s, s+L)); s += S
        if not slices: slices = [(0,T)]
        return slices

    def forward(self, x):  # x: (B,1,T,C)
        z = self.backbone(x)                      # (B,Dsum,T',1)
        z = self.proj(z).squeeze(-1).transpose(1,2)  # (B,T',emb)
        z = self.attn(z)
        z = self.stcn(z.transpose(1,2)).transpose(1,2)  # (B,T',emb)
        B,Tp,Demb = z.shape
        idxs = self._time_slices(Tp)
        slice_embs, slice_logits = [], []
        for (a,b) in idxs:
            seg = z[:,a:b,:]
            e = seg.mean(dim=1)  # GAP over time
            slice_embs.append(e)
            slice_logits.append(self.cls_head(e))
        if self.avg_on == 'logits':
            logits = torch.stack(slice_logits, dim=0).mean(dim=0)
            emb = torch.stack(slice_embs, dim=0).mean(dim=0)
        else:
            emb = torch.stack(slice_embs, dim=0).mean(dim=0)
            logits = self.cls_head(emb)
        return logits, emb

# ============== Stage-1 training & eval ==============

def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

class Stage1Trainer:
    def __init__(self, device=DEVICE, lr=LR_STAGE1, wd=WD_STAGE1, label_smoothing=0.05, max_norm=2.0):
        self.device = device; self.lr = lr; self.wd = wd
        self.label_smoothing = label_smoothing; self.max_norm = max_norm
    def _crit(self, n_classes, class_weights=None):
        return nn.CrossEntropyLoss(weight=class_weights, label_smoothing=self.label_smoothing)
    def fit(self, model, dl_train, dl_val, epochs=EPOCHS_STAGE1, class_weights=None):
        model.to(self.device)
        opt = optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.wd)
        sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)
        crit = self._crit(n_classes=model.cls_head.out_features, class_weights=class_weights)
        best_state = copy.deepcopy(model.state_dict()); best_val=-1.0; bad=0
        history = {'train_acc':[], 'val_acc':[]}
        for ep in range(1, epochs+1):
            # ---- train epoch ----
            model.train(); n_ok=0; n_tot=0
            for xb,yb,_ in dl_train:
                xb,yb = xb.to(self.device), yb.to(self.device)
                opt.zero_grad(set_to_none=True)
                logits,_ = model(xb)
                loss = crit(logits, yb)
                loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 5.0); opt.step()
                n_ok += (logits.argmax(1)==yb).sum().item(); n_tot += yb.numel()
            sched.step(ep-1 + 1e-8)
            train_acc = n_ok/max(1,n_tot)
            # ---- val ----
            val_acc = self.evaluate(model, dl_val)
            history['train_acc'].append(train_acc); history['val_acc'].append(val_acc)
            if (ep % LOG_EVERY == 0) or ep in (1, 10, 20, 50):
                cur_lr = opt.param_groups[0]['lr']
                print(f"[Stage-1] Ep {ep:3d} | train_acc={train_acc:.4f} | val_acc={val_acc:.4f} | LR={cur_lr:.5f}")
            if val_acc > best_val + 1e-4:
                best_val = val_acc; best_state = copy.deepcopy(model.state_dict()); bad = 0
            else:
                bad += 1
                if bad >= PATIENCE_STAGE1:
                    print(f"[Stage-1] Early stopping @ {ep} (best val_acc={best_val:.4f})")
                    break
        return best_state, history
    @torch.no_grad()
    def evaluate(self, model, dl):
        model.eval(); n_ok=0; n=0
        for xb,yb,_ in dl:
            xb,yb = xb.to(self.device), yb.to(self.device)
            logits,_ = model(xb)
            n_ok += (logits.argmax(1)==yb).sum().item(); n += yb.numel()
        return n/max(1,n)
    @torch.no_grad()
    def predict_logits(self, model, dl):
        model.eval(); outs=[]
        for xb,_,_ in dl:
            xb = xb.to(self.device)
            logits,_ = model(xb)
            outs.append(logits.detach().cpu().numpy())
        return np.vstack(outs)

@torch.no_grad()
def extract_embeddings(model, dl, device=DEVICE):
    model.eval().to(device)
    embs, ys = [], []
    for xb,yb,_ in dl:
        xb = xb.to(device)
        _, e = model(xb)
        embs.append(e.detach().cpu().numpy()); ys.append(yb.numpy())
    return np.vstack(embs), np.concatenate(ys)

# ============== PSD features (Welch) ==============
try:
    from scipy.signal import welch
except Exception:
    welch = None
    warnings.warn("scipy not found: PSD features disabled. Install scipy to enable.")

class PSDConfig:
    def __init__(self, fs=160.0, nperseg_sec=1.0, noverlap=0.5, bands=((1,4),(4,8),(8,14),(14,31),(31,49))):
        self.fs = fs; self.nperseg_sec = nperseg_sec; self.noverlap = noverlap; self.bands = bands

def _bandpower(f, Pxx, band):
    idx = (f >= band[0]) & (f < band[1])
    return (Pxx[idx].mean() + 1e-12)

def make_psd_features(X_np, cfg=PSDConfig()):
    if welch is None:
        raise RuntimeError("scipy.signal.welch not available.")
    N,T,C = X_np.shape
    nperseg = int(round(cfg.nperseg_sec*cfg.fs))
    noverlap = int(round(cfg.noverlap*nperseg))
    feats = []
    for i in range(N):
        x = X_np[i]
        ch_feats = []
        for c in range(C):
            f, P = welch(x[:,c], fs=cfg.fs, nperseg=nperseg, noverlap=noverlap)
            bp = [np.log(_bandpower(f, P, b)) for b in cfg.bands]
            ch_feats.extend(bp)
        feats.append(ch_feats)
    return np.asarray(feats, dtype=np.float32)

# ============== Stage-2: Tabular head (TabNet or MLP) ==============
class _MLPHead(nn.Module):
    def __init__(self, d_in, n_classes, hidden=256, drop=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, hidden), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(hidden, hidden//2), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(hidden//2, n_classes)
        )
    def forward(self, x): return self.net(x)

class TabularHead:
    def __init__(self, input_dim, n_classes, use_tabnet=USE_TABNET, device=DEVICE, tabnet_params=None):
        self.n_classes = n_classes; self.device = device
        self.use_tabnet = False; self.tabnet=None
        if use_tabnet:
            try:
                from pytorch_tabnet.tab_model import TabNetClassifier
                self.tabnet = TabNetClassifier(**(tabnet_params or dict(n_d=32, n_a=64, n_steps=3, gamma=1.5)))
                self.use_tabnet = True
            except Exception:
                warnings.warn("pytorch_tabnet not found; using MLP fallback.")
        if not self.use_tabnet:
            self.mlp = _MLPHead(input_dim, n_classes).to(self.device)
            self.opt = optim.Adam(self.mlp.parameters(), lr=1e-3, weight_decay=1e-5)
            self.crit = nn.CrossEntropyLoss()
            self.scaler_mean=None; self.scaler_std=None
    def fit(self, X, y, X_val=None, y_val=None, epochs=EPOCHS_STAGE2, patience=PATIENCE_STAGE2):
        if self.use_tabnet:
            eval_set=[(X_val, y_val)] if X_val is not None else None
            self.tabnet.fit(X, y, eval_set=eval_set, max_epochs=epochs, patience=patience,
                            batch_size=256, virtual_batch_size=128)
        else:
            self.scaler_mean = X.mean(axis=0, keepdims=True); self.scaler_std = X.std(axis=0, keepdims=True)+1e-6
            Xn = (X - self.scaler_mean)/self.scaler_std; Xn = torch.from_numpy(Xn).float().to(self.device)
            y_t = torch.from_numpy(y).long().to(self.device)
            dset = torch.utils.data.TensorDataset(Xn,y_t)
            loader = DataLoader(dset, batch_size=256, shuffle=True)
            if X_val is not None and y_val is not None:
                Xv = (X_val - self.scaler_mean)/self.scaler_std
                self.Xv_t = torch.from_numpy(Xv).float().to(self.device)
                self.yv_t = torch.from_numpy(y_val).long().to(self.device)
            else:
                self.Xv_t=None; self.yv_t=None
            best_state = copy.deepcopy(self.mlp.state_dict()); best=-1.0; bad=0
            for ep in range(1, epochs+1):
                self.mlp.train()
                for xb,yb in loader:
                    self.opt.zero_grad(set_to_none=True)
                    logits = self.mlp(xb)
                    loss = self.crit(logits, yb)
                    loss.backward(); self.opt.step()
                if self.Xv_t is not None:
                    self.mlp.eval();
                    with torch.no_grad():
                        lv = self.mlp(self.Xv_t)
                        acc = (lv.argmax(1)==self.yv_t).float().mean().item()
                    if ep % 10 == 0:
                        print(f"[Stage-2] Ep {ep:3d} | val_acc={acc:.4f}")
                    if acc > best + 1e-4:
                        best = acc; best_state = copy.deepcopy(self.mlp.state_dict()); bad=0
                    else:
                        bad += 1
                        if bad >= patience:
                            print(f"[Stage-2] Early stopping @ {ep} (best val_acc={best:.4f})"); break
            self.mlp.load_state_dict(best_state)
    def predict_logits(self, X):
        if self.use_tabnet:
            proba = self.tabnet.predict_proba(X)
            return np.log(proba + 1e-9)
        else:
            Xn = (X - self.scaler_mean)/self.scaler_std
            Xt = torch.from_numpy(Xn).float().to(self.device)
            self.mlp.eval();
            with torch.no_grad():
                logits = self.mlp(Xt).cpu().numpy()
            return logits

def blend_logits(logits1, logits2, w=BLEND_W):
    w = float(np.clip(w,0.0,1.0)); return w*logits1 + (1.0-w)*logits2

# ============== Folds JSON helpers (como tu script) ==============

def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="TwoStage_Exp", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)
    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]
        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]
        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()
        folds.append({"fold": int(fold_i), "train": train_sids, "test": test_sids, "tr_idx": tr_idx, "te_idx": te_idx})
    payload = {"created_at": datetime.now().isoformat(), "created_by": created_by,
               "description": description if description is not None else "",
               "n_splits": int(n_splits), "n_subjects": len(subject_ids),
               "subject_ids": subject_ids, "folds": folds}
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path


def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check: raise ValueError(msg)
            else: print("WARNING: " + msg)
    return payload

# ============== Confusion + curves ==============

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right'); plt.yticks(tick_marks, classes)
    fmt = '.2f'; thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")
    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname, title='Stage-1 Training curve'):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title(title); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

# ============== EXPERIMENTO PRINCIPAL ==============

def run_experiment_two_stage():
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    folds_json_path = Path(FOLDS_DIR)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)
    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]
    if not folds_json_path.exists():
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="TwoStage_Pipeline",
                                           description="Two-Stage GroupKFold")
    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    global_folds = []
    blend_folds = []
    all_true = []; all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)
        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue
        # Split de validación por sujetos
        gss = GroupShuffleSplit(n_splits=1, test_size=0.15, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE,
                               sampler=build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx]), drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # -------- Stage-1 --------
        model = TwoStageEncoder(n_ch=C, n_classes=n_classes,
                                F1=16, D=2, k_t_list=(32,64,96), attn_heads=2,
                                stcn_k=3, stcn_dils=(1,2,4), emb_dim=128, drop=0.2,
                                slice_len_steps=None, slice_stride_steps=None, avg_on='emb').to(DEVICE)

        trainer = Stage1Trainer(device=DEVICE, lr=LR_STAGE1, wd=WD_STAGE1, label_smoothing=0.05)
        class_counts = np.bincount(y[tr_sub_idx], minlength=n_classes).astype(np.float32)
        class_counts[class_counts==0]=1.0
        class_w = class_counts.sum()/class_counts
        class_w = class_w / class_w.mean()
        class_w = torch.tensor(class_w, dtype=torch.float32, device=DEVICE)

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando Stage-1 (encoder)...")
        best_state, hist = trainer.fit(model, tr_loader, va_loader, epochs=EPOCHS_STAGE1, class_weights=class_w)
        plot_training_curves(hist, f"stage1_curve_fold{fold}.png", title='Stage-1 Train/Val Acc')
        print(f"↳ Curva Stage-1 guardada: stage1_curve_fold{fold}.png")
        model.load_state_dict(best_state)

        # Evaluación Stage-1 pura en test
        logits_stage1 = trainer.predict_logits(model, te_loader)  # (Nte, C)
        y_te = y[te_idx]
        y_pred_stage1 = logits_stage1.argmax(1)
        acc_stage1 = (y_pred_stage1 == y_te).mean()
        print(f"[Fold {fold}] Stage-1 acc(test) = {acc_stage1:.4f}")

        # -------- Extraer embeddings + PSD para Stage-2 --------
        emb_tr, y_tr = extract_embeddings(model, tr_loader, DEVICE)
        emb_va, y_va = extract_embeddings(model, va_loader, DEVICE)
        emb_te, _    = extract_embeddings(model, te_loader, DEVICE)

        # recolecta X bruto para PSD
        def collect_numpy(dl):
            Xs, ys = [], []
            for xb,yb,_ in dl:
                Xs.append(xb.squeeze(1).numpy()); ys.append(yb.numpy())
            return np.concatenate(Xs, axis=0), np.concatenate(ys, axis=0)
        X_tr_np, _ = collect_numpy(tr_loader)
        X_va_np, _ = collect_numpy(va_loader)
        X_te_np, _ = collect_numpy(te_loader)

        if welch is None:
            print("WARNING: scipy no disponible → Stage-2 usará SOLO embeddings (sin PSD).")
            psd_tr = np.zeros((emb_tr.shape[0], 0), dtype=np.float32)
            psd_va = np.zeros((emb_va.shape[0], 0), dtype=np.float32)
            psd_te = np.zeros((emb_te.shape[0], 0), dtype=np.float32)
        else:
            psd_tr = make_psd_features(X_tr_np, PSDConfig(fs=FS))
            psd_va = make_psd_features(X_va_np, PSDConfig(fs=FS))
            psd_te = make_psd_features(X_te_np, PSDConfig(fs=FS))

        Xtr2 = np.hstack([emb_tr, psd_tr]); Xva2 = np.hstack([emb_va, psd_va]); Xte2 = np.hstack([emb_te, psd_te])

        # -------- Stage-2 (TabNet/MLP) --------
        head = TabularHead(input_dim=Xtr2.shape[1], n_classes=n_classes, use_tabnet=USE_TABNET, device=DEVICE)
        print(f"[Fold {fold}] Entrenando Stage-2 (tabular head)...")
        head.fit(Xtr2, y_tr, X_val=Xva2, y_val=y_va, epochs=EPOCHS_STAGE2, patience=PATIENCE_STAGE2)
        logits_stage2 = head.predict_logits(Xte2)

        # -------- Blend opcional --------
        logits_blend = blend_logits(logits_stage1, logits_stage2, w=BLEND_W)
        y_pred_blend = logits_blend.argmax(1)
        acc_blend = (y_pred_blend == y_te).mean()
        print(f"[Fold {fold}] BLEND acc(test) = {acc_blend:.4f} (w={BLEND_W}) | Δ vs Stage-1 = {acc_blend-acc_stage1:+.4f}")

        global_folds.append(acc_stage1)
        blend_folds.append(acc_blend)
        all_true.append(y_te); all_pred.append(y_pred_blend)

    if len(all_true)>0:
        all_true = np.concatenate(all_true); all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int); all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES (Two-Stage)")
    print("="*60)
    print("Stage-1 folds:", [f"{a:.4f}" for a in global_folds])
    if global_folds:
        print(f"Stage-1 mean: {np.mean(global_folds):.4f}")
    print("BLEND folds:", [f"{a:.4f}" for a in blend_folds])
    if blend_folds:
        print(f"BLEND mean: {np.mean(blend_folds):.4f}")
        print(f"Δ(BLEND - Stage-1) mean: {np.mean(blend_folds) - np.mean(global_folds):+.4f}")

    if all_true.size>0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Two-Stage BLEND (All Folds)",
                       fname="confusion_twostage_blend_allfolds.png")
        print("↳ Matriz de confusión guardada: confusion_twostage_blend_allfolds.png")

    return {
        "stage1_folds": global_folds,
        "blend_folds": blend_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(FOLDS_DIR)
    }

# -------- MAIN --------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO TWO-STAGE (Encoder+Tabular)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  Stage-1: epochs={EPOCHS_STAGE1}, lr={LR_STAGE1}, SGDR T0={SGDR_T0}, Tmult={SGDR_Tmult}")
    print(f"⚙️  Stage-2: {'TabNet' if USE_TABNET else 'MLP'} | epochs={EPOCHS_STAGE2}, patience={PATIENCE_STAGE2}")
    out = run_experiment_two_stage()
    print("✔️ Done.")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO TWO-STAGE (Encoder+Tabular)
🔧 Configuración: 4c, 8 canales, 3s
⚙️  Stage-1: epochs=70, lr=0.001, SGDR T0=6, Tmult=2
⚙️  Stage-2: TabNet | epochs=400, patience=40
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:40<00:00,  2.54it/s]


Dataset construido: N=8652 | T=480 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=480 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3606 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.5174 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5459 | val_acc=1.0000 | LR=0.00085
[Stage-1] Early stopping @ 13 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold1.png
[Fold 1] Stage-1 acc(test) = 0.3526




[Fold 1] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.57781 | val_0_accuracy: 0.30311 |  0:00:00s
epoch 1  | loss: 1.36085 | val_0_accuracy: 0.28938 |  0:00:01s
epoch 2  | loss: 1.3195  | val_0_accuracy: 0.2674  |  0:00:01s
epoch 3  | loss: 1.29776 | val_0_accuracy: 0.32234 |  0:00:01s
epoch 4  | loss: 1.28598 | val_0_accuracy: 0.35073 |  0:00:02s
epoch 5  | loss: 1.28318 | val_0_accuracy: 0.3544  |  0:00:02s
epoch 6  | loss: 1.27581 | val_0_accuracy: 0.35531 |  0:00:03s
epoch 7  | loss: 1.27677 | val_0_accuracy: 0.35348 |  0:00:03s
epoch 8  | loss: 1.26668 | val_0_accuracy: 0.37363 |  0:00:04s
epoch 9  | loss: 1.273   | val_0_accuracy: 0.34799 |  0:00:04s
epoch 10 | loss: 1.25982 | val_0_accuracy: 0.3663  |  0:00:05s
epoch 11 | loss: 1.26755 | val_0_accuracy: 0.35897 |  0:00:05s
epoch 12 | loss: 1.25274 | val_0_accuracy: 0.35256 |  0:00:06s
epoch 13 | loss: 1.25517 | val_0_accuracy: 0.37179 |  0:00:06s
epoch 14 | loss: 1.25671 | val_0_accuracy: 0.35256 |  0:00:07s
epoch 15 



[Fold 1] BLEND acc(test) = 0.3730 (w=0.5) | Δ vs Stage-1 = +0.0204

[Fold 2/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3447 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.5171 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5416 | val_acc=1.0000 | LR=0.00085
[Stage-1] Early stopping @ 13 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold2.png
[Fold 2] Stage-1 acc(test) = 0.4223




[Fold 2] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.56595 | val_0_accuracy: 0.26557 |  0:00:00s
epoch 1  | loss: 1.37343 | val_0_accuracy: 0.25    |  0:00:00s
epoch 2  | loss: 1.31592 | val_0_accuracy: 0.27015 |  0:00:01s
epoch 3  | loss: 1.29433 | val_0_accuracy: 0.26099 |  0:00:01s
epoch 4  | loss: 1.28363 | val_0_accuracy: 0.28846 |  0:00:02s
epoch 5  | loss: 1.26352 | val_0_accuracy: 0.33974 |  0:00:02s
epoch 6  | loss: 1.26247 | val_0_accuracy: 0.3663  |  0:00:03s
epoch 7  | loss: 1.26035 | val_0_accuracy: 0.36447 |  0:00:03s
epoch 8  | loss: 1.25784 | val_0_accuracy: 0.37088 |  0:00:04s
epoch 9  | loss: 1.26009 | val_0_accuracy: 0.36172 |  0:00:04s
epoch 10 | loss: 1.2434  | val_0_accuracy: 0.37454 |  0:00:05s
epoch 11 | loss: 1.2328  | val_0_accuracy: 0.37271 |  0:00:05s
epoch 12 | loss: 1.2349  | val_0_accuracy: 0.36996 |  0:00:05s
epoch 13 | loss: 1.22338 | val_0_accuracy: 0.36813 |  0:00:06s
epoch 14 | loss: 1.21859 | val_0_accuracy: 0.36264 |  0:00:06s
epoch 15 



[Fold 2] BLEND acc(test) = 0.4331 (w=0.5) | Δ vs Stage-1 = +0.0108

[Fold 3/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3651 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.5162 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5492 | val_acc=1.0000 | LR=0.00085
[Stage-1] Early stopping @ 13 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold3.png
[Fold 3] Stage-1 acc(test) = 0.3844




[Fold 3] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.56814 | val_0_accuracy: 0.28938 |  0:00:00s
epoch 1  | loss: 1.37503 | val_0_accuracy: 0.25641 |  0:00:00s
epoch 2  | loss: 1.32627 | val_0_accuracy: 0.25458 |  0:00:01s
epoch 3  | loss: 1.2972  | val_0_accuracy: 0.28205 |  0:00:01s
epoch 4  | loss: 1.30162 | val_0_accuracy: 0.30769 |  0:00:02s
epoch 5  | loss: 1.28396 | val_0_accuracy: 0.32967 |  0:00:02s
epoch 6  | loss: 1.27343 | val_0_accuracy: 0.33883 |  0:00:03s
epoch 7  | loss: 1.27326 | val_0_accuracy: 0.34524 |  0:00:03s
epoch 8  | loss: 1.26439 | val_0_accuracy: 0.36355 |  0:00:04s
epoch 9  | loss: 1.25808 | val_0_accuracy: 0.38553 |  0:00:04s
epoch 10 | loss: 1.25545 | val_0_accuracy: 0.35806 |  0:00:05s
epoch 11 | loss: 1.25803 | val_0_accuracy: 0.39652 |  0:00:05s
epoch 12 | loss: 1.25404 | val_0_accuracy: 0.38828 |  0:00:05s
epoch 13 | loss: 1.24225 | val_0_accuracy: 0.37821 |  0:00:06s
epoch 14 | loss: 1.24355 | val_0_accuracy: 0.39652 |  0:00:06s
epoch 15 



[Fold 3] BLEND acc(test) = 0.3889 (w=0.5) | Δ vs Stage-1 = +0.0045

[Fold 4/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3522 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.5078 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5337 | val_acc=1.0000 | LR=0.00085
[Stage-1] Early stopping @ 13 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold4.png
[Fold 4] Stage-1 acc(test) = 0.3756




[Fold 4] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.56813 | val_0_accuracy: 0.27839 |  0:00:00s
epoch 1  | loss: 1.35981 | val_0_accuracy: 0.3141  |  0:00:00s
epoch 2  | loss: 1.32926 | val_0_accuracy: 0.32234 |  0:00:01s
epoch 3  | loss: 1.30819 | val_0_accuracy: 0.34982 |  0:00:01s
epoch 4  | loss: 1.30547 | val_0_accuracy: 0.34615 |  0:00:02s
epoch 5  | loss: 1.29807 | val_0_accuracy: 0.34799 |  0:00:02s
epoch 6  | loss: 1.28665 | val_0_accuracy: 0.3489  |  0:00:03s
epoch 7  | loss: 1.26928 | val_0_accuracy: 0.37637 |  0:00:03s
epoch 8  | loss: 1.26966 | val_0_accuracy: 0.36813 |  0:00:04s
epoch 9  | loss: 1.26733 | val_0_accuracy: 0.37088 |  0:00:04s
epoch 10 | loss: 1.26133 | val_0_accuracy: 0.35623 |  0:00:04s
epoch 11 | loss: 1.25938 | val_0_accuracy: 0.36813 |  0:00:05s
epoch 12 | loss: 1.24944 | val_0_accuracy: 0.37729 |  0:00:05s
epoch 13 | loss: 1.24751 | val_0_accuracy: 0.38736 |  0:00:06s
epoch 14 | loss: 1.25004 | val_0_accuracy: 0.38553 |  0:00:06s
epoch 15 



[Fold 4] BLEND acc(test) = 0.4024 (w=0.5) | Δ vs Stage-1 = +0.0268

[Fold 5/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3543 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.5065 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5313 | val_acc=1.0000 | LR=0.00085
[Stage-1] Early stopping @ 13 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold5.png
[Fold 5] Stage-1 acc(test) = 0.4060




[Fold 5] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.60815 | val_0_accuracy: 0.23352 |  0:00:00s
epoch 1  | loss: 1.36743 | val_0_accuracy: 0.25641 |  0:00:00s
epoch 2  | loss: 1.3148  | val_0_accuracy: 0.28755 |  0:00:01s
epoch 3  | loss: 1.30859 | val_0_accuracy: 0.35073 |  0:00:01s
epoch 4  | loss: 1.29575 | val_0_accuracy: 0.35256 |  0:00:02s
epoch 5  | loss: 1.28065 | val_0_accuracy: 0.37454 |  0:00:02s
epoch 6  | loss: 1.27686 | val_0_accuracy: 0.36447 |  0:00:03s
epoch 7  | loss: 1.26341 | val_0_accuracy: 0.37088 |  0:00:03s
epoch 8  | loss: 1.25136 | val_0_accuracy: 0.39286 |  0:00:04s
epoch 9  | loss: 1.26418 | val_0_accuracy: 0.38187 |  0:00:04s
epoch 10 | loss: 1.26114 | val_0_accuracy: 0.39194 |  0:00:04s
epoch 11 | loss: 1.25717 | val_0_accuracy: 0.39927 |  0:00:05s
epoch 12 | loss: 1.2527  | val_0_accuracy: 0.41026 |  0:00:05s
epoch 13 | loss: 1.25438 | val_0_accuracy: 0.39927 |  0:00:06s
epoch 14 | loss: 1.23913 | val_0_accuracy: 0.413   |  0:00:06s
epoch 15 



[Fold 5] BLEND acc(test) = 0.4226 (w=0.5) | Δ vs Stage-1 = +0.0167

RESULTADOS FINALES (Two-Stage)
Stage-1 folds: ['0.3526', '0.4223', '0.3844', '0.3756', '0.4060']
Stage-1 mean: 0.3882
BLEND folds: ['0.3730', '0.4331', '0.3889', '0.4024', '0.4226']
BLEND mean: 0.4040
Δ(BLEND - Stage-1) mean: +0.0158
↳ Matriz de confusión guardada: confusion_twostage_blend_allfolds.png
✔️ Done.


In [5]:
# -*- coding: utf-8 -*-
"""
Two-Stage MI-EEG (PhysioNet) — Standalone, tuned
- Stage-1: EEGNet-parallel + MHA temporal + S-TCN + Time Slicing (por logits)
           con augments + TTA, mayor capacidad (F1=24, emb=160, heads=4).
- Stage-2: TabNet (si disponible) con hiperparámetros afinados; fallback MLP.
- Blend: selecciona w por validación y lo aplica en test.
- Mismas rutas/estructura que tu baseline, 8 canales MI, exclusiones y GroupKFold(5).

Archivos generados:
- stage1_curve_foldX.png por fold
- confusion_twostage_blend_allfolds.png
"""

import os, re, math, random, json, itertools, copy, warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =============== CONFIG ===============
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE   = '6s'   # usamos 6s y recortamos 0.5–4.5 s dentro (ver extract_trials_from_run)
FS = 160.0
N_FOLDS = 5

# Entrenamiento Stage-1
BATCH_SIZE = 64
EPOCHS_STAGE1 = 80
LR_STAGE1 = 1e-3
WD_STAGE1 = 1e-5
SGDR_T0 = 6
SGDR_Tmult = 2
PATIENCE_STAGE1 = 15
LOG_EVERY = 5

# Stage-2 (Tabular)
USE_TABNET = True  # si no está instalado, cae a MLP automáticamente
EPOCHS_STAGE2 = 400
PATIENCE_STAGE2 = 50

# Blend: se elegirá por validación entre estos pesos
BLEND_CANDIDATES = [0.3, 0.4, 0.5, 0.6]

# Normalización por época
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# ============== UTIL CANALES/EDF (idéntico base) ==============
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  # default
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch si no sobresale

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# ============== DATASET BUILD ==============
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            # usamos 6s pero recortamos la "ventana útil" 0.5–4.5 s
            rel_start, rel_end = 0.5, 4.5

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)  # (T,C)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='6s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# ============== Dataset y utils torch ==============
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

@torch.no_grad()
def apply_max_norm(layers, max_value=2.0, p=2.0):
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# ============== Augments + TTA ==============
def do_time_jitter(x, max_ms=50, fs=160.0):
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

@torch.no_grad()
def _tta_logits(model, xb, n=5):
    outs = []
    for _ in range(n):
        xj = do_time_jitter(xb, max_ms=25, fs=FS)
        outs.append(model(xj)[0])
    return torch.stack(outs, 0).mean(0)

# ============== Stage-1: Encoder ==============
class MHABlock(nn.Module):
    def __init__(self, dim: int, num_heads: int = 2, dropout: float = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(dim * 4, dim), nn.Dropout(dropout),
        )
    def forward(self, x):
        h = self.norm(x)
        y, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop(y)
        x = x + self.ffn(x)
        return x

class DSConv1d(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, k: int = 3, dilation: int = 1, dropout: float = 0.1):
        super().__init__()
        pad = (k - 1) // 2 * dilation
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, groups=in_ch, padding=pad, dilation=dilation, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU(); self.drop = nn.Dropout(dropout)
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x); x = self.act(x); x = self.drop(x); return x

class STCN(nn.Module):
    def __init__(self, dim: int, k: int = 3, dilations=(1,2,4,8), dropout: float = 0.1):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(
                DSConv1d(dim, dim, k=k, dilation=d, dropout=dropout),
                nn.Conv1d(dim, dim, kernel_size=1, bias=False),
                nn.BatchNorm1d(dim), nn.ELU(), nn.Dropout(dropout)
            ) for d in dilations
        ])
    def forward(self, x):  # (B,D,T)
        for b in self.blocks:
            res = x; x = b(x); x = x + res
        return x

class EEGNetParallel(nn.Module):
    def __init__(self, n_ch: int, F1: int = 24, D: int = 2, k_t_list=(32,64,96), pool1_t=4, k_sep=16, pool2_t=4, drop=0.2):
        super().__init__()
        self.n_ch = n_ch
        self.branches = nn.ModuleList()
        for k_t in k_t_list:
            self.branches.append(nn.Sequential(
                nn.Conv2d(1, F1, kernel_size=(k_t, 1), padding=(k_t // 2, 0), bias=False),
                nn.BatchNorm2d(F1), nn.ELU(),
                nn.Conv2d(F1, F1*D, kernel_size=(1, n_ch), groups=F1, bias=False),
                nn.BatchNorm2d(F1*D), nn.ELU(),
                nn.AvgPool2d(kernel_size=(pool1_t,1), stride=(pool1_t,1)),
                nn.Dropout(drop),
                nn.Conv2d(F1*D, F1*D, kernel_size=(k_sep,1), groups=F1*D, padding=(k_sep//2,0), bias=False),
                nn.Conv2d(F1*D, F1*D, kernel_size=(1,1), bias=False),
                nn.BatchNorm2d(F1*D), nn.ELU(),
                nn.AvgPool2d(kernel_size=(pool2_t,1), stride=(pool2_t,1)),
                nn.Dropout(drop),
            ))
        self.out_dim = len(k_t_list) * (F1*D)
    def forward(self, x):
        feats = [b(x) for b in self.branches]  # each: (B,F2,T',1)
        return torch.cat(feats, dim=1)         # (B,Dsum,T',1)

class TwoStageEncoder(nn.Module):
    def __init__(self, n_ch: int, n_classes: int,
                 F1=24, D=2, k_t_list=(32,64,96), attn_heads=4,
                 stcn_k=3, stcn_dils=(1,2,4,8), emb_dim=160, drop=0.2,
                 slice_len_steps=8, slice_stride_steps=4, avg_on='logits'):
        super().__init__()
        self.backbone = EEGNetParallel(n_ch=n_ch, F1=F1, D=D, k_t_list=k_t_list, drop=drop)
        self.proj = nn.Conv2d(self.backbone.out_dim, emb_dim, kernel_size=(1,1), bias=False)
        self.attn = MHABlock(dim=emb_dim, num_heads=attn_heads, dropout=drop)
        self.stcn = STCN(dim=emb_dim, k=stcn_k, dilations=stcn_dils, dropout=drop)
        self.cls_head = nn.Linear(emb_dim, n_classes)
        self.emb_dim = emb_dim
        self.avg_on = avg_on
        self.slice_len_steps = slice_len_steps
        self.slice_stride_steps = slice_stride_steps

    def _time_slices(self, T):
        L, S = self.slice_len_steps, self.slice_stride_steps
        slices, s = [], 0
        while s + L <= T:
            slices.append((s, s+L)); s += S
        if not slices: slices = [(0,T)]
        return slices

    def forward(self, x):  # x: (B,1,T,C)
        z = self.backbone(x)                          # (B,Dsum,T',1)
        z = self.proj(z).squeeze(-1).transpose(1,2)  # (B,T',emb)
        z = self.attn(z)
        z = self.stcn(z.transpose(1,2)).transpose(1,2)  # (B,T',emb)
        B,Tp,Demb = z.shape
        idxs = self._time_slices(Tp)
        slice_embs, slice_logits = [], []
        for (a,b) in idxs:
            seg = z[:,a:b,:]
            e = seg.mean(dim=1)  # GAP temporal
            slice_embs.append(e)
            slice_logits.append(self.cls_head(e))
        if self.avg_on == 'logits':
            logits = torch.stack(slice_logits, dim=0).mean(dim=0)
            emb = torch.stack(slice_embs, dim=0).mean(dim=0)
        else:
            emb = torch.stack(slice_embs, dim=0).mean(dim=0)
            logits = self.cls_head(emb)
        return logits, emb

# ============== Stage-1 Trainer ==============
def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

class Stage1Trainer:
    def __init__(self, device=DEVICE, lr=LR_STAGE1, wd=WD_STAGE1, label_smoothing=0.05):
        self.device = device; self.lr = lr; self.wd = wd
        self.label_smoothing = label_smoothing
    def _crit(self, n_classes, class_weights=None):
        return nn.CrossEntropyLoss(weight=class_weights, label_smoothing=self.label_smoothing)
    def fit(self, model, dl_train, dl_val, epochs=EPOCHS_STAGE1, class_weights=None):
        model.to(self.device)
        opt = optim.Adam(model.parameters(), lr=self.lr, weight_decay=self.wd)
        sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)
        crit = self._crit(n_classes=model.cls_head.out_features, class_weights=class_weights)
        best_state = copy.deepcopy(model.state_dict()); best_val=-1.0; bad=0
        history = {'train_acc':[], 'val_acc':[]}
        for ep in range(1, epochs+1):
            # ---- train ----
            model.train(); n_ok=0; n_tot=0
            for xb,yb,_ in dl_train:
                xb,yb = xb.to(self.device), yb.to(self.device)
                # Augments
                xb = do_time_jitter(xb, max_ms=50, fs=FS)
                xb = do_gaussian_noise(xb, sigma=0.01)
                opt.zero_grad(set_to_none=True)
                logits,_ = model(xb)
                loss = crit(logits, yb)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                opt.step()
                n_ok += (logits.argmax(1)==yb).sum().item(); n_tot += yb.numel()
            sched.step(ep-1 + 1e-8)
            train_acc = n_ok/max(1,n_tot)
            # Max-norm ligero en la cabeza
            apply_max_norm([model.cls_head], max_value=2.0, p=2.0)
            # ---- val (con TTA) ----
            val_acc = self.evaluate(model, dl_val)
            history['train_acc'].append(train_acc); history['val_acc'].append(val_acc)
            if (ep % LOG_EVERY == 0) or ep in (1, 10, 20, 50, 80):
                cur_lr = opt.param_groups[0]['lr']
                print(f"[Stage-1] Ep {ep:3d} | train_acc={train_acc:.4f} | val_acc={val_acc:.4f} | LR={cur_lr:.5f}")
            if val_acc > best_val + 1e-4:
                best_val = val_acc; best_state = copy.deepcopy(model.state_dict()); bad = 0
            else:
                bad += 1
                if bad >= PATIENCE_STAGE1:
                    print(f"[Stage-1] Early stopping @ {ep} (best val_acc={best_val:.4f})")
                    break
        return best_state, history
    @torch.no_grad()
    def evaluate(self, model, dl):
        model.eval(); n_ok=0; n=0
        for xb,yb,_ in dl:
            xb,yb = xb.to(self.device), yb.to(self.device)
            logits = _tta_logits(model, xb, n=5)
            n_ok += (logits.argmax(1)==yb).sum().item(); n += yb.numel()
        return n/max(1,n)
    @torch.no_grad()
    def predict_logits(self, model, dl):
        model.eval(); outs=[]
        for xb,_,_ in dl:
            xb = xb.to(self.device)
            outs.append(_tta_logits(model, xb, n=5).detach().cpu().numpy())
        return np.vstack(outs)

# ============== PSD features (Welch) ==============
try:
    from scipy.signal import welch
except Exception:
    welch = None
    warnings.warn("scipy not found: PSD features disabled. Install scipy to enable.")

class PSDConfig:
    def __init__(self, fs=160.0, nperseg_sec=1.0, noverlap=0.5, bands=((1,4),(4,8),(8,14),(14,31),(31,49))):
        self.fs = fs; self.nperseg_sec = nperseg_sec; self.noverlap = noverlap; self.bands = bands

def _bandpower(f, Pxx, band):
    idx = (f >= band[0]) & (f < band[1])
    return (Pxx[idx].mean() + 1e-12)

def make_psd_features(X_np, cfg=PSDConfig()):
    if welch is None:
        raise RuntimeError("scipy.signal.welch not available.")
    N,T,C = X_np.shape
    nperseg = int(round(cfg.nperseg_sec*cfg.fs))
    noverlap = int(round(cfg.noverlap*nperseg))
    feats = []
    for i in range(N):
        x = X_np[i]
        ch_feats = []
        for c in range(C):
            f, P = welch(x[:,c], fs=cfg.fs, nperseg=nperseg, noverlap=noverlap)
            # log-bandpowers + tres ratios simples
            bps = [np.log(_bandpower(f, P, b)) for b in cfg.bands]  # delta, theta, alpha, beta, gamma
            ratios = [bps[2]-bps[3], bps[1]-bps[2], bps[0]-bps[2]]  # alpha-beta, theta-alpha, delta-alpha
            ch_feats.extend(bps + ratios)
        feats.append(ch_feats)
    return np.asarray(feats, dtype=np.float32)

# ============== Tabular Head (TabNet/MLP) ==============
class _MLPHead(nn.Module):
    def __init__(self, d_in, n_classes, hidden=512, drop=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, hidden), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(hidden, hidden//2), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(hidden//2, n_classes)
        )
    def forward(self, x): return self.net(x)

class TabularHead:
    def __init__(self, input_dim, n_classes, use_tabnet=USE_TABNET, device=DEVICE, tabnet_params=None):
        self.n_classes = n_classes; self.device = device
        self.use_tabnet = False; self.tabnet=None
        if use_tabnet:
            try:
                from pytorch_tabnet.tab_model import TabNetClassifier
                params = dict(n_d=32, n_a=32, n_steps=4, gamma=1.3,
                              lambda_sparse=1e-4, momentum=0.02, clip_value=2.0)
                if tabnet_params: params.update(tabnet_params)
                self.tabnet = TabNetClassifier(**params)
                self.use_tabnet = True
            except Exception:
                warnings.warn("pytorch_tabnet not found; using MLP fallback.")
        if not self.use_tabnet:
            self.mlp = _MLPHead(input_dim, n_classes).to(self.device)
            self.opt = optim.Adam(self.mlp.parameters(), lr=1e-3, weight_decay=1e-5)
            self.crit = nn.CrossEntropyLoss()
            self.scaler_mean=None; self.scaler_std=None
    def fit(self, X, y, X_val=None, y_val=None, epochs=EPOCHS_STAGE2, patience=PATIENCE_STAGE2):
        if self.use_tabnet:
            eval_set=[(X_val, y_val)] if X_val is not None else None
            self.tabnet.fit(X, y, eval_set=eval_set, max_epochs=epochs, patience=patience,
                            batch_size=256, virtual_batch_size=128)
        else:
            self.scaler_mean = X.mean(axis=0, keepdims=True); self.scaler_std = X.std(axis=0, keepdims=True)+1e-6
            Xn = (X - self.scaler_mean)/self.scaler_std; Xn = torch.from_numpy(Xn).float().to(self.device)
            y_t = torch.from_numpy(y).long().to(self.device)
            dset = torch.utils.data.TensorDataset(Xn,y_t)
            loader = DataLoader(dset, batch_size=256, shuffle=True)
            if X_val is not None and y_val is not None:
                Xv = (X_val - self.scaler_mean)/self.scaler_std
                self.Xv_t = torch.from_numpy(Xv).float().to(self.device)
                self.yv_t = torch.from_numpy(y_val).long().to(self.device)
            else:
                self.Xv_t=None; self.yv_t=None
            best_state = copy.deepcopy(self.mlp.state_dict()); best=-1.0; bad=0
            for ep in range(1, epochs+1):
                self.mlp.train()
                for xb,yb in loader:
                    self.opt.zero_grad(set_to_none=True)
                    logits = self.mlp(xb)
                    loss = self.crit(logits, yb)
                    loss.backward(); self.opt.step()
                if self.Xv_t is not None:
                    self.mlp.eval();
                    with torch.no_grad():
                        lv = self.mlp(self.Xv_t)
                        acc = (lv.argmax(1)==self.yv_t).float().mean().item()
                    if ep % 10 == 0:
                        print(f"[Stage-2] Ep {ep:3d} | val_acc={acc:.4f}")
                    if acc > best + 1e-4:
                        best = acc; best_state = copy.deepcopy(self.mlp.state_dict()); bad=0
                    else:
                        bad += 1
                        if bad >= patience:
                            print(f"[Stage-2] Early stopping @ {ep} (best val_acc={best:.4f})"); break
            self.mlp.load_state_dict(best_state)
    def predict_logits(self, X):
        if self.use_tabnet:
            proba = self.tabnet.predict_proba(X)
            return np.log(proba + 1e-9)
        else:
            Xn = (X - self.scaler_mean)/self.scaler_std
            Xt = torch.from_numpy(Xn).float().to(self.device)
            self.mlp.eval();
            with torch.no_grad():
                logits = self.mlp(Xt).cpu().numpy()
            return logits

def blend_logits(logits1, logits2, w):
    w = float(np.clip(w,0.0,1.0)); return w*logits1 + (1.0-w)*logits2

# ============== Folds JSON helpers ==============
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="TwoStage_Exp", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)
    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]
        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]
        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()
        folds.append({"fold": int(fold_i), "train": train_sids, "test": test_sids, "tr_idx": tr_idx, "te_idx": te_idx})
    payload = {"created_at": datetime.now().isoformat(), "created_by": created_by,
               "description": description if description is not None else "",
               "n_splits": int(n_splits), "n_subjects": len(subject_ids),
               "subject_ids": subject_ids, "folds": folds}
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check: raise ValueError(msg)
            else: print("WARNING: " + msg)
    return payload

# ============== Confusion + curves ==============
def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right'); plt.yticks(tick_marks, classes)
    fmt = '.2f'; thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")
    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname, title='Stage-1 Training curve'):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title(title); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

# ============== EXPERIMENTO PRINCIPAL ==============
def run_experiment_two_stage():
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    folds_json_path = Path(FOLDS_DIR)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)
    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]
    if not folds_json_path.exists():
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="TwoStage_Pipeline_Tuned",
                                           description="Two-Stage GroupKFold")
    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    stage1_folds = []
    blend_folds  = []
    all_true = []; all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)
        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ======= B) VALIDACIÓN INTERNA POR SUJETOS CON GroupKFold =======
        n_groups_tr = len(np.unique(groups[tr_idx]))
        if n_groups_tr < 2:
            raise RuntimeError(f"[Fold {fold}] Muy pocos sujetos en train para validación interna (n_groups_tr={n_groups_tr}).")
        n_splits_inner = min(5, n_groups_tr)  # hasta 5 splits; si hay menos sujetos, se ajusta
        gkf_in = GroupKFold(n_splits=n_splits_inner)
        inner_splits = list(gkf_in.split(tr_idx, y[tr_idx], groups[tr_idx]))
        tr_pos, va_pos = inner_splits[0]  # puedes elegir otro índice si prefieres
        tr_sub_idx = tr_idx[tr_pos]
        va_idx     = tr_idx[va_pos]

        # ======= A) PRINTS + ASSERTS DE SANIDAD (no-overlap por sujeto) =======
        tr_sub_ids = set(groups[tr_sub_idx].tolist())
        va_sub_ids = set(groups[va_idx].tolist())
        te_sub_ids = set(groups[te_idx].tolist())

        print(f"[Fold {fold}] sujetos train={len(tr_sub_ids)}, val={len(va_sub_ids)}, test={len(te_sub_ids)}")
        print(f"  train subs: {sorted(list(tr_sub_ids))[:10]}{'...' if len(tr_sub_ids)>10 else ''}")
        print(f"  val   subs: {sorted(list(va_sub_ids))[:10]}{'...' if len(va_sub_ids)>10 else ''}")
        print(f"  test  subs: {sorted(list(te_sub_ids))[:10]}{'...' if len(te_sub_ids)>10 else ''}")

        assert tr_sub_ids.isdisjoint(va_sub_ids), "♨️ fuga: sujeto aparece en train y val"
        assert tr_sub_ids.isdisjoint(te_sub_ids), "♨️ fuga: sujeto aparece en train y test"
        assert va_sub_ids.isdisjoint(te_sub_ids), "♨️ fuga: sujeto aparece en val y test"

        print(f"[Fold {fold}] n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)}")
        vals, cnts = np.unique(y[va_idx], return_counts=True)
        print(f"[Fold {fold}] val class dist: " + str(dict(zip(vals.tolist(), cnts.tolist()))))

        # -------- DataLoaders --------
        tr_loader = DataLoader(
            Subset(ds, tr_sub_idx),
            batch_size=BATCH_SIZE,
            sampler=build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx]),
            drop_last=False
        )
        va_loader = DataLoader(Subset(ds, va_idx), batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx), batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # -------- Stage-1 --------
        model = TwoStageEncoder(n_ch=C, n_classes=n_classes,
                                F1=24, D=2, k_t_list=(32,64,96), attn_heads=4,
                                stcn_k=3, stcn_dils=(1,2,4,8), emb_dim=160, drop=0.2,
                                slice_len_steps=8, slice_stride_steps=4, avg_on='logits').to(DEVICE)

        trainer = Stage1Trainer(device=DEVICE, lr=LR_STAGE1, wd=WD_STAGE1, label_smoothing=0.05)
        # class weights inversos por clase
        class_counts = np.bincount(y[tr_sub_idx], minlength=n_classes).astype(np.float32)
        class_counts[class_counts==0]=1.0
        class_w = class_counts.sum()/class_counts
        class_w = class_w / class_w.mean()
        class_w = torch.tensor(class_w, dtype=torch.float32, device=DEVICE)

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando Stage-1 (encoder)...")
        best_state, hist = trainer.fit(model, tr_loader, va_loader, epochs=EPOCHS_STAGE1, class_weights=class_w)
        plot_training_curves(hist, f"stage1_curve_fold{fold}.png", title='Stage-1 Train/Val Acc')
        print(f"↳ Curva Stage-1 guardada: stage1_curve_fold{fold}.png")
        model.load_state_dict(best_state)

        # logits Stage-1 en validación y test (con TTA)
        logits_stage1_val = trainer.predict_logits(model, va_loader)
        logits_stage1_te  = trainer.predict_logits(model, te_loader)

        y_va = y[va_idx]
        y_te = y[te_idx]
        y_pred_stage1 = logits_stage1_te.argmax(1)
        acc_stage1 = (y_pred_stage1 == y_te).mean()
        print(f"[Fold {fold}] Stage-1 acc(test) = {acc_stage1:.4f}")

        # -------- Extraer embeddings + PSD --------
        @torch.no_grad()
        def extract_embeddings(model, dl, device=DEVICE):
            model.eval().to(device)
            embs, ys = [], []
            for xb,yb,_ in dl:
                xb = xb.to(device)
                _, e = model(xb)
                embs.append(e.detach().cpu().numpy()); ys.append(yb.numpy())
            return np.vstack(embs), np.concatenate(ys)

        emb_tr, y_tr = extract_embeddings(model, tr_loader, DEVICE)
        emb_va, _    = extract_embeddings(model, va_loader, DEVICE)
        emb_te, _    = extract_embeddings(model, te_loader, DEVICE)

        # recolecta X bruto para PSD
        def collect_numpy(dl):
            Xs, ys = [], []
            for xb,yb,_ in dl:
                Xs.append(xb.squeeze(1).numpy()); ys.append(yb.numpy())
            return np.concatenate(Xs, axis=0), np.concatenate(ys, axis=0)
        X_tr_np, _ = collect_numpy(tr_loader)
        X_va_np, _ = collect_numpy(va_loader)
        X_te_np, _ = collect_numpy(te_loader)

        if welch is None:
            print("WARNING: scipy no disponible → Stage-2 usará SOLO embeddings (sin PSD).")
            psd_tr = np.zeros((emb_tr.shape[0], 0), dtype=np.float32)
            psd_va = np.zeros((emb_va.shape[0], 0), dtype=np.float32)
            psd_te = np.zeros((emb_te.shape[0], 0), dtype=np.float32)
        else:
            psd_tr = make_psd_features(X_tr_np, PSDConfig(fs=FS))
            psd_va = make_psd_features(X_va_np, PSDConfig(fs=FS))
            psd_te = make_psd_features(X_te_np, PSDConfig(fs=FS))

        Xtr2 = np.hstack([emb_tr, psd_tr])
        Xva2 = np.hstack([emb_va, psd_va])
        Xte2 = np.hstack([emb_te, psd_te])

        # -------- Stage-2 --------
        head = TabularHead(input_dim=Xtr2.shape[1], n_classes=n_classes, use_tabnet=USE_TABNET, device=DEVICE,
                           tabnet_params=dict(n_d=32, n_a=32, n_steps=4, gamma=1.3,
                                              lambda_sparse=1e-4, momentum=0.02, clip_value=2.0))
        print(f"[Fold {fold}] Entrenando Stage-2 (tabular head)...")
        head.fit(Xtr2, y_tr, X_val=Xva2, y_val=y_va, epochs=EPOCHS_STAGE2, patience=PATIENCE_STAGE2)
        logits_stage2_val = head.predict_logits(Xva2)
        logits_stage2_te  = head.predict_logits(Xte2)

        # -------- Buscar mejor w en validación --------
        best_w, best_val_acc = None, -1.0
        for w in BLEND_CANDIDATES:
            val_logits = blend_logits(logits_stage1_val, logits_stage2_val, w=w)
            acc = (val_logits.argmax(1) == y_va).mean()
            if acc > best_val_acc:
                best_val_acc = acc; best_w = w
        print(f"[Fold {fold}] Mejor w (val) = {best_w} | val_acc_blend = {best_val_acc:.4f}")

        # -------- Evaluar en test con ese w --------
        logits_blend_te = blend_logits(logits_stage1_te, logits_stage2_te, w=best_w)
        y_pred_blend = logits_blend_te.argmax(1)
        acc_blend = (y_pred_blend == y_te).mean()
        print(f"[Fold {fold}] BLEND acc(test) = {acc_blend:.4f} (w={best_w}) | Δ vs Stage-1 = {acc_blend-acc_stage1:+.4f}")

        stage1_folds.append(acc_stage1)
        blend_folds.append(acc_blend)
        all_true.append(y_te); all_pred.append(y_pred_blend)

    if len(all_true)>0:
        all_true = np.concatenate(all_true); all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int); all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES (Two-Stage, tuned)")
    print("="*60)
    print("Stage-1 folds:", [f"{a:.4f}" for a in stage1_folds])
    if stage1_folds:
        print(f"Stage-1 mean: {np.mean(stage1_folds):.4f}")
    print("BLEND folds:", [f"{a:.4f}" for a in blend_folds])
    if blend_folds:
        print(f"BLEND mean: {np.mean(blend_folds):.4f}")
        print(f"Δ(BLEND - Stage-1) mean: {np.mean(blend_folds) - np.mean(stage1_folds):+.4f}")

    if all_true.size>0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Two-Stage BLEND (All Folds)",
                       fname="confusion_twostage_blend_allfolds.png")
        print("↳ Matriz de confusión guardada: confusion_twostage_blend_allfolds.png")

    return {
        "stage1_folds": stage1_folds,
        "blend_folds": blend_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(FOLDS_DIR)
    }

# -------- MAIN --------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO TWO-STAGE (Encoder+Tabular, tuned)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  Stage-1: epochs={EPOCHS_STAGE1}, lr={LR_STAGE1}, SGDR T0={SGDR_T0}, Tmult={SGDR_Tmult}, augments+TTA ✓")
    print(f"⚙️  Stage-2: {'TabNet' if USE_TABNET else 'MLP'} | epochs={EPOCHS_STAGE2}, patience={PATIENCE_STAGE2}")
    out = run_experiment_two_stage()
    print("✔️ Done.")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO TWO-STAGE (Encoder+Tabular, tuned)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  Stage-1: epochs=80, lr=0.001, SGDR T0=6, Tmult=2, augments+TTA ✓
⚙️  Stage-2: TabNet | epochs=400, patience=50
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

Construyendo dataset (RAW): 100%|██████████| 103/103 [00:36<00:00,  2.80it/s]


Dataset construido: N=8652 | T=640 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=640 | C=8 | clases=4 | sujetos=103
[Fold 1] sujetos train=65, val=17, test=21
  train subs: [1, 4, 5, 6, 7, 10, 11, 12, 14, 16]...
  val   subs: [2, 9, 15, 21, 27, 34, 41, 47, 53, 60]...
  test  subs: [3, 8, 13, 18, 23, 28, 33, 39, 44, 49]...
[Fold 1] n_train=5460 | n_val=1428 | n_test=1764
[Fold 1] val class dist: {0: 357, 1: 357, 2: 357, 3: 357}

[Fold 1/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3837 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.5099 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5209 | val_acc=1.0000 | LR=0.00085
[Stage-1] Ep  15 | train_acc=0.6016 | val_acc=1.0000 | LR=0.00025
[Stage-1] Early stopping @ 16 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold1.png
[Fold 1] Stage-1 acc(test) = 0.3764




[Fold 1] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.72381 | val_0_accuracy: 0.2521  |  0:00:00s
epoch 1  | loss: 1.39826 | val_0_accuracy: 0.28992 |  0:00:01s
epoch 2  | loss: 1.34501 | val_0_accuracy: 0.30462 |  0:00:01s
epoch 3  | loss: 1.3227  | val_0_accuracy: 0.29202 |  0:00:02s
epoch 4  | loss: 1.3081  | val_0_accuracy: 0.32633 |  0:00:02s
epoch 5  | loss: 1.3049  | val_0_accuracy: 0.34314 |  0:00:03s
epoch 6  | loss: 1.30153 | val_0_accuracy: 0.37535 |  0:00:03s
epoch 7  | loss: 1.303   | val_0_accuracy: 0.38375 |  0:00:04s
epoch 8  | loss: 1.28378 | val_0_accuracy: 0.37185 |  0:00:04s
epoch 9  | loss: 1.27604 | val_0_accuracy: 0.39216 |  0:00:05s
epoch 10 | loss: 1.27567 | val_0_accuracy: 0.39216 |  0:00:06s
epoch 11 | loss: 1.2724  | val_0_accuracy: 0.40126 |  0:00:06s
epoch 12 | loss: 1.27616 | val_0_accuracy: 0.39636 |  0:00:07s
epoch 13 | loss: 1.27536 | val_0_accuracy: 0.37815 |  0:00:07s
epoch 14 | loss: 1.272   | val_0_accuracy: 0.39636 |  0:00:08s
epoch 15 



[Fold 1] Mejor w (val) = 0.4 | val_acc_blend = 0.4286
[Fold 1] BLEND acc(test) = 0.3622 (w=0.4) | Δ vs Stage-1 = -0.0142
[Fold 2] sujetos train=65, val=17, test=21
  train subs: [1, 4, 5, 6, 8, 10, 11, 13, 14, 16]...
  val   subs: [3, 9, 15, 21, 28, 34, 41, 47, 54, 60]...
  test  subs: [2, 7, 12, 17, 22, 27, 32, 37, 43, 48]...
[Fold 2] n_train=5460 | n_val=1428 | n_test=1764
[Fold 2] val class dist: {0: 357, 1: 357, 2: 357, 3: 357}

[Fold 2/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3537 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.4716 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5007 | val_acc=1.0000 | LR=0.00085
[Stage-1] Ep  15 | train_acc=0.5570 | val_acc=1.0000 | LR=0.00025
[Stage-1] Early stopping @ 16 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold2.png
[Fold 2] Stage-1 acc(test) = 0.4019




[Fold 2] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.77574 | val_0_accuracy: 0.2507  |  0:00:00s
epoch 1  | loss: 1.4198  | val_0_accuracy: 0.29342 |  0:00:01s
epoch 2  | loss: 1.36164 | val_0_accuracy: 0.2423  |  0:00:01s
epoch 3  | loss: 1.34384 | val_0_accuracy: 0.29132 |  0:00:02s
epoch 4  | loss: 1.33698 | val_0_accuracy: 0.31653 |  0:00:02s
epoch 5  | loss: 1.33277 | val_0_accuracy: 0.34524 |  0:00:03s
epoch 6  | loss: 1.31691 | val_0_accuracy: 0.32003 |  0:00:03s
epoch 7  | loss: 1.32441 | val_0_accuracy: 0.30952 |  0:00:04s
epoch 8  | loss: 1.31943 | val_0_accuracy: 0.34804 |  0:00:04s
epoch 9  | loss: 1.31435 | val_0_accuracy: 0.34174 |  0:00:05s
epoch 10 | loss: 1.31102 | val_0_accuracy: 0.35154 |  0:00:06s
epoch 11 | loss: 1.30184 | val_0_accuracy: 0.35224 |  0:00:06s
epoch 12 | loss: 1.30103 | val_0_accuracy: 0.33754 |  0:00:07s
epoch 13 | loss: 1.30214 | val_0_accuracy: 0.35714 |  0:00:07s
epoch 14 | loss: 1.30027 | val_0_accuracy: 0.35224 |  0:00:08s
epoch 15 



[Fold 2] Mejor w (val) = 0.3 | val_acc_blend = 0.3873
[Fold 2] BLEND acc(test) = 0.3963 (w=0.3) | Δ vs Stage-1 = -0.0057
[Fold 3] sujetos train=65, val=17, test=21
  train subs: [2, 4, 5, 7, 8, 10, 12, 13, 14, 17]...
  val   subs: [3, 9, 15, 22, 28, 34, 41, 48, 54, 60]...
  test  subs: [1, 6, 11, 16, 21, 26, 31, 36, 42, 47]...
[Fold 3] n_train=5460 | n_val=1428 | n_test=1764
[Fold 3] val class dist: {0: 357, 1: 357, 2: 357, 3: 357}

[Fold 3/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3560 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.4813 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5071 | val_acc=1.0000 | LR=0.00085
[Stage-1] Ep  15 | train_acc=0.5885 | val_acc=1.0000 | LR=0.00025
[Stage-1] Early stopping @ 16 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold3.png
[Fold 3] Stage-1 acc(test) = 0.3810




[Fold 3] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.73112 | val_0_accuracy: 0.23039 |  0:00:00s
epoch 1  | loss: 1.41738 | val_0_accuracy: 0.2591  |  0:00:01s
epoch 2  | loss: 1.36009 | val_0_accuracy: 0.30182 |  0:00:01s
epoch 3  | loss: 1.33996 | val_0_accuracy: 0.35854 |  0:00:02s
epoch 4  | loss: 1.32274 | val_0_accuracy: 0.34454 |  0:00:02s
epoch 5  | loss: 1.32055 | val_0_accuracy: 0.37115 |  0:00:03s
epoch 6  | loss: 1.31153 | val_0_accuracy: 0.38655 |  0:00:03s
epoch 7  | loss: 1.30019 | val_0_accuracy: 0.39986 |  0:00:04s
epoch 8  | loss: 1.29564 | val_0_accuracy: 0.38866 |  0:00:04s
epoch 9  | loss: 1.29559 | val_0_accuracy: 0.39566 |  0:00:05s
epoch 10 | loss: 1.3034  | val_0_accuracy: 0.36905 |  0:00:06s
epoch 11 | loss: 1.29401 | val_0_accuracy: 0.41387 |  0:00:06s
epoch 12 | loss: 1.28695 | val_0_accuracy: 0.40336 |  0:00:07s
epoch 13 | loss: 1.29179 | val_0_accuracy: 0.40336 |  0:00:07s
epoch 14 | loss: 1.28562 | val_0_accuracy: 0.40476 |  0:00:08s
epoch 15 



[Fold 3] Mejor w (val) = 0.3 | val_acc_blend = 0.4328
[Fold 3] BLEND acc(test) = 0.3968 (w=0.3) | Δ vs Stage-1 = +0.0159
[Fold 4] sujetos train=66, val=17, test=20
  train subs: [1, 2, 4, 6, 7, 8, 11, 12, 13, 14]...
  val   subs: [3, 9, 16, 22, 28, 34, 42, 48, 54, 60]...
  test  subs: [5, 10, 15, 20, 25, 30, 35, 41, 46, 51]...
[Fold 4] n_train=5544 | n_val=1428 | n_test=1680
[Fold 4] val class dist: {0: 357, 1: 357, 2: 357, 3: 357}

[Fold 4/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3575 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.4829 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5090 | val_acc=1.0000 | LR=0.00085
[Stage-1] Ep  15 | train_acc=0.5785 | val_acc=1.0000 | LR=0.00025
[Stage-1] Early stopping @ 16 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold4.png
[Fold 4] Stage-1 acc(test) = 0.3726




[Fold 4] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.74597 | val_0_accuracy: 0.26891 |  0:00:00s
epoch 1  | loss: 1.40348 | val_0_accuracy: 0.28711 |  0:00:01s
epoch 2  | loss: 1.34584 | val_0_accuracy: 0.34314 |  0:00:01s
epoch 3  | loss: 1.31728 | val_0_accuracy: 0.34244 |  0:00:02s
epoch 4  | loss: 1.31294 | val_0_accuracy: 0.36625 |  0:00:02s
epoch 5  | loss: 1.30271 | val_0_accuracy: 0.37325 |  0:00:03s
epoch 6  | loss: 1.30412 | val_0_accuracy: 0.36765 |  0:00:03s
epoch 7  | loss: 1.30057 | val_0_accuracy: 0.38025 |  0:00:04s
epoch 8  | loss: 1.29688 | val_0_accuracy: 0.40196 |  0:00:04s
epoch 9  | loss: 1.29128 | val_0_accuracy: 0.40826 |  0:00:05s
epoch 10 | loss: 1.28585 | val_0_accuracy: 0.39286 |  0:00:05s
epoch 11 | loss: 1.27981 | val_0_accuracy: 0.38796 |  0:00:06s
epoch 12 | loss: 1.27985 | val_0_accuracy: 0.35574 |  0:00:06s
epoch 13 | loss: 1.27713 | val_0_accuracy: 0.38725 |  0:00:07s
epoch 14 | loss: 1.28152 | val_0_accuracy: 0.40056 |  0:00:07s
epoch 15 



[Fold 4] Mejor w (val) = 0.3 | val_acc_blend = 0.4083
[Fold 4] BLEND acc(test) = 0.3881 (w=0.3) | Δ vs Stage-1 = +0.0155
[Fold 5] sujetos train=66, val=17, test=20
  train subs: [1, 2, 5, 6, 7, 8, 11, 12, 13, 15]...
  val   subs: [3, 10, 16, 22, 28, 35, 42, 48, 54, 61]...
  test  subs: [4, 9, 14, 19, 24, 29, 34, 40, 45, 50]...
[Fold 5] n_train=5544 | n_val=1428 | n_test=1680
[Fold 5] val class dist: {0: 357, 1: 357, 2: 357, 3: 357}

[Fold 5/5] Entrenando Stage-1 (encoder)...
[Stage-1] Ep   1 | train_acc=0.3505 | val_acc=1.0000 | LR=0.00100
[Stage-1] Ep   5 | train_acc=0.4697 | val_acc=1.0000 | LR=0.00025
[Stage-1] Ep  10 | train_acc=0.5023 | val_acc=1.0000 | LR=0.00085
[Stage-1] Ep  15 | train_acc=0.5804 | val_acc=1.0000 | LR=0.00025
[Stage-1] Early stopping @ 16 (best val_acc=1.0000)
↳ Curva Stage-1 guardada: stage1_curve_fold5.png
[Fold 5] Stage-1 acc(test) = 0.4065




[Fold 5] Entrenando Stage-2 (tabular head)...
epoch 0  | loss: 1.71778 | val_0_accuracy: 0.27381 |  0:00:00s
epoch 1  | loss: 1.42375 | val_0_accuracy: 0.28431 |  0:00:01s
epoch 2  | loss: 1.3542  | val_0_accuracy: 0.30392 |  0:00:01s
epoch 3  | loss: 1.33522 | val_0_accuracy: 0.32773 |  0:00:02s
epoch 4  | loss: 1.32554 | val_0_accuracy: 0.33263 |  0:00:02s
epoch 5  | loss: 1.32158 | val_0_accuracy: 0.35854 |  0:00:03s
epoch 6  | loss: 1.3097  | val_0_accuracy: 0.34244 |  0:00:03s
epoch 7  | loss: 1.30535 | val_0_accuracy: 0.37465 |  0:00:04s
epoch 8  | loss: 1.29372 | val_0_accuracy: 0.37535 |  0:00:05s
epoch 9  | loss: 1.29652 | val_0_accuracy: 0.39146 |  0:00:05s
epoch 10 | loss: 1.28764 | val_0_accuracy: 0.38936 |  0:00:06s
epoch 11 | loss: 1.28833 | val_0_accuracy: 0.38025 |  0:00:06s
epoch 12 | loss: 1.29755 | val_0_accuracy: 0.38655 |  0:00:07s
epoch 13 | loss: 1.28871 | val_0_accuracy: 0.38725 |  0:00:07s
epoch 14 | loss: 1.29046 | val_0_accuracy: 0.37325 |  0:00:08s
epoch 15 



[Fold 5] Mejor w (val) = 0.5 | val_acc_blend = 0.4111
[Fold 5] BLEND acc(test) = 0.4071 (w=0.5) | Δ vs Stage-1 = +0.0006

RESULTADOS FINALES (Two-Stage, tuned)
Stage-1 folds: ['0.3764', '0.4019', '0.3810', '0.3726', '0.4065']
Stage-1 mean: 0.3877
BLEND folds: ['0.3622', '0.3963', '0.3968', '0.3881', '0.4071']
BLEND mean: 0.3901
Δ(BLEND - Stage-1) mean: +0.0024
↳ Matriz de confusión guardada: confusion_twostage_blend_allfolds.png
✔️ Done.
