# SHALLOWCNN + FINETUNING INTER

In [None]:
# -*- coding: utf-8 -*-
# Replicación fiel del paper "A Deep Learning MI-EEG Classification Model for BCIs"
# Dose et al., EUSIPCO 2018 — Shallow CNN sobre RAW EEG (PhysioNet BCI2000)
# Versión con FINE-TUNING progresivo por sujeto (CV 4-fold, LRs discriminativos, L2-SP, early stopping con validación)

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario de clases
CLASS_SCENARIO = '4c'
WINDOW_MODE = '3s'
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 16
EPOCHS_GLOBAL = 100
LR = 1e-3

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4            # 4-fold CV por sujeto ~ 75/25
FT_EPOCHS = 30                 # más alto, pero con ES por validación
FT_BASE_LR = 5e-5              # convs (temporal+spatial) — más bajo
FT_HEAD_LR = 1e-3              # fc+out — más alto
FT_L2SP = 1e-4                 # regularización más suave
FT_PATIENCE = 5                # early stopping con validación
FT_VAL_RATIO = 0.2             # validación dentro del set de calibración

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 en lugar de 64)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','CPz']

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    have = [ch for ch in desired_channels if ch in raw.ch_names]
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    # Band-pass opcional (desactivado)
    # raw.filter(l_freq=8., h_freq=30., picks='eeg', method='iir', verbose=False)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try:
            sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'):
                continue
            if scenario == '3c' and label not in ('L','R','BFISTS'):
                continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'):
                continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            # Normalización por época canal-a-canal (z-score)
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg)
                y.append(lab)
                groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# SHALLOW CNN
# =========================
class ShallowDose2018(nn.Module):
    def __init__(self, n_ch: int, n_classes: int, kernel_t: int = 30, n_feat: int = 40, pool_t: int = 15):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.kernel_t = kernel_t
        self.n_feat = n_feat
        self.pool_t = pool_t

        self.temporal = nn.Conv2d(1, n_feat, kernel_size=(kernel_t, 1),
                                  padding=(kernel_t // 2, 0), bias=True)
        self.spatial  = nn.Conv2d(n_feat, n_feat, kernel_size=(1, n_ch),
                                  padding=(0, 0), bias=True)
        self.avgpool  = nn.AvgPool2d(kernel_size=(pool_t, 1), stride=(pool_t, 1))
        self.act      = nn.ELU()
        self.flatten  = nn.Flatten()

        self.fc  = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T_pool = T_in // self.pool_t
        feat_dim = self.n_feat * T_pool
        self.fc  = nn.Linear(feat_dim, 80, bias=True).to(device)
        self.out = nn.Linear(80, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        z = self.temporal(x); z = self.act(z)
        z = self.spatial(z);  z = self.act(z)
        z = self.avgpool(z)
        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def train_epoch(model, loader, opt, criterion):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        opt.step()

@torch.no_grad()
def evaluate_with_preds(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    """
    Devuelve los parámetros a entrenar según 'mode':
      - 'out'            : solo capa final
      - 'head'           : fc + out
      - 'spatial+head'   : spatial + fc + out  (temporal queda congelada)
    """
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.spatial.parameters())
                 + list(model.fc.parameters())
                 + list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    # Primero congelamos todo
    for p in model.parameters(): p.requires_grad = False
    # Siempre dejamos la base temporal congelada en este protocolo
    # Descongelamos según 'mode'
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.spatial.parameters(): p.requires_grad = True
        for p in model.fc.parameters():      p.requires_grad = True
        for p in model.out.parameters():     p.requires_grad = True

def _class_weights(y_np, n_classes):
    # Pesos inversos a la frecuencia por clase en el set de calibración
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _make_optimizer(train_params, mode):
    # LR más alto para la cabeza, bajo para spatial cuando aplique
    if mode == 'spatial+head':
        # separar spatial vs head para LRs distintos
        spatial, head = [], []
        for p in train_params:
            # heurística: parámetros que pertenecen a spatial tendrán .shape acorde
            # mejor: detectarlos por referencia al módulo
            pass
        # Como no tenemos tags aquí, armamos dos grupos manualmente en la llamada principal.
        # Devolvemos None y lo construimos fuera.
        return None
    else:
        return optim.Adam(train_params, lr=FT_HEAD_LR)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=16,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    """
    Entrena en 'mode' con early stopping sobre un conjunto de validación interno.
    Devuelve el modelo con los mejores pesos (por val loss).
    """
    # Split Cal -> (train_cal, val_cal)
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    # Datasets
    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    # Congelar / descongelar según modo
    _freeze_for_mode(model, mode)
    # Parámetros a entrenar y referencia L2-SP
    if mode == 'spatial+head':
        base_params = list(model.spatial.parameters())
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf')
    bad = 0

    for _ in range(epochs):
        # --- train ---
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            # L2-SP hacia referencia de los parámetros que estamos entrenando
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward()
            opt.step()

        # --- val ---
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)  # misma pérdida con pesos
                val_loss += loss.item() * xb.size(0)
                nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss
            bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience:
                break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)  # (N,1,T,C)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    """
    Para un sujeto: 4-fold StratifiedKFold.
      - En cada fold: se entrena con 3 modos progresivos:
          1) 'out'            (solo capa final)
          2) 'head'           (fc + out)
          3) 'spatial+head'   (spatial + fc + out, temporal congelada)
        Se evalúa en el test_fold y se elige el mejor.
    Devuelve y_true_subj, y_pred_subj (OOF por fold).
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys)
    y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        # Stage A: 'out'
        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device)
        acc_out = (yhat_out == yho).mean()

        # Stage B: 'head'
        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device)
        acc_head = (yhat_head == yho).mean()

        # Stage C: 'spatial+head'
        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device)
        acc_sp = (yhat_sp == yho).mean()

        # Elegir mejor
        best_idx = np.argmax([acc_out, acc_head, acc_sp])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho
        y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    """
    - Crea/lee JSON con folds por sujeto (incluye tr_idx/te_idx)
    - Entrena modelo global por fold (inter-sujeto puro)
    - Evalúa Global acc en test
    - Realiza Fine-Tuning PROGRESIVO por sujeto (4-fold CV) y reporta acc
    """
    mne.set_log_level('WARNING')

    # sujetos y dataset
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)
    criterion = nn.CrossEntropyLoss()

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        tr_loader = DataLoader(Subset(ds, tr_idx), batch_size=BATCH_SIZE, shuffle=True,  drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx), batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        model = ShallowDose2018(n_ch=C, n_classes=n_classes).to(DEVICE)
        opt = optim.Adam(model.parameters(), lr=LR)

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global... (n_train={len(tr_idx)} | n_test={len(te_idx)})")
        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion)
            if epoch % 20 == 0:
                print(f"  Época {epoch}/{EPOCHS_GLOBAL}")

        # Evaluación global (inter-sujeto puro)
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            # Seguridad: requiere al menos n_splits muestras (estratificado).
            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    # Matriz de confusión global (sobre todos los folds)
    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON FINE-TUNING PROGRESIVO (por sujeto, 4-fold CV)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


### Su intra

In [None]:
# -*- coding: utf-8 -*-
# Experimento INTRA-subject CV completo (5 folds) con Shallow CNN + Feature Extraction
# Adaptado del modelo “inter” que compartiste

import os, re, math, random, json
from pathlib import Path
import numpy as np
import pandas as pd
import mne
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
import matplotlib.pyplot as plt
import itertools
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CLASS_SCENARIO = '4c'
WINDOW_MODE = '3s'
FS = 160.0
N_FOLDS = 5
BATCH_SIZE = 16
EPOCHS_GLOBAL = 50
LR = 1e-3

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
EXPECTED_64 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','CPz']

CLASSIFIER_TYPE = 'svm'  # 'svm' o 'logistic'
CLASS_NAMES_4C = ['Left','Right','Both Fists','Both Feet']

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {ch: normalize_label(ch) for ch in raw.ch_names}
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_64) -> mne.io.BaseRaw:
    have = [ch for ch in desired_channels if ch in raw.ch_names]
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames',[''])[0]}")
        return None
    return raw.reorder_channels(desired_channels)

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    return None

def read_raw_edf(path: Path) -> mne.io.BaseRaw:
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        raw.set_montage(mne.channels.make_standard_montage('standard_1020'), on_missing='ignore')
    except: pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad='auto')
    return ensure_channels_order(raw)

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0: return []
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = str(desc).strip().upper().replace(' ','')
        if tag in ('T1','T2'): res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag=='T1' and (t-last_t1)>=0.5: dedup.append((t,tag)); last_t1=t
        if tag=='T2' and (t-last_t2)>=0.5: dedup.append((t,tag)); last_t2=t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASET
# =========================
def subjects_available():
    subs=[]
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid=int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir/f"S{sid:03d}R{r:02d}.edf").exists() for r in MI_RUNS_LR+MI_RUNS_OF)
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF'): return []

    raw = read_raw_edf(edf_path)
    data = raw.get_data()
    fs = raw.info['sfreq']
    events = collect_events_T1T2(raw)

    out=[]
    for onset, tag in events:
        if kind=='LR':
            label=0 if tag=='T1' else 1
        else:
            label=2 if tag=='T1' else 3
        s = int(round(onset*fs))
        e = int(round((onset+3.0)*fs))
        if s<0 or e>data.shape[1]: continue
        seg=data[:,s:e].T.astype(np.float32)
        out.append((seg,label,subj))
    return out

def build_dataset_all(subjects):
    X,y,groups=[],[],[]
    for s in tqdm(subjects,"Construyendo dataset"):
        sdir = DATA_RAW/f"S{s:03d}"
        if not sdir.exists(): continue
        for r in MI_RUNS_LR+MI_RUNS_OF:
            path = sdir/f"S{s:03d}R{r:02d}.edf"
            if not path.exists(): continue
            trials=extract_trials_from_run(path)
            for seg,label,subj in trials:
                X.append(seg); y.append(label); groups.append(subj)
    X=np.stack(X,0); y=np.array(y,dtype=int); groups=np.array(groups,dtype=int)
    print(f"Dataset construido: N={X.shape[0]} | T={X.shape[1]} | C={X.shape[2]} | sujetos={len(np.unique(groups))}")
    return X,y,groups

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self,X,y,groups):
        self.X=X.astype(np.float32)
        self.y=y.astype(np.int64)
        self.g=groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self,idx):
        x=self.X[idx]; x=np.expand_dims(x,0)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

# =========================
# SHALLOW CNN
# =========================
class ShallowDose2018(nn.Module):
    def __init__(self,n_ch,n_classes,kernel_t=30,n_feat=40,pool_t=15):
        super().__init__()
        self.temporal=nn.Conv2d(1,n_feat,(kernel_t,1),padding=(kernel_t//2,0))
        self.spatial=nn.Conv2d(n_feat,n_feat,(1,n_ch))
        self.avgpool=nn.AvgPool2d((pool_t,1),stride=(pool_t,1))
        self.act=nn.ELU()
        self.flatten=nn.Flatten()
        self.fc=None; self.out=None; self._T_in=None
        self.n_classes=n_classes
    def _build_head(self,T_in,device):
        T_pool=T_in//15
        feat_dim=40*T_pool
        self.fc=nn.Linear(feat_dim,80).to(device)
        self.out=nn.Linear(80,self.n_classes).to(device)
        self._T_in=T_in
    def ensure_head(self,T_in,device):
        if self.fc is None or self.out is None or self._T_in!=T_in:
            self._build_head(T_in,device)
    def forward(self,x):
        B,_,T,C=x.shape
        self.ensure_head(T,x.device)
        z=self.temporal(x); z=self.act(z)
        z=self.spatial(z); z=self.act(z)
        z=self.avgpool(z); z=self.flatten(z)
        z=self.fc(z); z=self.act(z)
        return self.out(z)
    def extract_features(self,x):
        B,_,T,C=x.shape
        self.ensure_head(T,x.device)
        z=self.temporal(x); z=self.act(z)
        z=self.spatial(z); z=self.act(z)
        z=self.avgpool(z); z=self.flatten(z)
        z=self.fc(z); z=self.act(z)
        return z

# =========================
# TRAIN + FEATURE EXTRACTION
# =========================
def train_epoch(model,loader,opt,criterion):
    model.train()
    for xb,yb,_ in loader:
        xb,yb=xb.to(DEVICE),yb.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        loss=criterion(model(xb),yb)
        loss.backward(); opt.step()

def extract_features_from_model(model,X,batch_size=32):
    model.eval()
    feats=[]
    with torch.no_grad():
        for i in range(0,len(X),batch_size):
            x_batch=torch.from_numpy(X[i:i+batch_size]).float().unsqueeze(1).to(DEVICE)
            feats.append(model.extract_features(x_batch).cpu().numpy())
    return np.concatenate(feats,0)

# =========================
# GLOBAL MODEL + INTRA-SUBJECT SVM
# =========================
def train_global_model(X, y, model_class=ShallowDose2018, epochs=EPOCHS_GLOBAL):
    model = model_class(n_ch=X.shape[2], n_classes=len(np.unique(y))).to(DEVICE)
    ds = EEGTrials(X, y, np.zeros(len(y)))
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
    opt = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        train_epoch(model, loader, opt, criterion)
    return model

def intra_subject_svm(features, y, groups, n_splits=5, classifier_type='svm', class_names=None):
    subjects = np.unique(groups)
    all_accs = []
    for s in subjects:
        idx = np.where(groups == s)[0]
        X_s, y_s = features[idx], y[idx]
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
        subj_accs = []
        y_true_all, y_pred_all = [], []
        for tr_idx, te_idx in skf.split(X_s, y_s):
            X_tr, X_te = X_s[tr_idx], X_s[te_idx]
            y_tr, y_te = y_s[tr_idx], y_s[te_idx]

            if classifier_type == 'svm':
                clf = SVC(kernel='linear', random_state=RANDOM_STATE)
            else:
                clf = LogisticRegression(max_iter=500, random_state=RANDOM_STATE)
            clf.fit(X_tr, y_tr)
            y_pred = clf.predict(X_te)
            subj_accs.append((y_pred == y_te).mean())
            y_true_all.extend(y_te)
            y_pred_all.extend(y_pred)

        mean_acc = np.mean(subj_accs)
        all_accs.append(mean_acc)
        print(f"Sujeto {s} intra-subject acc={mean_acc:.4f}")

        # Matriz de confusión y classification report
        if class_names is not None:
            cm = confusion_matrix(y_true_all, y_pred_all)
            print(f"Confusion matrix sujeto {s}:\n{cm}")
            print(f"Classification report sujeto {s}:\n{classification_report(y_true_all, y_pred_all, target_names=class_names)}")

    print(f"\n✅ Promedio intra-subject accuracy: {np.mean(all_accs):.4f}")
    return all_accs

# =========================
# MAIN
# =========================
if __name__=="__main__":
    subs = subjects_available()
    X, y, groups = build_dataset_all(subs)

    print("🔹 Entrenando modelo CNN global...")
    global_model = train_global_model(X, y)

    print("🔹 Extrayendo features con modelo global...")
    features = extract_features_from_model(global_model, X)

    print("🔹 Evaluando intra-subject CV usando SVM sobre features...")
    intra_subject_svm(features, y, groups, n_splits=5, classifier_type=CLASSIFIER_TYPE, class_names=CLASS_NAMES_4C)


# EGGNET + FINETUNING INTER MEJOR

In [10]:
# -*- coding: utf-8 -*-
# Replicación fiel del paper "A Deep Learning MI-EEG Classification Model for BCIs"
# Dose et al., EUSIPCO 2018 — ahora con EEGNet (Lawhern et al., 2018) en vez de ShallowConvNet
# Protocolo: entrenamiento global inter-sujeto + FINE-TUNING PROGRESIVO por sujeto
# (CV 4-fold, LRs discriminativos, L2-SP, early stopping con validación)

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario de clases
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 16
EPOCHS_GLOBAL = 100
LR = 1e-3
# >>> Añadidos para diagnóstico/ES global <<<
GLOBAL_VAL_SPLIT = 0.15   # fracción de sujetos (dentro del train) para validación
GLOBAL_PATIENCE  = 10     # épocas sin mejora en val_acc
LOG_EVERY        = 5      # log cada N épocas

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4            # 4-fold CV por sujeto ~ 75/25
FT_EPOCHS = 30                 # ES con validación
FT_BASE_LR = 5e-5              # convs "base" (temporal/depthwise) — más bajo
FT_HEAD_LR = 1e-3              # fc+out — más alto
FT_L2SP = 1e-4                 # regularización suave
FT_PATIENCE = 5                # early stopping con validación
FT_VAL_RATIO = 0.2             # validación dentro del set de calibración

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 en lugar de 64)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    have = [ch for ch in desired_channels if ch in raw.ch_names]
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    # Reordenar y quedarse SOLO con los deseados (8)
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    # Band-pass opcional (desactivado)
    #raw.filter(l_freq=8., h_freq=30., picks='eeg', method='iir', verbose=False)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try:
            sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'):
                continue
            if scenario == '3c' and label not in ('L','R','BFISTS'):
                continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'):
                continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            # Normalización por época canal-a-canal (z-score)
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg)
                y.append(lab)
                groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# EEGNet (Lawhern et al., 2018) adaptado a (B,1,T,C)
# =========================
class EEGNet(nn.Module):
    """
    Entrada: x de forma (B, 1, T, C)  [T=tiempo, C=canales]
    Bloques:
      1) Temporal conv     : Conv2d(1 -> F1, (kernel_t,1), padding 'same'), BN, ELU
      2) Depthwise (espacial): Conv2d(F1 -> F1*D, (1,C), groups=F1, BN, ELU, AvgPool(4,1), Dropout
      3) Separable temporal: Depthwise temporal (k_sep,1) groups=F1*D + Pointwise 1x1 a F2, BN, ELU, AvgPool(8,1), Dropout
      4) FC -> OUT
    """
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 8, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 8, dropout: float = 0.5):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        # Bloque 1: temporal
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        self.act = nn.ELU()

        # Bloque 2: depthwise (espacial)
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(dropout)

        # Bloque 3: separable temporal (depthwise temporal + pointwise)
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(dropout)

        self.flatten = nn.Flatten()

        # Cabeza dinámica
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        # Con padding 'same' en temporal y separable temporal,
        # el tamaño temporal se reduce por los pools:
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1  # ancho=1 tras conv_depthwise (kernel (1,C))
        self.fc = nn.Linear(feat_dim, 80, bias=True).to(device)
        self.out = nn.Linear(80, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,1,T,C)
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)   # (B, F2, T, 1)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def train_epoch(model, loader, opt, criterion):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        opt.step()

@torch.no_grad()
def evaluate_with_preds(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    """
    En EEGNet:
      - 'out'            : solo capa final (model.out)
      - 'head'           : fc + out
      - 'spatial+head'   : depthwise + separable + fc + out  (temporal queda congelada)
    """
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    # Congelamos todo
    for p in model.parameters(): p.requires_grad = False
    # Siempre mantenemos CONGELADO el bloque temporal en este protocolo
    # (conv_temporal + bn1)
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=16,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    """
    Entrena en 'mode' con early stopping sobre un conjunto de validación interno.
    Devuelve el modelo con los mejores pesos (por val loss).
    """
    # Split Cal -> (train_cal, val_cal)
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        # grupos con LR discriminativos
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf')
    bad = 0

    for _ in range(epochs):
        # --- train ---
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            # L2-SP hacia referencia de los parámetros que estamos entrenando
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward()
            opt.step()

        # --- val ---
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0)
                nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss
            bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience:
                break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)  # (N,1,T,C)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    """
    Para un sujeto: 4-fold StratifiedKFold.
      - En cada fold: 'out' → 'head' → 'spatial+head' (temporal congelado)
      - Se elige el mejor en el split de holdout del sujeto.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys)
    y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        # Stage A: 'out'
        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device)
        acc_out = (yhat_out == yho).mean()

        # Stage B: 'head'
        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device)
        acc_head = (yhat_head == yho).mean()

        # Stage C: 'spatial+head'
        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device)
        acc_sp = (yhat_sp == yho).mean()

        best_idx = np.argmax([acc_out, acc_head, acc_sp])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho
        y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    """
    - Crea/lee JSON con folds por sujeto (incluye tr_idx/te_idx)
    - Entrena modelo global por fold (inter-sujeto puro) con validación interna por sujetos + ES
    - Evalúa Global acc en test
    - Realiza Fine-Tuning PROGRESIVO por sujeto (4-fold CV) y reporta acc
    """
    mne.set_log_level('WARNING')

    # sujetos y dataset
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)
    criterion = nn.CrossEntropyLoss()

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por SUJETOS dentro del set de entrenamiento =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, shuffle=True,  drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== EEGNet =====
        model = EEGNet(n_ch=C, n_classes=n_classes, F1=8, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=8, dropout=0.5).to(DEVICE)
        opt = optim.Adam(model.parameters(), lr=LR)

        def _acc(loader):
            return evaluate_with_preds(model, loader)[2]

        # ===== Entrenamiento con logging + EARLY STOPPING por val_acc =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global con validación interna por sujetos..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion)

            if epoch % LOG_EVERY == 0:
                tr_acc = _acc(tr_loader)
                va_acc = _acc(va_loader)
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f}")

                if va_acc > best_val + 1e-4:
                    best_val = va_acc
                    best_state = copy.deepcopy(model.state_dict())
                    bad = 0
                else:
                    bad += 1
                    if bad >= GLOBAL_PATIENCE:
                        print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                        break

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            # Seguridad
            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    # Matriz de confusión global (sobre todos los folds)
    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (por sujeto, 4-fold CV)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (por sujeto, 4-fold CV)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

Construyendo dataset (RAW): 100%|██████████| 103/103 [00:12<00:00,  8.24it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global con validación interna por sujetos... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   5 | train_acc=0.4476 | val_acc=0.4267
  Época  10 | train_acc=0.4777 | val_acc=0.4304
  Época  15 | train_acc=0.4941 | val_acc=0.4359
  Época  20 | train_acc=0.4976 | val_acc=0.4551
  Época  25 | train_acc=0.5074 | val_acc=0.4469
  Época  30 | train_acc=0.5055 | val_acc=0.4652
  Época  35 | train_acc=0.5097 | val_acc=0.4689
  Época  40 | train_acc=0.5043 | val_acc=0.4698
  Época  45 | train_acc=0.5192 | val_acc=0.4734
  Época  50 | train_acc=0.5164 | val_acc=0.4771
  Época  55 | train_acc=0.5110 | val_acc=0.4625
  Época  60 | train_acc=0.5212 | val_acc=0.4734
  Época  65 | train_acc=0.5126 | val_acc=0.4634
  Época  70 | train_acc=0.5228 | val_acc=0.4780
  Época  75 | train_acc=0.5181 | val_acc=0.4670
  Época  80 | train_acc

In [5]:
# -*- coding: utf-8 -*-
# EEGNet unificado: INTER (cross-subject) + INTRA (per-subject CV) con FT progresivo
# Mantiene comportamientos y resultados de tus dos scripts originales.

import os, re, math, random, json, itertools, copy, argparse
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset

from sklearn.model_selection import (GroupKFold, StratifiedKFold, StratifiedShuffleSplit,
                                     GroupShuffleSplit)
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report, f1_score

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GLOBAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_JSON_DEFAULT = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# determinismo
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# canales / runs / escenario
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','CPz']
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]   # solo para lectura de nombres de canales si hiciera falta
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

CLASS_SCENARIO = '4c'
WINDOW_MODE = '3s'  # 0–3s
FS = 160.0
CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# Hiperparámetros compartidos
N_FOLDS_INTER = 5
BATCH_SIZE_INTER = 32
EPOCHS_GLOBAL_INTER = 100
LR_GLOBAL_INTER = 1e-3
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE = 10
LOG_EVERY = 5

# Fine-tuning (compartidos)
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# INTRA específicos (mismos valores que tu script)
N_FOLDS_INTRA = 5
BATCH_SIZE_GLOBAL_INTRA = 16
EPOCHS_GLOBAL_INTRA = 100
LR_GLOBAL_INTRA = 1e-3

# =========================
# UTILIDADES EEG / EDF
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    have = [ch for ch in desired_channels if ch in raw.ch_names]
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None: return None
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    # deduplicación mínima
    dedup, last_t1, last_t2 = [], -1e9, -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# DATASETS (dos variantes)
# =========================
def extract_trials_from_run_common(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])
    raw = read_raw_edf(edf_path)
    if raw is None: return ([], [])
    data = raw.get_data()
    fs = raw.info['sfreq']; assert abs(fs - FS) < 1e-6
    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        rel_start, rel_end = (0.0, 3.0) if window_mode == '3s' else (-1.0, 5.0)
        for onset_sec, tag in events:
            if kind == 'LR':
                label = 0 if tag == 'T1' else 1
            else:
                label = 2 if tag == 'T1' else 3
            if scenario == '2c' and label not in (0,1): continue
            if scenario == '3c' and label not in (0,1,2): continue
            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]: continue
            seg = data[:, s:e].T.astype(np.float32)
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)
            out.append((seg, label, subj))
    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

# INTER: balancea a 21 ensayos por clase/sujeto (como tu script INTER)
def build_dataset_all_inter(subjects, scenario='4c', window_mode='3s'):
    X, y, groups, ch_template = [], [], [], None
    for s in tqdm(subjects, desc="Construyendo dataset INTER (balanceado)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        trials = {0:[],1:[],2:[],3:[]}
        for r in MI_RUNS_LR + MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run_common(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                trials[lab].append(seg)
        # skip if falta alguna clase
        if any(len(trials[k])==0 for k in (0,1,2,3)): continue
        need = 21
        rng = check_random_state(RANDOM_STATE + s)
        def pick(arr):
            if len(arr) < need:
                idx = rng.choice(len(arr), size=need, replace=True); return [arr[i] for i in idx]
            rng.shuffle(arr); return arr[:need]
        for lab in (0,1,2,3):
            for seg in pick(trials[lab]):
                X.append(seg); y.append(lab); groups.append(s)
    X = np.stack(X, axis=0); y = np.asarray(y, np.int64); groups = np.asarray(groups, np.int64)
    n, T, C = X.shape
    print(f"[INTER] Dataset: N={n} | T={T} | C={C} | clases={len(np.unique(y))} | sujetos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# INTRA: sin balanceo artificial (como tu script INTRA)
def build_dataset_all_intra(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    for s in subjects:
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        for r in MI_RUNS_LR + MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, _ = extract_trials_from_run_common(p, scenario, window_mode)
            for seg, lab, subj in outs:
                X.append(seg); y.append(lab); groups.append(subj)
    X = np.stack(X, axis=0); y = np.asarray(y, np.int64); groups = np.asarray(groups, np.int64)
    n, T, C = X.shape
    print(f"[INTRA] Dataset: N={n} | T={T} | C={C} | clases={len(np.unique(y))} | sujetos={len(np.unique(groups))}")
    return X, y, groups

def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

# =========================
# MODELO
# =========================
class EEGNet(nn.Module):
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 8, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 8, dropout: float = 0.5):
        super().__init__()
        self.n_ch = n_ch; self.n_classes = n_classes
        self.F1 = F1; self.D = D; self.F2 = F1 * D
        self.kernel_t = kernel_t; self.k_sep = k_sep
        self.pool1_t = pool1_t; self.pool2_t = pool2_t
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1); self.act = nn.ELU()
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(dropout)
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(dropout)
        self.flatten = nn.Flatten()
        self.fc = None; self.out = None; self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t; T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc  = nn.Linear(feat_dim, 80, bias=True).to(device)
        self.out = nn.Linear(80, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)
        z = self.conv_temporal(x); z = self.bn1(z); z = self.act(z)
        z = self.conv_depthwise(z); z = self.bn2(z); z = self.act(z)
        z = self.pool1(z); z = self.drop1(z)
        z = self.conv_sep_depth(z); z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z); z = self.drop2(z)
        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        return self.out(z)  # logits; CrossEntropyLoss aplica softmax internamente

# =========================
# DATASETS TORCH
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)  # (1,T,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

# =========================
# ENTRENAR / EVALUAR
# =========================
def train_epoch(model, loader, opt, criterion):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward(); opt.step()

@torch.no_grad()
def evaluate_with_preds(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred); y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 ha="center", color="white" if cm_norm[i, j] > thresh else "black")
    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO (compartido)
# =========================
def _param_groups(model, mode):
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def _train_one_mode_inter(model, X_cal, y_cal, n_classes, mode,
                          epochs=FT_EPOCHS, batch_size=16,
                          head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                          l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    """Versión INTER: early stopping por val_loss (como tu INTER)."""
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = TensorDataset(torch.from_numpy(Xtr).float().unsqueeze(1),
                          torch.from_numpy(ytr).long())
    ds_va = TensorDataset(torch.from_numpy(Xva).float().unsqueeze(1),
                          torch.from_numpy(yva).long())
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([{"params": base_params, "lr": base_lr},
                          {"params": head_params, "lr": head_lr}])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf')
    bad = 0

    for _ in range(epochs):
        # train
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
        # val -> val_loss
        model.eval()
        with torch.no_grad():
            val_loss, nval = 0.0, 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0); nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

def _train_one_mode_intra(model, X_tr, y_tr, X_va, y_va, mode, n_classes,
                          epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                          l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, batch_size=16):
    """Versión INTRA: selección por val_acc (como tu INTRA)."""
    _freeze_for_mode(model, mode)
    ds_tr = TensorDataset(torch.from_numpy(X_tr).float().unsqueeze(1),
                          torch.from_numpy(y_tr).long())
    ds_va = TensorDataset(torch.from_numpy(X_va).float().unsqueeze(1),
                          torch.from_numpy(y_va).long())
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([{"params": base_params, "lr": base_lr},
                          {"params": head_params, "lr": head_lr}])
    elif mode == 'head':
        train_params = list(model.fc.parameters()) + list(model.out.parameters())
        opt = optim.Adam(train_params, lr=head_lr)
    else:  # out
        train_params = list(model.out.parameters())
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(y_tr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val_acc = -1.0
    bad = 0

    for _ in range(epochs):
        # train
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
        # val -> val_acc
        y_true_va, y_pred_va, acc_va = _eval_dl(model, dl_va)
        if acc_va > best_val_acc + 1e-6:
            best_val_acc = acc_va; best_state = copy.deepcopy(model.state_dict()); bad = 0
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return copy.deepcopy(model), float(best_val_acc)

@torch.no_grad()
def _eval_dl(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        y_pred.extend(logits.argmax(1).cpu().numpy().tolist())
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, int); y_pred = np.asarray(y_pred, int)
    acc = (y_true == y_pred).mean() if y_true.size else 0.0
    return y_true, y_pred, float(acc)

@torch.no_grad()
def predict_numpy(model, X_np):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(DEVICE)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive_inter(model_global, Xs, ys, n_classes):
    """INTER: 4-fold en el sujeto, entrena out/head/spatial+head (val_loss) y elige por accuracy en holdout."""
    skf = StratifiedKFold(n_splits=CALIB_CV_FOLDS, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)
    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]
        # out
        m_out = copy.deepcopy(model_global)
        _train_one_mode_inter(m_out, Xcal, ycal, n_classes, mode='out',
                              epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                              patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho); acc_out = (yhat_out == yho).mean()
        # head
        m_head = copy.deepcopy(model_global)
        _train_one_mode_inter(m_head, Xcal, ycal, n_classes, mode='head',
                              epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                              patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho); acc_head = (yhat_head == yho).mean()
        # spatial+head
        m_sp = copy.deepcopy(model_global)
        _train_one_mode_inter(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                              epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                              l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho); acc_sp = (yhat_sp == yho).mean()
        best_idx = np.argmax([acc_out, acc_head, acc_sp])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]
        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best
    return y_true_full, y_pred_full

# =========================
# PIPELINE INTER (cross-subject)
# =========================
def run_inter(save_folds_json=True, folds_json_path=FOLDS_JSON_DEFAULT,
              folds_json_description="GroupKFold folds for comparison"):
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all_inter(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape; n_classes = len(np.unique(y))
    ds = EEGTrials(X, y, groups)
    criterion = nn.CrossEntropyLoss()

    folds_json_path = Path(folds_json_path) if folds_json_path else Path("folds/group_folds_5.json")
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)
    subject_ids_str = [f"S{s:03d}" for s in sorted(np.unique(groups).tolist())]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"No existe {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS_INTER,
                                           out_json_path=folds_json_path,
                                           created_by="EEGNet_INTER",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    global_folds, ft_prog_folds, all_true, all_pred = [], [], [], []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)
        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} vacío. Saltando."); continue

        # split de validación por sujetos dentro de TRAIN
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]; va_idx = tr_idx[va_subj_idx]

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE_INTER, shuffle=True,  drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE_INTER, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE_INTER, shuffle=False, drop_last=False)

        model = EEGNet(n_ch=C, n_classes=n_classes, F1=8, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=8, dropout=0.5).to(DEVICE)
        opt = optim.Adam(model.parameters(), lr=LR_GLOBAL_INTER)

        def _acc(loader): return evaluate_with_preds(model, loader)[2]

        print(f"\n[Fold {fold}/{N_FOLDS_INTER}] Entrenamiento global (val por sujetos)")
        best_state = copy.deepcopy(model.state_dict()); best_val = -1.0; bad = 0
        for epoch in range(1, EPOCHS_GLOBAL_INTER + 1):
            train_epoch(model, tr_loader, opt, criterion)
            if epoch % LOG_EVERY == 0:
                tr_acc = _acc(tr_loader); va_acc = _acc(va_loader)
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f}")
                if va_acc > best_val + 1e-4:
                    best_val = va_acc; best_state = copy.deepcopy(model.state_dict()); bad = 0
                else:
                    bad += 1
                    if bad >= GLOBAL_PATIENCE:
                        print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                        break
        model.load_state_dict(best_state)

        # test inter-sujeto
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader)
        global_folds.append(acc_global); all_true.append(y_true); all_pred.append(y_pred)
        print(f"[Fold {fold}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # FT progresivo por sujeto en TEST
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]
        y_true_ft_all, y_pred_ft_all, used_subjects = [], [], 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]
            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue
            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive_inter(
                model, Xs, ys, n_classes)
            y_true_ft_all.append(y_true_subj); y_pred_ft_all.append(y_pred_subj); used_subjects += 1
        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  FT PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  FT PROGRESIVO no ejecutado (sujeto(s) insuficientes).")
        ft_prog_folds.append(acc_ft)

    all_true = np.concatenate(all_true) if len(all_true)>0 else np.array([], int)
    all_pred = np.concatenate(all_pred) if len(all_pred)>0 else np.array([], int)

    print("\n" + "="*60)
    print("RESULTADOS INTER")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds)>0: print(f"Global mean: {np.mean(global_folds):.4f}")
    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds)>0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="INTER — Confusion Matrix (All Folds)",
                       fname="confusion_inter_allfolds.png")
        print("↳ Matriz de confusión guardada: confusion_inter_allfolds.png")

# =========================
# PIPELINE INTRA (per-subject CV)
# =========================
def train_global_model_intra(X, y, epochs=EPOCHS_GLOBAL_INTRA):
    n_ch = X.shape[2]; n_classes = len(np.unique(y))
    model = EEGNet(n_ch=n_ch, n_classes=n_classes, F1=8, D=2, kernel_t=64, k_sep=16,
                   pool1_t=4, pool2_t=8, dropout=0.5).to(DEVICE)
    ds = EEGTrials(X, y, np.zeros(len(y)))
    dl = DataLoader(ds, batch_size=BATCH_SIZE_GLOBAL_INTRA, shuffle=True, drop_last=False)
    crit = nn.CrossEntropyLoss(); opt = optim.Adam(model.parameters(), lr=LR_GLOBAL_INTRA)
    for ep in range(1, epochs+1):
        train_epoch(model, dl, opt, crit)
        if ep % LOG_EVERY == 0:
            y_t, y_p, acc = evaluate_with_preds(model, dl)
            print(f"[GLOBAL PRETRAIN] Época {ep:3d}/{epochs} | train_acc={acc:.4f}")
    return model

def run_intra(n_folds=N_FOLDS_INTRA):
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    OUT_BASE = PROJ / 'models' / 'eegnet_intra_ft'
    TAB_DIR = OUT_BASE / 'tables'; LOG_DIR = OUT_BASE / 'logs'; FIG_DIR = OUT_BASE / 'figures'
    for d in (TAB_DIR, LOG_DIR, FIG_DIR): d.mkdir(parents=True, exist_ok=True)
    mne.set_log_level('WARNING')

    subs = subjects_available()
    X, y, groups = build_dataset_all_intra(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    print("🔹 Entrenando modelo GLOBAL (pretrain sobre todos los sujetos mezclados)...")
    model_global = train_global_model_intra(X, y, epochs=EPOCHS_GLOBAL_INTRA)

    subjects = np.unique(groups)
    rows, cm_items = [], []
    all_true_global, all_pred_global = [], []

    for sid in subjects:
        idx = np.where(groups == sid)[0]
        Xs, ys = X[idx], y[idx]
        print(f"\n== Sujeto S{sid:03d} ==  N={len(ys)} | T={X.shape[1]} | C={X.shape[2]}")
        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=RANDOM_STATE)
        fold_accs, fold_f1s = [], []
        y_true_full, y_pred_full = [], []

        for fold_i, (tr_idx, te_idx) in enumerate(skf.split(Xs, ys), start=1):
            sss = StratifiedShuffleSplit(n_splits=1, test_size=FT_VAL_RATIO, random_state=RANDOM_STATE+fold_i)
            (cal_idx, va_idx), = sss.split(Xs[tr_idx], ys[tr_idx])
            cal_idx = tr_idx[cal_idx]; va_idx = tr_idx[va_idx]; te_idx_ = te_idx
            Xcal, ycal = Xs[cal_idx], ys[cal_idx]
            Xval, yval = Xs[va_idx], ys[va_idx]
            Xtest, ytest = Xs[te_idx_], ys[te_idx_]
            print(f"  [Fold {fold_i}/{n_folds}] train_cal={len(ycal)} | val={len(yval)} | test={len(ytest)}")

            # clones del global
            m_out  = copy.deepcopy(model_global).to(DEVICE)
            m_head = copy.deepcopy(model_global).to(DEVICE)
            m_sp   = copy.deepcopy(model_global).to(DEVICE)
            n_classes = len(np.unique(y))

            # out/head/spatial+head con selección por val_acc
            m_out,  accA = _train_one_mode_intra(m_out,  Xcal, ycal, Xval, yval, mode='out',
                                                 n_classes=n_classes, epochs=FT_EPOCHS,
                                                 head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                                                 patience=FT_PATIENCE, batch_size=16)
            m_head, accB = _train_one_mode_intra(m_head, Xcal, ycal, Xval, yval, mode='head',
                                                 n_classes=n_classes, epochs=FT_EPOCHS,
                                                 head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                                                 patience=FT_PATIENCE, batch_size=16)
            m_sp,   accC = _train_one_mode_intra(m_sp,   Xcal, ycal, Xval, yval, mode='spatial+head',
                                                 n_classes=n_classes, epochs=FT_EPOCHS,
                                                 head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                                                 l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, batch_size=16)

            stages = [('out',m_out,accA), ('head',m_head,accB), ('spatial+head',m_sp,accC)]
            best_name, best_model, best_val = max(stages, key=lambda t: t[2])

            yhat = predict_numpy(best_model, Xtest)
            acc = (yhat == ytest).mean()
            f1m = f1_score(ytest, yhat, average='macro')
            print(f"    → Best stage={best_name} (val_acc={best_val:.4f}) | TEST acc={acc:.4f} | f1m={f1m:.4f}")

            fold_accs.append(float(acc)); fold_f1s.append(float(f1m))
            y_true_full.extend(ytest.tolist()); y_pred_full.extend(yhat.tolist())

        acc_mu = float(np.mean(fold_accs)) if fold_accs else 0.0
        f1_mu  = float(np.mean(fold_f1s)) if fold_f1s else 0.0
        rows.append(dict(subject=f"S{sid:03d}", acc_mean=acc_mu, f1_macro_mean=f1_mu, k=n_folds))
        print(f"  ⇒ Sujeto S{sid:03d} | ACC={acc_mu:.3f} | F1m={f1_mu:.3f}")

        cm = confusion_matrix(y_true_full, y_pred_full, labels=list(range(len(CLASS_NAMES_4C))))
        cm_items.append((f"S{sid:03d}", cm, CLASS_NAMES_4C))
        all_true_global.extend(y_true_full); all_pred_global.extend(y_pred_full)

    # resumen global y guardados (idéntico a tu INTRA)
    df = pd.DataFrame(rows).sort_values("subject")
    acc_mu_glob = float(df['acc_mean'].mean()) if not df.empty else 0.0
    acc_sd_glob = float(df['acc_mean'].std(ddof=0)) if not df.empty else 0.0
    f1_mu_glob  = float(df['f1_macro_mean'].mean()) if not df.empty else 0.0
    f1_sd_glob  = float(df['f1_macro_mean'].std(ddof=0)) if not df.empty else 0.0

    df_glob = pd.DataFrame([{'subject':'GLOBAL','acc_mean':acc_mu_glob,'f1_macro_mean':f1_mu_glob,'k':n_folds}])
    df_out = pd.concat([df, df_glob], ignore_index=True)

    run_tag = f"{ts}_eegnet_intra_ft_{'0to3000ms' if WINDOW_MODE=='3s' else 'custom'}"
    out_csv = TAB_DIR / f"{run_tag}_metrics.csv"
    df_out.to_csv(out_csv, index=False); print(f"\n📄 CSV → {out_csv}")

    out_txt = LOG_DIR / f"{run_tag}_metrics.txt"
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(f"EEGNet INTRA FT (progresivo) — {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"k-fold por sujeto: {n_folds}\n\n")
        f.write("subject | acc_mean | f1_macro_mean | k\n")
        f.write("-"*60 + "\n")
        for _, r in df_out.iterrows():
            f.write(f"{r['subject']} | {r['acc_mean']:.4f} | {r['f1_macro_mean']:.4f} | {int(r['k'])}\n")
        f.write("\nGLOBAL:\n")
        f.write(f"ACC={acc_mu_glob:.3f}±{acc_sd_glob:.3f} | F1m={f1_mu_glob:.3f}±{f1_sd_glob:.3f}\n")
    print(f"📝 TXT → {out_txt}")

    # mosaicos por sujeto
    def _plot_mosaic(cm_items, per_fig=12, n_cols=4):
        n = len(cm_items)
        if n == 0: return
        n_figs = int(np.ceil(n / per_fig))
        for fig_idx in range(n_figs):
            start, end = fig_idx*per_fig, min((fig_idx+1)*per_fig, n)
            chunk = cm_items[start:end]
            count = len(chunk); n_rows = int(np.ceil(count / n_cols))
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5*n_cols, 3.8*n_rows), dpi=140)
            axes = np.atleast_2d(axes).flatten()
            for ax_i, (label, cm_sum, classes) in enumerate(chunk):
                ax = axes[ax_i]
                with np.errstate(invalid='ignore'):
                    cm_norm = cm_sum.astype('float') / cm_sum.sum(axis=1, keepdims=True)
                cm_norm = np.nan_to_num(cm_norm)
                ax.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues, vmin=0.0, vmax=1.0)
                ax.set_title(label); ax.set_xticks(range(len(classes))); ax.set_yticks(range(len(classes)))
                ax.set_xticklabels(classes, rotation=45, ha='right', fontsize=9)
                ax.set_yticklabels(classes, fontsize=9)
                for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
                    ax.text(j, i, f"{cm_norm[i, j]:.2f}", ha="center",
                            color="white" if cm_norm[i, j] > 0.5 else "black")
                ax.set_xlabel(""); ax.set_ylabel("")
            for j in range(ax_i + 1, len(axes)): axes[j].axis("off")
            fig.suptitle(f"INTRA — Matrices de confusión por sujeto (página {fig_idx+1}/{n_figs})", y=0.995, fontsize=14)
            fig.tight_layout(rect=[0, 0, 1, 0.97])
            out_png = (PROJ / 'models' / 'eegnet_intra_ft' / 'figures' / f"{run_tag}_confusions_p{fig_idx+1}.png")
            fig.savefig(out_png); plt.close(fig); print(f"🖼️  Fig → {out_png}")

    _plot_mosaic(cm_items, per_fig=12, n_cols=4)

    # confusión global
    if len(all_true_global) > 0:
        all_true = np.array(all_true_global, int); all_pred = np.array(all_pred_global, int)
        cmG = confusion_matrix(all_true, all_pred, labels=list(range(len(CLASS_NAMES_4C))))
        plt.figure(figsize=(6.5,5.2), dpi=140)
        with np.errstate(invalid='ignore'):
            cm_norm = cmG.astype('float') / cmG.sum(axis=1, keepdims=True)
        cm_norm = np.nan_to_num(cm_norm)
        plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=1)
        plt.title("INTRA — Matriz de confusión GLOBAL"); plt.colorbar(fraction=0.046, pad=0.04)
        ticks = np.arange(len(CLASS_NAMES_4C))
        plt.xticks(ticks, CLASS_NAMES_4C, rotation=45, ha='right'); plt.yticks(ticks, CLASS_NAMES_4C)
        for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
            plt.text(j, i, f"{cm_norm[i, j]:.2f}",
                     ha="center", va="center",
                     color="white" if cm_norm[i, j] > 0.5 else "black")
        plt.tight_layout()
        out_png = PROJ / 'models' / 'eegnet_intra_ft' / 'figures' / f"{ts}_global_confusion.png"
        plt.savefig(out_png); plt.close(); print(f"🖼️  Global Fig → {out_png}")

    print(f"\n[GLOBAL INTRA FT] ACC={acc_mu_glob:.3f}±{acc_sd_glob:.3f} | F1m={f1_mu_glob:.3f}±{f1_sd_glob:.3f}")

# =========================
# FOLDS JSON helpers (INTER)
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="eegnet_unificado", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} > n_subjects={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)
    folds = []; fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]
        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]
        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()
        folds.append({"fold": int(fold_i), "train": train_sids, "test": test_sids,
                      "tr_idx": tr_idx, "te_idx": te_idx})
    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists(): raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f: payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("subject_ids del JSON no coinciden.\n"
                   f"JSON has {len(subj_json)} vs expected {len(expected)}.")
            if strict_check: raise ValueError(msg)
            else: print("WARNING:", msg)
    return payload

# =========================
# MAIN
# =========================
if __name__ == "__main__":
    # Lista de sujetos elegibles (solo para feedback rápido)
    subs = subjects_available()
    if len(subs) == 0:
        print("No hay sujetos elegibles en data/raw (o todos excluidos).")
    else:
        print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    print(f"🔧 Configuración común: escenario={CLASS_SCENARIO}, ventana={WINDOW_MODE}, fs={FS}")

    # === Elige UNO: descomenta la línea del modo que quieras correr ===

    # --- MODO INTER-SUJETO (cross-subject) ---
    run_inter(
        save_folds_json=True,
        folds_json_path=FOLDS_JSON_DEFAULT,
        folds_json_description="GroupKFold folds for comparison"
    )

    # --- MODO INTRA-SUJETO (k-fold por sujeto) ---
    # run_intra(n_folds=N_FOLDS_INTRA)



🚀 Usando dispositivo: cuda
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...
🔧 Configuración común: escenario=4c, ventana=3s, fs=160.0
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset INTER (balanceado):   1%|          | 1/103 [00:00<00:13,  7.47it/s]

Construyendo dataset INTER (balanceado):   8%|▊         | 8/103 [00:01<00:13,  7.14it/s]


KeyboardInterrupt: 

In [3]:
# -*- coding: utf-8 -*-
# Replicación fiel del paper "A Deep Learning MI-EEG Classification Model for BCIs"
# Dose et al., EUSIPCO 2018 — ahora con EEGNet (Lawhern et al., 2018) en vez de ShallowConvNet
# Protocolo: entrenamiento global inter-sujeto + FINE-TUNING PROGRESIVO por sujeto
# (CV 4-fold, LRs discriminativos, L2-SP, early stopping con validación)
# + SupCon preentrenamiento contrastivo supervisado (opcional)

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario de clases
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 16
EPOCHS_GLOBAL = 100
LR = 1e-3
# >>> Añadidos para diagnóstico/ES global <<<
GLOBAL_VAL_SPLIT = 0.15   # fracción de sujetos (dentro del train) para validación
GLOBAL_PATIENCE  = 10     # épocas sin mejora en val_acc
LOG_EVERY        = 5      # log cada N épocas

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4            # 4-fold CV por sujeto ~ 75/25
FT_EPOCHS = 30                 # ES con validación
FT_BASE_LR = 5e-5              # convs "base" (temporal/depthwise) — más bajo
FT_HEAD_LR = 1e-3              # fc+out — más alto
FT_L2SP = 1e-4                 # regularización suave
FT_PATIENCE = 5                # early stopping con validación
FT_VAL_RATIO = 0.2             # validación dentro del set de calibración

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 en lugar de 64)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','CPz']

# ====== SupCon (preentrenamiento contrastivo supervisado) ======
USE_SUPCON_PRETRAIN = False      # <- pon False para desactivarlo
SUPCON_EPOCHS = 40
SUPCON_BATCH = 64               # intenta que sea >=64 si cabe en GPU
SUPCON_LR = 1e-3
SUPCON_TEMP = 0.07
SUPCON_PROJ_DIM = 128
SUPCON_LOG_EVERY = 5


# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    have = [ch for ch in desired_channels if ch in raw.ch_names]
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    # Reordenar y quedarse SOLO con los deseados (8)
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    # Band-pass opcional (desactivado)
    #raw.filter(l_freq=8., h_freq=30., picks='eeg', method='iir', verbose=False)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try:
            sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'):
                continue
            if scenario == '3c' and label not in ('L','R','BFISTS'):
                continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'):
                continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            # Normalización por época canal-a-canal (z-score)
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg)
                y.append(lab)
                groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# EEGNet (Lawhern et al., 2018) adaptado a (B,1,T,C)
# =========================
class EEGNet(nn.Module):
    """
    Entrada: x de forma (B, 1, T, C)  [T=tiempo, C=canales]
    Bloques:
      1) Temporal conv     : Conv2d(1 -> F1, (kernel_t,1), padding 'same'), BN, ELU
      2) Depthwise (espacial): Conv2d(F1 -> F1*D, (1,C), groups=F1, BN, ELU, AvgPool(4,1), Dropout
      3) Separable temporal: Depthwise temporal (k_sep,1) groups=F1*D + Pointwise 1x1 a F2, BN, ELU, AvgPool(8,1), Dropout
      4) FC -> OUT
    """
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 8, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 8, dropout: float = 0.5):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        # Bloque 1: temporal
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        self.act = nn.ELU()

        # Bloque 2: depthwise (espacial)
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(dropout)

        # Bloque 3: separable temporal (depthwise temporal + pointwise)
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(dropout)

        self.flatten = nn.Flatten()

        # Cabeza dinámica
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        # Con padding 'same' en temporal y separable temporal,
        # el tamaño temporal se reduce por los pools:
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1  # ancho=1 tras conv_depthwise (kernel (1,C))
        self.fc = nn.Linear(feat_dim, 80, bias=True).to(device)
        self.out = nn.Linear(80, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,1,T,C)
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)   # (B, F2, T, 1)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Devuelve el embedding pre-logits tras fc+ELU (dim=80).
        x: (B,1,T,C)
        """
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        z = self.conv_temporal(x); z = self.bn1(z); z = self.act(z)
        z = self.conv_depthwise(z); z = self.bn2(z); z = self.act(z)
        z = self.pool1(z); z = self.drop1(z)
        z = self.conv_sep_depth(z); z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z); z = self.drop2(z)
        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)   # embedding 80-D
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# ====== Augmentations EEG para SupCon ======
class EEGAugment(torch.nn.Module):
    def __init__(self, p_jitter=0.5, max_jitter=16,  # ~100 ms a 160 Hz
                 p_noise=0.8, noise_std=0.02,
                 p_tmask=0.4, tmask_max=32,          # ~200 ms
                 p_cdrop=0.2, cdrop_max=1):          # drop 0-1 canales
        super().__init__()
        self.p_jitter = p_jitter
        self.max_jitter = max_jitter
        self.p_noise = p_noise
        self.noise_std = noise_std
        self.p_tmask = p_tmask
        self.tmask_max = tmask_max
        self.p_cdrop = p_cdrop
        self.cdrop_max = cdrop_max

    def forward(self, x):
        """
        Acepta:
          - (1, T, C)  -> ejemplo único (3D)
          - (B, 1, T, C) -> batch (4D)
        Devuelve con la misma dimensionalidad de entrada.
        """
        original_3d = False
        if x.dim() == 3:           # (1, T, C)
            original_3d = True
            x = x.unsqueeze(0)     # -> (B=1, 1, T, C)
        elif x.dim() != 4:
            raise ValueError(f"EEGAugment espera 3D o 4D, recibido {x.dim()}D")

        B, _, T, C = x.shape
        out = x.clone()

        # Jitter temporal
        if torch.rand(1).item() < self.p_jitter and T > 2:
            shift = int(torch.randint(-self.max_jitter, self.max_jitter + 1, (1,)).item())
            if shift > 0:
                out[:, :, shift:, :] = out[:, :, :-shift, :].clone()
            elif shift < 0:
                out[:, :, :shift, :] = out[:, :, -shift:, :].clone()

        # Ruido gaussiano leve
        if torch.rand(1).item() < self.p_noise:
            out = out + torch.randn_like(out) * self.noise_std

        # Time masking corto
        if torch.rand(1).item() < self.p_tmask and T > 4:
            w = int(torch.randint(1, self.tmask_max + 1, (1,)).item())
            s = int(torch.randint(0, max(1, T - w), (1,)).item())
            out[:, :, s:s + w, :] = 0.0

        # Channel dropout
        if torch.rand(1).item() < self.p_cdrop and C > 1:
            k = int(torch.randint(1, self.cdrop_max + 1, (1,)).item())
            ch = torch.randperm(C)[:k]
            out[:, :, :, ch] = 0.0

        if original_3d:
            out = out.squeeze(0)   # vuelve a (1, T, C)
        return out

class ContrastiveTrials(Dataset):
    """ Devuelve dos vistas aumentadas + etiqueta """
    def __init__(self, X, y):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.aug = EEGAugment()

    def __len__(self): return self.X.shape[0]

    def __getitem__(self, idx):
        x = self.X[idx]               # (T,C)
        y = self.y[idx]
        x = np.expand_dims(x, 0)      # (1,T,C)
        x = torch.from_numpy(x)
        v1 = self.aug(x.clone())
        v2 = self.aug(x.clone())
        return v1, v2, torch.tensor(y, dtype=torch.long)

# ====== Proyección + pérdida SupCon ======
class ProjectionHead(nn.Module):
    def __init__(self, in_dim=80, proj_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, in_dim, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_dim, proj_dim, bias=True)
        )
    def forward(self, z):
        z = self.net(z)
        z = torch.nn.functional.normalize(z, dim=-1)
        return z

class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss (Khosla et al.)
    features: (B, n_views, D) normalizadas
    labels  : (B,)
    """
    def __init__(self, temperature=0.07):
        super().__init__()
        self.tau = temperature

    def forward(self, features, labels):
        device = features.device
        B, V, D = features.shape
        feat = features.view(B*V, D)
        labels = labels.view(B)
        labels = labels.contiguous().view(-1, 1)                    # (B,1)
        mask = torch.eq(labels, labels.T).float().to(device)        # (B,B)

        sim = torch.div(torch.matmul(feat, feat.T), self.tau)       # (BV,BV)
        logits_mask = torch.ones_like(sim) - torch.eye(B*V, device=device)
        sim = sim * logits_mask

        mask = mask.repeat(V, V)                                     # (BV,BV)

        exp_sim = torch.exp(sim) * logits_mask
        log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-12)

        mean_log_prob_pos = (mask * log_prob).sum(dim=1) / (mask.sum(dim=1) + 1e-12)
        loss = - mean_log_prob_pos.mean()
        return loss

def supcon_pretrain(model: EEGNet, X_tr: np.ndarray, y_tr: np.ndarray,
                    epochs: int = SUPCON_EPOCHS, batch_size: int = SUPCON_BATCH,
                    lr: float = SUPCON_LR, temperature: float = SUPCON_TEMP,
                    proj_dim: int = SUPCON_PROJ_DIM, log_every: int = SUPCON_LOG_EVERY):
    """
    Preentrena el backbone de EEGNet con SupCon usando etiquetas (positivos = misma clase).
    Usa dos vistas augmentadas por ensayo.
    """
    model.train()
    # asegurar construcción de cabeza (necesitamos conocer T)
    with torch.no_grad():
        dummy = torch.from_numpy(X_tr[:2]).float().unsqueeze(1).to(DEVICE)
        _ = model(dummy)

    proj = ProjectionHead(in_dim=80, proj_dim=proj_dim).to(DEVICE)
    crit = SupConLoss(temperature=temperature)
    ds = ContrastiveTrials(X_tr, y_tr)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)

    # Entrenamos backbone + proyector
    opt = optim.Adam(list(model.parameters()) + list(proj.parameters()), lr=lr)

    for ep in range(1, epochs+1):
        running = 0.0
        n = 0
        for v1, v2, yb in dl:
            v1 = v1.to(DEVICE)  # (B,1,T,C)
            v2 = v2.to(DEVICE)
            yb = yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)

            z1 = model.forward_features(v1)         # (B,80)
            z2 = model.forward_features(v2)         # (B,80)
            p1 = proj(z1)                           # (B,D)
            p2 = proj(z2)                           # (B,D)
            feats = torch.stack([p1, p2], dim=1)    # (B,2,D)

            loss = crit(feats, yb)
            loss.backward()
            opt.step()

            bs = yb.size(0)
            running += loss.item() * bs
            n += bs

        if ep % log_every == 0:
            print(f"[SupCon] Epoch {ep:03d}/{epochs} | loss={running/max(1,n):.4f}")

    print("[SupCon] Preentrenamiento contrastivo completado (backbone inicializado).")

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def train_epoch(model, loader, opt, criterion):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        opt.step()

@torch.no_grad()
def evaluate_with_preds(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    """
    En EEGNet:
      - 'out'            : solo capa final (model.out)
      - 'head'           : fc + out
      - 'spatial+head'   : depthwise + separable + fc + out  (temporal queda congelada)
    """
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    # Congelamos todo
    for p in model.parameters(): p.requires_grad = False
    # Siempre mantenemos CONGELADO el bloque temporal en este protocolo
    # (conv_temporal + bn1)
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=16,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    """
    Entrena en 'mode' con early stopping sobre un conjunto de validación interno.
    Devuelve el modelo con los mejores pesos (por val loss).
    """
    # Split Cal -> (train_cal, val_cal)
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        # grupos con LR discriminativos
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf')
    bad = 0

    for _ in range(epochs):
        # --- train ---
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            # L2-SP hacia referencia de los parámetros que estamos entrenando
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward()
            opt.step()

        # --- val ---
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0)
                nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss
            bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience:
                break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)  # (N,1,T,C)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    """
    Para un sujeto: 4-fold StratifiedKFold.
      - En cada fold: 'out' → 'head' → 'spatial+head' (temporal congelado)
      - Se elige el mejor en el split de holdout del sujeto.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys)
    y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        # Stage A: 'out'
        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device)

        # Stage B: 'head'
        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device)

        # Stage C: 'spatial+head'
        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device)

        # mejor de las tres
        accs = [ (yhat_out == yho).mean(), (yhat_head == yho).mean(), (yhat_sp == yho).mean() ]
        best_idx = int(np.argmax(accs))
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho
        y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    """
    - Crea/lee JSON con folds por sujeto (incluye tr_idx/te_idx)
    - Entrena modelo global por fold (inter-sujeto puro) con validación interna por sujetos + ES
    - Evalúa Global acc en test
    - Realiza Fine-Tuning PROGRESIVO por sujeto (4-fold CV) y reporta acc
    """
    mne.set_log_level('WARNING')

    # sujetos y dataset
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)
    criterion = nn.CrossEntropyLoss()

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por SUJETOS dentro del set de entrenamiento =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, shuffle=True,  drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== EEGNet =====
        model = EEGNet(n_ch=C, n_classes=n_classes, F1=8, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=8, dropout=0.5).to(DEVICE)

        # ===== PREENTRENAMIENTO SUPCON (opcional) =====
        if USE_SUPCON_PRETRAIN:
            X_tr_sup = X[tr_sub_idx]
            y_tr_sup = y[tr_sub_idx]
            print(f"[Fold {fold}/{N_FOLDS}] SupCon pretrain on {len(X_tr_sup)} trials...")
            supcon_pretrain(model, X_tr_sup, y_tr_sup,
                            epochs=SUPCON_EPOCHS, batch_size=SUPCON_BATCH,
                            lr=SUPCON_LR, temperature=SUPCON_TEMP, proj_dim=SUPCON_PROJ_DIM,
                            log_every=SUPCON_LOG_EVERY)

        # Optimizador para entrenamiento supervisado
        opt = optim.Adam(model.parameters(), lr=LR)

        def _acc(loader):
            return evaluate_with_preds(model, loader)[2]

        # ===== Entrenamiento con logging + EARLY STOPPING por val_acc =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global con validación interna por sujetos..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion)

            if epoch % LOG_EVERY == 0:
                tr_acc = _acc(tr_loader)
                va_acc = _acc(va_loader)
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f}")

                if va_acc > best_val + 1e-4:
                    best_val = va_acc
                    best_state = copy.deepcopy(model.state_dict())
                    bad = 0
                else:
                    bad += 1
                    if bad >= GLOBAL_PATIENCE:
                        print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                        break

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            # Seguridad
            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    # Matriz de confusión global (sobre todos los folds)
    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (por sujeto, 4-fold CV)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    print(f"🧲 SupCon: enabled={USE_SUPCON_PRETRAIN}, epochs={SUPCON_EPOCHS}, batch={SUPCON_BATCH}, temp={SUPCON_TEMP}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (por sujeto, 4-fold CV)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
🧲 SupCon: enabled=False, epochs=40, batch=64, temp=0.07
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:20<00:00,  5.13it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global con validación interna por sujetos... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   5 | train_acc=0.4545 | val_acc=0.4203
  Época  10 | train_acc=0.4827 | val_acc=0.4359
  Época  15 | train_acc=0.4841 | val_acc=0.4295
  Época  20 | train_acc=0.4990 | val_acc=0.4606
  Época  25 | train_acc=0.4928 | val_acc=0.4487
  Época  30 | train_acc=0.5093 | val_acc=0.4670
  Época  35 | train_acc=0.5047 | val_acc=0.4625
  Época  40 | train_acc=0.5000 | val_acc=0.4734
  Época  45 | train_acc=0.5060 | val_acc=0.4634
  Época  50 | train_acc=0.5091 | val_acc=0.4615
  Época  55 | train_acc=0.4945 | val_acc=0.4570
  Época  60 | train_acc=0.5166 | val_acc=0.4725
  Época  65 | train_acc=0.5185 | val_acc=0.4588
  Época  70 | train_acc=0.5155 | val_acc=0.4689
  Época  75 | train_acc=0.5207 | val_acc=0.4789
  Época  80 | train_acc

In [20]:
# -*- coding: utf-8 -*-
# EEGNet + protocolo global + fine-tuning progresivo por sujeto
# Cambios clave:
# - BN: momentum=0.99, eps=1e-3
# - Dropout: 0.25 (bloque 1), 0.5 (bloque 2)
# - Max-norm (p=2, max=2.0) en conv_depthwise, conv_sep_point, fc, out
# - Adam(lr=1e-2) + MultiStepLR([20,50], gamma=0.1) en global
# - Sin weight decay
# - 8 canales con FCz
# - Flag para apagar/encender z-score por época
# - Batch por defecto 32 (puedes probar 64)

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario de clases
CLASS_SCENARIO = '4c'

# Ventana por defecto 3s (prueba también '6s' para [-1,5]s)
WINDOW_MODE = '6s'
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 32          # ⇠ recomendado 32; si la GPU aguanta, prueba 64
EPOCHS_GLOBAL = 100
LR_INIT = 1e-2           # LR alto + BN + max-norm
SCHED_MILESTONES = [20, 50]
SCHED_GAMMA = 0.1

# Validación/ES global
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10
LOG_EVERY        = 5

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Normalización por época canal-a-canal (z-score)
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    # Notch a 60 Hz (o 50 Hz según tu red) con spectrum_fit
    raw.notch_filter(freqs=[60.0], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try:
            sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'):
                continue
            if scenario == '3c' and label not in ('L','R','BFISTS'):
                continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'):
                continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg)
                y.append(lab)
                groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# EEGNet (Lawhern et al., 2018) adaptado a (B,1,T,C)
# =========================
class EEGNet(nn.Module):
    """
    Entrada: x de forma (B, 1, T, C)
    """
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 8, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 8,
                 drop1_p: float = 0.25, drop2_p: float = 0.5):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        # Bloque 1: temporal
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)
        self.act = nn.ELU()

        # Bloque 2: depthwise (espacial)
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        # Bloque 3: separable temporal
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()

        # Cabeza dinámica
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc = nn.Linear(feat_dim, 80, bias=True).to(device)
        self.out = nn.Linear(80, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)   # (B, F2, T, 1)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: MAX-NORM
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    """
    Aplica max-norm (||W||_p <= max_value) a capas clave de EEGNet.
    """
    layers = []
    if hasattr(model, 'conv_depthwise'): layers.append(model.conv_depthwise)
    if hasattr(model, 'conv_sep_point'): layers.append(model.conv_sep_point)
    if hasattr(model, 'fc'):             layers.append(model.fc)
    if hasattr(model, 'out'):            layers.append(model.out)

    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            # normalizamos por filtros (dim=0)
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def train_epoch(model, loader, opt, criterion, maxnorm=None):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        opt.step()
        if maxnorm is not None:
            apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def evaluate_with_preds(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf')
    bad = 0

    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward()
            opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0)
                nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss
            bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience:
                break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys)
    y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device)
        acc_out = (yhat_out == yho).mean()

        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device)
        acc_head = (yhat_head == yho).mean()

        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device)
        acc_sp = (yhat_sp == yho).mean()

        best_idx = np.argmax([acc_out, acc_head, acc_sp])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho
        y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)
    criterion = nn.CrossEntropyLoss()

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por sujetos =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, shuffle=True,  drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== EEGNet =====
        model = EEGNet(n_ch=C, n_classes=n_classes, F1=8, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=8, drop1_p=0.25, drop2_p=0.5).to(DEVICE)

        # Adam sin weight decay + scheduler tipo paper
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.MultiStepLR(opt, milestones=SCHED_MILESTONES, gamma=SCHED_GAMMA)

        def _acc(loader):
            return evaluate_with_preds(model, loader)[2]

        # ===== Entrenamiento global con ES por val_acc + max-norm =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion, maxnorm=2.0)
            scheduler.step()

            if epoch % LOG_EVERY == 0 or epoch in (1, 20, 50, 100):
                tr_acc = _acc(tr_loader)
                va_acc = _acc(va_loader)
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={opt.param_groups[0]['lr']:.5f}")

                if va_acc > best_val + 1e-4:
                    best_val = va_acc
                    best_state = copy.deepcopy(model.state_dict())
                    bad = 0
                else:
                    bad += 1
                    if bad >= GLOBAL_PATIENCE:
                        print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                        break

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (por sujeto, 4-fold CV)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (por sujeto, 4-fold CV)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

Construyendo dataset (RAW): 100%|██████████| 103/103 [01:14<00:00,  1.39it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   1 | train_acc=0.4275 | val_acc=0.4084 | LR=0.01000
  Época   5 | train_acc=0.4463 | val_acc=0.4112 | LR=0.01000
  Época  10 | train_acc=0.4940 | val_acc=0.4533 | LR=0.01000
  Época  15 | train_acc=0.5059 | val_acc=0.4789 | LR=0.01000
  Época  20 | train_acc=0.5059 | val_acc=0.4478 | LR=0.00100
  Época  25 | train_acc=0.5236 | val_acc=0.4762 | LR=0.00100
  Época  30 | train_acc=0.5254 | val_acc=0.4853 | LR=0.00100
  Época  35 | train_acc=0.5162 | val_acc=0.4615 | LR=0.00100
  Época  40 | train_acc=0.5264 | val_acc=0.4789 | LR=0.00100
  Época  45 | train_acc=0.5267 | val_acc=0.4716 | LR=0.00100
  Época  50 | train_acc=0.5317 | val_acc=0.4771 | LR=0.00010
  Época  55 | train_acc=0.5283 | val_acc=0.4881 | LR=0.00010
  Época  60 | train_acc=0.5202 | val_acc=0.466

In [None]:
# -*- coding: utf-8 -*-
# EEGNet + protocolo global + fine-tuning progresivo (mejoras BFISTS)
# Cambios: CE ponderada + label smoothing, cutout focalizado para OF,
# F1=16 (más capacidad), paciencia=12, SGDR, jitter+noise, curvas de entrenamiento,
# notch adaptativo según SNR y sampler balanceado sujeto+clase.

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'   # 6s como acordamos
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 32
EPOCHS_GLOBAL = 100
LR_INIT = 1e-2
SGDR_T0 = 10
SGDR_Tmult = 2

# Validación/ES global
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 12   # ← más laxo
LOG_EVERY        = 5

# Fine-tuning por sujeto
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Normalización por época canal-a-canal (z-score)
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch adaptativo: lee CSV si existe (subject, run, snr50_db, snr60_db)
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  # default
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0])
    snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch si no sobresale

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    # Notch adaptativo (por SNR); si no hay CSV, por defecto 60 Hz
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try:
            sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# AUGMENTS (solo train)
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    # x: (B,1,T,C) torch.float
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=20, max_ms=40, fs=160.0):
    # Aplica cutout corto SOLO a clases en classes_mask (OF: BFISTS=2, BFEET=3)
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    for i in range(B):
        if int(y[i].item()) in classes_mask:
            L = random.randint(Lmin, Lmax)
            if L < T:
                s = random.randint(0, T-L)
                out[i,:,s:s+L,:] = 0
    return out

# =========================
# EEGNet (Lawhern) (B,1,T,C)
# =========================
class EEGNet(nn.Module):
    """
    Entrada: x de forma (B, 1, T, C)
    """
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 16, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 8,
                 drop1_p: float = 0.25, drop2_p: float = 0.5):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        # Bloque 1: temporal
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)
        self.act = nn.ELU()

        # Bloque 2: depthwise (espacial)
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        # Bloque 3: separable temporal
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()

        # Cabeza dinámica
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc = nn.Linear(feat_dim, 80, bias=True).to(device)
        self.out = nn.Linear(80, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)   # (B, F2, T, 1)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: MAX-NORM
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    if hasattr(model, 'conv_depthwise'): layers.append(model.conv_depthwise)
    if hasattr(model, 'conv_sep_point'): layers.append(model.conv_sep_point)
    if hasattr(model, 'fc'):             layers.append(model.fc)
    if hasattr(model, 'out'):            layers.append(model.out)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def build_weighted_sampler(y, groups):
    # pesos inversos por clase y por sujeto → balance en cada batch
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

def make_class_weighted_criterion(y_indices, n_classes):
    from collections import Counter
    cnt = Counter(y_indices.tolist())
    total = sum(cnt.values())
    weights = torch.tensor(
        [total / max(1, cnt.get(c, 0)) for c in range(n_classes)],
        dtype=torch.float32, device=DEVICE
    )
    weights = weights / weights.mean()
    # CE con smoothing
    return nn.CrossEntropyLoss(weight=weights, label_smoothing=0.05)

def train_epoch(model, loader, opt, criterion, do_aug=True, fs=160.0, maxnorm=None):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # ---- AUGMENTS sin mixup ----
        if do_aug:
            xb = do_time_jitter(xb, max_ms=50, fs=fs)
            xb = do_gaussian_noise(xb, sigma=0.01)
            xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=20, max_ms=40, fs=fs)

        logits = model(xb)
        loss = criterion(logits, yb)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        if maxnorm is not None:
            apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def evaluate_with_preds(model, loader):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w, label_smoothing=0.05)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf'); bad = 0

    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0; nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0); nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device)

        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device)

        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device)

        best_idx = np.argmax([
            (yhat_out == yho).mean(),
            (yhat_head == yho).mean(),
            (yhat_sp == yho).mean()
        ])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por sujetos =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Sampler balanceado para train
        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== Modelo =====
        model = EEGNet(n_ch=C, n_classes=n_classes,
                       F1=16, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=8, drop1_p=0.25, drop2_p=0.5).to(DEVICE)

        # Opt y scheduler SGDR
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

        # Criterio ponderado (en función del split de train del fold)
        criterion = make_class_weighted_criterion(torch.tensor(y[tr_sub_idx]), n_classes)

        # Métrica rápida
        def _acc(loader):
            return evaluate_with_preds(model, loader)[2]

        # Historia para curvas
        history = {'train_acc': [], 'val_acc': []}

        # ===== Entrenamiento global =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion,
                        do_aug=True, fs=FS, maxnorm=2.0)
            scheduler.step(epoch-1 + 1e-8)  # tick suave

            # eval
            tr_acc = _acc(tr_loader)
            va_acc = _acc(va_loader)
            history['train_acc'].append(tr_acc)
            history['val_acc'].append(va_acc)

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            if va_acc > best_val + 1e-4:
                best_val = va_acc
                best_state = copy.deepcopy(model.state_dict())
                bad = 0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                    break

        # guardar curva de entrenamiento
        curve_path = f"training_curve_fold{fold}.png"
        plot_training_curves(history, curve_path)
        print(f"↳ Curva de entrenamiento guardada: {curve_path}")

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet (capacidad↑) + CE ponderada + smoothing + aug OF focalizado")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet (capacidad↑) + CE ponderada + smoothing + aug OF focalizado
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:37<00:00,  2.73it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   1 | train_acc=0.4707 | val_acc=0.4441 | LR=0.01000
  Época   5 | train_acc=0.5000 | val_acc=0.4496 | LR=0.00655
  Época  10 | train_acc=0.5502 | val_acc=0.4844 | LR=0.00024
  Época  15 | train_acc=0.5200 | val_acc=0.4734 | LR=0.00905
  Época  20 | train_acc=0.5148 | val_acc=0.4689 | LR=0.00578
  Early stopping en época 22 (mejor val_acc=0.4844)
↳ Curva de entrenamiento guardada: training_curve_fold1.png
[Fold 1/5] Global acc=0.4649
              precision    recall  f1-score   support

        Left     0.6081    0.4785    0.5355       441
       Right     0.5249    0.5261    0.5255       441
  Both Fists     0.3494    0.3946    0.3706       441
   Both Feet     0.4256    0.4603    0.4423       441

    accuracy                         0.4649      1764
   mac

In [25]:
# -*- coding: utf-8 -*-
# EEGNet + protocolo global + fine-tuning progresivo por sujeto (augments + SGDR + tweaks)
# Cambios clave:
# 1) Capacidad ↑: F1=24, fc=128
# 2) Loss ponderada por clase (+20% a Both Fists) con soporte para mixup (WeightedSoftCrossEntropy)
# 3) Cutout focalizado (30–60 ms) solo para clases OF (2,3)
# 4) SGDR con T0=6, Tmult=2
# 5) TTA (n=5, jitter ±25 ms) en evaluación

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'   # mantienes 6s
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 64
EPOCHS_GLOBAL = 100
LR_INIT = 1e-2
SGDR_T0 = 6      # ↓ ciclos cortos al inicio
SGDR_Tmult = 2

# Validación/ES global
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10
LOG_EVERY        = 5

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Normalización por época canal-a-canal (z-score)
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch adaptativo: lee CSV si existe (subject, run, snr50_db, snr60_db)
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  # default
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch si no sobresale

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    # Notch adaptativo (por SNR); si no hay CSV, por defecto 60 Hz
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# AUGMENTS (solo train)
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    # x: (B,1,T,C) torch.float
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout(x, min_ms=100, max_ms=150, fs=160.0):
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    for i in range(B):
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    # Aplica cutout solo si y pertenece a classes_mask
    B,_,T,C = x.shape
    if B == 0: return x
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask: 
            continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    """
    Cross-entropy suave (targets en probas, p.ej. mixup) con pesos por clase.
    """
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)

    def forward(self, logits, target_probs):
        # label smoothing: mezcla target_probs con uniforme
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)  # (B,C)
        loss_per_class = -(target_probs * logp)  # (B,C)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)  # pesar por clase
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# =========================
# EEGNet (Lawhern et al., 2018) adaptado a (B,1,T,C) + ChannelDropout
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

class EEGNet(nn.Module):
    """
    Entrada: x de forma (B, 1, T, C)
    """
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 24, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 6,
                 drop1_p: float = 0.35, drop2_p: float = 0.6,
                 chdrop_p: float = 0.1):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        self.chdrop = ChannelDropout(p=chdrop_p)

        # Bloque 1: temporal
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)
        self.act = nn.ELU()

        # Bloque 2: depthwise (espacial)
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        # Bloque 3: separable temporal
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()

        # Cabeza dinámica
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc  = nn.Linear(feat_dim, 128, bias=True).to(device)
        self.out = nn.Linear(128, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        x = self.chdrop(x)  # ChannelDropout

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)   # (B, F2, T, 1)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: MAX-NORM
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    if hasattr(model, 'conv_depthwise'): layers.append(model.conv_depthwise)
    if hasattr(model, 'conv_sep_point'): layers.append(model.conv_sep_point)
    if hasattr(model, 'fc'):             layers.append(model.fc)
    if hasattr(model, 'out'):            layers.append(model.out)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def build_weighted_sampler(y, groups):
    # pesos inversos por clase y por sujeto → promueve balance en cada batch
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20):
    # pesos inversos por clase, normalizados; clase 2 (*Both Fists*) con boost
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    w[2] *= boost_bfists
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def train_epoch(model, loader, opt, criterion_soft, n_classes,
                do_aug=True, fs=160.0, maxnorm=None):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # ---- AUGMENTS (orden: jitter → noise → cutout focalizado → mixup) ----
        if do_aug:
            xb = do_time_jitter(xb, max_ms=50, fs=fs)
            xb = do_gaussian_noise(xb, sigma=0.01)
            xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=fs)
            xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=0.2)
            logits = model(xb)
            loss = criterion_soft(logits, yt)
        else:
            logits = model(xb)
            # yb como índices → convertir a one-hot para criterio suave
            yt = torch.nn.functional.one_hot(yb, num_classes=n_classes).float()
            loss = criterion_soft(logits, yt)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        if maxnorm is not None:
            apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def _predict_tta(model, xb, n=5, fs=160.0):
    outs = []
    for _ in range(n):
        xj = do_time_jitter(xb, max_ms=25, fs=fs)
        outs.append(model(xj))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_with_preds(model, loader, use_tta=True, tta_n=5):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        if use_tta:
            logits = _predict_tta(model, xb, n=tta_n, fs=FS)
        else:
            logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf'); bad = 0

    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0; nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0); nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device); acc_out = (yhat_out == yho).mean()

        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device); acc_head = (yhat_head == yho).mean()

        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device); acc_sp = (yhat_sp == yho).mean()

        best_idx = np.argmax([acc_out, acc_head, acc_sp])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por sujetos =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Sampler balanceado para train
        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== EEGNet =====
        model = EEGNet(n_ch=C, n_classes=n_classes,
                       F1=24, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=6, drop1_p=0.35, drop2_p=0.6,
                       chdrop_p=0.1).to(DEVICE)

        # Opt y scheduler SGDR
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

        # Métrica rápida
        def _acc(loader):
            return evaluate_with_preds(model, loader, use_tta=True, tta_n=5)[2]

        # Historia para curvas
        history = {'train_acc': [], 'val_acc': []}

        # ---- Criterio suave con ponderación por clase (+20% BFISTS) ----
        class_weights = make_class_weight_tensor(y[tr_sub_idx], n_classes, boost_bfists=1.20)
        criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.05)

        # ===== Entrenamiento global =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion_soft, n_classes=n_classes,
                        do_aug=True, fs=FS, maxnorm=2.0)
            scheduler.step(epoch-1 + 1e-8)  # tick suave

            # eval
            tr_acc = _acc(tr_loader)
            va_acc = _acc(va_loader)
            history['train_acc'].append(tr_acc)
            history['val_acc'].append(va_acc)

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            if va_acc > best_val + 1e-4:
                best_val = va_acc
                best_state = copy.deepcopy(model.state_dict())
                bad = 0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                    break

        # guardar curva de entrenamiento
        curve_path = f"training_curve_fold{fold}.png"
        plot_training_curves(history, curve_path)
        print(f"↳ Curva de entrenamiento guardada: {curve_path}")

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader, use_tta=True, tta_n=5)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (augments + SGDR + tweaks)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (augments + SGDR + tweaks)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

Construyendo dataset (RAW): 100%|██████████| 103/103 [00:37<00:00,  2.72it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   1 | train_acc=0.4605 | val_acc=0.4304 | LR=0.01000
  Época   5 | train_acc=0.5138 | val_acc=0.4643 | LR=0.00250
  Época  10 | train_acc=0.5226 | val_acc=0.4744 | LR=0.00854
  Época  15 | train_acc=0.5514 | val_acc=0.4872 | LR=0.00250
  Época  20 | train_acc=0.5371 | val_acc=0.4698 | LR=0.00996
  Época  25 | train_acc=0.5229 | val_acc=0.4799 | LR=0.00854
  Early stopping en época 26 (mejor val_acc=0.4881)
↳ Curva de entrenamiento guardada: training_curve_fold1.png
[Fold 1/5] Global acc=0.4575
              precision    recall  f1-score   support

        Left     0.5516    0.4603    0.5019       441
       Right     0.5642    0.4785    0.5178       441
  Both Fists     0.3540    0.5057    0.4164       441
   Both Feet     0.4337    0.3855    0.4082       441


In [26]:
# -*- coding: utf-8 -*-
# EEGNet + protocolo global + fine-tuning progresivo por sujeto (augments + SGDR + tweaks++)

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'   # 6s
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 64
EPOCHS_GLOBAL = 100
LR_INIT = 1e-2
# LR_INIT = 1.5e-2  # ← opcional para probar con batch=64
SGDR_T0 = 6
SGDR_Tmult = 2
SGDR_ETA_MIN = 1e-4  # nuevo: no caer tan bajo entre reinicios

# Validación/ES global
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 15      # ↑ paciencia
MIN_EPOCHS_BEFORE_ES = 12  # no cortar antes de completar ~2 ciclos de T0=6
LOG_EVERY        = 5

# Fine-tuning por sujeto
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Normalización por época canal-a-canal (z-score)
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# Augments
MIXUP_ALPHA = 0.25         # ↑ respecto a 0.20
CUTOUT_MASKED_MS = (30,60) # igual
TIME_JITTER_MS = 50
TTA_N = 7                  # ↑ TTA en evaluación
TTA_JITTER_MS = 25

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch adaptativo: lee CSV si existe (subject, run, snr50_db, snr60_db)
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# AUGMENTS (solo train)
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    B,_,T,C = x.shape
    if B == 0: return x
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask:
            continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=0.25):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    def __init__(self, class_weights=None, label_smoothing=0.10):  # ↑ smoothing
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)

    def forward(self, logits, target_probs):
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)
        loss_per_class = -(target_probs * logp)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# =========================
# EEGNet (Lawhern et al., 2018) adaptado a (B,1,T,C) + ChannelDropout
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.15):  # ↑ 0.10 → 0.15
        super().__init__()
        self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

class EEGNet(nn.Module):
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 24, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 6,
                 drop1_p: float = 0.40, drop2_p: float = 0.60,  # drop1 ↑
                 chdrop_p: float = 0.15):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        self.chdrop = ChannelDropout(p=chdrop_p)

        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)
        self.act = nn.ELU()

        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc  = nn.Linear(feat_dim, 128, bias=True).to(device)
        self.out = nn.Linear(128, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        x = self.chdrop(x)

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: MAX-NORM
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    if hasattr(model, 'conv_depthwise'): layers.append(model.conv_depthwise)
    if hasattr(model, 'conv_sep_point'): layers.append(model.conv_sep_point)
    if hasattr(model, 'fc'):             layers.append(model.fc)
    if hasattr(model, 'out'):            layers.append(model.out)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20, boost_bfeet=1.10):
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    w[2] *= boost_bfists  # Both Fists
    w[3] *= boost_bfeet   # Both Feet (nuevo)
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def train_epoch(model, loader, opt, criterion_soft, n_classes,
                do_aug=True, fs=160.0, maxnorm=None):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        if do_aug:
            xb = do_time_jitter(xb, max_ms=TIME_JITTER_MS, fs=fs)
            xb = do_gaussian_noise(xb, sigma=0.01)
            xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3},
                                           min_ms=CUTOUT_MASKED_MS[0], max_ms=CUTOUT_MASKED_MS[1], fs=fs)
            xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=MIXUP_ALPHA)
            logits = model(xb)
            loss = criterion_soft(logits, yt)
        else:
            logits = model(xb)
            yt = torch.nn.functional.one_hot(yb, num_classes=n_classes).float()
            loss = criterion_soft(logits, yt)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        if maxnorm is not None:
            apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def _predict_tta(model, xb, n=TTA_N, fs=160.0):
    outs = []
    for _ in range(n):
        xj = do_time_jitter(xb, max_ms=TTA_JITTER_MS, fs=fs)
        outs.append(model(xj))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_with_preds(model, loader, use_tta=True, tta_n=TTA_N):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        if use_tta:
            logits = _predict_tta(model, xb, n=tta_n, fs=FS)
        else:
            logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf'); bad = 0

    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0; nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0); nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device)

        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device)

        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device)

        accs = [ (yhat_out == yho).mean(), (yhat_head == yho).mean(), (yhat_sp == yho).mean() ]
        best_idx = int(np.argmax(accs))
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por sujetos =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Sampler balanceado para train
        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== EEGNet =====
        model = EEGNet(n_ch=C, n_classes=n_classes,
                       F1=24, D=2, kernel_t=64, k_sep=16,
                       pool1_t=4, pool2_t=6, drop1_p=0.40, drop2_p=0.60,
                       chdrop_p=0.15).to(DEVICE)

        # Opt y scheduler SGDR (con eta_min)
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt, T_0=SGDR_T0, T_mult=SGDR_Tmult, eta_min=SGDR_ETA_MIN
        )

        # Métrica rápida
        def _acc(loader):
            return evaluate_with_preds(model, loader, use_tta=True, tta_n=TTA_N)[2]

        # Historia para curvas
        history = {'train_acc': [], 'val_acc': []}

        # ---- Criterio suave con ponderación por clase (+20% Fists, +10% Feet) ----
        class_weights = make_class_weight_tensor(y[tr_sub_idx], n_classes, boost_bfists=1.20, boost_bfeet=1.10)
        criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.10)

        # ===== Entrenamiento global =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion_soft, n_classes=n_classes,
                        do_aug=True, fs=FS, maxnorm=2.0)
            scheduler.step(epoch-1 + 1e-8)

            # eval
            tr_acc = _acc(tr_loader)
            va_acc = _acc(va_loader)
            history['train_acc'].append(tr_acc)
            history['val_acc'].append(va_acc)

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            # early stopping "amable" con SGDR
            improved = (va_acc > best_val + 1e-4)
            if improved:
                best_val = va_acc
                best_state = copy.deepcopy(model.state_dict())
                bad = 0
            else:
                bad += 1
                if (epoch >= MIN_EPOCHS_BEFORE_ES) and (bad >= GLOBAL_PATIENCE):
                    print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                    break

        # guardar curva de entrenamiento
        curve_path = f"training_curve_fold{fold}.png"
        plot_training_curves(history, curve_path)
        print(f"↳ Curva de entrenamiento guardada: {curve_path}")

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader, use_tta=True, tta_n=TTA_N)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (augments + SGDR + tweaks++)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (augments + SGDR + tweaks++)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:37<00:00,  2.71it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   1 | train_acc=0.4441 | val_acc=0.4176 | LR=0.01000
  Época   5 | train_acc=0.5048 | val_acc=0.4597 | LR=0.00257
  Época  10 | train_acc=0.5097 | val_acc=0.4579 | LR=0.00855
  Época  15 | train_acc=0.5409 | val_acc=0.4771 | LR=0.00257
  Época  20 | train_acc=0.5228 | val_acc=0.4789 | LR=0.00996
  Época  25 | train_acc=0.5223 | val_acc=0.4679 | LR=0.00855
  Early stopping en época 26 (mejor val_acc=0.4799)
↳ Curva de entrenamiento guardada: training_curve_fold1.png
[Fold 1/5] Global acc=0.4467
              precision    recall  f1-score   support

        Left     0.5895    0.4331    0.4993       441
       Right     0.5511    0.4036    0.4660       441
  Both Fists     0.3551    0.4444    0.3948       441
   Both Feet     0.3947    0.5057    0.4433       441


In [27]:
# -*- coding: utf-8 -*-
# EEGNet + protocolo global + fine-tuning progresivo por sujeto (augments + SGDR + tweaks)
# Variant-B:
# - F1=24 (igual), fc=256
# - kernel_t=80 (↑ contexto temporal), pool2_t=5 (↓ submuestreo)
# - Dropout: drop1=0.40, drop2=0.60, chdrop=0.15
# - Loss ponderada por clase (+20% a Both Fists) con soporte mixup (WeightedSoftCrossEntropy)
# - Mixup alpha=0.30
# - SGDR con T0=6, Tmult=2 + warmup lineal 2 épocas
# - TTA (n=5, jitter ±25 ms) en evaluación

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'   # 6s
FS = 160.0
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 64
EPOCHS_GLOBAL = 100
LR_INIT = 1e-2
SGDR_T0 = 6
SGDR_Tmult = 2
WARMUP_EPOCHS = 2  # <-- warmup lineal antes de SGDR

# Validación/ES global
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10
LOG_EVERY        = 5

# Fine-tuning por sujeto (protocolo robusto)
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Normalización por época canal-a-canal (z-score)
NORM_EPOCH_ZSCORE = True

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

# Augments
MIXUP_ALPHA = 0.30  # Variant-B (antes 0.2–0.25)

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch adaptativo: lee CSV si existe (subject, run, snr50_db, snr60_db)
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            df = pd.read_csv(csv_path)
            _SNR_TABLE = df
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  # default
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None  # no notch si no sobresale

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    # Notch adaptativo (por SNR); si no hay CSV, por defecto 60 Hz
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            if NORM_EPOCH_ZSCORE:
                seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# AUGMENTS (solo train)
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    # x: (B,1,T,C) torch.float
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    B,_,T,C = x.shape
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]
            out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]
            out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout(x, min_ms=100, max_ms=150, fs=160.0):
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    for i in range(B):
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    # Aplica cutout solo si y pertenece a classes_mask
    B,_,T,C = x.shape
    if B == 0: return x
    Lmin = int(round(min_ms/1000.0*fs))
    Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone()
    y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask:
            continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=MIXUP_ALPHA):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

class WeightedSoftCrossEntropy(nn.Module):
    """
    Cross-entropy suave (targets en probas, p.ej. mixup) con pesos por clase.
    """
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)

    def forward(self, logits, target_probs):
        # label smoothing: mezcla target_probs con uniforme
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)  # (B,C)
        loss_per_class = -(target_probs * logp)  # (B,C)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)  # pesar por clase
        loss = loss_per_class.sum(dim=1).mean()
        return loss

# =========================
# EEGNet (Lawhern et al., 2018) adaptado a (B,1,T,C) + ChannelDropout
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.15):  # Variant-B: 0.15
        super().__init__()
        self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

class EEGNet(nn.Module):
    """
    Entrada: x de forma (B, 1, T, C)
    """
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 24, D: int = 2, kernel_t: int = 80, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 5,
                 drop1_p: float = 0.40, drop2_p: float = 0.60,
                 chdrop_p: float = 0.15):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        self.chdrop = ChannelDropout(p=chdrop_p)

        # Bloque 1: temporal
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)
        self.act = nn.ELU()

        # Bloque 2: depthwise (espacial)
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        # Bloque 3: separable temporal
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()

        # Cabeza dinámica
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc  = nn.Linear(feat_dim, 256, bias=True).to(device)  # Variant-B: 256
        self.out = nn.Linear(256, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)

        x = self.chdrop(x)  # ChannelDropout

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)   # (B, F2, T, 1)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# TORCH DATASET
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)                 # (1, T, C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: MAX-NORM
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    if hasattr(model, 'conv_depthwise'): layers.append(model.conv_depthwise)
    if hasattr(model, 'conv_sep_point'): layers.append(model.conv_sep_point)
    if hasattr(model, 'fc'):             layers.append(model.fc)
    if hasattr(model, 'out'):            layers.append(model.out)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

# =========================
# ENTRENAMIENTO / EVALUACIÓN
# =========================
def build_weighted_sampler(y, groups):
    # pesos inversos por clase y por sujeto → promueve balance en cada batch
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    w_t = torch.from_numpy(w).float()
    sampler = WeightedRandomSampler(weights=w_t, num_samples=len(w_t), replacement=True)
    return sampler

def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20):
    # pesos inversos por clase, normalizados; clase 2 (*Both Fists*) con boost
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    w[2] *= boost_bfists
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def train_epoch(model, loader, opt, criterion_soft, n_classes,
                do_aug=True, fs=160.0, maxnorm=None):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # ---- AUGMENTS (orden: jitter → noise → cutout focalizado → mixup) ----
        if do_aug:
            xb = do_time_jitter(xb, max_ms=50, fs=fs)
            xb = do_gaussian_noise(xb, sigma=0.01)
            xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=fs)
            xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=MIXUP_ALPHA)
            logits = model(xb)
            loss = criterion_soft(logits, yt)
        else:
            logits = model(xb)
            # yb como índices → convertir a one-hot para criterio suave
            yt = torch.nn.functional.one_hot(yb, num_classes=n_classes).float()
            loss = criterion_soft(logits, yt)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        if maxnorm is not None:
            apply_max_norm(model, max_value=maxnorm, p=2.0)

@torch.no_grad()
def _predict_tta(model, xb, n=5, fs=160.0):
    outs = []
    for _ in range(n):
        xj = do_time_jitter(xb, max_ms=25, fs=fs)
        outs.append(model(xj))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_with_preds(model, loader, use_tta=True, tta_n=5):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        if use_tta:
            logits = _predict_tta(model, xb, n=tta_n, fs=FS)
        else:
            logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred)
        y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title)
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, ha='right')
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm_norm.max() / 2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")

    plt.ylabel('True label'); plt.xlabel('Predicted label')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING PROGRESIVO por sujeto (con validación interna)
# =========================
def _param_groups(model, mode):
    if mode == 'out':
        train = list(model.out.parameters())
    elif mode == 'head':
        train = list(model.fc.parameters()) + list(model.out.parameters())
    elif mode == 'spatial+head':
        train = (list(model.conv_depthwise.parameters()) +
                 list(model.bn2.parameters()) +
                 list(model.conv_sep_depth.parameters()) +
                 list(model.conv_sep_point.parameters()) +
                 list(model.bn3.parameters()) +
                 list(model.fc.parameters()) +
                 list(model.out.parameters()))
    else:
        raise ValueError(mode)
    return train

def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal, y_cal)
    Xtr, ytr = X_cal[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal[va_idx], y_cal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(
        torch.from_numpy(Xtr).float().unsqueeze(1),
        torch.from_numpy(ytr).long()
    )
    ds_va = torch.utils.data.TensorDataset(
        torch.from_numpy(Xva).float().unsqueeze(1),
        torch.from_numpy(yva).long()
    )
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    _freeze_for_mode(model, mode)

    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = _param_groups(model, mode)
        opt = optim.Adam(train_params, lr=head_lr)

    ref = [p.detach().clone().to(p.device) for p in train_params]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf'); bad = 0

    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        model.eval()
        with torch.no_grad():
            val_loss = 0.0; nval = 0
            for xb, yb in dl_va:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                logits = model(xb)
                loss = crit(logits, yb)
                val_loss += loss.item() * xb.size(0); nval += xb.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np, device):
    model.eval()
    xb = torch.from_numpy(X_np).float().unsqueeze(1).to(device)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)

    for tr_idx, te_idx in skf.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device); acc_out = (yhat_out == yho).mean()

        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device); acc_head = (yhat_head == yho).mean()

        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device); acc_sp = (yhat_sp == yho).mean()

        best_idx = np.argmax([acc_out, acc_head, acc_sp])
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# FOLDS JSON helpers
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="dose_experiment", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]

    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")

    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)

    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]

        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]

        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()

        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })

    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }

    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)

    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)

    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR, folds_json_description="GroupKFold folds for comparison"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # preparar JSON folds
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)

    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]

    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="Joel_Clasificador",
                                           description=folds_json_description)

    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    # bucle por folds
    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)

        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"Advertencia: fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # ===== Split de validación por sujetos =====
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        # Sampler balanceado para train
        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])

        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # ===== EEGNet =====
        model = EEGNet(n_ch=C, n_classes=n_classes,
                       F1=24, D=2, kernel_t=80, k_sep=16,
                       pool1_t=4, pool2_t=5, drop1_p=0.40, drop2_p=0.60,
                       chdrop_p=0.15).to(DEVICE)

        # Opt y scheduler SGDR
        opt = optim.Adam(model.parameters(), lr=LR_INIT)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

        # Métrica rápida
        def _acc(loader):
            return evaluate_with_preds(model, loader, use_tta=True, tta_n=5)[2]

        # Historia para curvas
        history = {'train_acc': [], 'val_acc': []}

        # ---- Criterio suave con ponderación por clase (+20% BFISTS) ----
        class_weights = make_class_weight_tensor(y[tr_sub_idx], n_classes, boost_bfists=1.20)
        criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.05)

        # ===== Entrenamiento global =====
        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0
        bad = 0

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_epoch(model, tr_loader, opt, criterion_soft, n_classes=n_classes,
                        do_aug=True, fs=FS, maxnorm=2.0)

            # ---- Warmup + SGDR
            cur_epoch = epoch - 1 + 1e-8
            if epoch <= WARMUP_EPOCHS:
                # warmup lineal de 0 → LR_INIT
                for pg in opt.param_groups:
                    pg['lr'] = LR_INIT * epoch / float(WARMUP_EPOCHS)
            else:
                scheduler.step(cur_epoch)

            # eval
            tr_acc = _acc(tr_loader)
            va_acc = _acc(va_loader)
            history['train_acc'].append(tr_acc)
            history['val_acc'].append(va_acc)

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            if va_acc > best_val + 1e-4:
                best_val = va_acc
                best_state = copy.deepcopy(model.state_dict())
                bad = 0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                    break

        # guardar curva de entrenamiento
        curve_path = f"training_curve_fold{fold}.png"
        plot_training_curves(history, curve_path)
        print(f"↳ Curva de entrenamiento guardada: {curve_path}")

        # cargar mejor estado antes de evaluar en test
        model.load_state_dict(best_state)

        # ===== Evaluación global (inter-sujeto puro) =====
        y_true, y_pred, acc_global = evaluate_with_preds(model, te_loader, use_tta=True, tta_n=5)
        global_folds.append(acc_global)
        all_true.append(y_true); all_pred.append(y_pred)

        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning PROGRESIVO por sujeto con 4-fold CV ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]

        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]

            if len(ys) < CALIB_CV_FOLDS or len(np.unique(ys)) < 2:
                continue

            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
                model, Xs, ys, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=n_classes
            )
            y_true_ft_all.append(y_true_subj)
            y_pred_ft_all.append(y_pred_subj)
            used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")

    print("Fine-tune PROGRESIVO folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune PROGRESIVO mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global Model (All Folds)",
                       fname="confusion_global_allfolds.png")
        print("\n↳ Matriz de confusión guardada: confusion_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (Variant-B)")
    print(f"🔧 Configuración: {CLASS_SCENARIO}, {len(EXPECTED_8)} canales, {WINDOW_MODE}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet + FINE-TUNING PROGRESIVO (Variant-B)
🔧 Configuración: 4c, 8 canales, 6s
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:38<00:00,  2.69it/s]


Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=960 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   1 | train_acc=0.4651 | val_acc=0.4469 | LR=0.00500
  Época   5 | train_acc=0.5104 | val_acc=0.4643 | LR=0.00250
  Época  10 | train_acc=0.5119 | val_acc=0.4753 | LR=0.00854
  Época  15 | train_acc=0.5400 | val_acc=0.4936 | LR=0.00250
  Época  20 | train_acc=0.5376 | val_acc=0.4698 | LR=0.00996
  Early stopping en época 24 (mejor val_acc=0.5101)
↳ Curva de entrenamiento guardada: training_curve_fold1.png
[Fold 1/5] Global acc=0.4756
              precision    recall  f1-score   support

        Left     0.5802    0.4263    0.4915       441
       Right     0.5622    0.4921    0.5248       441
  Both Fists     0.3862    0.5079    0.4388       441
   Both Feet     0.4430    0.4762    0.4590       441

    accuracy                         0.4756      1764
   mac