In [3]:
# -*- coding: utf-8 -*-
# Conformer-MI (v2) — Balanceo previo (21 eventos/clase/sujeto) ANTES de crops
# Cambios:
# - Balanceo (down/up) por sujeto/clase a 21 EVENTOS en bruto (no croppeados)
# - Luego se generan los CROPS a partir de esos eventos balanceados
# - Ventanas: (0.5–3.5), (1.0–4.0), (1.5–4.5) → ~3 s @160 Hz
# - Regularización ↑: dropout=0.3, sd_prob=0.2, weight_decay=1e-4, GLOBAL_PATIENCE=10
# - TTA en test
# - Augments train: time jitter + amplitude scaling + channel dropout
# - Band-pass augment opcional (apagado por defecto), notch adaptativo como antes

import os, re, math, json, copy, random, itertools
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import (GroupKFold, GroupShuffleSplit, StratifiedKFold,
                                     StratifiedShuffleSplit)
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.utils import check_random_state

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# RUTAS (del usuario)
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# =========================
# CONFIG GLOBAL
# =========================
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

FS = 160.0
N_FOLDS = 5
CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']
N_CLASSES = 4

# === CROPS (DESPUÉS del balanceo de eventos)
CROPS = [(0.0, 3.0)]  # ~3.0 s cada uno

# === Augments (train)
USE_BANDPASS_AUG = True      # <- por defecto SIN band-pass; ponlo True si quieres probar
BANDPASS_PROB = 0.25          # prob. de band-pass cuando USE_BANDPASS_AUG=True
AMP_SCALE_RANGE = (0.8, 1.2)  # escalado por canal
CHANNEL_DROP_PROB = 0.1       # prob. de "apagar" un canal completo en train

# Entrenamiento global
BATCH_SIZE = 32
EPOCHS_GLOBAL = 100
LR_INIT = 1e-3
WARMUP_EPOCHS = 5
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE = 10          # más estricto para evitar overfitting
LOG_EVERY = 5

# Fine-tuning por sujeto
DO_SUBJECT_FT = True
CALIB_CV_FOLDS = 4
FT_EPOCHS = 15
FT_HEAD_LR = 1e-3
FT_BASE_LR = 5e-5
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Sujetos/runs
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
MI_RUNS_LR = [4, 8, 12]     # Left/Right
MI_RUNS_OF = [6, 10, 14]    # Both Fists / Both Feet
BASELINE_RUNS_EO = [1]

# Canales esperados
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

DEBUG_WINDOWS = True           # <--- pon a False cuando termines de depurar
DEBUG_MAX_PRINTS = 50   
# =========================
# UTILIDADES CANALES / EDF
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch según tabla SNR opcional ---
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            _SNR_TABLE = pd.read_csv(csv_path)
        except:
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    """Extrae eventos T1/T2 y aplica deduplicación mínima (>=0.5s)."""
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')

    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()

    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5:
                dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5:
                dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# DISCOVERY DE SUJETOS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

# =========================
# BALANCEO POR EVENTOS (ANTES de CROPS)
# =========================
def _label_from_tag(kind: str, tag: str):
    if kind == 'LR':
        return 0 if tag == 'T1' else 1 if tag == 'T2' else None
    elif kind == 'OF':
        return 2 if tag == 'T1' else 3 if tag == 'T2' else None
    return None

def gather_all_events(subjects):
    """Junta TODOS los eventos T1/T2 por sujeto, SIN crops."""
    per_subj = {}  # subj -> {0:[...],1:[...],2:[...],3:[...]}
    for s in tqdm(subjects, desc="Escaneando eventos (sin crops)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        per_subj[s] = {0:[], 1:[], 2:[], 3:[]}
        for r in (MI_RUNS_LR + MI_RUNS_OF + BASELINE_RUNS_EO):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            kind = run_kind(r)
            if kind not in ('LR','OF'): 
                continue
            raw = read_raw_edf(p)
            if raw is None: 
                continue
            events = collect_events_T1T2(raw)
            for onset_sec, tag in events:
                lab = _label_from_tag(kind, tag)
                if lab is None: 
                    continue
                per_subj[s][lab].append({
                    "path": p,
                    "onset": float(onset_sec),
                    "label": int(lab),
                    "subject": int(s),
                    "run": int(r)
                })
    return per_subj

def balance_events_per_subject(per_subj, need_per_class=21, rng_seed=RANDOM_STATE):
    """Para cada sujeto y clase, deja EXACTAMENTE 'need_per_class' eventos (down/up)."""
    rng = check_random_state(rng_seed)
    kept = {}
    for s, cls_dict in per_subj.items():
        kept[s] = {}
        skip_subject = False
        for c in range(4):
            events = cls_dict.get(c, [])
            n0 = len(events)
            if n0 == 0:
                skip_subject = True
                break
            if n0 == need_per_class:
                sel = events
            elif n0 > need_per_class:
                idx = rng.choice(n0, size=need_per_class, replace=False)
                sel = [events[i] for i in idx]
            else:
                idx = rng.choice(n0, size=need_per_class, replace=True)
                sel = [events[i] for i in idx]
            kept[s][c] = sel
        if skip_subject:
            kept[s] = None
            print(f"[SKIP] S{s:03d} (alguna clase sin eventos)")
    kept = {s:cls for s,cls in kept.items() if cls is not None}
    return kept

# =========================
# APLICAR CROPS a los eventos balanceados
# =========================
# === debug global ===
DEBUG_WINDOWS = True           # <--- pon a False cuando termines de depurar
DEBUG_MAX_PRINTS = 50          # límite para no spamear

def segments_from_balanced_events(kept_balanced, crops=CROPS, fs=FS):
    """Genera segmentos (T,C) a partir de los eventos ya balanceados (lee EDF una vez por archivo)."""
    X, y, groups = [], [], []
    ch_template = None
    cache_raw = {}  # path -> raw

    # contadores de diagnóstico
    dbg_total_events = 0
    dbg_total_crops_ok = 0
    dbg_skips_oob = 0       # fuera de rango (out-of-bounds)
    dbg_skips_short = 0     # longitud inesperada (< 1 muestra)
    dbg_skips_nan = 0       # datos NaN o inf
    dbg_prints = 0

    try:
        for s, cls_dict in tqdm(kept_balanced.items(), desc="Generando segmentos (con crops)"):
            for c in range(4):
                ev_list = cls_dict[c]
                for ev in ev_list:
                    dbg_total_events += 1
                    p = ev["path"]; onset = ev["onset"]; sid = ev["subject"]; rid = ev["run"]
                    if p not in cache_raw:
                        raw = read_raw_edf(p)
                        if raw is None:
                            if DEBUG_WINDOWS and dbg_prints < DEBUG_MAX_PRINTS:
                                print(f"[DEBUG] No se pudo leer RAW: {p}")
                                dbg_prints += 1
                            continue
                        cache_raw[p] = raw
                    raw = cache_raw[p]
                    data = raw.get_data()  # (C, T_total)
                    T_total = data.shape[1]
                    if ch_template is None:
                        ch_template = raw.ch_names

                    for rel_start, rel_end in crops:
                        s_samp = int(round((raw.first_time + onset + rel_start) * fs))
                        e_samp = int(round((raw.first_time + onset + rel_end)   * fs))
                        exp_len = int(round((rel_end - rel_start) * fs))

                        # chequeo de rango
                        if s_samp < 0 or e_samp > T_total:
                            dbg_skips_oob += 1
                            if DEBUG_WINDOWS and dbg_prints < DEBUG_MAX_PRINTS:
                                print(f"[SKIP OOB] S{sid:03d} R{rid:02d} C{c} "
                                      f"onset={onset:.3f}s win=({rel_start:.2f},{rel_end:.2f}) "
                                      f"s={s_samp} e={e_samp} T_total={T_total}")
                                dbg_prints += 1
                            continue

                        # extrae segmento
                        seg = data[:, s_samp:e_samp].T.astype(np.float32)  # (T, C)

                        # chequeo de longitud
                        if seg.shape[0] != exp_len or seg.shape[0] <= 0:
                            dbg_skips_short += 1
                            if DEBUG_WINDOWS and dbg_prints < DEBUG_MAX_PRINTS:
                                print(f"[SKIP LEN] S{sid:03d} R{rid:02d} C{c} "
                                      f"esperado_T={exp_len} got_T={seg.shape[0]} "
                                      f"s={s_samp} e={e_samp}")
                                dbg_prints += 1
                            continue

                        # chequeo NaNs/infs
                        if not np.isfinite(seg).all():
                            dbg_skips_nan += 1
                            if DEBUG_WINDOWS and dbg_prints < DEBUG_MAX_PRINTS:
                                print(f"[SKIP NaN] S{sid:03d} R{rid:02d} C{c} "
                                      f"s={s_samp} e={e_samp}")
                                dbg_prints += 1
                            continue

                        # OK: añadimos
                        X.append(seg); y.append(c); groups.append(sid)
                        dbg_total_crops_ok += 1

                        if DEBUG_WINDOWS and dbg_prints < DEBUG_MAX_PRINTS:
                            print(f"[OK] S{sid:03d} R{rid:02d} C{c} onset={onset:.3f} "
                                  f"win=({rel_start:.2f},{rel_end:.2f}) s={s_samp} e={e_samp} "
                                  f"T={seg.shape[0]} C={seg.shape[1]}")
                            dbg_prints += 1
    finally:
        cache_raw.clear()

    if DEBUG_WINDOWS:
        print("=== DEBUG WINDOWS SUMMARY ===")
        print(f"Eventos brutos: {dbg_total_events}")
        print(f"Crops OK:      {dbg_total_crops_ok}")
        print(f"Skips OOB:     {dbg_skips_oob}")
        print(f"Skips LEN:     {dbg_skips_short}")
        print(f"Skips NaN:     {dbg_skips_nan}")

    X = np.stack(X, axis=0) if len(X)>0 else np.zeros((0,1,1), np.float32)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)
    n, T, C = X.shape
    print(f"Dataset construido: N={n} | T={T} | C={C} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# NORMALIZACIÓN POR FOLD
# =========================
def compute_fold_norm_stats(X_tr):
    mu = X_tr.mean(axis=(0,1), keepdims=True)
    sigma = X_tr.std(axis=(0,1), keepdims=True) + 1e-6
    return mu, sigma

def apply_norm_stats(X, mu, sigma):
    return (X - mu) / sigma

# =========================
# AUGMENTS (utils)
# =========================
def bandpass_augment(x_np, fs=160.0, p=BANDPASS_PROB):
    if not USE_BANDPASS_AUG or random.random() > p:
        return x_np
    x = x_np.copy()
    try:
        x_mne = mne.filter.filter_data(x.T, sfreq=fs, l_freq=8., h_freq=30., verbose='ERROR').T
        return x_mne.astype(np.float32)
    except Exception:
        return x_np

def amplitude_scale_per_channel(x_np, low=AMP_SCALE_RANGE[0], high=AMP_SCALE_RANGE[1]):
    scales = np.random.uniform(low, high, size=(x_np.shape[1],)).astype(np.float32)  # C
    return (x_np * scales[np.newaxis, :]).astype(np.float32)

def channel_dropout(x_np, p=CHANNEL_DROP_PROB):
    if p <= 0: return x_np
    x = x_np.copy()
    C = x.shape[1]
    mask = (np.random.rand(C) > p).astype(np.float32)  # 0 -> canal apagado
    # evitar caso todos-cero
    if mask.sum() == 0:
        mask[np.random.randint(0, C)] = 1.0
    return (x * mask[np.newaxis, :]).astype(np.float32)

def do_time_jitter(xb, max_ms=30, fs=160.0):
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return xb
    B,_,T,C = xb.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
    out = torch.empty_like(xb)
    for i,s in enumerate(shifts):
        if s==0: out[i] = xb[i]; continue
        if s>0:
            out[i,:,s:,:] = xb[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = xb[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

# =========================
# DATASET TORCH
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups, fs=160.0, train_mode=False):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.fs = fs
        self.train_mode = train_mode
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]   # (T,C)
        if self.train_mode:
            # augment opcional band-pass (por defecto off)
            x = bandpass_augment(x, fs=self.fs)
            # robustez inter-sujeto
            x = amplitude_scale_per_channel(x, *AMP_SCALE_RANGE)
            x = channel_dropout(x, p=CHANNEL_DROP_PROB)
        x = np.expand_dims(x, 0)  # (1,T,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

# =========================
# MODELO: Conformer-MI
# =========================
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe)
    def forward(self, x):
        L = x.size(1)
        return x + self.pe[:L].unsqueeze(0)

class ConvModule(nn.Module):
    def __init__(self, d_model, dw_kernel=15):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.pw1 = nn.Conv1d(d_model, 2*d_model, kernel_size=1)
        self.dw = nn.Conv1d(d_model, d_model, kernel_size=dw_kernel, padding=dw_kernel//2, groups=d_model)
        self.bn = nn.BatchNorm1d(d_model)
        self.act = nn.SiLU()
        self.pw2 = nn.Conv1d(d_model, d_model, kernel_size=1)
    def forward(self, x):
        y = self.ln(x)
        y = y.transpose(1,2)
        y = self.pw1(y)
        a, b = y.chunk(2, dim=1)
        y = a * torch.sigmoid(b)
        y = self.dw(y)
        y = self.bn(y)
        y = self.act(y)
        y = self.pw2(y)
        y = y.transpose(1,2)
        return y

class ConformerBlock(nn.Module):
    def __init__(self, d_model, n_heads=2, dropout=0.3, dw_kernel=15, ffn_mult=2.0, sd_prob=0.2):
        super().__init__()
        self.sd_prob = sd_prob
        self.ln_attn = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)
        self.conv = ConvModule(d_model, dw_kernel=dw_kernel)
        self.drop2 = nn.Dropout(dropout)
        self.ln_ffn = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, int(ffn_mult*d_model)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(ffn_mult*d_model), d_model),
        )
        self.drop3 = nn.Dropout(dropout)

    def _stochastic_residual(self, x, y, training):
        if not training or self.sd_prob <= 0.0:
            return x + y
        keep = torch.rand(x.size(0), device=x.device) > self.sd_prob
        keep = keep.float().view(-1,1,1)
        return x + y * keep

    def forward(self, x):
        y = self.ln_attn(x)
        y, _ = self.mha(y, y, y, need_weights=False)
        x = self._stochastic_residual(x, self.drop1(y), self.training)
        y = self.conv(x)
        x = self._stochastic_residual(x, self.drop2(y), self.training)
        y = self.ffn(self.ln_ffn(x))
        x = self._stochastic_residual(x, self.drop3(y), self.training)
        return x

class ConformerMI(nn.Module):
    def __init__(self, n_ch=8, n_classes=4, d_model=64, n_heads=2, n_blocks=3,
                 patch_len=6, dropout=0.3, sd_prob=0.2, dw_kernel=15):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.d_model = d_model
        self.patch_len = patch_len

        # Stem: proyección espacial a d_model
        self.spatial_pw = nn.Conv1d(n_ch, d_model, kernel_size=1, bias=False)
        self.spatial_bn = nn.BatchNorm1d(d_model)
        self.spatial_act = nn.GELU()

        # Patchify temporal
        self.patchify = nn.Conv1d(d_model, d_model, kernel_size=patch_len, stride=patch_len, bias=False)

        self.pos = SinusoidalPositionalEncoding(d_model, max_len=4000)
        self.blocks = nn.ModuleList([
            ConformerBlock(d_model, n_heads=n_heads, dropout=dropout, dw_kernel=dw_kernel,
                           ffn_mult=2.0, sd_prob=sd_prob*(i+1)/n_blocks)
            for i in range(n_blocks)
        ])
        self.ln_out = nn.LayerNorm(d_model)
        self.head = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes)
        )

    def forward_features(self, x):
        B,_,T,C = x.shape
        x = x.squeeze(1).permute(0,2,1)      # (B,C,T)
        x = self.spatial_pw(x)               # (B,D,T)
        x = self.spatial_bn(x); x = self.spatial_act(x)
        x = self.patchify(x)                 # (B,D,L)
        x = x.transpose(1,2)                 # (B,L,D)
        x = self.pos(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_out(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = x.mean(dim=1)
        logits = self.head(x)
        return logits

# =========================
# ENTRENAMIENTO/EVAL
# =========================
def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    return WeightedRandomSampler(weights=torch.from_numpy(w).float(), num_samples=len(w), replacement=True)

def make_class_weight_tensor(y_indices, n_classes):
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

@torch.no_grad()
def evaluate(model, loader, use_tta=False):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        if use_tta:
            outs = []
            for _ in range(3):
                outs.append(model(do_time_jitter(xb, max_ms=20, fs=FS)))
            logits = torch.stack(outs, dim=0).mean(dim=0)
        else:
            logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred); y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int); y_pred = np.asarray(y_pred, dtype=int)
    return y_true, y_pred, float((y_true==y_pred).mean())

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i,j], fmt), ha="center",
                 color="white" if cm_norm[i,j] > thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def cosine_with_warmup(epoch, total_epochs, base_lr=LR_INIT, warmup=WARMUP_EPOCHS, min_lr=1e-5):
    if epoch < warmup:
        return base_lr * (epoch+1) / warmup
    progress = (epoch - warmup) / max(1, (total_epochs - warmup))
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

def train_one_epoch(model, loader, optimizer, criterion, epoch, total_epochs):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        xb = do_time_jitter(xb, max_ms=30, fs=FS)
        logits = model(xb)
        loss = criterion(logits, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING POR SUJETO (CV adaptativo)
# =========================
def _freeze_all(m): 
    for p in m.parameters(): p.requires_grad = False

def _unfreeze_head(m):
    for p in m.head.parameters(): p.requires_grad = True

def _unfreeze_lastblock_head(m):
    _unfreeze_head(m)
    for p in m.blocks[-1].parameters(): p.requires_grad = True
    for p in m.ln_out.parameters(): p.requires_grad = True

def l2sp_reg(train_params, ref_params):
    reg = 0.0
    for p_cur, p_ref in zip(train_params, ref_params):
        reg = reg + torch.sum((p_cur - p_ref)**2)
    return reg

def _adaptive_cv_splits(Xs, ys, base_splits=CALIB_CV_FOLDS):
    ys_np = np.asarray(ys)
    counts = np.bincount(ys_np, minlength=N_CLASSES)
    min_class = counts[counts>0].min() if counts.sum() > 0 else 0
    splits = min(base_splits, max(2, int(min_class)))
    if splits < 2:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_STATE)
        return sss.split(Xs, ys_np)
    try:
        skf = StratifiedKFold(n_splits=splits, shuffle=True, random_state=RANDOM_STATE)
        list(skf.split(Xs, ys_np))
        return skf.split(Xs, ys_np)
    except Exception:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_STATE)
        return sss.split(Xs, ys_np)

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, n_classes=4):
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)
    for tr_idx, te_idx in _adaptive_cv_splits(Xs, ys, base_splits=CALIB_CV_FOLDS):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        ds_te = EEGTrials(Xho,  yho,  np.zeros_like(yho),  train_mode=False)
        dl_te = DataLoader(ds_te, batch_size=64, shuffle=False)

        if len(np.unique(ycal)) >= 2 and len(ycal) >= 5:
            sss = StratifiedShuffleSplit(n_splits=1, test_size=FT_VAL_RATIO, random_state=RANDOM_STATE)
            (tr2, va2), = sss.split(Xcal, ycal)
        else:
            idx = np.arange(len(ycal))
            np.random.shuffle(idx)
            cut = max(1, int(0.85*len(idx)))
            tr2, va2 = idx[:cut], idx[cut:]

        ds_tr2 = EEGTrials(Xcal[tr2], ycal[tr2], np.zeros_like(ycal[tr2]), train_mode=True)
        ds_va2 = EEGTrials(Xcal[va2], ycal[va2], np.zeros_like(ycal[va2]), train_mode=False)
        dl_tr2 = DataLoader(ds_tr2, batch_size=32, shuffle=True)
        dl_va2 = DataLoader(ds_va2, batch_size=64, shuffle=False)

        # HEAD only
        m = copy.deepcopy(model_global).to(DEVICE)
        _freeze_all(m); _unfreeze_head(m)
        opt = optim.Adam(m.head.parameters(), lr=FT_HEAD_LR)
        crit = nn.CrossEntropyLoss()
        best = copy.deepcopy(m.state_dict()); best_va = 1e9; bad=0
        for _ in range(FT_EPOCHS):
            m.train()
            for xb, yb, _ in dl_tr2:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad(set_to_none=True)
                loss = crit(m(xb), yb)
                loss.backward(); opt.step()
            # val
            m.eval(); val_loss=0.0; n=0
            with torch.no_grad():
                for xb, yb, _ in dl_va2:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    val_loss += crit(m(xb), yb).item()*xb.size(0); n+=xb.size(0)
            val_loss/=max(1,n)
            if val_loss < best_va-1e-7:
                best_va = val_loss; best=copy.deepcopy(m.state_dict()); bad=0
            else:
                bad+=1
                if bad>=FT_PATIENCE: break
        m.load_state_dict(best)

        # LAST BLOCK + HEAD con L2-SP
        _freeze_all(m); _unfreeze_lastblock_head(m)
        train_params = list(m.blocks[-1].parameters()) + list(m.ln_out.parameters()) + list(m.head.parameters())
        ref = [p.detach().clone() for p in train_params]
        opt = optim.Adam([
            {"params": list(m.blocks[-1].parameters()) + list(m.ln_out.parameters()), "lr": FT_BASE_LR},
            {"params": m.head.parameters(), "lr": FT_HEAD_LR},
        ])
        crit = nn.CrossEntropyLoss()
        best = copy.deepcopy(m.state_dict()); best_va = 1e9; bad=0
        for _ in range(FT_EPOCHS):
            m.train()
            for xb, yb, _ in dl_tr2:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad(set_to_none=True)
                loss = crit(m(xb), yb) + FT_L2SP * l2sp_reg(train_params, ref)
                loss.backward(); opt.step()
            m.eval(); val_loss=0.0; n=0
            with torch.no_grad():
                for xb, yb, _ in dl_va2:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    val_loss += crit(m(xb), yb).item()*xb.size(0); n+=xb.size(0)
            val_loss/=max(1,n)
            if val_loss < best_va-1e-7:
                best_va = val_loss; best=copy.deepcopy(m.state_dict()); bad=0
            else:
                bad+=1
                if bad>=FT_PATIENCE: break
        m.load_state_dict(best)

        with torch.no_grad():
            m.eval()
            preds=[]
            for xb, _, _ in dl_te:
                xb = xb.to(DEVICE)
                logits = m(xb)
                preds.extend(logits.argmax(1).cpu().numpy().tolist())
        y_pred_full[te_idx] = np.asarray(preds); y_true_full[te_idx] = yho
    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="ConformerMI_v2_balFirst", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)
    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]
        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]
        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()
        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })
    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR,
                   folds_json_description="GroupKFold folds for ConformerMI v2 (balance-first)"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    # 1) Recolectar eventos brutos T1/T2 (sin crops)
    per_subj = gather_all_events(subs)

    # 2) Balancear a 21 eventos por clase/sujeto
    kept_bal = balance_events_per_subject(per_subj, need_per_class=21, rng_seed=RANDOM_STATE)

    # 3) Generar segmentos aplicando CROPS SOLO sobre los eventos balanceados
    X, y, groups, chs = segments_from_balanced_events(kept_bal, crops=CROPS, fs=FS)
    N, T, C = X.shape
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={len(np.unique(y))} | sujetos={len(np.unique(groups))}")

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="ConformerMI_v2_balFirst",
                                           description=folds_json_description)
    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)
        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"[WARN] fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # split de validación por sujetos
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Normalización por fold (usar sólo train)
        mu, sigma = compute_fold_norm_stats(X[tr_sub_idx])
        Xn = apply_norm_stats(X, mu, sigma)

        ds_tr = EEGTrials(Xn[tr_sub_idx], y[tr_sub_idx], groups[tr_sub_idx], fs=FS, train_mode=True)
        ds_va = EEGTrials(Xn[va_idx],     y[va_idx],     groups[va_idx],     fs=FS, train_mode=False)
        ds_te = EEGTrials(Xn[te_idx],     y[te_idx],     groups[te_idx],     fs=FS, train_mode=False)

        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])
        dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        dl_te = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # modelo
        model = ConformerMI(n_ch=C, n_classes=N_CLASSES,
                            d_model=64, n_heads=2, n_blocks=3,
                            patch_len=6, dropout=0.3, sd_prob=0.2, dw_kernel=15).to(DEVICE)

        # CE ponderada (además del sampler)
        class_weights = make_class_weight_tensor(y[tr_sub_idx], n_classes=N_CLASSES)
        crit = nn.CrossEntropyLoss(weight=class_weights)

        opt = optim.Adam(model.parameters(), lr=LR_INIT, weight_decay=1e-4)

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0; bad = 0

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando ConformerMI v2 (balance-first)..."
              f" (train={len(tr_sub_idx)} | val={len(va_idx)} | test={len(te_idx)})")

        for epoch in range(EPOCHS_GLOBAL):
            for pg in opt.param_groups:
                pg['lr'] = cosine_with_warmup(epoch, EPOCHS_GLOBAL, base_lr=LR_INIT, warmup=WARMUP_EPOCHS)
            train_one_epoch(model, dl_tr, opt, crit, epoch, EPOCHS_GLOBAL)

            _, _, tr_acc = evaluate(model, dl_tr, use_tta=False)
            _, _, va_acc = evaluate(model, dl_va, use_tta=False)

            if (epoch+1) % LOG_EVERY == 0 or epoch in (0, 4, 9, 19, 49, 79):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Ep {epoch+1:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            if va_acc > best_val + 1e-4:
                best_val = va_acc; best_state = copy.deepcopy(model.state_dict()); bad = 0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping en época {epoch+1} (mejor val_acc={best_val:.4f})")
                    break

        model.load_state_dict(best_state)

        # evaluación global (con TTA activado)
        y_true, y_pred, acc_global = evaluate(model, dl_te, use_tta=True)
        global_folds.append(acc_global); all_true.append(y_true); all_pred.append(y_pred)
        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning por sujeto (CV adaptativo) ----------
        acc_ft = np.nan
        if DO_SUBJECT_FT:
            X_te, y_te, g_te = Xn[te_idx], y[te_idx], groups[te_idx]
            y_true_ft_all, y_pred_ft_all = [], []
            used_subjects = 0
            for sid in np.unique(g_te):
                idx = np.where(g_te == sid)[0]
                Xs, ys = X_te[idx], y_te[idx]
                if len(np.unique(ys)) < 2:
                    continue
                y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(model, Xs, ys, n_classes=N_CLASSES)
                y_true_ft_all.append(y_true_subj); y_pred_ft_all.append(y_pred_subj); used_subjects += 1
            if len(y_true_ft_all) > 0:
                y_true_ft_all = np.concatenate(y_true_ft_all)
                y_pred_ft_all = np.concatenate(y_pred_ft_all)
                acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
                print(f"  Fine-tuning (por sujeto, CV adaptativo) acc={acc_ft:.4f} | sujetos={used_subjects}")
                print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
            else:
                print("  Fine-tuning no ejecutado (sujeto(s) con muestras/clases insuficientes).")
        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - ConformerMI v2 (Global, all folds, balance-first)",
                       fname="confusion_conformer_v2_balancefirst_global_allfolds.png")
        print("↳ Matriz de confusión guardada: confusion_conformer_v2_balancefirst_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO Conformer-MI v2 — balance-first (4 clases, 8 canales)")
    print(f"⚙️  GLOBAL: epochs={EPOCHS_GLOBAL}, lr={LR_INIT}, warmup={WARMUP_EPOCHS}, batch={BATCH_SIZE}")
    print(f"🔧 Augments: bandpass={'ON' if USE_BANDPASS_AUG else 'OFF'}, amp_scale={AMP_SCALE_RANGE}, ch_drop={CHANNEL_DROP_PROB}")
    if DO_SUBJECT_FT:
        print(f"🔧 FT por sujeto (adaptativo): epochs={FT_EPOCHS}, head_lr={FT_HEAD_LR}, base_lr={FT_BASE_LR}, L2SP={FT_L2SP}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO Conformer-MI v2 — balance-first (4 clases, 8 canales)
⚙️  GLOBAL: epochs=100, lr=0.001, warmup=5, batch=32
🔧 Augments: bandpass=ON, amp_scale=(0.8, 1.2), ch_drop=0.1
🔧 FT por sujeto (adaptativo): epochs=15, head_lr=0.001, base_lr=5e-05, L2SP=0.0001
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Escaneando eventos (sin crops):   0%|          | 0/103 [00:00<?, ?it/s]

Escaneando eventos (sin crops): 100%|██████████| 103/103 [00:39<00:00,  2.64it/s]
Generando segmentos (con crops):   1%|          | 1/103 [00:00<00:14,  6.98it/s]

[OK] S001 R08 C0 onset=120.400 win=(0.00,3.00) s=19264 e=19744 T=480 C=8
[OK] S001 R08 C0 onset=20.800 win=(0.00,3.00) s=3328 e=3808 T=480 C=8
[OK] S001 R04 C0 onset=12.500 win=(0.00,3.00) s=2000 e=2480 T=480 C=8
[OK] S001 R08 C0 onset=4.200 win=(0.00,3.00) s=672 e=1152 T=480 C=8
[OK] S001 R12 C0 onset=29.100 win=(0.00,3.00) s=4656 e=5136 T=480 C=8
[OK] S001 R08 C0 onset=78.900 win=(0.00,3.00) s=12624 e=13104 T=480 C=8
[OK] S001 R04 C0 onset=20.800 win=(0.00,3.00) s=3328 e=3808 T=480 C=8
[OK] S001 R08 C0 onset=87.200 win=(0.00,3.00) s=13952 e=14432 T=480 C=8
[OK] S001 R04 C0 onset=87.200 win=(0.00,3.00) s=13952 e=14432 T=480 C=8
[OK] S001 R04 C0 onset=45.700 win=(0.00,3.00) s=7312 e=7792 T=480 C=8
[OK] S001 R08 C0 onset=54.000 win=(0.00,3.00) s=8640 e=9120 T=480 C=8
[OK] S001 R12 C0 onset=70.600 win=(0.00,3.00) s=11296 e=11776 T=480 C=8
[OK] S001 R04 C0 onset=62.300 win=(0.00,3.00) s=9968 e=10448 T=480 C=8
[OK] S001 R04 C0 onset=78.900 win=(0.00,3.00) s=12624 e=13104 T=480 C=8
[OK] S00

Generando segmentos (con crops): 100%|██████████| 103/103 [00:39<00:00,  2.59it/s]


=== DEBUG WINDOWS SUMMARY ===
Eventos brutos: 8652
Crops OK:      8652
Skips OOB:     0
Skips LEN:     0
Skips NaN:     0
Dataset construido: N=8652 | T=480 | C=8 | sujetos únicos=103
Listo para entrenar: N=8652 | T=480 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando ConformerMI v2 (balance-first)... (train=5796 | val=1092 | test=1764)
  Ep   1 | train_acc=0.3092 | val_acc=0.3159 | LR=0.00020
  Ep   5 | train_acc=0.4382 | val_acc=0.4103 | LR=0.00100
  Ep  10 | train_acc=0.4750 | val_acc=0.4167 | LR=0.00100
  Ep  15 | train_acc=0.4762 | val_acc=0.4167 | LR=0.00098
  Ep  20 | train_acc=0.5300 | val_acc=0.4460 | LR=0.00095
  Ep  25 | train_acc=0.5064 | val_acc=0.4332 | LR=0.00091
  Ep  30 | train_acc=0.5342 | val_acc=0.4451 | LR=0.00085
  Ep  35 | train_acc=0.5437 | val_acc=0.4341 | LR=0.00079
  Early stopping en época 36 (mejor val_acc=0.4597)
[Fold 1/5] Global acc=0.4461
              precision    recall  f1-score   support

        Left     0.4753    0.5669    0.5171       441
  

In [None]:
# -*- coding: utf-8 -*-
# CNN + Transformer (Conformer-lite) para MI
# - MISMO pipeline que tu EEGNet (6 s ventana, augments, mixup, cutout OF, SGDR, TTA=5, FT progresivo L2-SP)
# - Solo cambia el modelo

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'   # 6 s → ~960 muestras @160 Hz
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 64
EPOCHS_GLOBAL = 100
LR_INIT = 1e-2
SGDR_T0 = 6
SGDR_Tmult = 2

# Validación/ES global
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10
LOG_EVERY        = 5

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Normalización por época canal-a-canal (z-score)
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTILIDADES DE CANALES / EDF
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch adaptativo: lee CSV si existe (subject, run, snr50_db, snr60_db)
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  # default
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch si no sobresale

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    # Notch adaptativo (por SNR); si no hay CSV, por defecto 60 Hz
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS (ventana 6s)
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []
    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        # 6 s: de -1 a +5 s respecto al onset → T=960
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='6s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# AUGMENTS (solo train)
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    # x: (B,1,T,C) torch.float
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    # Aplica cutout solo si y pertenece a classes_mask (OF)
    B,_,T,C = x.shape
    if B == 0: return x
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask: 
            continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    """
    Cross-entropy suave (targets en probas, p.ej. mixup) con pesos por clase.
    """
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)
    def forward(self, logits, target_probs):
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)  # (B,C)
        loss_per_class = -(target_probs * logp)  # (B,C)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# =========================
# MODELO: CNN + Transformer (Conformer-lite)
# =========================
class SinPosEnc(nn.Module):
    def __init__(self, d_model, max_len=4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe)
    def forward(self, x):  # x: (B,L,D)
        L = x.size(1)
        return x + self.pe[:L].unsqueeze(0)

class ConvModule1D(nn.Module):
    def __init__(self, d_model, dw_kernel=15, dropout=0.3):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.pw1 = nn.Conv1d(d_model, 2*d_model, kernel_size=1)
        self.dw  = nn.Conv1d(d_model, d_model, kernel_size=dw_kernel, padding=dw_kernel//2, groups=d_model)
        self.bn  = nn.BatchNorm1d(d_model)
        self.act = nn.SiLU()
        self.pw2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):  # (B,L,D)
        y = self.ln(x).transpose(1,2)          # (B,D,L)
        y = self.pw1(y)
        a,b = y.chunk(2, dim=1)
        y = a * torch.sigmoid(b)
        y = self.dw(y)
        y = self.bn(y); y = self.act(y)
        y = self.pw2(y)
        y = self.drop(y.transpose(1,2))        # (B,L,D)
        return y

class ConformerLiteBlock(nn.Module):
    def __init__(self, d_model=48, n_heads=2, dropout=0.3, dw_kernel=15, ffn_mult=2.0, sd_prob=0.2):
        super().__init__()
        self.sd_prob = sd_prob
        self.ln_attn = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)
        self.conv = ConvModule1D(d_model, dw_kernel=dw_kernel, dropout=dropout)
        self.drop2 = nn.Dropout(dropout)
        self.ln_ffn = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, int(ffn_mult*d_model)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(ffn_mult*d_model), d_model),
        )
        self.drop3 = nn.Dropout(dropout)

    def _res(self, x, y, training):
        if not training or self.sd_prob <= 0.0:
            return x + y
        keep = (torch.rand(x.size(0), device=x.device) > self.sd_prob).float().view(-1,1,1)
        return x + y * keep

    def forward(self, x):  # (B,L,D)
        y = self.ln_attn(x)
        y, _ = self.mha(y, y, y, need_weights=False)
        x = self._res(x, self.drop1(y), self.training)
        y = self.conv(x)
        x = self._res(x, self.drop2(y), self.training)
        y = self.ffn(self.ln_ffn(x))
        x = self._res(x, self.drop3(y), self.training)
        return x

class CNNTransMI(nn.Module):
    """
    Entrada: (B, 1, T, C) con T=960,C=8 (6s @160Hz)
    Pipeline:
      - Stem temporal (CNN) ligero
      - Proyección espacial 1x1 (C->D)
      - Patchify temporal (stride=patch_len)
      - N bloques Conformer-lite
      - AvgPool temporal + MLP
    """
    def __init__(self, n_ch=8, n_classes=4,
                 d_model=48, n_heads=2, n_blocks=3,
                 patch_len=8, dropout=0.3, sd_prob=0.2, dw_kernel=15):
        super().__init__()
        self.n_ch = n_ch
        self.d_model = d_model
        self.patch_len = patch_len

        # Stem temporal ligero (sobre T; trabaja canal-independiente)
        self.temporal_stem = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(15,1), padding=(7,0), bias=False),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.Conv2d(8, 8, kernel_size=(15,1), padding=(7,0), bias=False),
            nn.BatchNorm2d(8),
            nn.GELU()
        )
        # Reordenar y proyección espacial a D
        self.spatial_pw = nn.Conv1d(n_ch*8, d_model, kernel_size=1, bias=False)
        self.spatial_bn = nn.BatchNorm1d(d_model)
        self.spatial_act = nn.GELU()

        # Patchify temporal (stride=patch_len)
        self.patchify = nn.Conv1d(d_model, d_model, kernel_size=patch_len, stride=patch_len, bias=False)

        self.pos = SinPosEnc(d_model, max_len=4000)
        self.blocks = nn.ModuleList([
            ConformerLiteBlock(d_model, n_heads=n_heads, dropout=dropout,
                               dw_kernel=dw_kernel, ffn_mult=2.0, sd_prob=sd_prob*(i+1)/n_blocks)
            for i in range(n_blocks)
        ])
        self.ln_out = nn.LayerNorm(d_model)
        self.head = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):  # x: (B,1,T,C)
        B,_,T,C = x.shape
        z = self.temporal_stem(x)                 # (B,8,T,C)
        z = z.permute(0,3,2,1).contiguous()       # (B,C,T,8)
        z = z.view(B, C*8, T)                     # (B,C*8,T)
        z = self.spatial_pw(z)                    # (B,D,T)
        z = self.spatial_bn(z); z = self.spatial_act(z)
        z = self.patchify(z)                      # (B,D,L)
        z = z.transpose(1,2)                      # (B,L,D)
        z = self.pos(z)
        for blk in self.blocks:
            z = blk(z)
        z = self.ln_out(z)
        z = z.mean(dim=1)                         # (B,D)
        return self.head(z)

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

# =========================
# UTIL: MAX-NORM (para estabilizar FT)
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    # aplicarlo a capas densas finales; opcional a convs
    if hasattr(model, 'head'):
        for m in model.head:
            if isinstance(m, nn.Linear):
                layers.append(m)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def build_weighted_sampler(y, groups):
    # pesos inversos por clase y por sujeto → promueve balance en cada batch
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20):
    # pesos inversos por clase, normalizados; clase 2 (*Both Fists*) con boost
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    w[2] *= boost_bfists
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def train_epoch(model, loader, opt, criterion_soft, n_classes,
                do_aug=True, fs=160.0, maxnorm=None):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # ---- AUGMENTS (orden: jitter → noise → cutout focalizado → mixup) ----
        if do_aug:
            xb = do_time_jitter(xb, max_ms=50, fs=fs)
            xb = do_gaussian_noise(xb, sigma=0.01)
            xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=fs)
            xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=0.2)
            logits = model(xb)
            loss = criterion_soft(logits, yt)
        else:
            logits = model(xb)
            yt = torch.nn.functional.one_hot(yb, num_classes=n_classes).float()
            loss = criterion_soft(logits, yt)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        if maxnorm is not None:
            apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def _predict_tta(model, xb, n=5, fs=160.0):
    outs = []
    for _ in range(n):
        xj = do_time_jitter(xb, max_ms=25, fs=fs)
        outs.append(model(xj))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_with_preds(model, loader, use_tta=True, tta_n=5):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        if use_tta:
            logits = _predict_tta(model, xb, n=tta_n, fs=FS)
        else:
            logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _freeze_for_mode_cnntrans(model, mode):
    # Congela todo y libera según el modo
    for p in model.parameters(): p.requires_grad = False

    if mode == 'out':
        # Solo la última capa
        for p in model.head[-1].parameters(): p.requires_grad = True
    elif mode == 'head':
        for m in model.head:
            if isinstance(m, nn.Linear):
                for p in m.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        # Último bloque + ln_out + head
        for p in model.blocks[-1].parameters(): p.requires_grad = True
        for p in model.ln_out.parameters(): p.requires_grad = True
        for m in model.head:
            if isinstance(m, nn.Linear):
                for p in m.parameters(): p.requires_grad = True
    else:
        raise ValueError(mode)

def _param_groups_cnntrans(model, mode):
    if mode == 'out':
        return list(model.head[-1].parameters())
    elif mode == 'head':
        params = []
        for m in model.head:
            if isinstance(m, nn.Linear):
                params += list(m.parameters())
        return params
    elif mode == 'spatial+head':
        base = list(model.blocks[-1].parameters()) + list(model.ln_out.parameters())
        head = []
        for m in model.head:
            if isinstance(m, nn.Linear):
                head += list(m.parameters())
        return base + head
    else:
        raise ValueError(mode)

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode_cnntrans(model, mode)

    if mode == 'spatial+head':
        base_params = list(model.blocks[-1].parameters()) + list(model.ln_out.parameters())
        head_params = []
        for m in model.head:
            if isinstance(m, nn.Linear):
                head_params += list(m.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups_cnntrans(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf'); bad = 0

    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0; nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0); nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device); acc_out = (yhat_out == yho).mean()

        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device); acc_head = (yhat_head == yho).mean()

        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device); acc_sp = (yhat_sp == yho).mean()

        best_idx = np.argmax([acc_out, acc_head, acc_sp])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="cnntrans_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for CNN+Transformer"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="CNNTransMI",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por sujetos =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Sampler balanceado para train
        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== CNN + Transformer =====
        model = CNNTransMI(
            n_ch=C, n_classes=n_classes,
            d_model=48, n_heads=2, n_blocks=3,
            patch_len=8, dropout=0.3, sd_prob=0.2, dw_kernel=15
        ).to(DEVICE)

        # Opt y scheduler SGDR
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

        # Métrica rápida
        def _acc(loader):
            return evaluate_with_preds(model, loader, use_tta=True, tta_n=5)[2]

        # Historia para curvas
        history = {'train_acc': [], 'val_acc': []}

        # ---- Criterio suave con ponderación por clase (+20% BFISTS) ----
        class_weights = make_class_weight_tensor(y[tr_sub_idx], n_classes, boost_bfists=1.20)
        criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.05)

        # ===== Entrenamiento global =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion_soft, n_classes=n_classes,
                        do_aug=True, fs=FS, maxnorm=2.0)
            scheduler.step(epoch-1 + 1e-8)  # tick suave

            # eval
            tr_acc = _acc(tr_loader)
            va_acc = _acc(va_loader)
            history['train_acc'].append(tr_acc)
            history['val_acc'].append(va_acc)

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            if va_acc > best_val + 1e-4:
                best_val = va_acc
                best_state = copy.deepcopy(model.state_dict())
                bad = 0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                    break

        # guardar curva de entrenamiento
        curve_path = f"training_curve_fold{fold}.png"
        plot_training_curves(history, curve_path)
        print(f"↳ Curva de entrenamiento guardada: {curve_path}")

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader, use_tta=True, tta_n=5)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CNN + Transformer (Conformer-lite) + FINE-TUNING PROGRESIVO")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


In [1]:
# -*- coding: utf-8 -*-
# CNN+Transformer v2 — mismo pipeline que tu EEGNet (6s), con mejoras:
# - Optimizer: AdamW (lr=3e-3, weight_decay=1e-3) + SGDR
# - Modelo: stem temporal ligero + depthwise espacial (estilo EEGNet) + PW→D
# - Patchify: patch_len=12 (menos tokens), n_blocks=2
# - Regularización: dropout=0.4, sd_prob=0.30, label_smoothing=0.10, ChannelDropout
# - CLS token + attention pooling
# - Mantiene: mixup, cutout OF, TTA, sampler por sujeto/clase, FT progresivo por sujeto

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'   # 6 segundos (-1 a 5) como en tu EEGNet
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 64
EPOCHS_GLOBAL = 100
LR_INIT = 3e-3            # <<< cambio
SGDR_T0 = 6
SGDR_Tmult = 2
WEIGHT_DECAY = 1e-3       # <<< cambio

# Validación/ES global
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10
LOG_EVERY        = 5

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Normalización por época canal-a-canal (z-score)
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch adaptativo: lee CSV si existe (subject, run, snr50_db, snr60_db)
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  # default
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch si no sobresale

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0   # 6s

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='6s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# AUGMENTS (solo train)
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    B,_,T,C = x.shape
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask: 
            continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)
    def forward(self, logits, target_probs):
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)
        loss_per_class = -(target_probs * logp)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# =========================
# DATASET TORCH
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)     # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: MAX-NORM (opcional tras step)
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    if hasattr(model, 'head') and isinstance(model.head, nn.Sequential):
        for mod in model.head:
            if hasattr(mod, 'weight'): layers.append(mod)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# =========================
# MODELO: CNN + Transformer v2
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1): super().__init__(); self.p = p
    def forward(self, x):  # (B,1,T,C)
        if not self.training or self.p <= 0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

class SinPosEnc(nn.Module):
    def __init__(self, d_model, max_len=4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div); pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe)
    def forward(self, x):  # (B,L,D)
        L = x.size(1); return x + self.pe[:L].unsqueeze(0)

class ConvModule1D(nn.Module):
    def __init__(self, d_model, dw_kernel=15, dropout=0.4):
        super().__init__()
        self.ln  = nn.LayerNorm(d_model)
        self.pw1 = nn.Conv1d(d_model, 2*d_model, kernel_size=1)
        self.dw  = nn.Conv1d(d_model, d_model, kernel_size=dw_kernel, padding=dw_kernel//2, groups=d_model)
        self.bn  = nn.BatchNorm1d(d_model)
        self.act = nn.SiLU()
        self.pw2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.drop= nn.Dropout(dropout)
    def forward(self, x):  # (B,L,D)
        y = self.ln(x).transpose(1,2)         # (B,D,L)
        y = self.pw1(y); a,b = y.chunk(2, dim=1); y = a * torch.sigmoid(b)
        y = self.dw(y); y = self.bn(y); y = self.act(y); y = self.pw2(y)
        return self.drop(y.transpose(1,2))

class ConformerLiteBlock(nn.Module):
    def __init__(self, d_model=48, n_heads=2, dropout=0.4, dw_kernel=15, ffn_mult=2.0, sd_prob=0.30):
        super().__init__()
        self.sd_prob = sd_prob
        self.ln_attn = nn.LayerNorm(d_model)
        self.mha     = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.drop1   = nn.Dropout(dropout)
        self.conv    = ConvModule1D(d_model, dw_kernel=dw_kernel, dropout=dropout)
        self.drop2   = nn.Dropout(dropout)
        self.ln_ffn  = nn.LayerNorm(d_model)
        self.ffn     = nn.Sequential(
            nn.Linear(d_model, int(ffn_mult*d_model)),
            nn.GELU(), nn.Dropout(dropout),
            nn.Linear(int(ffn_mult*d_model), d_model),
        )
        self.drop3   = nn.Dropout(dropout)

    def _res(self, x, y, training):
        if not training or self.sd_prob <= 0.0: return x + y
        keep = (torch.rand(x.size(0), device=x.device) > self.sd_prob).float().view(-1,1,1)
        return x + y * keep

    def forward(self, x):            # (B,L,D)
        q = self.ln_attn(x)
        y, _ = self.mha(q, q, q, need_weights=False)
        x = self._res(x, self.drop1(y), self.training)
        y = self.conv(x)
        x = self._res(x, self.drop2(y), self.training)
        y = self.ffn(self.ln_ffn(x))
        x = self._res(x, self.drop3(y), self.training)
        return x

class CNNTransMIv2(nn.Module):
    def __init__(self, n_ch=8, n_classes=4,
                 d_model=48, n_heads=2, n_blocks=2,
                 patch_len=12, dropout=0.4, sd_prob=0.30, dw_kernel=15):
        super().__init__()
        self.n_ch = n_ch
        self.d_model = d_model
        self.patch_len = patch_len
        self.chdrop = ChannelDropout(p=0.1)

        # Stem temporal ligero (conv solo en T)
        self.temporal_stem = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(15,1), padding=(7,0), bias=False),
            nn.BatchNorm2d(8), nn.GELU(),
            nn.Conv2d(8, 8, kernel_size=(15,1), padding=(7,0), bias=False),
            nn.BatchNorm2d(8), nn.GELU()
        )

        # Spatial estilo EEGNet: depthwise (1×C) por mapa
        self.spatial_dw = nn.Conv2d(8, 8, kernel_size=(1, n_ch), groups=8, bias=False)
        self.spatial_bn = nn.BatchNorm2d(8)
        self.spatial_act= nn.GELU()

        # Proyección a D con 1x1 sobre canales
        self.proj_pw = nn.Conv1d(8, d_model, kernel_size=1, bias=False)
        self.proj_bn = nn.BatchNorm1d(d_model)
        self.proj_act= nn.GELU()

        # Patchify temporal
        self.patchify = nn.Conv1d(d_model, d_model, kernel_size=patch_len, stride=patch_len, bias=False)

        # CLS token
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.cls, std=0.02)

        self.pos = SinPosEnc(d_model, max_len=4000)
        self.blocks = nn.ModuleList([
            ConformerLiteBlock(d_model, n_heads=n_heads, dropout=dropout,
                               dw_kernel=dw_kernel, ffn_mult=2.0, sd_prob=sd_prob*(i+1)/max(1,n_blocks))
            for i in range(n_blocks)
        ])
        self.ln_out = nn.LayerNorm(d_model)

        # Attention pooling
        self.attn_pool = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.Tanh(),
            nn.Linear(d_model, 1)
        )

        self.head = nn.Sequential(
            nn.Linear(d_model, 128), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):  # (B,1,T,C)
        B,_,T,C = x.shape
        x = self.chdrop(x)

        z = self.temporal_stem(x)               # (B,8,T,C)
        z = self.spatial_dw(z)                  # (B,8,T,1)
        z = self.spatial_bn(z); z = self.spatial_act(z)
        z = z.squeeze(3).transpose(1,2)         # (B,T,8)
        z = z.transpose(1,2)                    # (B,8,T)
        z = self.proj_pw(z)                     # (B,D,T)
        z = self.proj_bn(z); z = self.proj_act(z)

        z = self.patchify(z)                    # (B,D,L)
        z = z.transpose(1,2)                    # (B,L,D)

        cls = self.cls.expand(B, 1, -1)         # (B,1,D)
        z = torch.cat([cls, z], dim=1)          # (B,1+L,D)
        z = self.pos(z)

        for blk in self.blocks:
            z = blk(z)

        z = self.ln_out(z)                      # (B,1+L,D)

        att = self.attn_pool(z)                 # (B,1+L,1)
        att = torch.softmax(att, dim=1)
        z = (z * att).sum(dim=1)                # (B,D)

        return self.head(z)

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20):
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    w[2] *= boost_bfists  # Both Fists
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def train_epoch(model, loader, opt, criterion_soft, n_classes,
                do_aug=True, fs=160.0, maxnorm=None):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        if do_aug:
            xb = do_time_jitter(xb, max_ms=50, fs=fs)
            xb = do_gaussian_noise(xb, sigma=0.01)
            xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=fs)
            xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=0.2)
            logits = model(xb)
            loss = criterion_soft(logits, yt)
        else:
            logits = model(xb)
            yt = torch.nn.functional.one_hot(yb, num_classes=n_classes).float()
            loss = criterion_soft(logits, yt)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        if maxnorm is not None:
            apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def _predict_tta(model, xb, n=5, fs=160.0):
    outs = []
    for _ in range(n):
        xj = do_time_jitter(xb, max_ms=25, fs=fs)
        outs.append(model(xj))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_with_preds(model, loader, use_tta=True, tta_n=5):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        logits = _predict_tta(model, xb, n=tta_n, fs=FS) if use_tta else model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred); y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int); y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i,j], fmt), ha="center",
                 color="white" if cm_norm[i,j] > thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto
# =========================
def _param_groups(model, mode):
    if mode == 'out':
        train = list(model.head.parameters())
    elif mode == 'head':
        train = list(model.head.parameters())
    elif mode == 'spatial+head':
        train = (list(model.temporal_stem.parameters()) +
                 list(model.spatial_dw.parameters()) +
                 list(model.spatial_bn.parameters()) +
                 list(model.proj_pw.parameters()) +
                 list(model.proj_bn.parameters()) +
                 list(model.head.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out' or mode == 'head':
        for p in model.head.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for mod in [model.temporal_stem, model.spatial_dw, model.spatial_bn,
                    model.proj_pw, model.proj_bn, model.head]:
            for p in mod.parameters(): p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        base_params = (list(model.temporal_stem.parameters()) +
                       list(model.spatial_dw.parameters()) +
                       list(model.spatial_bn.parameters()) +
                       list(model.proj_pw.parameters()) +
                       list(model.proj_bn.parameters()))
        head_params = list(model.head.parameters())
        train_params = base_params + head_params
        opt = optim.AdamW([
            {"params": base_params, "lr": base_lr, "weight_decay": WEIGHT_DECAY},
            {"params": head_params, "lr": head_lr, "weight_decay": WEIGHT_DECAY},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.AdamW(train_params, lr=head_lr, weight_decay=WEIGHT_DECAY)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf'); bad = 0

    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0; nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0); nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device)

        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device)

        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device)

        scores = [ (yhat_out==yho).mean(), (yhat_head==yho).mean(), (yhat_sp==yho).mean() ]
        best_idx = int(np.argmax(scores))
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="CNNTransMIv2", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for CNNTransMIv2"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="CNNTransMIv2",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por sujetos =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Sampler balanceado para train
        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== CNN + Transformer v2 =====
        model = CNNTransMIv2(
            n_ch=C, n_classes=n_classes,
            d_model=48, n_heads=2, n_blocks=2,
            patch_len=12, dropout=0.4, sd_prob=0.30, dw_kernel=15
        ).to(DEVICE)

        # Opt y scheduler SGDR (AdamW + WD)
        opt = optim.AdamW(model.parameters(), lr=LR_INIT, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

        # Métrica rápida
        def _acc(loader):
            return evaluate_with_preds(model, loader, use_tta=True, tta_n=5)[2]

        # Historia para curvas
        history = {'train_acc': [], 'val_acc': []}

        # Criterio suave con ponderación por clase (+20% BFISTS) y LS=0.10
        class_weights = make_class_weight_tensor(y[tr_sub_idx], n_classes, boost_bfists=1.20)
        criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.10)

        # ===== Entrenamiento global =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion_soft, n_classes=n_classes,
                        do_aug=True, fs=FS, maxnorm=2.0)
            scheduler.step(epoch-1 + 1e-8)

            tr_acc = _acc(tr_loader)
            va_acc = _acc(va_loader)
            history['train_acc'].append(tr_acc)
            history['val_acc'].append(va_acc)

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            if va_acc > best_val + 1e-4:
                best_val = va_acc; best_state = copy.deepcopy(model.state_dict()); bad = 0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                    break

        curve_path = f"training_curve_fold{fold}.png"
        plot_training_curves(history, curve_path)
        print(f"↳ Curva de entrenamiento guardada: {curve_path}")

        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader, use_tta=True, tta_n=5)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - CNNTransMIv2 (Global, all folds)",
                       fname="confusion_cnntransmi_v2_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_cnntransmi_v2_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CNN+Transformer v2 (6s, 8 canales)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CNN+Transformer v2 (6s, 8 canales)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:48<00:00,  2.14it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   1 | train_acc=0.2692 | val_acc=0.2564 | LR=0.00300
  Época   5 | train_acc=0.2985 | val_acc=0.2830 | LR=0.00075
  Época  10 | train_acc=0.4665 | val_acc=0.4249 | LR=0.00256
  Época  15 | train_acc=0.5204 | val_acc=0.4368 | LR=0.00075
  Época  20 | train_acc=0.4971 | val_acc=0.4377 | LR=0.00299
  Época  25 | train_acc=0.5311 | val_acc=0.4670 | LR=0.00256
  Época  30 | train_acc=0.5711 | val_acc=0.4808 | LR=0.00170
  Época  35 | train_acc=0.5863 | val_acc=0.4670 | LR=0.00075
  Época  40 | train_acc=0.5908 | val_acc=0.4799 | LR=0.00011
  Early stopping en época 40 (mejor val_acc=0.4808)
↳ Curva de entrenamiento guardada: training_curve_fold1.png
[Fold 1/5] Global acc=0.4751
              precision    recall  f1-score   support

        Left     0.6065    0.5102

In [2]:
# -*- coding: utf-8 -*-
# EEGNet + Sliding Windows por evento + Eval agrupada por EVENTO (promedio de logits entre crops)
# - Excluye sujetos: {38,88,89,92,100,104}
# - Solo 8 canales: ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
# - Sliding policies configurables
# - SGDR + augments
# - FT por sujeto (opcional)

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, GroupShuffleSplit, StratifiedShuffleSplit, StratifiedKFold
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

FS = 160.0
N_FOLDS = 5
N_CLASSES = 4
CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# ===== ENTRENAMIENTO GLOBAL =====
BATCH_SIZE    = 64
EPOCHS_GLOBAL = 100
LR_INIT       = 1e-2
SGDR_T0       = 6
SGDR_Tmult    = 2
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10
LOG_EVERY        = 5

# ===== FINE-TUNING POR SUJETO (opcional) =====
DO_SUBJECT_FT = True
CALIB_CV_FOLDS = 4
FT_EPOCHS   = 25
FT_BASE_LR  = 5e-5
FT_HEAD_LR  = 1e-3
FT_L2SP     = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO= 0.2

# ===== Normalización por epoch (z-score por canal) =====
NORM_EPOCH_ZSCORE = True

# ===== Sujetos / Runs / Canales =====
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# ===== Sliding windows =====
DEBUG_WINDOWS = True  # logs de debug (resumen)
# Elige set base: "3s_base" (0–3, 0.5–3.5, 1–4) o "6s_base" (-1–5, -0.5–5.5, 0–6)
WINDOW_SET = "3s_base"
# Política de train (agrega más ventanas para robustecer)
# - "standard": usa mismo set en train y eval
# - "train_precue": añade ventana con precue solo en train
# - "train_aggressive": stride 0.25s solo en train (eval se queda con 3)
WINDOW_POLICY = "standard"

# =========================
# UTILIDADES CANALES / EDF
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch adaptativo opcional (si hay CSV de SNR) ---
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            _SNR_TABLE = pd.read_csv(csv_path)
        except:
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    """Extrae T1/T2 y dedup mínimo (>=0.5s)."""
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# DISCOVERY DE SUJETOS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

# =========================
# BALANCEO por EVENTOS (ANTES de crops)
# =========================
def _label_from_tag(kind: str, tag: str):
    if kind == 'LR':
        return 0 if tag == 'T1' else 1 if tag == 'T2' else None
    elif kind == 'OF':
        return 2 if tag == 'T1' else 3 if tag == 'T2' else None
    return None

def gather_all_events(subjects):
    per_subj = {}
    for s in tqdm(subjects, desc="Escaneando eventos (sin crops)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        per_subj[s] = {0:[], 1:[], 2:[], 3:[]}
        for r in (MI_RUNS_LR + MI_RUNS_OF + BASELINE_RUNS_EO):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            kind = run_kind(r)
            if kind not in ('LR','OF'):
                continue
            raw = read_raw_edf(p)
            if raw is None: 
                continue
            events = collect_events_T1T2(raw)
            for onset_sec, tag in events:
                lab = _label_from_tag(kind, tag)
                if lab is None: 
                    continue
                per_subj[s][lab].append({
                    "path": p,
                    "onset": float(onset_sec),
                    "label": int(lab),
                    "subject": int(s),
                    "run": int(r)
                })
    return per_subj

def balance_events_per_subject(per_subj, need_per_class=21, rng_seed=RANDOM_STATE):
    rng = check_random_state(rng_seed)
    kept = {}
    for s, cls_dict in per_subj.items():
        kept[s] = {}
        skip_subject = False
        for c in range(4):
            events = cls_dict.get(c, [])
            n0 = len(events)
            if n0 == 0:
                skip_subject = True
                break
            if n0 == need_per_class:
                sel = events
            elif n0 > need_per_class:
                idx = rng.choice(n0, size=need_per_class, replace=False)
                sel = [events[i] for i in idx]
            else:
                idx = rng.choice(n0, size=need_per_class, replace=True)
                sel = [events[i] for i in idx]
            kept[s][c] = sel
        if skip_subject:
            kept[s] = None
            print(f"[SKIP] S{s:03d} (alguna clase sin eventos)")
    kept = {s:cls for s,cls in kept.items() if cls is not None}
    return kept

# =========================
# SLIDING WINDOWS (train & eval)
# =========================
def base_windows(set_name="3s_base"):
    if set_name == "3s_base":
        # 3 s cada una dentro de [0,4] (típico PhysioNet 2009)
        return [(0.0, 3.0), (0.5, 3.5), (1.0, 4.0)]
    elif set_name == "6s_base":
        # 6 s alrededor del evento (usa con cuidado por OOB según run)
        return [(-1.0, 5.0), (-0.5, 5.5), (0.0, 6.0)]
    else:
        raise ValueError(set_name)

def get_crops_for_policy(window_set="3s_base", policy="standard"):
    crops_eval = base_windows(window_set)
    if policy == "standard":
        crops_train = crops_eval
    elif policy == "train_precue":
        if window_set == "3s_base":
            # añade -0.5 a 2.5 solo en train
            crops_train = [(-0.5, 2.5)] + crops_eval
        else:
            crops_train = [(-1.5, 4.5)] + crops_eval
    elif policy == "train_aggressive":
        if window_set == "3s_base":
            starts = [0.0, 0.25, 0.5, 0.75, 1.0]  # 5 ventanas (3 s con stride 0.25 s)
            crops_train = [(s, s+3.0) for s in starts]
        else:
            # 6 s con stride 0.5 s (3 ventanas ya cubren, mantenemos igual para no inflar mucho)
            crops_train = crops_eval
    else:
        raise ValueError(policy)
    return crops_train, crops_eval

CROPS_TRAIN, CROPS_EVAL = get_crops_for_policy(WINDOW_SET, WINDOW_POLICY)
print(f"🎯 Sliding windows → TRAIN={CROPS_TRAIN} | EVAL={CROPS_EVAL}")

# =========================
# APLICAR CROPS a eventos balanceados (con event_id)
# =========================
def segments_from_balanced_events_with_ids(kept_balanced, crops, fs=FS):
    X, y, groups = [], [], []
    event_ids = []
    ch_template = None
    cache_raw = {}
    # debug
    dbg_total_events = 0
    dbg_total_crops_ok = 0
    dbg_skips_oob = 0
    dbg_skips_short = 0
    dbg_skips_nan = 0

    next_event_id = 0
    try:
        for s, cls_dict in tqdm(kept_balanced.items(), desc="Generando segmentos (con crops + event_ids)"):
            for c in range(4):
                for ev in cls_dict[c]:
                    dbg_total_events += 1
                    p = ev["path"]; onset = ev["onset"]; sid = ev["subject"]; rid = ev["run"]
                    if p not in cache_raw:
                        raw = read_raw_edf(p)
                        if raw is None: 
                            continue
                        cache_raw[p] = raw
                    raw = cache_raw[p]
                    data = raw.get_data()
                    T_total = data.shape[1]
                    if ch_template is None:
                        ch_template = raw.ch_names

                    eid = next_event_id
                    next_event_id += 1

                    for rel_start, rel_end in crops:
                        s_samp = int(round((raw.first_time + onset + rel_start) * fs))
                        e_samp = int(round((raw.first_time + onset + rel_end)   * fs))
                        exp_len = int(round((rel_end - rel_start) * fs))
                        if s_samp < 0 or e_samp > T_total:
                            dbg_skips_oob += 1
                            continue
                        seg = data[:, s_samp:e_samp].T.astype(np.float32)
                        if seg.shape[0] != exp_len or seg.shape[0] <= 0:
                            dbg_skips_short += 1
                            continue
                        if not np.isfinite(seg).all():
                            dbg_skips_nan += 1
                            continue
                        # normalización por epoch (opcional)
                        if NORM_EPOCH_ZSCORE:
                            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

                        X.append(seg); y.append(c); groups.append(s); event_ids.append(eid)
                        dbg_total_crops_ok += 1
    finally:
        cache_raw.clear()

    if DEBUG_WINDOWS:
        print("=== DEBUG WINDOWS SUMMARY ===")
        print(f"Eventos brutos: {dbg_total_events}")
        print(f"Crops OK:      {dbg_total_crops_ok}")
        print(f"Skips OOB:     {dbg_skips_oob}")
        print(f"Skips LEN:     {dbg_skips_short}")
        print(f"Skips NaN:     {dbg_skips_nan}")

    X = np.stack(X, axis=0) if len(X)>0 else np.zeros((0,1,1), np.float32)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)
    event_ids = np.asarray(event_ids, dtype=np.int64)

    n, T, C = X.shape
    print(f"Dataset construido: N={n} | T={T} | C={C} | sujetos únicos={len(np.unique(groups))} | eventos únicos={len(np.unique(event_ids))}")
    return X, y, groups, event_ids, ch_template

# =========================
# DATASET TORCH
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]  # (T,C)
        x = np.expand_dims(x, 0)  # (1,T,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

# =========================
# AUGMENTS (solo train)
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    B,_,T,C = x.shape
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask:
            continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)
    def forward(self, logits, target_probs):
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)
        loss_per_class = -(target_probs * logp)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# =========================
# EEGNet (adaptado a (B,1,T,C))
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    if hasattr(model, 'conv_depthwise'): layers.append(model.conv_depthwise)
    if hasattr(model, 'conv_sep_point'): layers.append(model.conv_sep_point)
    if hasattr(model, 'fc'):             layers.append(model.fc)
    if hasattr(model, 'out'):            layers.append(model.out)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

class EEGNet(nn.Module):
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 24, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 6,
                 drop1_p: float = 0.35, drop2_p: float = 0.6,
                 chdrop_p: float = 0.1):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        self.chdrop = ChannelDropout(p=chdrop_p)

        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)
        self.act = nn.ELU()

        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()
        self.fc  = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc  = nn.Linear(feat_dim, 128, bias=True).to(device)
        self.out = nn.Linear(128, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        x = self.chdrop(x)

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# SAMPLER / PESOS
# =========================
def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20):
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    w[2] *= boost_bfists
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

# =========================
# TRAIN / EVAL
# =========================
def train_epoch(model, loader, opt, criterion_soft, n_classes, fs=160.0, maxnorm=2.0):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        xb = do_time_jitter(xb, max_ms=50, fs=fs)
        xb = do_gaussian_noise(xb, sigma=0.01)
        xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=fs)
        xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=0.2)
        logits = model(xb)
        loss = criterion_soft(logits, yt)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def evaluate_grouped_by_event(model, loader, event_ids_np, use_tta_jitter=True, tta_jitter_n=3):
    model.eval()
    all_logits, all_eids, y_all = [], [], []
    idx_offset = 0
    for xb, yb, _ in loader:
        B = xb.size(0)
        xb = xb.to(DEVICE)
        if use_tta_jitter:
            outs = []
            for _ in range(tta_jitter_n):
                outs.append(model(do_time_jitter(xb, max_ms=20, fs=FS)))
            logits = torch.stack(outs, dim=0).mean(dim=0)
        else:
            logits = model(xb)
        all_logits.append(logits.cpu().numpy())
        all_eids.append(event_ids_np[idx_offset: idx_offset + B].copy())
        y_all.append(yb.numpy())
        idx_offset += B
    all_logits = np.concatenate(all_logits, axis=0)
    all_eids   = np.concatenate(all_eids, axis=0)
    y_all      = np.concatenate(y_all, axis=0)

    uniq_eids = np.unique(all_eids)
    logits_ev, ytrue_ev = [], []
    for eid in uniq_eids:
        idx = (all_eids == eid)
        logits_ev.append(all_logits[idx].mean(axis=0))
        ytrue_ev.append(int(y_all[idx][0]))
    logits_ev = np.stack(logits_ev, axis=0)
    ytrue_ev  = np.asarray(ytrue_ev, dtype=np.int64)
    ypred_ev  = logits_ev.argmax(axis=1)
    acc = float((ypred_ev == ytrue_ev).mean())
    return ytrue_ev, ypred_ev, acc

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i,j], fmt), ha="center",
                 color="white" if cm_norm[i,j] > thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FT POR SUJETO (opcional, con agrupación por evento)
# =========================
def _freeze_for_ft(model):
    for p in model.parameters(): p.requires_grad = False
    for p in model.conv_depthwise.parameters(): p.requires_grad = True
    for p in model.bn2.parameters():           p.requires_grad = True
    for p in model.conv_sep_depth.parameters():p.requires_grad = True
    for p in model.conv_sep_point.parameters():p.requires_grad = True
    for p in model.bn3.parameters():           p.requires_grad = True
    for p in model.fc.parameters():            p.requires_grad = True
    for p in model.out.parameters():           p.requires_grad = True

def subject_cv_finetune_predict_grouped(model_global, Xs, ys, eids, n_splits=CALIB_CV_FOLDS, n_classes=4):
    """
    CV interno en el sujeto; evalúa agrupando por event_id del sujeto.
    Retorna y_true_event, y_pred_event.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_all, y_pred_all = [], []
    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal, ecal = Xs[tr_idx], ys[tr_idx], eids[tr_idx]
        Xho,  yho,  eho  = Xs[te_idx], ys[te_idx], eids[te_idx]

        # split val
        sss = StratifiedShuffleSplit(n_splits=1, test_size=FT_VAL_RATIO, random_state=RANDOM_STATE)
        (tr2, va2), = sss.split(Xcal, ycal)
        Xtr, ytr = Xcal[tr2], ycal[tr2]
        Xva, yva = Xcal[va2], ycal[va2]

        # loaders
        tr_ds = torch.utils.data.TensorDataset(torch.from_numpy(Xtr).float().unsqueeze(1),
                                               torch.from_numpy(ytr).long())
        va_ds = torch.utils.data.TensorDataset(torch.from_numpy(Xva).float().unsqueeze(1),
                                               torch.from_numpy(yva).long())
        te_ds = torch.utils.data.TensorDataset(torch.from_numpy(Xho).float().unsqueeze(1),
                                               torch.from_numpy(yho).long())
        tr_dl = DataLoader(tr_ds, batch_size=32, shuffle=True)
        va_dl = DataLoader(va_ds, batch_size=64, shuffle=False)
        te_dl = DataLoader(te_ds, batch_size=64, shuffle=False)

        m = copy.deepcopy(model_global).to(DEVICE)
        _freeze_for_ft(m)

        base_params = (list(m.conv_depthwise.parameters()) +
                       list(m.bn2.parameters()) +
                       list(m.conv_sep_depth.parameters()) +
                       list(m.conv_sep_point.parameters()) +
                       list(m.bn3.parameters()))
        head_params = list(m.fc.parameters()) + list(m.out.parameters())

        opt = optim.Adam([
            {"params": base_params, "lr": FT_BASE_LR},
            {"params": head_params, "lr": FT_HEAD_LR},
        ])
        ref = [p.detach().clone() for p in base_params + head_params]
        class_w = make_class_weight_tensor(ytr, n_classes, boost_bfists=1.0)  # sin boost extra
        crit = nn.CrossEntropyLoss(weight=class_w)

        best_state = copy.deepcopy(m.state_dict())
        best_val = float('inf'); bad = 0

        for _ in range(FT_EPOCHS):
            m.train()
            for xb, yb in tr_dl:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad(set_to_none=True)
                logits = m(xb)
                loss = crit(logits, yb)
                # L2-SP
                tr_params = base_params + head_params
                reg = 0.0
                for p_cur, p_ref in zip(tr_params, ref):
                    reg = reg + torch.sum((p_cur - p_ref)**2)
                loss = loss + FT_L2SP * reg
                loss.backward(); opt.step()
                apply_max_norm(m, max_value=2.0, p=2.0)

            # val
            m.eval()
            val_loss = 0.0; n = 0
            with torch.no_grad():
                for xb, yb in va_dl:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = m(xb)
                    loss = crit(logits, yb)
                    val_loss += loss.item()*xb.size(0); n += xb.size(0)
            val_loss /= max(1,n)
            if val_loss + 1e-7 < best_val:
                best_val = val_loss; bad = 0
                best_state = copy.deepcopy(m.state_dict())
            else:
                bad += 1
                if bad >= FT_PATIENCE: break

        m.load_state_dict(best_state)

        # pred en HOLD-OUT del sujeto, agrupado por event_id
        m.eval()
        logits_all = []
        with torch.no_grad():
            for xb, _ in te_dl:
                xb = xb.to(DEVICE)
                logits_all.append(m(xb).cpu().numpy())
        logits_all = np.concatenate(logits_all, axis=0)  # (Nh, C)
        # agrupa por eid
        uniq = np.unique(eho)
        for eid in uniq:
            idx = (eho == eid)
            log_ev = logits_all[idx].mean(axis=0)
            y_true_all.append(int(yho[idx][0]))
            y_pred_all.append(int(np.argmax(log_ev)))
    return np.asarray(y_true_all, dtype=np.int64), np.asarray(y_pred_all, dtype=np.int64)

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="EEGNet_Sliding", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)
    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]
        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]
        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()
        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })
    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR,
                   folds_json_description="GroupKFold folds (sliding windows + event eval)"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    # 1) Recolectar eventos (sin crops)
    per_subj = gather_all_events(subs)

    # 2) Balancear por sujeto/clase (21 eventos)
    kept_bal = balance_events_per_subject(per_subj, need_per_class=21, rng_seed=RANDOM_STATE)

    # 3) Generar segmentos con ventanas (TRAIN policy) + IDs de evento
    X, y, groups, eids, chs = segments_from_balanced_events_with_ids(kept_bal, crops=CROPS_TRAIN, fs=FS)
    N, T, C = X.shape
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={len(np.unique(y))} | sujetos={len(np.unique(groups))}")

    # Folds por sujetos
    folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)
    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="EEGNet_Sliding",
                                           description=folds_json_description)
    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    global_folds = []
    ft_folds = []
    all_true_ev = []
    all_pred_ev = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)
        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"[WARN] fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # split de validación por sujetos (sobre train)
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Dataset y loaders
        ds_tr = EEGTrials(X[tr_sub_idx], y[tr_sub_idx], groups[tr_sub_idx])
        ds_va = EEGTrials(X[va_idx],     y[va_idx],     groups[va_idx])
        ds_te = EEGTrials(X[te_idx],     y[te_idx],     groups[te_idx])

        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])
        tr_loader = DataLoader(ds_tr, batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # Modelo
        model = EEGNet(n_ch=C, n_classes=N_CLASSES,
                       F1=24, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=6, drop1_p=0.35, drop2_p=0.6,
                       chdrop_p=0.1).to(DEVICE)

        # Opt + SGDR
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

        # Criterio suave con ponderación (+20% Both Fists)
        class_weights = make_class_weight_tensor(y[tr_sub_idx], N_CLASSES, boost_bfists=1.20)
        criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.05)

        best_state = copy.deepcopy(model.state_dict())
        best_val_ev_acc = -1.0
        bad = 0

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando (train={len(tr_sub_idx)} | val={len(va_idx)} | test={len(te_idx)})")
        for epoch in range(1, EPOCHS_GLOBAL+1):
            train_epoch(model, tr_loader, opt, criterion_soft, n_classes=N_CLASSES, fs=FS, maxnorm=2.0)
            scheduler.step(epoch-1 + 1e-8)

            # Eval por evento en VALIDACIÓN
            y_true_va, y_pred_va, va_acc_ev = evaluate_grouped_by_event(
                model, va_loader, event_ids_np=eids[va_idx], use_tta_jitter=True, tta_jitter_n=3
            )

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | val_acc(ev)={va_acc_ev:.4f} | LR={cur_lr:.5f}")

            if va_acc_ev > best_val_ev_acc + 1e-4:
                best_val_ev_acc = va_acc_ev; bad = 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping @ {epoch} (mejor val_acc(ev)={best_val_ev_acc:.4f})")
                    break

        # Cargar mejor estado
        model.load_state_dict(best_state)

        # EVALUACIÓN GLOBAL por EVENTO en TEST
        y_true_te, y_pred_te, acc_ev = evaluate_grouped_by_event(
            model, te_loader, event_ids_np=eids[te_idx], use_tta_jitter=True, tta_jitter_n=5
        )
        global_folds.append(acc_ev)
        all_true_ev.append(y_true_te); all_pred_ev.append(y_pred_te)
        print(f"[Fold {fold}/{N_FOLDS}] Global acc (por evento) = {acc_ev:.4f}")
        print_report(y_true_te, y_pred_te, CLASS_NAMES_4C)

        # ---------- FT POR SUJETO (opcional) ----------
        acc_ft = np.nan
        if DO_SUBJECT_FT:
            X_te, y_te, g_te, e_te = X[te_idx], y[te_idx], groups[te_idx], eids[te_idx]
            y_true_ft_all, y_pred_ft_all = [], []
            used_subjects = 0
            for sid in np.unique(g_te):
                idx = np.where(g_te == sid)[0]
                Xs, ys, es = X_te[idx], y_te[idx], e_te[idx]
                if len(np.unique(ys)) < 2 or len(ys) < CALIB_CV_FOLDS:
                    continue
                yt, yp = subject_cv_finetune_predict_grouped(model, Xs, ys, es, n_splits=CALIB_CV_FOLDS, n_classes=N_CLASSES)
                y_true_ft_all.append(yt); y_pred_ft_all.append(yp); used_subjects += 1
            if len(y_true_ft_all) > 0:
                y_true_ft_all = np.concatenate(y_true_ft_all)
                y_pred_ft_all = np.concatenate(y_pred_ft_all)
                acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
                print(f"  Fine-tuning (por sujeto, agrupado ev) acc={acc_ft:.4f} | Δ={acc_ft-acc_ev:+.4f} | sujetos={used_subjects}")
            else:
                print("  Fine-tuning no ejecutado (sujeto(s) con muestras/clases insuficientes).")
        ft_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true_ev) > 0:
        all_true_ev = np.concatenate(all_true_ev)
        all_pred_ev = np.concatenate(all_pred_ev)
    else:
        all_true_ev = np.array([], dtype=int)
        all_pred_ev = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES (por EVENTO)")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_folds])
    if len(ft_folds) > 0:
        print(f"Fine-tune mean: {np.nanmean(ft_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_folds) - np.mean(global_folds):+.4f}")

    if all_true_ev.size > 0:
        fname = "confusion_eegnet_sliding_event_eval_allfolds.png"
        plot_confusion(all_true_ev, all_pred_ev, CLASS_NAMES_4C,
                       title="Confusion Matrix (Event-level, ALL FOLDS)",
                       fname=fname)
        print(f"↳ Matriz de confusión guardada: {fname}")

    return {
        "global_folds": global_folds,
        "ft_folds": ft_folds,
        "all_true_ev": all_true_ev,
        "all_pred_ev": all_pred_ev,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EEGNet + Sliding Windows + Event-level Eval")
    print(f"🔧 WINDOW_SET={WINDOW_SET} | POLICY={WINDOW_POLICY}")
    print(f"⚙️  GLOBAL: epochs={EPOCHS_GLOBAL}, lr={LR_INIT}, batch={BATCH_SIZE}, SGDR(T0={SGDR_T0},Tmult={SGDR_Tmult})")
    print(f"🧪 FT por sujeto: {'ON' if DO_SUBJECT_FT else 'OFF'} "
          f"(epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP})")
    run_experiment()


🚀 Usando dispositivo: cuda
🎯 Sliding windows → TRAIN=[(0.0, 3.0), (0.5, 3.5), (1.0, 4.0)] | EVAL=[(0.0, 3.0), (0.5, 3.5), (1.0, 4.0)]
🧠 INICIANDO EEGNet + Sliding Windows + Event-level Eval
🔧 WINDOW_SET=3s_base | POLICY=standard
⚙️  GLOBAL: epochs=100, lr=0.01, batch=64, SGDR(T0=6,Tmult=2)
🧪 FT por sujeto: ON (epochs=25, base_lr=5e-05, head_lr=0.001, L2SP=0.0001)
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Escaneando eventos (sin crops):   1%|          | 1/103 [00:00<00:14,  7.15it/s]

Escaneando eventos (sin crops): 100%|██████████| 103/103 [00:37<00:00,  2.72it/s]
Generando segmentos (con crops + event_ids): 100%|██████████| 103/103 [00:43<00:00,  2.37it/s]


=== DEBUG WINDOWS SUMMARY ===
Eventos brutos: 8652
Crops OK:      25956
Skips OOB:     0
Skips LEN:     0
Skips NaN:     0
Dataset construido: N=25956 | T=480 | C=8 | sujetos únicos=103 | eventos únicos=8652
Listo para entrenar: N=25956 | T=480 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando (train=5712 | val=1176 | test=1764)
  Época   1 | val_acc(ev)=0.3571 | LR=0.01000
  Época   5 | val_acc(ev)=0.3929 | LR=0.00250
  Época  10 | val_acc(ev)=0.3801 | LR=0.00854
  Época  15 | val_acc(ev)=0.3878 | LR=0.00250
  Época  20 | val_acc(ev)=0.3903 | LR=0.00996
  Early stopping @ 23 (mejor val_acc(ev)=0.4082)
[Fold 1/5] Global acc (por evento) = 0.3214
              precision    recall  f1-score   support

        Left     0.4898    0.1633    0.2449       147
       Right     0.5135    0.1293    0.2065       147
  Both Fists     0.3022    0.7483    0.4305       147
   Both Feet     0.2609    0.2449    0.2526       147

    accuracy                         0.3214       588
   macro avg    

In [None]:
# -*- coding: utf-8 -*-
# EEGNet + Sliding Windows por evento
# Mejoras:
#  - RandomEventSampler: 1 crop aleatorio por evento/época (combatir overfitting)
#  - Curvas train(ev) vs val(ev) por época y guardado PNG
#  - Tag de crops "eval" para que train-curve use SOLO los eval-crops (sin augs)
#  - SpecAug (band masking) opcional
#  - Hooks para consistency regularization entre ventanas del mismo evento (OFF por defecto)

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler

from sklearn.model_selection import GroupKFold, GroupShuffleSplit, StratifiedKFold, StratifiedShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'
FIGS_DIR = PROJ / 'reports' / 'figs'
CACHE_DIR.mkdir(parents=True, exist_ok=True)
FIGS_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Dispositivo: {DEVICE}")

FS = 160.0
N_FOLDS = 5
N_CLASSES = 4
CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# Entrenamiento global
BATCH_SIZE    = 64
EPOCHS_GLOBAL = 100
LR_INIT       = 1e-2
SGDR_T0       = 6
SGDR_Tmult    = 2
GLOBAL_VAL_SPLIT = 0.17
GLOBAL_PATIENCE  = 10
LOG_EVERY        = 5

# Fine-tuning por sujeto
DO_SUBJECT_FT = True
CALIB_CV_FOLDS = 4
FT_EPOCHS   = 25
FT_BASE_LR  = 5e-5
FT_HEAD_LR  = 1e-3
FT_L2SP     = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO= 0.2

# Normalización por epoch (z-score por canal)
NORM_EPOCH_ZSCORE = True

# Sujetos / Runs / Canales
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# Sliding windows
WINDOW_SET   = "3s_base"   # "3s_base" o "6s_base"
WINDOW_POLICY= "train_aggressive"  # "standard" | "train_precue" | "train_aggressive"
DEBUG_WINDOWS = True

# Random per-event crop sampling
RANDOM_PER_EVENT = True   # clave para combatir overfitting
K_PER_EVENT      = 1      # 1 crop aleatorio por evento y época

# SpecAug (band masking) opcional
USE_SPECAUG = True
SPECAUG_PROB = 0.35
SPECAUG_BW_HZ = (2.0, 6.0)   # ancho de banda a enmascarar

# Consistency regularization entre dos crops del mismo evento (apagado por defecto)
USE_CONSISTENCY = False
CONS_WEIGHT = 0.1

# =========================
# UTILIDADES CANALES / EDF
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    # dedup simple
    out = []
    last = {'T1': -1e9, 'T2': -1e9}
    for t, tag in res:
        if t - last[tag] >= 0.5:
            out.append((t, tag))
            last[tag] = t
    return out

# =========================
# DISCOVERY SUJETOS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

# =========================
# LABELS / BALANCE por evento
# =========================
def _label_from_tag(kind: str, tag: str):
    if kind == 'LR':
        return 0 if tag == 'T1' else 1 if tag == 'T2' else None
    elif kind == 'OF':
        return 2 if tag == 'T1' else 3 if tag == 'T2' else None
    return None

def gather_all_events(subjects):
    per_subj = {}
    for s in tqdm(subjects, desc="Escaneando eventos (sin crops)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        per_subj[s] = {0:[], 1:[], 2:[], 3:[]}
        for r in (MI_RUNS_LR + MI_RUNS_OF + BASELINE_RUNS_EO):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            kind = run_kind(r)
            if kind not in ('LR','OF'):
                continue
            raw = read_raw_edf(p)
            if raw is None: 
                continue
            events = collect_events_T1T2(raw)
            for onset_sec, tag in events:
                lab = _label_from_tag(kind, tag)
                if lab is None: continue
                per_subj[s][lab].append({
                    "path": p,
                    "onset": float(onset_sec),
                    "label": int(lab),
                    "subject": int(s),
                    "run": int(r)
                })
    return per_subj

def balance_events_per_subject(per_subj, need_per_class=21, rng_seed=RANDOM_STATE):
    rng = check_random_state(rng_seed)
    kept = {}
    for s, cls_dict in per_subj.items():
        kept[s] = {}
        skip = False
        for c in range(4):
            events = cls_dict.get(c, [])
            n0 = len(events)
            if n0 == 0:
                skip = True
                break
            if n0 == need_per_class:
                sel = events
            elif n0 > need_per_class:
                idx = rng.choice(n0, size=need_per_class, replace=False)
                sel = [events[i] for i in idx]
            else:
                idx = rng.choice(n0, size=need_per_class, replace=True)
                sel = [events[i] for i in idx]
            kept[s][c] = sel
        if skip:
            kept[s] = None
            print(f"[SKIP] S{s:03d} (alguna clase sin eventos)")
    return {s:cls for s,cls in kept.items() if cls is not None}

# =========================
# SLIDING WINDOWS
# =========================
def base_windows(set_name="3s_base"):
    if set_name == "3s_base":
        return [(0.0, 3.0), (0.5, 3.5), (1.0, 4.0)]
    elif set_name == "6s_base":
        return [(-1.0, 5.0), (-0.5, 5.5), (0.0, 6.0)]
    else:
        raise ValueError(set_name)

def get_crops_for_policy(window_set="3s_base", policy="standard"):
    crops_eval = base_windows(window_set)
    if policy == "standard":
        crops_train = crops_eval
    elif policy == "train_precue":
        crops_train = [(-0.5, 2.5)] + crops_eval if window_set == "3s_base" else [(-1.5, 4.5)] + crops_eval
    elif policy == "train_aggressive":
        if window_set == "3s_base":
            starts = [0.0, 0.25, 0.5, 0.75, 1.0]
            crops_train = [(s, s+3.0) for s in starts]
        else:
            crops_train = crops_eval
    else:
        raise ValueError(policy)
    return crops_train, crops_eval

CROPS_TRAIN, CROPS_EVAL = get_crops_for_policy(WINDOW_SET, WINDOW_POLICY)
print(f"🎯 Sliding windows → TRAIN={CROPS_TRAIN} | EVAL={CROPS_EVAL}")

# =========================
# CROPS (con etiqueta is_eval) + event_id
# =========================
def segments_from_balanced_events_with_flags(kept_balanced, crops_train, crops_eval, fs=FS):
    """
    Genera TODOS los crops de crops_train ∪ crops_eval y marca cuáles son de eval.
    """
    all_crops = list({c for c in (crops_train + crops_eval)})  # unión sin duplicados
    X, y, groups, eids, is_eval = [], [], [], [], []
    ch_template = None
    cache_raw = {}

    dbg_total_events = 0
    dbg_total_crops_ok = 0
    dbg_skips_oob = dbg_skips_short = dbg_skips_nan = 0

    next_event_id = 0
    try:
        for s, cls_dict in tqdm(kept_balanced.items(), desc="Generando segmentos (train∪eval crops)"):
            for c in range(4):
                for ev in cls_dict[c]:
                    dbg_total_events += 1
                    p = ev["path"]; onset = ev["onset"]; sid = ev["subject"]; rid = ev["run"]
                    if p not in cache_raw:
                        raw = read_raw_edf(p)
                        if raw is None: continue
                        cache_raw[p] = raw
                    raw = cache_raw[p]
                    data = raw.get_data()
                    T_total = data.shape[1]
                    if ch_template is None:
                        ch_template = raw.ch_names
                    eid = next_event_id; next_event_id += 1

                    for rel_start, rel_end in all_crops:
                        s_samp = int(round((raw.first_time + onset + rel_start) * fs))
                        e_samp = int(round((raw.first_time + onset + rel_end)   * fs))
                        exp_len = int(round((rel_end - rel_start) * fs))
                        if s_samp < 0 or e_samp > T_total:
                            dbg_skips_oob += 1; continue
                        seg = data[:, s_samp:e_samp].T.astype(np.float32)
                        if seg.shape[0] != exp_len or seg.shape[0] <= 0:
                            dbg_skips_short += 1; continue
                        if not np.isfinite(seg).all():
                            dbg_skips_nan += 1; continue

                        if NORM_EPOCH_ZSCORE:
                            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

                        X.append(seg); y.append(c); groups.append(s); eids.append(eid)
                        is_eval.append(1 if (rel_start, rel_end) in crops_eval else 0)
                        dbg_total_crops_ok += 1
    finally:
        cache_raw.clear()

    if DEBUG_WINDOWS:
        print("=== DEBUG WINDOWS SUMMARY ===")
        print(f"Eventos brutos: {dbg_total_events}")
        print(f"Crops OK:      {dbg_total_crops_ok}")
        print(f"Skips OOB:     {dbg_skips_oob}")
        print(f"Skips LEN:     {dbg_skips_short}")
        print(f"Skips NaN:     {dbg_skips_nan}")

    X = np.stack(X, axis=0) if len(X)>0 else np.zeros((0,1,1), np.float32)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)
    eids = np.asarray(eids, dtype=np.int64)
    is_eval = np.asarray(is_eval, dtype=np.int8)

    n, T, C = X.shape
    print(f"Dataset construido: N={n} | T={T} | C={C} | sujetos únicos={len(np.unique(groups))} | eventos únicos={len(np.unique(eids))}")
    return X, y, groups, eids, is_eval, ch_template

# =========================
# DATASET / SAMPLERS
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups, eids, is_eval):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.e = eids.astype(np.int64)
        self.is_eval = is_eval.astype(np.int8)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]  # (T,C)
        x = np.expand_dims(x, 0)  # (1,T,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx]), torch.tensor(self.e[idx]), torch.tensor(self.is_eval[idx])

class RandomEventSampler(Sampler):
    """
    Devuelve exactamente K_PER_EVENT índices por event_id en cada época (aleatorios).
    Útil para no sobre-contar eventos cuando hay muchos crops por evento.
    """
    def __init__(self, eids, is_train_mask, k_per_event=1, generator=None):
        self.eids = np.asarray(eids)
        self.is_train = np.asarray(is_train_mask).astype(bool)
        self.k = int(k_per_event)
        self.gen = generator
        # agrupar indices de train por eid
        self.by_e = {}
        for i, (eid, flag) in enumerate(zip(self.eids, self.is_train)):
            if not flag: continue
            self.by_e.setdefault(int(eid), []).append(i)
        self.events = sorted(self.by_e.keys())
        self.num_samples = len(self.events) * self.k
    def __len__(self): return self.num_samples
    def __iter__(self):
        g = torch.Generator()
        if self.gen is not None: g = self.gen
        for eid in self.events:
            idxs = self.by_e[eid]
            if len(idxs) <= self.k:
                choice = np.random.choice(idxs, size=self.k, replace=True)
            else:
                choice = np.random.choice(idxs, size=self.k, replace=False)
            for c in choice:
                yield int(c)

# =========================
# AUGMENTS
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    B,_,T,C = x.shape
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask:
            continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def specaug_band_mask(x, fs=160.0, prob=0.35, bw_hz=(2.0,6.0)):
    """Enmascara una banda de frecuencia con filtro notch pobre (en tiempo) → aproximación barata."""
    if prob <= 0: return x
    if random.random() > prob: return x
    B,_,T,C = x.shape
    # simula una banda bloqueada multiplicando por envolvente sinusoidal (barato, evita FFT)
    bw = random.uniform(bw_hz[0], bw_hz[1])  # ancho
    f0 = random.uniform(6.0, 30.0-bw)        # centro
    t = torch.linspace(0, (T-1)/fs, T, device=x.device).view(1,1,T,1)
    w1 = 2*math.pi*(f0 - bw/2.0); w2 = 2*math.pi*(f0 + bw/2.0)
    mask = 1.0 - 0.35*(torch.sin(w1*t)**2 + torch.sin(w2*t)**2)/2.0
    return x * mask

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)
    def forward(self, logits, target_probs):
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)
        loss_per_class = -(target_probs * logp)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# =========================
# EEGNet
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1): super().__init__(); self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    for name in ['conv_depthwise','conv_sep_point','fc','out']:
        if hasattr(model, name): layers.append(getattr(model, name))
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

class EEGNet(nn.Module):
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 24, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 6,
                 drop1_p: float = 0.35, drop2_p: float = 0.6,
                 chdrop_p: float = 0.1):
        super().__init__()
        self.n_ch = n_ch; self.n_classes = n_classes
        self.F1 = F1; self.D = D; self.F2 = F1 * D
        self.kernel_t = kernel_t; self.k_sep = k_sep
        self.pool1_t = pool1_t; self.pool2_t = pool2_t
        self.chdrop = ChannelDropout(p=chdrop_p)

        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)
        self.act = nn.ELU()

        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()
        self.fc  = None; self.out = None; self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc  = nn.Linear(feat_dim, 128, bias=True).to(device)
        self.out = nn.Linear(128, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)
        x = self.chdrop(x)
        z = self.conv_temporal(x); z = self.bn1(z); z = self.act(z)
        z = self.conv_depthwise(z); z = self.bn2(z); z = self.act(z)
        z = self.pool1(z); z = self.drop1(z)
        z = self.conv_sep_depth(z); z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z); z = self.drop2(z)
        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# PESOS CLASE
# =========================
def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20):
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    w[2] *= boost_bfists
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

# =========================
# EVAL AGRUPADA por EVENTO
# =========================
@torch.no_grad()
def evaluate_grouped_by_event(model, loader, eids_full, use_tta_jitter=True, tta_jitter_n=3):
    model.eval()
    all_logits, all_eids, y_all = [], [], []
    idx_offset = 0
    for xb, yb, _, _, _ in loader:
        B = xb.size(0)
        xb = xb.to(DEVICE)
        if use_tta_jitter:
            outs = []
            for _ in range(tta_jitter_n):
                outs.append(model(do_time_jitter(xb, max_ms=20, fs=FS)))
            logits = torch.stack(outs, dim=0).mean(dim=0)
        else:
            logits = model(xb)
        all_logits.append(logits.cpu().numpy())
        all_eids.append(eids_full[idx_offset: idx_offset + B].copy())
        y_all.append(yb.numpy())
        idx_offset += B
    all_logits = np.concatenate(all_logits, axis=0)
    all_eids   = np.concatenate(all_eids, axis=0)
    y_all      = np.concatenate(y_all, axis=0)

    uniq_eids = np.unique(all_eids)
    logits_ev, ytrue_ev = [], []
    for eid in uniq_eids:
        idx = (all_eids == eid)
        logits_ev.append(all_logits[idx].mean(axis=0))
        ytrue_ev.append(int(y_all[idx][0]))
    logits_ev = np.stack(logits_ev, axis=0)
    ytrue_ev  = np.asarray(ytrue_ev, dtype=np.int64)
    ypred_ev  = logits_ev.argmax(axis=1)
    acc = float((ypred_ev == ytrue_ev).mean())
    return ytrue_ev, ypred_ev, acc

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i in range(cm_norm.shape[0]),:
        pass
    for i in range(cm_norm.shape[0]):
        for j in range(cm_norm.shape[1]):
            plt.text(j, i, format(cm_norm[i,j], '.2f'), ha="center",
                     color="white" if cm_norm[i,j] > thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# TRAIN LOOP
# =========================
def train_epoch(model, loader, opt, criterion_soft, n_classes):
    model.train()
    for xb, yb, _, _, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        xb = do_time_jitter(xb, max_ms=50, fs=FS)
        xb = do_gaussian_noise(xb, sigma=0.01)
        xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=FS)
        if USE_SPECAUG:
            xb = specaug_band_mask(xb, fs=FS, prob=SPECAUG_PROB, bw_hz=SPECAUG_BW_HZ)
        xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=0.2)

        logits = model(xb)
        loss = criterion_soft(logits, yt)

        # (opcional) consistencia entre dos views del mismo batch
        if USE_CONSISTENCY:
            with torch.no_grad():
                xb2 = do_time_jitter(xb.clone(), max_ms=35, fs=FS)
            logits2 = model(xb2)
            p1 = torch.log_softmax(logits, dim=1)
            p2 = torch.softmax(logits2, dim=1)
            kl = torch.sum(torch.exp(p1) * (p1 - torch.log(p2 + 1e-8)), dim=1).mean()
            loss = loss + CONS_WEIGHT * kl

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        apply_max_norm(model, max_value=2.0, p=2.0)

# =========================
# FT por sujeto
# =========================
def _freeze_for_ft(model):
    for p in model.parameters(): p.requires_grad = False
    for p in model.conv_depthwise.parameters(): p.requires_grad = True
    for p in model.bn2.parameters():           p.requires_grad = True
    for p in model.conv_sep_depth.parameters():p.requires_grad = True
    for p in model.conv_sep_point.parameters():p.requires_grad = True
    for p in model.bn3.parameters():           p.requires_grad = True
    for p in model.fc.parameters():            p.requires_grad = True
    for p in model.out.parameters():           p.requires_grad = True

def subject_cv_finetune_predict_grouped(model_global, Xs, ys, eids, n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_all, y_pred_all = [], []
    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal, ecal = Xs[tr_idx], ys[tr_idx], eids[tr_idx]
        Xho,  yho,  eho  = Xs[te_idx], ys[te_idx], eids[te_idx]
        sss = StratifiedShuffleSplit(n_splits=1, test_size=FT_VAL_RATIO, random_state=RANDOM_STATE)
        (tr2, va2), = sss.split(Xcal, ycal)
        Xtr, ytr = Xcal[tr2], ycal[tr2]
        Xva, yva = Xcal[va2], ycal[va2]

        tr_ds = torch.utils.data.TensorDataset(torch.from_numpy(Xtr).float().unsqueeze(1),
                                               torch.from_numpy(ytr).long())
        va_ds = torch.utils.data.TensorDataset(torch.from_numpy(Xva).float().unsqueeze(1),
                                               torch.from_numpy(yva).long())
        te_ds = torch.utils.data.TensorDataset(torch.from_numpy(Xho).float().unsqueeze(1),
                                               torch.from_numpy(yho).long())
        tr_dl = DataLoader(tr_ds, batch_size=32, shuffle=True)
        va_dl = DataLoader(va_ds, batch_size=64, shuffle=False)
        te_dl = DataLoader(te_ds, batch_size=64, shuffle=False)

        m = copy.deepcopy(model_global).to(DEVICE)
        _freeze_for_ft(m)

        base_params = (list(m.conv_depthwise.parameters()) +
                       list(m.bn2.parameters()) +
                       list(m.conv_sep_depth.parameters()) +
                       list(m.conv_sep_point.parameters()) +
                       list(m.bn3.parameters()))
        head_params = list(m.fc.parameters()) + list(m.out.parameters())

        opt = optim.Adam([
            {"params": base_params, "lr": FT_BASE_LR},
            {"params": head_params, "lr": FT_HEAD_LR},
        ])
        ref = [p.detach().clone() for p in base_params + head_params]
        class_w = make_class_weight_tensor(ytr, n_classes, boost_bfists=1.0)
        crit = nn.CrossEntropyLoss(weight=class_w)

        best_state = copy.deepcopy(m.state_dict())
        best_val = float('inf'); bad = 0

        for _ in range(FT_EPOCHS):
            m.train()
            for xb, yb in tr_dl:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad(set_to_none=True)
                logits = m(xb)
                loss = crit(logits, yb)
                reg = 0.0
                tr_params = base_params + head_params
                for p_cur, p_ref in zip(tr_params, ref):
                    reg = reg + torch.sum((p_cur - p_ref)**2)
                loss = loss + FT_L2SP * reg
                loss.backward(); opt.step()
                apply_max_norm(m, max_value=2.0, p=2.0)

            # val
            m.eval()
            val_loss = 0.0; n = 0
            with torch.no_grad():
                for xb, yb in va_dl:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    logits = m(xb)
                    loss = crit(logits, yb)
                    val_loss += loss.item()*xb.size(0); n += xb.size(0)
            val_loss /= max(1,n)
            if val_loss + 1e-7 < best_val:
                best_val = val_loss; bad = 0
                best_state = copy.deepcopy(m.state_dict())
            else:
                bad += 1
                if bad >= FT_PATIENCE: break

        m.load_state_dict(best_state)

        # pred en HOLD-OUT del sujeto, agrupado por event_id
        m.eval()
        logits_all = []
        with torch.no_grad():
            for xb, _ in te_dl:
                xb = xb.to(DEVICE)
                logits_all.append(m(xb).cpu().numpy())
        logits_all = np.concatenate(logits_all, axis=0)
        uniq = np.unique(eho)
        for eid in uniq:
            idx = (eho == eid)
            log_ev = logits_all[idx].mean(axis=0)
            y_true_all.append(int(yho[idx][0]))
            y_pred_all.append(int(np.argmax(log_ev)))
    return np.asarray(y_true_all, dtype=np.int64), np.asarray(y_pred_all, dtype=np.int64)

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="EEGNet_Sliding+", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)
    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]
        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]
        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()
        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })
    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment():
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)}")

    per_subj = gather_all_events(subs)
    kept_bal = balance_events_per_subject(per_subj, need_per_class=21, rng_seed=RANDOM_STATE)

    X, y, groups, eids, is_eval, chs = segments_from_balanced_events_with_flags(
        kept_bal, crops_train=CROPS_TRAIN, crops_eval=CROPS_EVAL, fs=FS
    )
    N, T, C = X.shape
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={len(np.unique(y))} | sujetos={len(np.unique(groups))}")

    # Folds
    FOLDS_JSON.parent.mkdir(parents=True, exist_ok=True)
    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]
    if not FOLDS_JSON.exists():
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=FOLDS_JSON,
                                           created_by="EEGNet_Sliding+",
                                           description="RandomEventSampler + eval-crops")
    payload = load_group_folds_json(FOLDS_JSON, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    global_folds = []
    ft_folds = []
    all_true_ev = []
    all_pred_ev = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)
        # split val por sujeto
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Máscaras para loaders:
        # - TRAIN: todos los crops de los sujetos de train (is_eval==0 si policy produce más crops; da igual).
        #   pero sampler RandomEventSampler cogerá 1 crop random por eid/época
        train_mask = np.isin(np.arange(N), tr_sub_idx)
        val_mask   = np.isin(np.arange(N), va_idx)
        test_mask  = np.isin(np.arange(N), te_idx)

        # DataSets (comparten backing arrays)
        ds = EEGTrials(X, y, groups, eids, is_eval)

        # Sampler: uno por evento (sobre train_mask)
        if RANDOM_PER_EVENT:
            sampler = RandomEventSampler(eids, is_train_mask=train_mask, k_per_event=K_PER_EVENT)
            tr_loader = DataLoader(ds, batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        else:
            idx_tr = np.where(train_mask)[0]
            tr_loader = DataLoader(torch.utils.data.Subset(ds, idx_tr),
                                   batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

        # VALIDACIÓN y TEST → usar SOLO crops marcados como eval (is_eval==1)
        idx_va = np.where(val_mask & (is_eval==1))[0]
        idx_te = np.where(test_mask & (is_eval==1))[0]
        va_loader = DataLoader(torch.utils.data.Subset(ds, idx_va), batch_size=BATCH_SIZE, shuffle=False)
        te_loader = DataLoader(torch.utils.data.Subset(ds, idx_te), batch_size=BATCH_SIZE, shuffle=False)

        # TRAIN-EVAL (para curva): también usar SOLO eval-crops de TRAIN (sin augs)
        idx_tr_eval = np.where(train_mask & (is_eval==1))[0]
        tr_eval_loader = DataLoader(torch.utils.data.Subset(ds, idx_tr_eval), batch_size=BATCH_SIZE, shuffle=False)

        # Modelo + Opt
        model = EEGNet(n_ch=C, n_classes=N_CLASSES,
                       F1=24, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=6, drop1_p=0.35, drop2_p=0.6,
                       chdrop_p=0.1).to(DEVICE)
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

        class_weights = make_class_weight_tensor(y[tr_sub_idx], N_CLASSES, boost_bfists=1.20)
        criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.05)

        best_state = copy.deepcopy(model.state_dict())
        best_val_ev_acc = -1.0
        bad = 0

        # Curvas
        hist_train_ev = []
        hist_val_ev   = []
        hist_lr       = []

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando (train_ev≈{len(np.unique(eids[tr_sub_idx]))} | val_ev≈{len(np.unique(eids[va_idx]))} | test_ev≈{len(np.unique(eids[te_idx]))})")
        for epoch in range(1, EPOCHS_GLOBAL+1):
            train_epoch(model, tr_loader, opt, criterion_soft, n_classes=N_CLASSES)
            scheduler.step(epoch-1 + 1e-8)

            # Train-EVAL (por evento, sin augs, usando eval-crops del split de TRAIN)
            y_tr_ev_t, y_tr_ev_p, tr_acc_ev = evaluate_grouped_by_event(
                model, tr_eval_loader, eids_full=eids[idx_tr_eval], use_tta_jitter=False, tta_jitter_n=1
            )
            # Val-EVAL (por evento)
            y_va_ev_t, y_va_ev_p, va_acc_ev = evaluate_grouped_by_event(
                model, va_loader, eids_full=eids[idx_va], use_tta_jitter=True, tta_jitter_n=3
            )
            hist_train_ev.append(tr_acc_ev)
            hist_val_ev.append(va_acc_ev)
            hist_lr.append(opt.param_groups[0]['lr'])

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                print(f"  Época {epoch:3d} | train_acc(ev)={tr_acc_ev:.4f} | val_acc(ev)={va_acc_ev:.4f} | LR={opt.param_groups[0]['lr']:.5f}")

            if va_acc_ev > best_val_ev_acc + 1e-4:
                best_val_ev_acc = va_acc_ev; bad = 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping @ {epoch} (mejor val_acc(ev)={best_val_ev_acc:.4f})")
                    break

        # Guardar curva
        epochs_r = range(1, len(hist_val_ev)+1)
        plt.figure(figsize=(6.2,4.2))
        plt.plot(epochs_r, hist_train_ev, label='Train (event)')
        plt.plot(epochs_r, hist_val_ev,   label='Val (event)')
        plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title(f'Fold {fold} - Train vs Val (Event)')
        plt.legend(); plt.grid(True, alpha=0.3)
        fig_path = FIGS_DIR / f"training_curve_event_fold{fold}.png"
        plt.tight_layout(); plt.savefig(fig_path, dpi=150); plt.close()
        print(f"↳ Curva train/val guardada: {fig_path.name}")

        # Cargar mejor
        model.load_state_dict(best_state)

        # Test por evento
        y_true_te, y_pred_te, acc_ev = evaluate_grouped_by_event(
            model, te_loader, eids_full=eids[idx_te], use_tta_jitter=True, tta_jitter_n=5
        )
        global_folds.append(acc_ev)
        all_true_ev.append(y_true_te); all_pred_ev.append(y_pred_te)
        print(f"[Fold {fold}/{N_FOLDS}] Global acc (por evento) = {acc_ev:.4f}")
        print_report(y_true_te, y_pred_te, CLASS_NAMES_4C)

        # ---------- FT POR SUJETO ----------
        acc_ft = np.nan
        if DO_SUBJECT_FT:
            # usa SOLO eval-crops del sujeto para evaluación justa
            X_te, y_te, g_te, e_te = X[idx_te], y[idx_te], groups[idx_te], eids[idx_te]
            y_true_ft_all, y_pred_ft_all = [], []
            used_subjects = 0
            for sid in np.unique(g_te):
                idxs = np.where(g_te == sid)[0]
                Xs, ys, es = X_te[idxs], y_te[idxs], e_te[idxs]
                if len(np.unique(ys)) < 2 or len(ys) < CALIB_CV_FOLDS:
                    continue
                yt, yp = subject_cv_finetune_predict_grouped(model, Xs, ys, es, n_splits=CALIB_CV_FOLDS, n_classes=N_CLASSES)
                y_true_ft_all.append(yt); y_pred_ft_all.append(yp); used_subjects += 1
            if len(y_true_ft_all) > 0:
                y_true_ft_all = np.concatenate(y_true_ft_all)
                y_pred_ft_all = np.concatenate(y_pred_ft_all)
                acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
                print(f"  Fine-tuning (por sujeto, agrup. ev) acc={acc_ft:.4f} | Δ={acc_ft-acc_ev:+.4f} | sujetos={used_subjects}")
            else:
                print("  Fine-tuning no ejecutado (muestras/clases insuficientes).")
        ft_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true_ev) > 0:
        all_true_ev = np.concatenate(all_true_ev)
        all_pred_ev = np.concatenate(all_pred_ev)
    else:
        all_true_ev = np.array([], dtype=int)
        all_pred_ev = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES (por EVENTO)")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_folds])
    if len(ft_folds) > 0:
        print(f"Fine-tune mean: {np.nanmean(ft_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_folds) - np.mean(global_folds):+.4f}")

    if all_true_ev.size > 0:
        fname = FIGS_DIR / "confusion_eegnet_sliding_event_eval_allfolds.png"
        plot_confusion(all_true_ev, all_pred_ev, CLASS_NAMES_4C,
                       title="Confusion Matrix (Event-level, ALL FOLDS)",
                       fname=str(fname))
        print(f"↳ Matriz de confusión guardada: {fname.name}")

    return {
        "global_folds": global_folds,
        "ft_folds": ft_folds,
        "all_true_ev": all_true_ev,
        "all_pred_ev": all_pred_ev,
        "folds_json_path": str(FOLDS_JSON)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 EEGNet + Sliding Windows (+RandomEventSampler) + Event-level Eval")
    print(f"🔧 WINDOW_SET={WINDOW_SET} | POLICY={WINDOW_POLICY} | RANDOM_PER_EVENT={RANDOM_PER_EVENT} (k={K_PER_EVENT})")
    print(f"⚙️  GLOBAL: epochs={EPOCHS_GLOBAL}, lr={LR_INIT}, batch={BATCH_SIZE}, SGDR(T0={SGDR_T0},Tmult={SGDR_Tmult})")
    print(f"🎛️ Augs: SpecAug={'ON' if USE_SPECAUG else 'OFF'}, Mixup=0.2, TimeJitter=50ms, Cutout(OF)")
    print(f"🧪 FT por sujeto: {'ON' if DO_SUBJECT_FT else 'OFF'} "
          f"(epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP})")
    run_experiment()


🚀 Dispositivo: cuda
🎯 Sliding windows → TRAIN=[(0.0, 3.0), (0.25, 3.25), (0.5, 3.5), (0.75, 3.75), (1.0, 4.0)] | EVAL=[(0.0, 3.0), (0.5, 3.5), (1.0, 4.0)]
🧠 EEGNet + Sliding Windows (+RandomEventSampler) + Event-level Eval
🔧 WINDOW_SET=3s_base | POLICY=train_aggressive | RANDOM_PER_EVENT=True (k=1)
⚙️  GLOBAL: epochs=100, lr=0.01, batch=64, SGDR(T0=6,Tmult=2)
🎛️ Augs: SpecAug=ON, Mixup=0.2, TimeJitter=50ms, Cutout(OF)
🧪 FT por sujeto: ON (epochs=25, base_lr=5e-05, head_lr=0.001, L2SP=0.0001)
Sujetos elegibles: 103


Escaneando eventos (sin crops):   1%|          | 1/103 [00:00<00:13,  7.76it/s]

Escaneando eventos (sin crops): 100%|██████████| 103/103 [00:12<00:00,  8.34it/s]
Generando segmentos (train∪eval crops): 100%|██████████| 103/103 [00:16<00:00,  6.37it/s]


=== DEBUG WINDOWS SUMMARY ===
Eventos brutos: 8652
Crops OK:      43260
Skips OOB:     0
Skips LEN:     0
Skips NaN:     0
Dataset construido: N=43260 | T=480 | C=8 | sujetos únicos=103 | eventos únicos=8652
Listo para entrenar: N=43260 | T=480 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando (train_ev≈1122 | val_ev≈272 | test_ev≈378)
  Época   1 | train_acc(ev)=0.3182 | val_acc(ev)=0.3272 | LR=0.01000
  Época   5 | train_acc(ev)=0.3939 | val_acc(ev)=0.4412 | LR=0.00250
  Época  10 | train_acc(ev)=0.4100 | val_acc(ev)=0.4118 | LR=0.00854
  Época  15 | train_acc(ev)=0.4198 | val_acc(ev)=0.4632 | LR=0.00250
  Época  20 | train_acc(ev)=0.4242 | val_acc(ev)=0.4632 | LR=0.00996
  Época  25 | train_acc(ev)=0.4385 | val_acc(ev)=0.4375 | LR=0.00854
  Early stopping @ 27 (mejor val_acc(ev)=0.4743)
↳ Curva train/val guardada: training_curve_event_fold1.png
[Fold 1/5] Global acc (por evento) = 0.3360
              precision    recall  f1-score   support

        Left     0.0000    0.0000    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  Fine-tuning (por sujeto, agrup. ev) acc=0.5238 | Δ=+0.1878 | sujetos=21

[Fold 2/5] Entrenando (train_ev≈1123 | val_ev≈272 | test_ev≈378)
  Época   1 | train_acc(ev)=0.3134 | val_acc(ev)=0.3162 | LR=0.01000
  Época   5 | train_acc(ev)=0.3963 | val_acc(ev)=0.4081 | LR=0.00250
  Época  10 | train_acc(ev)=0.4078 | val_acc(ev)=0.3897 | LR=0.00854
  Época  15 | train_acc(ev)=0.4310 | val_acc(ev)=0.4338 | LR=0.00250
  Época  20 | train_acc(ev)=0.4256 | val_acc(ev)=0.4228 | LR=0.00996
  Época  25 | train_acc(ev)=0.3989 | val_acc(ev)=0.4191 | LR=0.00854
  Early stopping @ 29 (mejor val_acc(ev)=0.4596)
↳ Curva train/val guardada: training_curve_event_fold2.png
[Fold 2/5] Global acc (por evento) = 0.4370
              precision    recall  f1-score   support

        Left     0.5893    0.3929    0.4714        84
       Right     0.9462    0.4505    0.6104       273
  Both Fists     0.0000    0.0000    0.0000         0
   Both Feet     0.0000    0.0000    0.0000         0

    accuracy          

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  Fine-tuning (por sujeto, agrup. ev) acc=0.6585 | Δ=+0.2216 | sujetos=21

[Fold 3/5] Entrenando (train_ev≈1123 | val_ev≈272 | test_ev≈357)
  Época   1 | train_acc(ev)=0.2848 | val_acc(ev)=0.2873 | LR=0.01000
  Época   5 | train_acc(ev)=0.2414 | val_acc(ev)=0.2313 | LR=0.00250
  Época  10 | train_acc(ev)=0.2450 | val_acc(ev)=0.2313 | LR=0.00854
  Early stopping @ 11 (mejor val_acc(ev)=0.2873)
↳ Curva train/val guardada: training_curve_event_fold3.png
[Fold 3/5] Global acc (por evento) = 0.5966
              precision    recall  f1-score   support

        Left     1.0000    0.5966    0.7474       357
       Right     0.0000    0.0000    0.0000         0
  Both Fists     0.0000    0.0000    0.0000         0
   Both Feet     0.0000    0.0000    0.0000         0

    accuracy                         0.5966       357
   macro avg     0.2500    0.1492    0.1868       357
weighted avg     1.0000    0.5966    0.7474       357

  Fine-tuning no ejecutado (muestras/clases insuficientes).


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[Fold 4/5] Entrenando (train_ev≈1139 | val_ev≈272 | test_ev≈340)
  Época   1 | train_acc(ev)=0.2467 | val_acc(ev)=0.2169 | LR=0.01000
  Época   5 | train_acc(ev)=0.2862 | val_acc(ev)=0.2941 | LR=0.00250
  Época  10 | train_acc(ev)=0.2888 | val_acc(ev)=0.2757 | LR=0.00854
  Época  15 | train_acc(ev)=0.2994 | val_acc(ev)=0.3309 | LR=0.00250
  Época  20 | train_acc(ev)=0.3310 | val_acc(ev)=0.3309 | LR=0.00996
  Época  25 | train_acc(ev)=0.3512 | val_acc(ev)=0.3272 | LR=0.00854
  Early stopping @ 27 (mejor val_acc(ev)=0.3566)
↳ Curva train/val guardada: training_curve_event_fold4.png
[Fold 4/5] Global acc (por evento) = 0.7324
              precision    recall  f1-score   support

        Left     0.0000    0.0000    0.0000         0
       Right     0.0000    0.0000    0.0000         0
  Both Fists     0.0000    0.0000    0.0000         0
   Both Feet     1.0000    0.7324    0.8455       340

    accuracy                         0.7324       340
   macro avg     0.2500    0.1831    0.211

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  Época   1 | train_acc(ev)=0.2827 | val_acc(ev)=0.2794 | LR=0.01000
  Época   5 | train_acc(ev)=0.2906 | val_acc(ev)=0.2794 | LR=0.00250
  Época  10 | train_acc(ev)=0.3108 | val_acc(ev)=0.3640 | LR=0.00854
  Época  15 | train_acc(ev)=0.3327 | val_acc(ev)=0.3860 | LR=0.00250
  Época  20 | train_acc(ev)=0.4083 | val_acc(ev)=0.4412 | LR=0.00996
  Época  25 | train_acc(ev)=0.4688 | val_acc(ev)=0.4743 | LR=0.00854
  Época  30 | train_acc(ev)=0.4539 | val_acc(ev)=0.4706 | LR=0.00565
  Época  35 | train_acc(ev)=0.4688 | val_acc(ev)=0.4706 | LR=0.00250
  Early stopping @ 35 (mejor val_acc(ev)=0.4743)
↳ Curva train/val guardada: training_curve_event_fold5.png
[Fold 5/5] Global acc (por evento) = 0.2750
              precision    recall  f1-score   support

        Left     0.0000    0.0000    0.0000         0
       Right     0.0000    0.0000    0.0000         0
  Both Fists     0.8393    0.1808    0.2975       260
   Both Feet     0.4262    0.5200    0.4685       100

    accuracy            

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  Fine-tuning (por sujeto, agrup. ev) acc=0.5750 | Δ=+0.3000 | sujetos=20

RESULTADOS FINALES (por EVENTO)
Global folds: ['0.3360', '0.4370', '0.5966', '0.7324', '0.2750']
Global mean: 0.4754
Fine-tune folds: ['0.5238', '0.6585', 'nan', 'nan', '0.5750']
Fine-tune mean: 0.5858
Δ(FT-Global) mean: +0.1104
↳ Matriz de confusión guardada: confusion_eegnet_sliding_event_eval_allfolds.png


In [9]:
# -*- coding: utf-8 -*-
# CNN + Transformer con sliding windows + evaluación por evento
# - Lectura EDF (PhysioNet 2009 MI), 8 canales principales, sujetos excluidos
# - Ventanas con stride fijo (sin fuga entre eventos)
# - Entrenamiento por ventanas; evaluación global AGRUPADA por evento (voto)
# - Curvas train/val por fold
# - Scheduler Cosine + warmup; mixup + label smoothing + class weights
# - Fine-tuning por sujeto (CV por evento) tras evaluar el modelo global

import os, re, math, random, json, copy, itertools
from pathlib import Path
from datetime import datetime
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, GroupShuffleSplit, StratifiedKFold
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.utils import check_random_state

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 DEVICE = {DEVICE}")

# Datos
FS = 160.0
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Escenario
CLASS_SCENARIO = '4c'
CLASS_NAMES_4C = ['Left','Right','Both Fists','Both Feet']
N_CLASSES = 4
N_FOLDS = 5

# Runs
MI_RUNS_LR = [4, 8, 12]   # T1/T2 → L/R
MI_RUNS_OF = [6, 10, 14]  # T1/T2 → BothFists/BothFeet
BASELINE_RUNS_EO = [1]

# Normalización por evento (z-score canal-a-canal)
NORM_EPOCH_ZSCORE = True

# Sliding windows
WIN_LEN_S = 1.5   # 1.5 s
STRIDE_S  = 0.5   # 0.5 s
# Rango donde EXTRAEMOS ventanas para entrenamiento (más largo que eval)
TRAIN_RANGE_S = (-0.5, 4.0)  # relativo al onset (robust MI latente)
# Rango donde EVALUAMOS por evento (ventanas solo aquí para evitar fuga)
EVAL_RANGE_S  = ( 0.0, 3.0)

# Entrenamiento
BATCH_SIZE = 64
EPOCHS_GLOBAL = 80
GLOBAL_VAL_SPLIT = 0.17
GLOBAL_PATIENCE = 12
LOG_EVERY = 5

# Opt & Scheduler
LR_BASE = 1e-2
WARMUP_EPOCHS = 4   # warmup lineal a LR_BASE
SGDR_T0 = 10
SGDR_Tmult = 1

# Pérdidas
LABEL_SMOOTH = 0.05
MIXUP_ALPHA  = 0.2
BOOST_BFISTS = 1.20  # clase 2 con ligero boost

# FT por sujeto (CV por evento)
DO_SUBJECT_FT = True
FT_EPOCHS = 20
FT_LR = 5e-4
FT_PATIENCE = 5
FT_CV_SPLITS = 4
FT_BATCH = 64

# =========================
# UTILS: canales y parsing
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp','Fp').replace('FP','Fp')
    s = ''.join(ch.upper() if ch!='z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired=EXPECTED_8):
    missing = [c for c in desired if c not in raw.ch_names]
    if missing:
        print(f"[WARN] Faltan canales {missing} en {getattr(raw,'filenames',[''])[0]}")
        return None
    raw.reorder_channels([c for c in raw.ch_names if c in desired] +
                         [c for c in raw.ch_names if c not in desired])
    raw.pick_channels(desired, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(r):
    if r in MI_RUNS_LR: return 'LR'
    if r in MI_RUNS_OF: return 'OF'
    if r in BASELINE_RUNS_EO: return 'EO'
    return None

# (opcional) notch adaptativo vía CSV SNR
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None: return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try: _SNR_TABLE = pd.read_csv(csv_path)
        except: _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None: return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty: return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except: pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None: return None
    # notch (si hay SNR)
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations)==0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ','')
    out = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'): out.append((float(onset), tag))
    out.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in out:
        if tag=='T1':
            if (t-last_t1) >= 0.5: dedup.append((t,tag)); last_t1=t
        else:
            if (t-last_t2) >= 0.5: dedup.append((t,tag)); last_t2=t
    return dedup

# =========================
# WINDOWING (train/eval)
# =========================
def make_windows_for_event(data_tc, fs, onset_sec, label_idx,
                           rel_range,   # (start_s, end_s) relativo al onset
                           win_len_s, stride_s,
                           zscore=True):
    """
    data_tc: (T, C)
    Devuelve lista de (win_tc, label_idx) y (samp_idx_ini, samp_idx_fin) absolutos
    """
    T, C = data_tc.shape
    s_rel, e_rel = rel_range
    s_abs = int(round((onset_sec + s_rel) * fs))
    e_abs = int(round((onset_sec + e_rel) * fs))
    s_abs = max(0, s_abs); e_abs = min(T, e_abs)
    if e_abs - s_abs <= 0: return []

    wl = int(round(win_len_s * fs))
    st = int(round(stride_s * fs))
    if wl <= 0 or st <= 0: return []

    out = []
    t = s_abs
    while t + wl <= e_abs:
        seg = data_tc[t:t+wl, :].astype(np.float32)
        if zscore:
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (1e-6 + seg.std(axis=0, keepdims=True))
        out.append((seg, label_idx, t, t+wl))
        t += st
    return out

def label_from_tag(kind, tag):
    if kind=='LR':
        return 0 if tag=='T1' else 1
    else:
        return 2 if tag=='T1' else 3

def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR+MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def build_windows_dataset(subjects):
    X_wins, y_wins, g_wins, ev_ids, ev_meta = [], [], [], [], []  # ev_meta: (subj, run, onset_idx)
    ch_template = None
    event_counter = 0

    for s in tqdm(subjects, desc="Construyendo (windows)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        for r in MI_RUNS_LR + MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            kind = run_kind(r)
            raw = read_raw_edf(p)
            if raw is None: continue
            if ch_template is None: ch_template = raw.ch_names

            fs = raw.info['sfreq']
            data = raw.get_data().T.astype(np.float32)  # (T, C)
            events = collect_events_T1T2(raw)

            for onset, tag in events:
                lab = label_from_tag(kind, tag)
                # ventanas de TRAIN (amplio)
                wins_tr = make_windows_for_event(data, fs, raw.first_time+onset, lab,
                                                 rel_range=TRAIN_RANGE_S,
                                                 win_len_s=WIN_LEN_S, stride_s=STRIDE_S,
                                                 zscore=NORM_EPOCH_ZSCORE)
                # ventanas de EVAL (estricto)
                wins_ev = make_windows_for_event(data, fs, raw.first_time+onset, lab,
                                                 rel_range=EVAL_RANGE_S,
                                                 win_len_s=WIN_LEN_S, stride_s=STRIDE_S,
                                                 zscore=NORM_EPOCH_ZSCORE)

                # asignamos un id único por evento (para agrupar en eval/FT)
                ev_id = event_counter; event_counter += 1

                for seg, yk, si, ei in wins_tr:
                    X_wins.append(seg); y_wins.append(yk); g_wins.append(s); ev_ids.append(ev_id)
                # guardamos meta de eval (cantidad/índices por evento)
                ev_meta.append({
                    'event_id': ev_id, 'subject': s, 'run': r, 'label': lab,
                    'n_train_wins': len(wins_tr), 'n_eval_wins': len(wins_ev)
                })

    X = np.stack(X_wins, axis=0) if len(X_wins)>0 else np.zeros((0,int(WIN_LEN_S*FS),len(EXPECTED_8)), np.float32)
    y = np.asarray(y_wins, dtype=np.int64)
    g = np.asarray(g_wins, dtype=np.int64)
    ev_ids = np.asarray(ev_ids, dtype=np.int64)

    print("=== DEBUG WINDOWS SUMMARY ===")
    print(f"Ventanas totales: {len(X_wins)}  | sujetos: {len(set(g_wins))}  | eventos: {event_counter}")
    return X, y, g, ev_ids, ch_template, pd.DataFrame(ev_meta)

# =========================
# DATASET torch
# =========================
class EEGWindows(Dataset):
    def __init__(self, X, y, groups, ev_ids):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.e = ev_ids.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]   # (T, C)
        x = np.expand_dims(x, 0)  # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx]), torch.tensor(self.e[idx])

# =========================
# AUGMENTS
# =========================
def do_time_jitter(x, max_ms=25, fs=160.0):
    max_shift = int(round(max_ms/1000.0*fs))
    if max_shift<=0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(-max_shift, max_shift+1, (B,), device=x.device)
    out = torch.zeros_like(x)
    for i,s in enumerate(shifts):
        if s>=0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
        else:
            s=-s
            out[i,:,:T-s,:] = x[i,:,s:,:]
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs)); Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask: continue
        L = random.randint(Lmin, Lmax); 
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=MIXUP_ALPHA):
    if alpha<=0:
        y_one = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_one, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    ya = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    yb = ya[perm]
    y_mix = lam*ya + (1-lam)*yb
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)
    def forward(self, logits, target_probs):
        if self.ls>0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)
        loss_per_class = -(target_probs * logp)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        return loss_per_class.sum(dim=1).mean()

def class_weight_tensor(y_indices, n_classes, boost_bfists=1.2):
    cnt = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    cnt[cnt==0]=1.0
    w = cnt.sum()/cnt
    w = w/w.mean()
    if n_classes>=3:
        w[2] *= boost_bfists
        w = w/w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

# =========================
# MODELO: CNN + Transformer
# =========================
class ConvFrontEnd(nn.Module):
    """Conv temporal + depthwise espacial (tipo EEGNet lite)"""
    def __init__(self, n_ch, feat_dim=64, k_t=64, pool_t=4, drop_p=0.25):
        super().__init__()
        self.conv_t = nn.Conv2d(1, feat_dim, kernel_size=(k_t,1), padding=(k_t//2,0), bias=False)
        self.bn1 = nn.BatchNorm2d(feat_dim)
        self.act = nn.ELU()
        self.conv_dw = nn.Conv2d(feat_dim, feat_dim, kernel_size=(1,n_ch), groups=feat_dim, bias=False)
        self.bn2 = nn.BatchNorm2d(feat_dim)
        self.pool = nn.AvgPool2d(kernel_size=(pool_t,1), stride=(pool_t,1))
        self.drop = nn.Dropout(drop_p)
    def forward(self, x):  # x: (B,1,T,C)
        z = self.conv_t(x); z = self.bn1(z); z = self.act(z)
        z = self.conv_dw(z); z = self.bn2(z); z = self.act(z)
        z = self.pool(z); z = self.drop(z)      # (B, F, T', 1)
        z = z.squeeze(-1).permute(0,2,1)        # (B, T', F)
        return z

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0)/d_model))
        pe[:,0::2] = torch.sin(pos*div)
        pe[:,1::2] = torch.cos(pos*div)
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)
    def forward(self, x):  # x: (B, T, D)
        T = x.size(1)
        return x + self.pe[:,:T,:]

class CNNTransformer(nn.Module):
    def __init__(self, n_ch, n_classes, cnn_feat=64, k_t=64, pool_t=4,
                 tf_d_model=96, nhead=4, tf_dim_ff=192, tf_layers=3, tf_dropout=0.1,
                 proj_dim=128, drop_p=0.3):
        super().__init__()
        self.front = ConvFrontEnd(n_ch, feat_dim=cnn_feat, k_t=k_t, pool_t=pool_t, drop_p=0.2)
        self.proj = nn.Linear(cnn_feat, tf_d_model)
        self.pe = PositionalEncoding(tf_d_model)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=tf_d_model, nhead=nhead,
            dim_feedforward=tf_dim_ff, dropout=tf_dropout,
            activation='gelu', batch_first=True, norm_first=True
        )
        # 🔇 evitar warning nested_tensor
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=tf_layers, enable_nested_tensor=False)

        self.head = nn.Sequential(
            nn.LayerNorm(tf_d_model),
            nn.Linear(tf_d_model, proj_dim),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(proj_dim, n_classes)
        )

    def forward(self, x):  # (B,1,T,C)
        z = self.front(x)              # (B, T', F)
        z = self.proj(z)               # (B, T', D)
        z = self.pe(z)
        z = self.encoder(z)            # (B, T', D)
        z = z.mean(dim=1)              # pool temporal
        return self.head(z)

# =========================
# TRAIN / EVAL helpers
# =========================
def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    cls_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    cls_w = 1.0 / cls_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    s_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/s_map[g] for g in groups], float)
    w = cls_w * subj_w; w = w / w.mean()
    return WeightedRandomSampler(torch.tensor(w, dtype=torch.float32), num_samples=len(w), replacement=True)

def train_one_epoch(model, loader, opt, criterion_soft, n_classes, fs=160.0, maxnorm=None):
    model.train()
    total = 0; correct = 0
    for xb, yb, _, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        # augments
        xb = do_time_jitter(xb, max_ms=25, fs=fs)
        xb = do_gaussian_noise(xb, sigma=0.01)
        xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=fs)
        xb, ymix, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=MIXUP_ALPHA)

        logits = model(xb)
        loss = criterion_soft(logits, ymix)

        opt.zero_grad(set_to_none=True)
        loss.backward(); opt.step()

        with torch.no_grad():
            pred = logits.argmax(1)
            # para precisión aproximada con mixup, comparamos con argmax de ymix (solo para logging)
            tgt = ymix.argmax(1)
            correct += (pred==tgt).sum().item(); total += xb.size(0)
        if maxnorm is not None:
            # aplica max-norm a capas densas/pointwise si lo deseas (opcional)
            for m in model.modules():
                if isinstance(m, nn.Linear):
                    w = m.weight.data
                    norms = w.view(w.size(0), -1).norm(p=2, dim=1, keepdim=True)
                    desired = torch.clamp(norms, max=maxnorm)
                    w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))
    return correct/max(1,total)

@torch.no_grad()
def eval_accuracy_windows(model, loader):
    model.eval()
    total = 0; correct = 0
    all_y, all_p = [], []
    for xb, yb, _, _ in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(1).cpu()
        all_y.extend(yb.numpy().tolist())
        all_p.extend(pred.numpy().tolist())
        correct += (pred==yb).sum().item(); total += yb.size(0)
    return correct/max(1,total), np.asarray(all_y, int), np.asarray(all_p, int)

@torch.no_grad()
def evaluate_grouped_by_event(model, loader):
    """
    Devuelve y_true_event, y_pred_event, acc_event
    """
    model.eval()
    y_win, p_win, ev_win = [], [], []
    for xb, yb, _, ev in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(1).cpu().numpy()
        y_win.extend(yb.numpy().tolist())
        p_win.extend(pred.tolist())
        ev_win.extend(ev.numpy().tolist())

    y_win = np.asarray(y_win, int)
    p_win = np.asarray(p_win, int)
    ev_win = np.asarray(ev_win, int)

    # voto mayoritario + desempate por frecuencia
    y_true_ev, y_pred_ev = majority_vote_per_event(y_win, p_win, ev_win, proba=None, n_classes=N_CLASSES)
    acc_ev = (y_true_ev == y_pred_ev).mean() if len(y_true_ev)>0 else 0.0
    return y_true_ev, y_pred_ev, acc_ev

def majority_vote_per_event(y_win, p_win, ev_win, proba=None, n_classes=4):
    y_true_ev, y_pred_ev = [], []
    for e in np.unique(ev_win):
        idx = np.where(ev_win==e)[0]
        yt = int(np.bincount(y_win[idx]).argmax())  # ground-truth por evento
        counts = np.bincount(p_win[idx], minlength=n_classes)
        top = np.where(counts == counts.max())[0]
        if len(top)==1:
            yhat = int(top[0])
        else:
            # desempate por suma de probas si las tuvieras; aquí por índice mínimo
            yhat = int(top.min())
        y_true_ev.append(yt); y_pred_ev.append(yhat)
    return np.asarray(y_true_ev, int), np.asarray(y_pred_ev, int)

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc(win)')
    plt.plot(history['val_acc'], label='val_acc(win)')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve (windows)')
    plt.legend(); plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cmn = cm.astype(float) / cm.sum(axis=1, keepdims=True)
    cmn = np.nan_to_num(cmn)
    plt.figure(figsize=(6,5))
    plt.imshow(cmn, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cmn.max()/2
    for i,j in itertools.product(range(cmn.shape[0]), range(cmn.shape[1])):
        plt.text(j, i, format(cmn[i,j], fmt),
                 ha='center', va='center',
                 color='white' if cmn[i,j]>thresh else 'black')
    plt.ylabel('True'); plt.xlabel('Pred'); plt.tight_layout()
    plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

# =========================
# SCHEDULER (warmup + Cosine)
# =========================
class WarmupThenCosine:
    def __init__(self, opt, base_lr, warmup_epochs, T0, Tmult=1, min_lr=1e-5):
        self.opt = opt
        self.base_lr = base_lr
        self.warm = max(1, warmup_epochs)
        self.cos = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=T0, T_mult=Tmult, eta_min=min_lr)
        self._epoch = 0
        # arranca en lr pequeño
        for pg in self.opt.param_groups: pg['lr'] = base_lr / (self.warm*10.0)
    def step(self):
        self._epoch += 1
        if self._epoch <= self.warm:
            lr = (self.base_lr * self._epoch) / float(self.warm)
            for pg in self.opt.param_groups: pg['lr'] = lr
        else:
            self.cos.step(self._epoch - self.warm)

# =========================
# FT por sujeto (CV por evento)
# =========================
def _events_from_indices(idx, ev_ids, y):
    idx = np.asarray(idx, int)
    ev = ev_ids[idx]; yw = y[idx]
    events = {}
    first = {}
    for k,i in enumerate(idx):
        e = ev[k]
        events.setdefault(e, []).append(i)
        if e not in first: first[e] = int(yw[k])
    ev_list = np.array(sorted(events.keys()), dtype=int)
    y_ev = np.array([first[e] for e in ev_list], dtype=int)
    win_lists = [events[e] for e in ev_list]
    return ev_list, y_ev, win_lists

def finetune_subject_eventcv(model_global, X, y, ev_ids, subj_ids, subj,
                             n_splits=FT_CV_SPLITS, epochs=FT_EPOCHS, lr=FT_LR, patience=FT_PATIENCE, batch=FT_BATCH):
    device = next(model_global.parameters()).device
    idx_subj = np.where(subj_ids==subj)[0]
    if idx_subj.size==0: return np.array([], int), np.array([], int)

    ev_list, y_events, win_lists = _events_from_indices(idx_subj, ev_ids, y)
    if len(np.unique(y_events))<2 or len(y_events)<n_splits:
        return np.array([], int), np.array([], int)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_ev_all, y_pred_ev_all = [], []

    for tr_e, te_e in skf.split(ev_list, y_events):
        tr_idx = np.concatenate([win_lists[i] for i in tr_e])
        te_idx = np.concatenate([win_lists[i] for i in te_e])

        ds_tr = EEGWindows(X[tr_idx], y[tr_idx], subj_ids[tr_idx], ev_ids[tr_idx])
        ds_te = EEGWindows(X[te_idx], y[te_idx], subj_ids[te_idx], ev_ids[te_idx])
        dl_tr = DataLoader(ds_tr, batch_size=batch, shuffle=True)
        dl_te = DataLoader(ds_te, batch_size=batch, shuffle=False)

        m = copy.deepcopy(model_global)
        for p in m.parameters(): p.requires_grad = True  # FT completo; si quieres, puedes hacer etapas
        opt = optim.Adam(m.parameters(), lr=lr)
        crit = nn.CrossEntropyLoss()

        best = copy.deepcopy(m.state_dict()); best_val = 1e9; bad=0
        # validación rápida: pérdida media en train (proxy) + early stop
        for ep in range(1, epochs+1):
            m.train()
            for xb, yb, _, _ in dl_tr:
                xb, yb = xb.to(device), yb.to(device)
                opt.zero_grad(set_to_none=True)
                loss = crit(m(xb), yb)
                loss.backward(); opt.step()
            # proxy val
            with torch.no_grad():
                m.eval()
                tot, cor = 0, 0
                for xb, yb, _, _ in dl_tr:
                    xb, yb = xb.to(device), yb.to(device)
                    pred = m(xb).argmax(1)
                    cor += (pred==yb).sum().item(); tot += yb.size(0)
                val = 1.0 - cor/max(1,tot)
                if val<best_val: best_val, bad, best = val, 0, copy.deepcopy(m.state_dict())
                else:
                    bad += 1
                    if bad>=patience: break
        m.load_state_dict(best)

        # pred por evento
        y_te_win, p_te_win, _ = eval_accuracy_windows(m, dl_te)
        # reconstruimos ev ids de te
        ev_te = []
        for _, _, _, e in dl_te:
            ev_te.extend(e.numpy().tolist())
        ev_te = np.asarray(ev_te, int)

        yt_ev, yp_ev = majority_vote_per_event(y_te_win, p_te_win, ev_te, proba=None, n_classes=N_CLASSES)
        y_true_ev_all.append(yt_ev); y_pred_ev_all.append(yp_ev)

    return np.concatenate(y_true_ev_all), np.concatenate(y_pred_ev_all)

# =========================
# Folds helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="cnn_tf_event", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} > num_subjects={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []; fold_i = 0
    for tr_grp, te_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        tr_sids = [subject_ids[int(i)] for i in tr_grp]
        te_sids = [subject_ids[int(i)] for i in te_grp]
        tr_int = [int(s[1:]) for s in tr_sids]; te_int = [int(s[1:]) for s in te_sids]
        tr_idx = np.where(np.isin(groups_array, tr_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, te_int))[0].tolist()
        folds.append({"fold": fold_i, "train": tr_sids, "test": te_sids, "tr_idx": tr_idx, "te_idx": te_idx})

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description or "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(str(path_json))
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        exp = sorted(list(expected_subject_ids))
        if subj_json != exp:
            msg = ("Subject IDs del JSON no coinciden con los esperados.\n"
                   f"JSON={len(subj_json)} vs EXP={len(exp)}\n"
                   f"JSON first10={subj_json[:10]} | EXP first10={exp[:10]}")
            if strict_check: raise ValueError(msg)
            else: print("[WARN] " + msg)
    return payload

# =========================
# MAIN EXP
# =========================
def run_experiment():
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)}  →  {subs[:12]}{'...' if len(subs)>12 else ''}")

    # Construir ventanas (train-range), y ev_ids (por-evento consistente)
    X, y, groups, ev_ids, chs, ev_meta = build_windows_dataset(subs)
    N, T, C = X.shape
    print(f"Dataset windows: N={N} | T={T} | C={C} | sujetos={len(np.unique(groups))} | eventos={len(np.unique(ev_ids))}")

    ds = EEGWindows(X, y, groups, ev_ids)

    # Folds por sujetos
    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]
    folds_json = Path(FOLDS_JSON)
    folds_json.parent.mkdir(parents=True, exist_ok=True)
    if not folds_json.exists():
        save_group_folds_json_with_indices(subject_ids_str, groups, N_FOLDS, folds_json,
                                           created_by="CNNTF_windows",
                                           description="GroupKFold por sujeto (windows)")

    payload = load_group_folds_json(folds_json, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    global_folds = []
    all_true_ev = []
    all_pred_ev = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f["tr_idx"], int)
        te_idx = np.asarray(f["te_idx"], int)

        # validación por sujeto (split de sujetos dentro de train)
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Log distribuciones por evento (aprox)
        print(f"[TRAIN-EVENT] eventos≈{len(np.unique(ev_ids[tr_sub_idx]))} dist={Counter(y[tr_sub_idx])}")
        print(f"[VAL-EVENT]   eventos≈{len(np.unique(ev_ids[va_idx]))} dist={Counter(y[va_idx])}")
        print(f"[TEST-EVENT]  eventos≈{len(np.unique(ev_ids[te_idx]))} dist={Counter(y[te_idx])}")

        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])
        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # Modelo
        model = CNNTransformer(n_ch=C, n_classes=N_CLASSES,
                               cnn_feat=64, k_t=64, pool_t=4,
                               tf_d_model=96, nhead=4, tf_dim_ff=192, tf_layers=3, tf_dropout=0.1,
                               proj_dim=128, drop_p=0.3).to(DEVICE)

        # Opt + scheduler
        opt = optim.Adam(model.parameters(), lr=LR_BASE, weight_decay=0.0)
        sched = WarmupThenCosine(opt, base_lr=LR_BASE, warmup_epochs=WARMUP_EPOCHS, T0=SGDR_T0, Tmult=SGDR_Tmult, min_lr=1e-4)

        # Criterio suave + pesos de clase
        w = class_weight_tensor(y[tr_sub_idx], N_CLASSES, boost_bfists=BOOST_BFISTS)
        criterion_soft = WeightedSoftCrossEntropy(class_weights=w, label_smoothing=LABEL_SMOOTH)

        history = {'train_acc': [], 'val_acc': []}
        best_model = copy.deepcopy(model.state_dict()); best_val = -1.0; bad=0

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando (train_ev≈{len(np.unique(ev_ids[tr_sub_idx]))} | val_ev≈{len(np.unique(ev_ids[va_idx]))} | test_ev≈{len(np.unique(ev_ids[te_idx]))})")
        for ep in range(1, EPOCHS_GLOBAL+1):
            tr_acc = train_one_epoch(model, tr_loader, opt, criterion_soft, n_classes=N_CLASSES, fs=FS, maxnorm=2.0)
            va_acc_win, _, _ = eval_accuracy_windows(model, va_loader)
            history['train_acc'].append(tr_acc)
            history['val_acc'].append(va_acc_win)

            if ep % LOG_EVERY == 0 or ep in (1, 5, 10, 15, 20, 25, 30):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {ep:3d} | train_acc(win)={tr_acc:.4f} | val_acc(win)={va_acc_win:.4f} | LR={cur_lr:.5f}")

            # early stop por val (windows)
            if va_acc_win > best_val + 1e-4:
                best_val = va_acc_win; best_model = copy.deepcopy(model.state_dict()); bad=0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping @ {ep} (mejor val_acc(win)={best_val:.4f})")
                    break

            sched.step()

        curve_path = f"training_curve_event_fold{fold}_cnntf.png"
        plot_training_curves(history, curve_path)
        print(f"↳ Curva train/val guardada: {curve_path}")

        model.load_state_dict(best_model)

        # ===== Evaluación global por EVENTO =====
        y_true_te, y_pred_te, acc_ev = evaluate_grouped_by_event(model, te_loader)
        print(f"[Fold {fold}/{N_FOLDS}] Global acc (por evento) = {acc_ev:.4f}")
        print(classification_report(y_true_te, y_pred_te, target_names=CLASS_NAMES_4C, digits=4))

        global_folds.append(acc_ev)
        all_true_ev.append(y_true_te); all_pred_ev.append(y_pred_te)

        # ===== Fine-tuning por sujeto (CV por evento) =====
        if DO_SUBJECT_FT:
            subj_te = groups[te_idx]
            y_true_ft_all, y_pred_ft_all = [], []
            used = 0
            for sid in np.unique(subj_te):
                yt, yp = finetune_subject_eventcv(model, X, y, ev_ids, groups, sid,
                                                  n_splits=FT_CV_SPLITS, epochs=FT_EPOCHS,
                                                  lr=FT_LR, patience=FT_PATIENCE, batch=FT_BATCH)
                if yt.size==0: continue
                y_true_ft_all.append(yt); y_pred_ft_all.append(yp); used += 1
            if y_true_ft_all:
                y_true_ft_all = np.concatenate(y_true_ft_all)
                y_pred_ft_all = np.concatenate(y_pred_ft_all)
                acc_ft = (y_true_ft_all==y_pred_ft_all).mean()
                print(f"  Fine-tuning (por sujeto, agrup. ev) acc={acc_ft:.4f} | Δ={acc_ft-acc_ev:+.4f} | sujetos={used}")
            else:
                print("  Fine-tuning no ejecutado (eventos/clases insuficientes).")

    # ===== Resultados finales =====
    if all_true_ev:
        all_true_ev = np.concatenate(all_true_ev)
        all_pred_ev = np.concatenate(all_pred_ev)
    else:
        all_true_ev = np.array([], int); all_pred_ev = np.array([], int)

    print("\n============================================================")
    print("RESULTADOS FINALES (por EVENTO) - CNN+Transformer")
    print("============================================================")
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if global_folds:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    if all_true_ev.size>0:
        plot_confusion(all_true_ev, all_pred_ev, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global (event-level, all folds)",
                       fname="confusion_cnntf_sliding_event_eval_allfolds.png")
        print("↳ Matriz de confusión guardada: confusion_cnntf_sliding_event_eval_allfolds.png")

    return {
        "global_folds": global_folds,
        "all_true_ev": all_true_ev,
        "all_pred_ev": all_pred_ev,
        "folds_json": str(FOLDS_JSON)
    }

# =============== MAIN ===============
if __name__ == "__main__":
    print("🧪 CNN+Transformer (sliding windows) | 8 canales | 4 clases | eval por EVENTO")
    print(f"⚙️  Window: len={WIN_LEN_S}s, stride={STRIDE_S}s | TRAIN_RANGE={TRAIN_RANGE_S}s | EVAL_RANGE={EVAL_RANGE_S}s")
    print(f"🔧 Augs: jitter±25ms, noise=0.01, cutout 30–60ms (OF), mixup α={MIXUP_ALPHA}, LS={LABEL_SMOOTH}")
    print(f"🗂️  Folds={N_FOLDS} (GroupKFold por sujeto) | FT por sujeto={DO_SUBJECT_FT}")
    run_experiment()


🚀 DEVICE = cuda
🧪 CNN+Transformer (sliding windows) | 8 canales | 4 clases | eval por EVENTO
⚙️  Window: len=1.5s, stride=0.5s | TRAIN_RANGE=(-0.5, 4.0)s | EVAL_RANGE=(0.0, 3.0)s
🔧 Augs: jitter±25ms, noise=0.01, cutout 30–60ms (OF), mixup α=0.2, LS=0.05
🗂️  Folds=5 (GroupKFold por sujeto) | FT por sujeto=True
Sujetos elegibles: 103  →  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]...


Construyendo (windows):   0%|          | 0/103 [00:00<?, ?it/s]

Construyendo (windows): 100%|██████████| 103/103 [00:38<00:00,  2.64it/s]


=== DEBUG WINDOWS SUMMARY ===
Ventanas totales: 64890  | sujetos: 103  | eventos: 9270
Dataset windows: N=64890 | T=240 | C=8 | sujetos=103 | eventos=9270
[TRAIN-EVENT] eventos≈768 dist=Counter({np.int64(0): 1386, np.int64(1): 1365, np.int64(3): 1323, np.int64(2): 1302})
[VAL-EVENT]   eventos≈216 dist=Counter({np.int64(1): 392, np.int64(0): 385, np.int64(3): 371, np.int64(2): 364})
[TEST-EVENT]  eventos≈252 dist=Counter({np.int64(0): 462, np.int64(3): 441, np.int64(2): 441, np.int64(1): 420})

[Fold 1/5] Entrenando (train_ev≈768 | val_ev≈216 | test_ev≈252)
  Época   1 | train_acc(win)=0.2504 | val_acc(win)=0.2500 | LR=0.00025
  Época   5 | train_acc(win)=0.2608 | val_acc(win)=0.2407 | LR=0.01000
  Época  10 | train_acc(win)=0.2476 | val_acc(win)=0.2407 | LR=0.00505
  Época  15 | train_acc(win)=0.2431 | val_acc(win)=0.2407 | LR=0.01000
  Early stopping @ 15 (mejor val_acc(win)=0.2652)
↳ Curva train/val guardada: training_curve_event_fold1_cnntf.png
[Fold 1/5] Global acc (por evento) = 0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


TypeError: 'float' object is not subscriptable