In [7]:
# -*- coding: utf-8 -*-
# CNN + Transformer (PyTorch) para PhysioNet/BCI2000 MI — IMAGINERÍA (4 clases) con 8 canales.
# Incluye reproducibilidad (seed=42) y determinismo CUDA para cuBLAS/cuDNN en notebook.

# =========================
# Reproducibilidad (poner ANTES de importar torch)
# =========================
import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # determinismo cuBLAS

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

# =========================
# REPRODUCIBILIDAD (seed + determinismo)
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinismo cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Evitar TF32 (puede romper determinismo en algunas GPU)
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    # reforzar determinismo cuando sea posible
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = seed_everything.__defaults__[0] + worker_id if seed_everything.__defaults__ else 42 + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG (edita a tu gusto)
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'                     # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

FOLD = 1
EPOCHS = 40
BATCH_SIZE = 64
LR = 1e-3
RESAMPLE_HZ = None              # None para mantener original
DO_NOTCH = True                 # 60 Hz
DO_BANDPASS = False             # band-pass (4–38) desactivado por defecto
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False                   # re-referencia promedio sobre 8 canales

# Ventana temporal (seg)
TMIN, TMAX = 0.5, 4.5

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

# Runs IMAGINERÍA
IMAGERY_RUNS_LR = {4, 8, 12}    # T1=left, T2=right
IMAGERY_RUNS_BF = {6, 10, 14}   # T1=both_fists, T2=both_feet

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Device: {DEVICE}")

# =========================
# UTILIDADES
# =========================
def normalize_ch_name(name: str) -> str:
    """Normaliza: quita no-alfanum, pasa a upper. 'Fc3.' -> 'FC3', 'Cz..' -> 'CZ'."""
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    """Mapea canales del EDF a EXPECTED_8 aunque tengan puntos/minúsculas."""
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    """Devuelve lista de EDFs IMAGERY para un sujeto S###: R04,R06,R08,R10,R12,R14."""
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        pattern = str(subj_dir / f"{subject_id}R{r:02d}.edf")
        edfs.extend(glob(pattern))
    return sorted(edfs)

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    """Carga EDFs de imaginería, aplica preproc y devuelve (X[N,8,T], y[N], sfreq)."""
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        raise FileNotFoundError(f"No imagery EDF files for {subject_id} under {DATA_RAW}")

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

        # Selección de 8 canales
        raw = pick_8_channels(raw)

        # --- PREPROC OPCIONAL ---
        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        # Resample
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        # Eventos (T0/T1/T2)
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        # Mantener solo T1/T2 (T0 ignorado)
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()  # (n_epochs, 8, T)

        # Construir y según RUN
        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}  # id -> 'T1'/'T2'
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)  # left/right
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)  # both_fists/feet
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        keep_mask = y_run >= 0
        X = X[keep_mask]
        y = y_run[keep_mask]

        if len(y) == 0:
            continue

        X_list.append(X)
        y_list.append(y)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0, 8, 1)), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)

    if len(set([round(s) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Inconsistent sampling rates: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            train_sub = list(item.get('train', []))
            test_sub  = list(item.get('test', []))
            return train_sub, test_sub
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def class_count_summary(y, name):
    bc = np.bincount(y, minlength=4)
    print(f"{name} counts:", dict(zip(CLASS_NAMES, bc)), "| total =", bc.sum())

# =========================
# MODELO: CNN + Transformer (con dropout)
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        # x: (B, C, T)
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)          # regularización extra
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)  # (B,1,D)
        z = torch.cat([cls_tok, z], dim=1)    # (B, 1+L, D)
        z = self.encoder(z)                   # (B, 1+L, D)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# ENTRENAMIENTO / EVAL
# =========================
def standardize_per_channel(train_X, test_X):
    """Estandariza por canal usando SOLO estadísticas del train. Entradas (N,C,T)."""
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    test_X  = test_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        test_X[:, c, :]  = (test_X[:, c, :] - mu) / sd
    return train_X, test_X

def train_one_fold(fold:int, resample_hz:int, do_notch:bool, do_bandpass:bool,
                   bp_lo:float, bp_hi:float, epochs:int, batch_size:int, lr:float,
                   device:str=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):

    # --- sujetos fold ---
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # --- carga datos ---
    X_tr_list, y_tr_list, X_te_list, y_te_list = [], [], [], []
    sfreq = None

    for sid in tqdm(train_sub, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys); sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys); sfreq = sf if sfreq is None else sfreq

    if len(X_tr_list) == 0 or len(X_te_list) == 0:
        raise RuntimeError("Datos insuficientes tras carga de sujetos.")

    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)

    # --- diagnósticos de desbalance ---
    print("train counts:", dict(zip(CLASS_NAMES, np.bincount(y_tr, minlength=4))), "| total =", len(y_tr))
    print("test  counts:", dict(zip(CLASS_NAMES, np.bincount(y_te, minlength=4))), "| total =", len(y_te))

    # --- estandarización sin leakage ---
    X_tr, X_te = standardize_per_channel(X_tr, X_te)

    # --- DataLoaders (reproducibles) ---
    g = torch.Generator(device="cpu"); g.manual_seed(RANDOM_STATE)
    tr_ld = DataLoader(
        TensorDataset(torch.tensor(X_tr), torch.tensor(y_tr).long()),
        batch_size=batch_size, shuffle=True, drop_last=False,
        generator=g, worker_init_fn=seed_worker
    )
    te_ld = DataLoader(
        TensorDataset(torch.tensor(X_te), torch.tensor(y_te).long()),
        batch_size=batch_size, shuffle=False, drop_last=False,
        generator=g, worker_init_fn=seed_worker
    )

    # --- Modelo, optimizador, pérdida ---
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    crit = torch.nn.CrossEntropyLoss()

    best_acc, best_state = 0.0, None

    for ep in range(1, epochs+1):
        model.train()
        tr_loss, n_seen = 0.0, 0
        for xb, yb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb); n_seen += len(yb)
        tr_loss /= max(1, n_seen)

        # --- Eval ---
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        if acc > best_acc:
            best_acc = acc
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}

        print(f"Epoch {ep:03d} | train_loss={tr_loss:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f}")

    # cargar el mejor
    if best_state is not None:
        model.load_state_dict(best_state)

    # --- reporte final ---
    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for xb, yb in te_ld:
            xb = xb.to(device)
            p = model(xb).argmax(dim=1).cpu().numpy()
            preds.append(p); gts.append(yb.numpy())
    preds = np.concatenate(preds); gts = np.concatenate(gts)

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    cm = confusion_matrix(gts, preds, labels=[0,1,2,3])

    print(f"\n=== RESULTADOS (fold {fold}) ===")
    print(f"Acc: {acc:.4f} | Macro-F1: {f1m:.4f}")
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)

    return acc, f1m, cm, model

# =========================
# EJECUCIÓN (Notebook)
# =========================
acc, f1m, cm, model = train_one_fold(
    fold=FOLD,
    resample_hz=RESAMPLE_HZ,
    do_notch=DO_NOTCH,
    do_bandpass=DO_BANDPASS,
    bp_lo=BP_LO,
    bp_hi=BP_HI,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
)


🚀 Device: cuda


Cargando train fold1: 100%|██████████| 82/82 [00:12<00:00,  6.56it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:03<00:00,  6.55it/s]


train counts: {'left': np.int64(1767), 'right': np.int64(1748), 'both_fists': np.int64(1752), 'both_feet': np.int64(1761)} | total = 7028
test  counts: {'left': np.int64(448), 'right': np.int64(446), 'both_fists': np.int64(442), 'both_feet': np.int64(450)} | total = 1786




Epoch 001 | train_loss=1.3710 | val_acc=0.3533 | val_f1m=0.3243
Epoch 002 | train_loss=1.2879 | val_acc=0.3897 | val_f1m=0.3910
Epoch 003 | train_loss=1.2595 | val_acc=0.3763 | val_f1m=0.3440
Epoch 004 | train_loss=1.2336 | val_acc=0.3757 | val_f1m=0.3722
Epoch 005 | train_loss=1.2297 | val_acc=0.4149 | val_f1m=0.4140
Epoch 006 | train_loss=1.2192 | val_acc=0.4009 | val_f1m=0.3996
Epoch 007 | train_loss=1.2103 | val_acc=0.3992 | val_f1m=0.3996
Epoch 008 | train_loss=1.2072 | val_acc=0.3852 | val_f1m=0.3780
Epoch 009 | train_loss=1.1913 | val_acc=0.4065 | val_f1m=0.4018
Epoch 010 | train_loss=1.1832 | val_acc=0.3852 | val_f1m=0.3849
Epoch 011 | train_loss=1.1808 | val_acc=0.4071 | val_f1m=0.4083
Epoch 012 | train_loss=1.1699 | val_acc=0.3852 | val_f1m=0.3858
Epoch 013 | train_loss=1.1525 | val_acc=0.3824 | val_f1m=0.3767
Epoch 014 | train_loss=1.1547 | val_acc=0.3886 | val_f1m=0.3849
Epoch 015 | train_loss=1.1468 | val_acc=0.3757 | val_f1m=0.3682
Epoch 016 | train_loss=1.1294 | val_acc=

In [21]:
# -*- coding: utf-8 -*-
# CNN + Transformer (PyTorch) para PhysioNet/BCI2000 MI — IMAGINERÍA (4 clases) con 8 canales.
# Incluye: class-weights + label smoothing, warmup+cosine+early stopping, sampler balanceado,
# augmentations ligeras, TTA estilo EEGNet y reproducibilidad fuerte (seed=42).
# Ventana: 6 s (TMIN=-1.0, TMAX=5.0). Z-score por época activable con ZSCORE_PER_EPOCH.
# Imprime métricas por fold (classification_report) y un runner de 5 folds tipo EEGNet.

# =========================
# Reproducibilidad (poner ANTES de importar torch)
# =========================
import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # determinismo cuBLAS

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# =========================
# REPRODUCIBILIDAD (seed + determinismo)
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinismo cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # evitar TF32 (puede romper determinismo)
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    # reforzar determinismo
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG (edita a tu gusto)
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'                     # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

FOLD = 5
EPOCHS = 60
BATCH_SIZE = 64
LR = 1e-3
RESAMPLE_HZ = None          # None: mantiene 160 Hz original
DO_NOTCH = True             # ganador
DO_BANDPASS = False         # ganador
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False              # ganador

# Normalización
ZSCORE_PER_EPOCH = False     # <- activa/desactiva z-score por época (como EEGNet)

# Hiperparámetros del modelo
D_MODEL = 128              # (si quieres probar más capacidad: 160/192)
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2               # sube a 0.3 si ves sobreajuste

# Ventana temporal (seg) — igual a EEGNet (6 s)
TMIN, TMAX = -1.0, 5.0

# Sliding windows / TTA en evaluación
SW_MODE = 'tta'   # 'none' | 'subwin' | 'tta'
SW_ENABLE = True  # si quieres apagar todo, pon False

# Para TTA de desplazamientos cortos ±25 ms (EEGNet)
TTA_SHIFTS_S = [-0.025, -0.0125, 0.0, 0.0125, 0.025]

# (solo si usas 'subwin'; no se usa en 'tta' EEGNet)
SW_LEN   = 2.0
SW_STRIDE = 0.5

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

# Runs IMAGINERÍA
IMAGERY_RUNS_LR = {4, 8, 12}    # T1=left, T2=right
IMAGERY_RUNS_BF = {6, 10, 14}   # T1=both_fists, T2=both_feet

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Device: {DEVICE}")

# =========================
# UTILIDADES
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        pattern = str(subj_dir / f"{subject_id}R{r:02d}.edf")
        edfs.extend(glob(pattern))
    return sorted(edfs)

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        raise FileNotFoundError(f"No imagery EDF files for {subject_id} under {DATA_RAW}")

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

        # Selección de 8 canales
        raw = pick_8_channels(raw)

        # --- PREPROC OPCIONAL ---
        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        # Resample
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        # Eventos (T0/T1/T2)
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        # Mantener solo T1/T2
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()  # (n_epochs, 8, T)

        # ---- Z-SCORE POR ÉPOCA (opcional, como EEGNet) ----
        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)              # (N,8,1)
            sd = X.std(axis=2, keepdims=True) + eps         # (N,8,1)
            X = (X - mu) / sd

        # Construir y según RUN
        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}  # id -> 'T1'/'T2'
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)  # left/right
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)  # both_fists/feet
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        keep_mask = y_run >= 0
        X = X[keep_mask]
        y = y_run[keep_mask]

        if len(y) == 0:
            continue

        X_list.append(X)
        y_list.append(y)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0, 8, 1)), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)

    if len(set([round(s) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Inconsistent sampling rates: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            train_sub = list(item.get('train', []))
            test_sub  = list(item.get('test', []))
            return train_sub, test_sub
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def class_count_summary(y, name):
    bc = np.bincount(y, minlength=4)
    print(f"{name} counts:", dict(zip(CLASS_NAMES, bc)), "| total =", bc.sum())

# =========================
# MODELO: CNN + Transformer
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)  # (B,1,D)
        z = torch.cat([cls_tok, z], dim=1)    # (B, 1+L, D)
        z = self.encoder(z)                   # (B, 1+L, D)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# ENTRENAMIENTO / EVAL
# =========================
def standardize_per_channel(train_X, test_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    test_X  = test_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        test_X[:, c, :]  = (test_X[:, c, :] - mu) / sd
    return train_X, test_X

# --- Augmentations ligeras ---
def augment_batch(xb, p_jitter=0.25, p_noise=0.25, p_chdrop=0.15,
                  max_jitter_frac=0.02, noise_std=0.02):
    """
    xb: torch.Tensor (B,C,T) (se asume ya normalizado)
    - jitter temporal (roll)
    - ruido gaussiano
    - channel dropout (apagar 1 canal aleatorio)
    """
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop:
        k = 1
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    """
    Promedia logits sobre sub-ventanas de longitud sw_len (seg).
    Entrada:  X (N,C,T) estandarizado
    Salida:   (N,4)
    """
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1]))
    st = max(1, st)

    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]  # (C,T)
            acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]  # (4,)
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)  # (N,4)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    """
    Test-Time Augmentation (EEGNet-like): 5 desplazamientos cortos ±25ms manteniendo longitud T.
    Entrada:  X (N,C,T) estandarizado
    Salida:   (N,4)
    """
    model.eval()
    T = X.shape[-1]
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]  # (C,T)
            acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]  # (4,)
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0)
            out.append(acc)
    return np.stack(out, axis=0)  # (N,4)

def load_fold_subject_ids_and_counts(fold:int):
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    return train_sub, test_sub

def train_one_fold(fold:int, resample_hz:int, do_notch:bool, do_bandpass:bool,
                   bp_lo:float, bp_hi:float, epochs:int, batch_size:int, lr:float,
                   device:str=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):

    # --- sujetos fold ---
    train_sub, test_sub = load_fold_subject_ids_and_counts(fold)

    # --- carga datos ---
    X_tr_list, y_tr_list, X_te_list, y_te_list = [], [], [], []
    sfreq = None

    for sid in tqdm(train_sub, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys); sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys); sfreq = sf if sfreq is None else sfreq

    if len(X_tr_list) == 0 or len(X_te_list) == 0:
        raise RuntimeError("Datos insuficientes tras carga de sujetos.")

    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)

    # --- diagnósticos de desbalance ---
    # (y cabecera al estilo EEGNet con n_train / n_test)
    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_test={len(y_te)})")
    print("train counts:", dict(zip(CLASS_NAMES, np.bincount(y_tr, minlength=4))), "| total =", len(y_tr))
    print("test  counts:", dict(zip(CLASS_NAMES, np.bincount(y_te, minlength=4))), "| total =", len(y_te))

    # --- Normalización ---
    if ZSCORE_PER_EPOCH:
        # Ya se normalizó por época dentro de load_subject_epochs → NO hacer z-score global
        X_tr_std, X_te_std = X_tr, X_te
    else:
        # Sin z-score por época → aplicar z-score global por canal usando SOLO train
        X_tr_std, X_te_std = standardize_per_channel(X_tr, X_te)

    # --- DataLoaders (sampler balanceado + reproducible) ---
    train_ds = TensorDataset(torch.tensor(X_tr_std), torch.tensor(y_tr).long())
    class_counts = np.bincount(y_tr, minlength=4)
    class_weights_vec = class_counts.sum() / (4.0 * np.maximum(class_counts, 1))
    sample_weights = class_weights_vec[train_ds.tensors[1].numpy()]
    sampler = WeightedRandomSampler(
        weights=torch.tensor(sample_weights, dtype=torch.float32),
        num_samples=len(train_ds), replacement=True,
        generator=torch.Generator().manual_seed(RANDOM_STATE)
    )
    tr_ld = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                       drop_last=False, worker_init_fn=seed_worker)

    te_ld = DataLoader(
        TensorDataset(torch.tensor(X_te_std), torch.tensor(y_te).long()),
        batch_size=batch_size, shuffle=False, drop_last=False,
        worker_init_fn=seed_worker
    )

    # --- Modelo ---
    model = EEGCNNTransformer(
        n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS, p_drop=P_DROP
    ).to(device)

    # --- Optimizador ---
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)

    # (1) Class-weights + label smoothing
    class_weights = torch.tensor(class_weights_vec, dtype=torch.float32, device=device)
    crit = torch.nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

    # (2) Warmup + Cosine LR + Early stopping por macro-F1
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = epochs
    warmup_epochs = max(1, min(5, epochs//10))
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return 0.5 * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    best_f1, best_state, patience, wait = 0.0, None, 10, 0

    for ep in range(1, epochs+1):
        model.train()
        tr_loss, n_seen = 0.0, 0
        for xb, yb in tr_ld:
            # augs en GPU (se asume input ya normalizado)
            xb = xb.to(device)
            xb = augment_batch(xb)
            yb = yb.to(device)

            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb); n_seen += len(yb)
        tr_loss /= max(1, n_seen)

        # --- Eval en test set (mismo patrón que tu script actual) ---
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        # Log estilo EEGNet (usa acc/f1m del split de test como "val_*" porque no tenemos val separado)
        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        if wait >= patience:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # cargar el mejor
    if best_state is not None:
        model.load_state_dict(best_state)

    # --- reporte final con TTA/subwin/none ---
    model.eval()
    sfreq_used = RESAMPLE_HZ if RESAMPLE_HZ is not None else int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)

    elif SW_MODE == 'subwin':
        logits = subwindow_logits(model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)  # (N,4)
        preds = logits.argmax(axis=1)
        gts = y_te

    elif SW_MODE == 'tta':
        logits = time_shift_tta_logits(model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)  # (N,4)
        preds = logits.argmax(axis=1)
        gts = y_te

    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    cm = confusion_matrix(gts, preds, labels=[0,1,2,3])

    # ---- Salida estilo EEGNet ----
    print(f"[Fold {fold}/5] Global acc={acc:.4f}")
    print("\n" + classification_report(gts, preds, target_names=[
        'Left','Right','Both Fists','Both Feet'
    ], digits=4))

    print(f"=== RESULTADOS (fold {fold}) ===")
    print(f"Acc: {acc:.4f} | Macro-F1: {f1m:.4f}")
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)

    return acc, f1m, cm, model

# =========================
# Runner multi-fold con resumen (igual estilo EEGNet)
# =========================
def run_all_folds(n_folds=5):
    accs, f1s = [], []
    for f in range(1, n_folds+1):
        print("\n" + "="*12 + f"  FOLD {f}  " + "="*12)
        acc, f1m, cm, _ = train_one_fold(
            fold=f,
            resample_hz=RESAMPLE_HZ,
            do_notch=DO_NOTCH,
            do_bandpass=DO_BANDPASS,
            bp_lo=BP_LO,
            bp_hi=BP_HI,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LR,
        )
        accs.append(acc); f1s.append(f1m)
    if accs:
        print("\n" + "="*60)
        print("RESULTADOS FINALES (Transformer)")
        print("="*60)
        print("Acc folds:", [f"{a:.4f}" for a in accs])
        print("F1m folds:", [f"{a:.4f}" for a in f1s])
        print(f"Acc mean: {np.mean(accs):.4f} | std: {np.std(accs):.4f}")
        print(f"F1m mean: {np.mean(f1s):.4f} | std: {np.std(f1s):.4f}")

# =========================
# EJECUCIÓN
# =========================
if __name__ == "__main__":
    # Opción A: correr todos los folds 1..5 con reporte por fold y resumen
    run_all_folds(n_folds=5)

    # Opción B: correr un solo fold
    # acc, f1m, cm, model = train_one_fold(
    #     fold=FOLD,
    #     resample_hz=RESAMPLE_HZ,
    #     do_notch=DO_NOTCH,
    #     do_bandpass=DO_BANDPASS,
    #     bp_lo=BP_LO,
    #     bp_hi=BP_HI,
    #     epochs=EPOCHS,
    #     batch_size=BATCH_SIZE,
    #     lr=LR,
    # )


🚀 Device: cuda



Cargando train fold1: 100%|██████████| 82/82 [00:09<00:00,  8.35it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.35it/s]


[Fold 1/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
train counts: {'left': np.int64(1722), 'right': np.int64(1722), 'both_fists': np.int64(1722), 'both_feet': np.int64(1722)} | total = 6888
test  counts: {'left': np.int64(441), 'right': np.int64(441), 'both_fists': np.int64(441), 'both_feet': np.int64(441)} | total = 1764




  Época   1 | train_loss=1.3483 | val_acc=0.3656 | val_f1m=0.3498 | LR=0.000400
  Época   2 | train_loss=1.2134 | val_acc=0.4427 | val_f1m=0.4359 | LR=0.000600
  Época   3 | train_loss=1.1752 | val_acc=0.4598 | val_f1m=0.4615 | LR=0.000800
  Época   4 | train_loss=1.1564 | val_acc=0.4495 | val_f1m=0.4522 | LR=0.001000
  Época   5 | train_loss=1.1524 | val_acc=0.4507 | val_f1m=0.4365 | LR=0.001000
  Época   6 | train_loss=1.1156 | val_acc=0.4717 | val_f1m=0.4685 | LR=0.000999
  Época   7 | train_loss=1.1040 | val_acc=0.4745 | val_f1m=0.4572 | LR=0.000997
  Época   8 | train_loss=1.1060 | val_acc=0.4705 | val_f1m=0.4688 | LR=0.000993
  Época   9 | train_loss=1.1008 | val_acc=0.4422 | val_f1m=0.4248 | LR=0.000987
  Época  10 | train_loss=1.0739 | val_acc=0.4586 | val_f1m=0.4385 | LR=0.000980
  Época  11 | train_loss=1.0749 | val_acc=0.4830 | val_f1m=0.4818 | LR=0.000971
  Época  12 | train_loss=1.0540 | val_acc=0.4615 | val_f1m=0.4579 | LR=0.000961
  Época  13 | train_loss=1.0441 | val_ac

Cargando train fold2: 100%|██████████| 82/82 [00:09<00:00,  8.31it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.19it/s]


[Fold 2/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
train counts: {'left': np.int64(1722), 'right': np.int64(1722), 'both_fists': np.int64(1722), 'both_feet': np.int64(1722)} | total = 6888
test  counts: {'left': np.int64(441), 'right': np.int64(441), 'both_fists': np.int64(441), 'both_feet': np.int64(441)} | total = 1764




  Época   1 | train_loss=1.3581 | val_acc=0.4286 | val_f1m=0.4178 | LR=0.000400
  Época   2 | train_loss=1.2361 | val_acc=0.4864 | val_f1m=0.4735 | LR=0.000600
  Época   3 | train_loss=1.2014 | val_acc=0.5249 | val_f1m=0.5194 | LR=0.000800
  Época   4 | train_loss=1.1905 | val_acc=0.5283 | val_f1m=0.5318 | LR=0.001000
  Época   5 | train_loss=1.1873 | val_acc=0.5340 | val_f1m=0.5350 | LR=0.001000
  Época   6 | train_loss=1.1505 | val_acc=0.5368 | val_f1m=0.5404 | LR=0.000999
  Época   7 | train_loss=1.1486 | val_acc=0.5459 | val_f1m=0.5367 | LR=0.000997
  Época   8 | train_loss=1.1518 | val_acc=0.5193 | val_f1m=0.4985 | LR=0.000993
  Época   9 | train_loss=1.1339 | val_acc=0.5329 | val_f1m=0.5064 | LR=0.000987
  Época  10 | train_loss=1.1026 | val_acc=0.5482 | val_f1m=0.5315 | LR=0.000980
  Época  11 | train_loss=1.1050 | val_acc=0.5465 | val_f1m=0.5389 | LR=0.000971
  Época  12 | train_loss=1.0913 | val_acc=0.5527 | val_f1m=0.5528 | LR=0.000961
  Época  13 | train_loss=1.0917 | val_ac

Cargando train fold3: 100%|██████████| 82/82 [00:09<00:00,  8.30it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.33it/s]


[Fold 3/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
train counts: {'left': np.int64(1722), 'right': np.int64(1722), 'both_fists': np.int64(1722), 'both_feet': np.int64(1722)} | total = 6888
test  counts: {'left': np.int64(441), 'right': np.int64(441), 'both_fists': np.int64(441), 'both_feet': np.int64(441)} | total = 1764




  Época   1 | train_loss=1.3863 | val_acc=0.3254 | val_f1m=0.2574 | LR=0.000400
  Época   2 | train_loss=1.2550 | val_acc=0.4677 | val_f1m=0.4658 | LR=0.000600
  Época   3 | train_loss=1.1853 | val_acc=0.4773 | val_f1m=0.4771 | LR=0.000800
  Época   4 | train_loss=1.1739 | val_acc=0.4575 | val_f1m=0.4612 | LR=0.001000
  Época   5 | train_loss=1.1491 | val_acc=0.4864 | val_f1m=0.4779 | LR=0.001000
  Época   6 | train_loss=1.1223 | val_acc=0.4881 | val_f1m=0.4881 | LR=0.000999
  Época   7 | train_loss=1.1163 | val_acc=0.4683 | val_f1m=0.4501 | LR=0.000997
  Época   8 | train_loss=1.1067 | val_acc=0.4688 | val_f1m=0.4621 | LR=0.000993
  Época   9 | train_loss=1.0928 | val_acc=0.4813 | val_f1m=0.4665 | LR=0.000987
  Época  10 | train_loss=1.0775 | val_acc=0.4813 | val_f1m=0.4804 | LR=0.000980
  Época  11 | train_loss=1.0863 | val_acc=0.4717 | val_f1m=0.4595 | LR=0.000971
  Época  12 | train_loss=1.0623 | val_acc=0.4586 | val_f1m=0.4511 | LR=0.000961
  Época  13 | train_loss=1.0616 | val_ac

Cargando train fold4: 100%|██████████| 83/83 [00:10<00:00,  8.27it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.26it/s]


[Fold 4/5] Entrenando modelo global... (n_train=6972 | n_test=1680)
train counts: {'left': np.int64(1743), 'right': np.int64(1743), 'both_fists': np.int64(1743), 'both_feet': np.int64(1743)} | total = 6972
test  counts: {'left': np.int64(420), 'right': np.int64(420), 'both_fists': np.int64(420), 'both_feet': np.int64(420)} | total = 1680




  Época   1 | train_loss=1.3840 | val_acc=0.4030 | val_f1m=0.4010 | LR=0.000400
  Época   2 | train_loss=1.2280 | val_acc=0.4929 | val_f1m=0.4864 | LR=0.000600
  Época   3 | train_loss=1.1767 | val_acc=0.4726 | val_f1m=0.4500 | LR=0.000800
  Época   4 | train_loss=1.1629 | val_acc=0.4994 | val_f1m=0.4880 | LR=0.001000
  Época   5 | train_loss=1.1573 | val_acc=0.5018 | val_f1m=0.4922 | LR=0.001000
  Época   6 | train_loss=1.1326 | val_acc=0.5018 | val_f1m=0.4944 | LR=0.000999
  Época   7 | train_loss=1.1316 | val_acc=0.5071 | val_f1m=0.5104 | LR=0.000997
  Época   8 | train_loss=1.1127 | val_acc=0.5131 | val_f1m=0.5082 | LR=0.000993
  Época   9 | train_loss=1.1066 | val_acc=0.5119 | val_f1m=0.5065 | LR=0.000987
  Época  10 | train_loss=1.0880 | val_acc=0.4940 | val_f1m=0.4961 | LR=0.000980
  Época  11 | train_loss=1.0813 | val_acc=0.5095 | val_f1m=0.5083 | LR=0.000971
  Época  12 | train_loss=1.0809 | val_acc=0.5244 | val_f1m=0.5223 | LR=0.000961
  Época  13 | train_loss=1.0706 | val_ac

Cargando train fold5: 100%|██████████| 83/83 [00:09<00:00,  8.40it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.30it/s]


[Fold 5/5] Entrenando modelo global... (n_train=6972 | n_test=1680)
train counts: {'left': np.int64(1743), 'right': np.int64(1743), 'both_fists': np.int64(1743), 'both_feet': np.int64(1743)} | total = 6972
test  counts: {'left': np.int64(420), 'right': np.int64(420), 'both_fists': np.int64(420), 'both_feet': np.int64(420)} | total = 1680




  Época   1 | train_loss=1.3911 | val_acc=0.3518 | val_f1m=0.3082 | LR=0.000400
  Época   2 | train_loss=1.2738 | val_acc=0.5071 | val_f1m=0.5038 | LR=0.000600
  Época   3 | train_loss=1.1954 | val_acc=0.5208 | val_f1m=0.5012 | LR=0.000800
  Época   4 | train_loss=1.1863 | val_acc=0.5393 | val_f1m=0.5232 | LR=0.001000
  Época   5 | train_loss=1.1657 | val_acc=0.5232 | val_f1m=0.5089 | LR=0.001000
  Época   6 | train_loss=1.1465 | val_acc=0.5399 | val_f1m=0.5326 | LR=0.000999
  Época   7 | train_loss=1.1450 | val_acc=0.5054 | val_f1m=0.4860 | LR=0.000997
  Época   8 | train_loss=1.1193 | val_acc=0.5214 | val_f1m=0.5132 | LR=0.000993
  Época   9 | train_loss=1.1186 | val_acc=0.5488 | val_f1m=0.5431 | LR=0.000987
  Época  10 | train_loss=1.1034 | val_acc=0.5333 | val_f1m=0.5356 | LR=0.000980
  Época  11 | train_loss=1.0964 | val_acc=0.5339 | val_f1m=0.5147 | LR=0.000971
  Época  12 | train_loss=1.0925 | val_acc=0.5423 | val_f1m=0.5327 | LR=0.000961
  Época  13 | train_loss=1.0821 | val_ac

In [None]:
# -*- coding: utf-8 -*-
# CNN+Transformer para MI (4 clases, 8 canales) con:
# (1) Focal Loss
# (4) Aumentos más fuertes y estables
# (6) LR warmup+cosine, early stopping por F1 macro
# Cambios propuestos:
# - GroupNorm en conv (en lugar de BN)
# - EMA de pesos para val/test
# - Balanced sampler por sujeto/clase (WeightedRandomSampler)
# - Dropout antes del encoder = 0.3
# - Cosine min_factor=0.2
# - Ensemble TTA + Subwindows en test (opcional)
# - Val estratificada por sujeto (opcional)

import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42
def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64        # múltiplo de 4
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18  # ≈ 18% de sujetos del train → val
VAL_STRAT_SUBJECT = True # <- ACTIVADO: estratifica por etiqueta dominante de cada sujeto

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False       # dejamos apagado para respetar tu setup original
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2            # dropout en conv stack
P_DROP_ENCODER = 0.3    # dropout antes del encoder (↑ respecto a 0.2)

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW en TEST
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.05, -0.025, 0.0, 0.025, 0.05]
SW_LEN, SW_STRIDE = 4.5, 2.0
COMBINE_TTA_AND_SUBWIN = False  # Promedia logits TTA y Subwindow (50/50) si True

# Sampler balanceado
USE_WEIGHTED_SAMPLER = True

# EMA
USE_EMA = True
EMA_DECAY = 0.999  # 0.999 ~ suave y efectivo

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :]  = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO (GroupNorm en conv)
# =========================
def make_gn(num_channels, num_groups=8):
    # Ajusta grupos para que dividan a num_channels
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)  # ↑ 0.3
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# (1) FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# (4) AUGMENTS MÁS FUERTES (pero estables)
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    """
    - jitter: shift temporal pequeño (±3% de la longitud)
    - noise: ruido gaussiano leve
    - chdrop: apagado de 1 canal aleatorio
    """
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# EMA de pesos
# =========================
class ModelEMA:
    def __init__(self, model: nn.Module, decay: float = 0.999, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        ema = type(model)()
        ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        d = self.decay
        if self._updates < 200:
            # warmup EMA decay para primeras iteraciones
            d = 0.0 + (self._updates / 200.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                esd[k] = msd[k]
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    """
    Estima una etiqueta dominante por sujeto (argmax de su histograma de clases).
    Se usa sólo para estratificar la selección de sujetos de validación.
    """
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=4)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)

# =========================
# TRAIN/EVAL por FOLD (con train/val/test)
# =========================
def train_one_fold(fold:int, device):
    # --- sujetos por fold ---
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación por sujetos (determinista y opcionalmente estratificado)
    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        # Reemplaza -1 (sujetos sin trials) por la moda para no romper el split
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        # sss.split(X, y) — X no importa, pasamos índice
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)  # reutiliza stats de train

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # Weighted sampler por sujeto/clase
    def make_weighted_sampler(dataset: TensorDataset):
        Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        # frec por sujeto
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        # frec por (sujeto,clase)
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)  # asumiendo clases <10
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
        # peso = 1 / (frec_sujeto * frec_clase_en_sujeto)
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = map_s[int(s)]
            wk = map_k[k]
            w.append(1.0 / (float(ws) * float(wk)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)

    # Optimizador + (1) Focal Loss (ligero upweight a both_fists)
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    # Aumentamos un poco el peso de both_fists para mejorar su recall
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.5 * alpha_mean
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # (6) Warmup+Cosine (min_factor 0.1)
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento con early stopping por F1 macro (en VAL; usando EMA si está activo)
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                p = mdl(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    for ep in range(1, EPOCHS+1):
        # ---- Train ----
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)  # (4) aumentos más fuertes
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # ---- Val (con EMA si aplica) ----
        acc, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            # guardamos el estado EMA si existe; si no, el del modelo
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Cargamos mejor estado guardado (EMA si estaba activo)
    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["val_f1m"], label="val_f1m")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST (con TTA/subwindows y ensemble si aplica) ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(gts, preds, labels=[0,1,2,3]))

    return acc, f1m

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  8.35it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.27it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2225 | train_acc=0.2747 | val_acc=0.2984 | val_f1m=0.2443 | LR=0.000125
  Época   2 | train_loss=0.2016 | train_acc=0.3625 | val_acc=0.4056 | val_f1m=0.3910 | LR=0.000250
  Época   3 | train_loss=0.1802 | train_acc=0.4655 | val_acc=0.4143 | val_f1m=0.4060 | LR=0.000375
  Época   4 | train_loss=0.1765 | train_acc=0.4730 | val_acc=0.4175 | val_f1m=0.4095 | LR=0.000500
  Época   5 | train_loss=0.1699 | train_acc=0.5023 | val_acc=0.4230 | val_f1m=0.4166 | LR=0.000625
  Época   6 | train_loss=0.1664 | train_acc=0.5263 | val_acc=0.4317 | val_f1m=0.4279 | LR=0.000750
  Época   7 | train_loss=0.1640 | train_acc=0.5258 | val_acc=0.4365 | val_f1m=0.4340 | LR=0.000875
  Época   8 | train_loss=0.1614 | train_acc=0.5339 | val_acc=0.4421 | val_f1m=0.4401 | LR=0.001000
  Época   9 | train_loss=0.1648 | train_acc=0.5171 | val_acc=0.4429 | val_f1m=0.4412 | LR=0.001000
  Época  10 | train_loss=0.1569 | train_acc=0.5437 | val_acc=0.4437 | val_f1m=0.4417 | LR=0.000999
  Época  1

Cargando train fold2: 100%|██████████| 67/67 [00:07<00:00,  8.41it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.50it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.48it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2221 | train_acc=0.2790 | val_acc=0.3063 | val_f1m=0.2355 | LR=0.000125
  Época   2 | train_loss=0.1983 | train_acc=0.3882 | val_acc=0.4302 | val_f1m=0.4300 | LR=0.000250
  Época   3 | train_loss=0.1829 | train_acc=0.4662 | val_acc=0.4484 | val_f1m=0.4476 | LR=0.000375
  Época   4 | train_loss=0.1810 | train_acc=0.4739 | val_acc=0.4516 | val_f1m=0.4514 | LR=0.000500
  Época   5 | train_loss=0.1767 | train_acc=0.4803 | val_acc=0.4532 | val_f1m=0.4523 | LR=0.000625
  Época   6 | train_loss=0.1763 | train_acc=0.4742 | val_acc=0.4698 | val_f1m=0.4689 | LR=0.000750
  Época   7 | train_loss=0.1725 | train_acc=0.4927 | val_acc=0.4722 | val_f1m=0.4715 | LR=0.000875
  Época   8 | train_loss=0.1758 | train_acc=0.4869 | val_acc=0.4698 | val_f1m=0.4685 | LR=0.001000
  Época   9 | train_loss=0.1687 | train_acc=0.5084 | val_acc=0.4770 | val_f1m=0.4744 | LR=0.001000
  Época  10 | train_loss=0.1666 | train_acc=0.5171 | val_acc=0.4810 | val_f1m=0.4783 | LR=0.000999
  Época  1

Cargando train fold3: 100%|██████████| 67/67 [00:07<00:00,  8.47it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.41it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2247 | train_acc=0.2742 | val_acc=0.2976 | val_f1m=0.2007 | LR=0.000125
  Época   2 | train_loss=0.2040 | train_acc=0.3666 | val_acc=0.4468 | val_f1m=0.4450 | LR=0.000250
  Época   3 | train_loss=0.1837 | train_acc=0.4536 | val_acc=0.4579 | val_f1m=0.4595 | LR=0.000375
  Época   4 | train_loss=0.1757 | train_acc=0.4890 | val_acc=0.4603 | val_f1m=0.4628 | LR=0.000500
  Época   5 | train_loss=0.1704 | train_acc=0.5011 | val_acc=0.4619 | val_f1m=0.4654 | LR=0.000625
  Época   6 | train_loss=0.1704 | train_acc=0.5030 | val_acc=0.4714 | val_f1m=0.4752 | LR=0.000750
  Época   7 | train_loss=0.1684 | train_acc=0.5126 | val_acc=0.4722 | val_f1m=0.4753 | LR=0.000875
  Época   8 | train_loss=0.1651 | train_acc=0.5229 | val_acc=0.4810 | val_f1m=0.4851 | LR=0.001000
  Época   9 | train_loss=0.1606 | train_acc=0.5444 | val_acc=0.4817 | val_f1m=0.4865 | LR=0.001000
  Época  10 | train_loss=0.1603 | train_acc=0.5396 | val_acc=0.4881 | val_f1m=0.4926 | LR=0.000999
  Época  1

Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.42it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.44it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.38it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2243 | train_acc=0.2635 | val_acc=0.3310 | val_f1m=0.2866 | LR=0.000125
  Época   2 | train_loss=0.2013 | train_acc=0.3853 | val_acc=0.4270 | val_f1m=0.4305 | LR=0.000250
  Época   3 | train_loss=0.1904 | train_acc=0.4268 | val_acc=0.4397 | val_f1m=0.4423 | LR=0.000375
  Época   4 | train_loss=0.1795 | train_acc=0.4758 | val_acc=0.4429 | val_f1m=0.4452 | LR=0.000500
  Época   5 | train_loss=0.1781 | train_acc=0.4855 | val_acc=0.4579 | val_f1m=0.4601 | LR=0.000625
  Época   6 | train_loss=0.1710 | train_acc=0.4979 | val_acc=0.4587 | val_f1m=0.4610 | LR=0.000750
  Época   7 | train_loss=0.1742 | train_acc=0.4911 | val_acc=0.4754 | val_f1m=0.4777 | LR=0.000875
  Época   8 | train_loss=0.1716 | train_acc=0.5004 | val_acc=0.4810 | val_f1m=0.4834 | LR=0.001000
  Época   9 | train_loss=0.1642 | train_acc=0.5165 | val_acc=0.4857 | val_f1m=0.4877 | LR=0.001000
  Época  10 | train_loss=0.1656 | train_acc=0.5114 | val_acc=0.4865 | val_f1m=0.4880 | LR=0.000999
  Época  1

Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.42it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.49it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.42it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2241 | train_acc=0.2742 | val_acc=0.3183 | val_f1m=0.2769 | LR=0.000125
  Época   2 | train_loss=0.2039 | train_acc=0.3699 | val_acc=0.4238 | val_f1m=0.4264 | LR=0.000250
  Época   3 | train_loss=0.1856 | train_acc=0.4508 | val_acc=0.4206 | val_f1m=0.4102 | LR=0.000375
  Época   4 | train_loss=0.1807 | train_acc=0.4624 | val_acc=0.4238 | val_f1m=0.4147 | LR=0.000500
  Época   5 | train_loss=0.1752 | train_acc=0.4933 | val_acc=0.4254 | val_f1m=0.4166 | LR=0.000625
  Época   6 | train_loss=0.1775 | train_acc=0.4855 | val_acc=0.4278 | val_f1m=0.4172 | LR=0.000750
  Época   7 | train_loss=0.1736 | train_acc=0.4956 | val_acc=0.4373 | val_f1m=0.4288 | LR=0.000875
  Época   8 | train_loss=0.1700 | train_acc=0.4986 | val_acc=0.4413 | val_f1m=0.4329 | LR=0.001000
  Época   9 | train_loss=0.1673 | train_acc=0.5130 | val_acc=0.4317 | val_f1m=0.4234 | LR=0.001000
  Época  10 | train_loss=0.1647 | train_acc=0.5191 | val_acc=0.4333 | val_f1m=0.4257 | LR=0.000999
  Época  1

In [8]:
# -*- coding: utf-8 -*-
# CNN+Transformer para MI (4 clases, 8 canales) con:
# (1) Focal Loss
# (4) Aumentos más fuertes y estables
# (6) LR warmup+cosine, early stopping por F1 macro
# Cambios aplicados (2,3,4,5,6,7):
# 2) WeightedRandomSampler templado (a=0.5, b=1.0)
# 3) Focal alpha: boost también para both_feet
# 4) TransformerEncoderLayer(norm_first=False)
# 5) Gradient clipping (max_norm=1.0)
# 6) EMA más lenta (EMA_DECAY=0.9995; warmup decay a 1000 updates)
# 7) Cosine LR min_factor=0.1

import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42
def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64        # múltiplo de 4
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18  # ≈ 18% de sujetos del train → val
VAL_STRAT_SUBJECT = True # estratifica por etiqueta dominante de cada sujeto

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False       # respetamos tu setup original
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2            # dropout en conv stack
P_DROP_ENCODER = 0.3    # dropout antes del encoder

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW en TEST
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.05, -0.025, 0.0, 0.025, 0.05]
SW_LEN, SW_STRIDE = 4.5, 2.0
COMBINE_TTA_AND_SUBWIN = False  # solicitado (no cambio 1)

# Sampler balanceado
USE_WEIGHTED_SAMPLER = True

# EMA
USE_EMA = True
EMA_DECAY = 0.9995  # (6) más lenta

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :]  = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO (GroupNorm en conv)
# =========================
def make_gn(num_channels, num_groups=8):
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True  # (4)
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# (1) FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# (4) AUGMENTS MÁS FUERTES (pero estables)
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# EMA de pesos
# =========================
class ModelEMA:
    def __init__(self, model: nn.Module, decay: float = 0.9995, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        ema = type(model)()
        ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        d = self.decay
        # (6) warmup del decay extendido a 1000 updates
        if self._updates < 1000:
            d = (self._updates / 1000.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                esd[k] = msd[k]
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=4)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)

# =========================
# TRAIN/EVAL por FOLD (con train/val/test)
# =========================
def train_one_fold(fold:int, device):
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación por sujetos (determinista y opcionalmente estratificado)
    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # (2) Weighted sampler templado: w = (1/ws)^a * (1/wk)^b
    def make_weighted_sampler(dataset: TensorDataset):
        _Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
        a, b = 0.8, 1.0
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = float(map_s[int(s)])
            wk = float(map_k[k])
            w.append((ws ** (-a)) * (wk ** (-b)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)

    # Optimizador + Focal Loss (3) alpha tweaking
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.5 * alpha_mean      # both_fists (ya estaba)
    alpha[3] = 1.15 * alpha_mean      # (3) pequeño empujón a both_feet
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # (7) Warmup+Cosine (min_factor=0.1)
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento con early stopping por F1 macro (EMA en val/test si aplica)
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                p = mdl(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    for ep in range(1, EPOCHS+1):
        # ---- Train ----
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            # (5) Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # ---- Val (EMA si aplica) ----
        acc, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Cargamos mejor estado guardado (EMA si estaba activo)
    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["val_f1m"], label="val_f1m")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(gts, preds, labels=[0,1,2,3]))

    return acc, f1m

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  8.34it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  8.28it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.29it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2228 | train_acc=0.2822 | val_acc=0.2968 | val_f1m=0.2339 | LR=0.000125
  Época   2 | train_loss=0.2013 | train_acc=0.3701 | val_acc=0.4087 | val_f1m=0.3803 | LR=0.000250
  Época   3 | train_loss=0.1803 | train_acc=0.4726 | val_acc=0.4302 | val_f1m=0.4253 | LR=0.000375
  Época   4 | train_loss=0.1769 | train_acc=0.4812 | val_acc=0.4500 | val_f1m=0.4505 | LR=0.000500
  Época   5 | train_loss=0.1705 | train_acc=0.4993 | val_acc=0.4651 | val_f1m=0.4590 | LR=0.000625
  Época   6 | train_loss=0.1670 | train_acc=0.5258 | val_acc=0.4643 | val_f1m=0.4612 | LR=0.000750
  Época   7 | train_loss=0.1638 | train_acc=0.5352 | val_acc=0.4452 | val_f1m=0.4363 | LR=0.000875
  Época   8 | train_loss=0.1619 | train_acc=0.5288 | val_acc=0.4532 | val_f1m=0.4557 | LR=0.001000
  Época   9 | train_loss=0.1647 | train_acc=0.5135 | val_acc=0.4413 | val_f1m=0.4447 | LR=0.001000
  Época  10 | train_loss=0.1573 | train_acc=0.5439 | val_acc=0.4508 | val_f1m=0.4501 | LR=0.000999
  Época  1

Cargando train fold2: 100%|██████████| 67/67 [00:08<00:00,  8.12it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.18it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.36it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2254 | train_acc=0.2774 | val_acc=0.2960 | val_f1m=0.1950 | LR=0.000125
  Época   2 | train_loss=0.2034 | train_acc=0.3651 | val_acc=0.4381 | val_f1m=0.4369 | LR=0.000250
  Época   3 | train_loss=0.1904 | train_acc=0.4314 | val_acc=0.4651 | val_f1m=0.4496 | LR=0.000375
  Época   4 | train_loss=0.1830 | train_acc=0.4645 | val_acc=0.4635 | val_f1m=0.4623 | LR=0.000500
  Época   5 | train_loss=0.1794 | train_acc=0.4627 | val_acc=0.4651 | val_f1m=0.4664 | LR=0.000625
  Época   6 | train_loss=0.1770 | train_acc=0.4789 | val_acc=0.4444 | val_f1m=0.4288 | LR=0.000750
  Época   7 | train_loss=0.1771 | train_acc=0.4815 | val_acc=0.4794 | val_f1m=0.4826 | LR=0.000875
  Época   8 | train_loss=0.1734 | train_acc=0.4931 | val_acc=0.4484 | val_f1m=0.4509 | LR=0.001000
  Época   9 | train_loss=0.1707 | train_acc=0.5002 | val_acc=0.4778 | val_f1m=0.4829 | LR=0.001000
  Época  10 | train_loss=0.1664 | train_acc=0.5155 | val_acc=0.4849 | val_f1m=0.4854 | LR=0.000999
  Época  1

Cargando train fold3: 100%|██████████| 67/67 [00:08<00:00,  8.34it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.41it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.41it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2213 | train_acc=0.2937 | val_acc=0.3230 | val_f1m=0.2394 | LR=0.000125
  Época   2 | train_loss=0.1972 | train_acc=0.3898 | val_acc=0.4627 | val_f1m=0.4549 | LR=0.000250
  Época   3 | train_loss=0.1835 | train_acc=0.4595 | val_acc=0.4698 | val_f1m=0.4589 | LR=0.000375
  Época   4 | train_loss=0.1791 | train_acc=0.4829 | val_acc=0.4778 | val_f1m=0.4768 | LR=0.000500
  Época   5 | train_loss=0.1750 | train_acc=0.4982 | val_acc=0.4770 | val_f1m=0.4783 | LR=0.000625
  Época   6 | train_loss=0.1718 | train_acc=0.5034 | val_acc=0.4667 | val_f1m=0.4606 | LR=0.000750
  Época   7 | train_loss=0.1682 | train_acc=0.5114 | val_acc=0.4738 | val_f1m=0.4708 | LR=0.000875
  Época   8 | train_loss=0.1677 | train_acc=0.5203 | val_acc=0.4905 | val_f1m=0.4921 | LR=0.001000
  Época   9 | train_loss=0.1653 | train_acc=0.5183 | val_acc=0.5008 | val_f1m=0.5042 | LR=0.001000
  Época  10 | train_loss=0.1617 | train_acc=0.5295 | val_acc=0.4968 | val_f1m=0.5005 | LR=0.000999
  Época  1

Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.32it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.34it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.43it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2248 | train_acc=0.2799 | val_acc=0.3206 | val_f1m=0.2356 | LR=0.000125
  Época   2 | train_loss=0.2045 | train_acc=0.3641 | val_acc=0.4484 | val_f1m=0.4439 | LR=0.000250
  Época   3 | train_loss=0.1881 | train_acc=0.4396 | val_acc=0.4611 | val_f1m=0.4624 | LR=0.000375
  Época   4 | train_loss=0.1807 | train_acc=0.4718 | val_acc=0.4738 | val_f1m=0.4721 | LR=0.000500
  Época   5 | train_loss=0.1742 | train_acc=0.4937 | val_acc=0.4722 | val_f1m=0.4751 | LR=0.000625
  Época   6 | train_loss=0.1775 | train_acc=0.4795 | val_acc=0.4857 | val_f1m=0.4847 | LR=0.000750
  Época   7 | train_loss=0.1746 | train_acc=0.4977 | val_acc=0.4651 | val_f1m=0.4615 | LR=0.000875
  Época   8 | train_loss=0.1728 | train_acc=0.4982 | val_acc=0.4778 | val_f1m=0.4771 | LR=0.001000
  Época   9 | train_loss=0.1687 | train_acc=0.5117 | val_acc=0.4817 | val_f1m=0.4825 | LR=0.001000
  Época  10 | train_loss=0.1670 | train_acc=0.5186 | val_acc=0.4960 | val_f1m=0.4986 | LR=0.000999
  Época  1

Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.31it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.30it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.23it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2235 | train_acc=0.2638 | val_acc=0.2960 | val_f1m=0.2855 | LR=0.000125
  Época   2 | train_loss=0.2031 | train_acc=0.3780 | val_acc=0.3984 | val_f1m=0.3959 | LR=0.000250
  Época   3 | train_loss=0.1842 | train_acc=0.4629 | val_acc=0.4413 | val_f1m=0.4405 | LR=0.000375
  Época   4 | train_loss=0.1765 | train_acc=0.4879 | val_acc=0.4397 | val_f1m=0.4409 | LR=0.000500
  Época   5 | train_loss=0.1809 | train_acc=0.4690 | val_acc=0.4254 | val_f1m=0.4172 | LR=0.000625
  Época   6 | train_loss=0.1749 | train_acc=0.4821 | val_acc=0.4373 | val_f1m=0.4356 | LR=0.000750
  Época   7 | train_loss=0.1715 | train_acc=0.5023 | val_acc=0.4000 | val_f1m=0.3734 | LR=0.000875
  Época   8 | train_loss=0.1718 | train_acc=0.5058 | val_acc=0.4571 | val_f1m=0.4603 | LR=0.001000
  Época   9 | train_loss=0.1667 | train_acc=0.5093 | val_acc=0.4524 | val_f1m=0.4541 | LR=0.001000
  Época  10 | train_loss=0.1633 | train_acc=0.5278 | val_acc=0.4611 | val_f1m=0.4629 | LR=0.000999
  Época  1

# MEJOR 54

In [4]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42


def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass


def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)


seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64        # múltiplo de 4
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18  # ≈ 18% de sujetos del train → val
VAL_STRAT_SUBJECT = True # estratifica por etiqueta dominante de cada sujeto

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False       # respetamos tu setup original
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2            # dropout en conv stack
P_DROP_ENCODER = 0.3    # dropout antes del encoder

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW en TEST
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5
COMBINE_TTA_AND_SUBWIN = False  # solicitado (no cambio 1)

# Sampler balanceado
USE_WEIGHTED_SAMPLER = True

# EMA
USE_EMA = True
EMA_DECAY = 0.9995  # (6) más lenta

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()


NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]


def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)


def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)


def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1


def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]


def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")


def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :] = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO (GroupNorm en conv)
# =========================
def make_gn(num_channels, num_groups=8):
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x


class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=False  # (4)
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# (1) FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# (4) AUGMENTS MÁS FUERTES (pero estables)
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# EMA de pesos
# =========================
class ModelEMA:
    def __init__(self, model: nn.Module, decay: float = 0.9995, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        ema = type(model)()
        ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        d = self.decay
        # (6) warmup del decay extendido a 1000 updates
        if self._updates < 1000:
            d = (self._updates / 1000.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                esd[k] = msd[k]
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)


def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=4)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)

# =========================
# TRAIN/EVAL por FOLD (con train/val/test)
# =========================
def train_one_fold(fold:int, device):
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación por sujetos (determinista y opcionalmente estratificado)
    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        # sss.split(X, y) — X no importa, pasamos índice
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)  # reutiliza stats de train

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # (2) Weighted sampler templado: w = (1/ws)^a * (1/wk)^b
    def make_weighted_sampler(dataset: TensorDataset):
        _Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
        a, b = 0.8, 1.0
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = float(map_s[int(s)])
            wk = float(map_k[k])
            w.append((ws ** (-a)) * (wk ** (-b)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)

    # Optimizador + Focal Loss (3) alpha tweaking
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.25 * alpha_mean  # both_fists (ya estaba)
    alpha[3] = 1.05 * alpha_mean # (3) pequeño empujón a both_feet
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # (7) Warmup+Cosine (min_factor=0.1)
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1

    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))

    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento con early stopping por F1 macro (EMA en val/test si aplica)
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                p = mdl(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    for ep in range(1, EPOCHS+1):
        # ---- Train ----
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            # (5) Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # ---- Val (EMA si aplica) ----
        acc, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Cargamos mejor estado guardado (EMA si estaba activo)
    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["tr_acc"], label="tr_acc")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(gts, preds, labels=[0,1,2,3]))

    return acc, f1m

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  8.35it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  8.34it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.31it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2326 | train_acc=0.2772 | val_acc=0.2627 | val_f1m=0.2225 | LR=0.000125
  Época   2 | train_loss=0.2126 | train_acc=0.3257 | val_acc=0.3825 | val_f1m=0.3362 | LR=0.000250
  Época   3 | train_loss=0.1839 | train_acc=0.4577 | val_acc=0.4381 | val_f1m=0.4333 | LR=0.000375
  Época   4 | train_loss=0.1797 | train_acc=0.4844 | val_acc=0.4571 | val_f1m=0.4561 | LR=0.000500
  Época   5 | train_loss=0.1724 | train_acc=0.5137 | val_acc=0.4508 | val_f1m=0.4461 | LR=0.000625
  Época   6 | train_loss=0.1673 | train_acc=0.5348 | val_acc=0.4476 | val_f1m=0.4405 | LR=0.000750
  Época   7 | train_loss=0.1650 | train_acc=0.5311 | val_acc=0.4397 | val_f1m=0.4288 | LR=0.000875
  Época   8 | train_loss=0.1653 | train_acc=0.5359 | val_acc=0.4611 | val_f1m=0.4632 | LR=0.001000
  Época   9 | train_loss=0.1664 | train_acc=0.5226 | val_acc=0.4484 | val_f1m=0.4503 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold2: 100%|██████████| 67/67 [00:08<00:00,  8.24it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.36it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.35it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2281 | train_acc=0.2681 | val_acc=0.2865 | val_f1m=0.2064 | LR=0.000125
  Época   2 | train_loss=0.2145 | train_acc=0.3339 | val_acc=0.4103 | val_f1m=0.3789 | LR=0.000250
  Época   3 | train_loss=0.1877 | train_acc=0.4597 | val_acc=0.4476 | val_f1m=0.4424 | LR=0.000375
  Época   4 | train_loss=0.1827 | train_acc=0.4696 | val_acc=0.4659 | val_f1m=0.4621 | LR=0.000500
  Época   5 | train_loss=0.1765 | train_acc=0.4888 | val_acc=0.4921 | val_f1m=0.4929 | LR=0.000625
  Época   6 | train_loss=0.1761 | train_acc=0.4883 | val_acc=0.4881 | val_f1m=0.4785 | LR=0.000750
  Época   7 | train_loss=0.1764 | train_acc=0.4986 | val_acc=0.4643 | val_f1m=0.4580 | LR=0.000875
  Época   8 | train_loss=0.1739 | train_acc=0.5084 | val_acc=0.4659 | val_f1m=0.4536 | LR=0.001000
  Época   9 | train_loss=0.1720 | train_acc=0.4986 | val_acc=0.4937 | val_f1m=0.4930 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold3: 100%|██████████| 67/67 [00:08<00:00,  8.37it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.40it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.38it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2284 | train_acc=0.2509 | val_acc=0.2786 | val_f1m=0.2070 | LR=0.000125
  Época   2 | train_loss=0.2205 | train_acc=0.2996 | val_acc=0.3230 | val_f1m=0.2591 | LR=0.000250
  Época   3 | train_loss=0.1958 | train_acc=0.4232 | val_acc=0.4619 | val_f1m=0.4565 | LR=0.000375
  Época   4 | train_loss=0.1820 | train_acc=0.4869 | val_acc=0.4643 | val_f1m=0.4600 | LR=0.000500
  Época   5 | train_loss=0.1747 | train_acc=0.5080 | val_acc=0.4643 | val_f1m=0.4611 | LR=0.000625
  Época   6 | train_loss=0.1756 | train_acc=0.4973 | val_acc=0.4889 | val_f1m=0.4889 | LR=0.000750
  Época   7 | train_loss=0.1709 | train_acc=0.5162 | val_acc=0.4754 | val_f1m=0.4776 | LR=0.000875
  Época   8 | train_loss=0.1686 | train_acc=0.5240 | val_acc=0.4849 | val_f1m=0.4766 | LR=0.001000
  Época   9 | train_loss=0.1650 | train_acc=0.5339 | val_acc=0.4937 | val_f1m=0.4983 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.28it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.35it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.37it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_loss=0.2296 | train_acc=0.2553 | val_acc=0.2762 | val_f1m=0.1851 | LR=0.000125
  Época   2 | train_loss=0.2137 | train_acc=0.3381 | val_acc=0.4167 | val_f1m=0.4011 | LR=0.000250
  Época   3 | train_loss=0.1933 | train_acc=0.4405 | val_acc=0.4762 | val_f1m=0.4725 | LR=0.000375
  Época   4 | train_loss=0.1807 | train_acc=0.4944 | val_acc=0.4738 | val_f1m=0.4669 | LR=0.000500
  Época   5 | train_loss=0.1771 | train_acc=0.4904 | val_acc=0.4730 | val_f1m=0.4696 | LR=0.000625
  Época   6 | train_loss=0.1761 | train_acc=0.4981 | val_acc=0.4841 | val_f1m=0.4874 | LR=0.000750
  Época   7 | train_loss=0.1746 | train_acc=0.5077 | val_acc=0.4865 | val_f1m=0.4865 | LR=0.000875
  Época   8 | train_loss=0.1710 | train_acc=0.5065 | val_acc=0.4857 | val_f1m=0.4889 | LR=0.001000
  Época   9 | train_loss=0.1664 | train_acc=0.5340 | val_acc=0.4770 | val_f1m=0.4766 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.30it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.38it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.44it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_loss=0.2275 | train_acc=0.2700 | val_acc=0.2579 | val_f1m=0.1263 | LR=0.000125
  Época   2 | train_loss=0.2264 | train_acc=0.2775 | val_acc=0.3063 | val_f1m=0.2825 | LR=0.000250
  Época   3 | train_loss=0.2047 | train_acc=0.3699 | val_acc=0.4087 | val_f1m=0.4094 | LR=0.000375
  Época   4 | train_loss=0.1864 | train_acc=0.4604 | val_acc=0.4167 | val_f1m=0.3820 | LR=0.000500
  Época   5 | train_loss=0.1767 | train_acc=0.4911 | val_acc=0.4206 | val_f1m=0.4213 | LR=0.000625
  Época   6 | train_loss=0.1744 | train_acc=0.4995 | val_acc=0.4381 | val_f1m=0.4302 | LR=0.000750
  Época   7 | train_loss=0.1728 | train_acc=0.4977 | val_acc=0.4476 | val_f1m=0.4490 | LR=0.000875
  Época   8 | train_loss=0.1681 | train_acc=0.5186 | val_acc=0.4587 | val_f1m=0.4609 | LR=0.001000
  Época   9 | train_loss=0.1673 | train_acc=0.5215 | val_acc=0.4643 | val_f1m=0.4620 | LR=0.001000
  Época  10 | train_loss=0.1

## FINETUNING

In [9]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18
VAL_STRAT_SUBJECT = True

# Prepro global
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2
P_DROP_ENCODER = 0.3

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW en TEST
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5
COMBINE_TTA_AND_SUBWIN = False

# Sampler balanceado
USE_WEIGHTED_SAMPLER = True

# EMA (solo GLOBAL). En FT se desactiva.
USE_EMA = True
EMA_DECAY = 0.9995

# Fine-tuning (por sujeto)
FT_N_FOLDS = 4
FT_FREEZE_EPOCHS = 8           # etapa 1 (congelar backbone)
FT_UNFREEZE_EPOCHS = 8         # etapa 2 (descongelar)
FT_PATIENCE = 6                # early stopping en FT
FT_BATCH = 64
FT_LR_HEAD = 1e-3              # cabeza (5x aprox del backbone)
FT_LR_BACKBONE = 2e-4          # backbone
FT_WD = 1e-3                   # weight decay menor en FT
FT_AUG = dict(p_jitter=0.25, p_noise=0.25, p_chdrop=0.10, max_jitter_frac=0.02, noise_std=0.02, max_chdrop=1)
FT_ALPHA_BOOST_FISTS = 1.25
FT_ALPHA_BOOST_FEET  = 1.05

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []
    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)

        sfreq = raw.info['sfreq']
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :] = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO (GroupNorm en conv)
# =========================
def make_gn(num_channels, num_groups=8):
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=False
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# AUGMENTS
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# Versión suave para FT
def augment_batch_ft(xb):
    return augment_batch(xb, **FT_AUG)

# =========================
# EMA de pesos (solo global)
# =========================
class ModelEMA:
    def __init__(self, model: nn.Module, decay: float = 0.9995, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        ema = type(model)()
        ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        d = self.decay
        if self._updates < 1000:
            d = (self._updates / 1000.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                esd[k] = msd[k]
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=4)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)

# =========================
# TRAIN/EVAL GLOBAL POR FOLD
# =========================
def train_one_fold(fold:int, device):
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # Weighted sampler templado: w = (1/ws)^a * (1/wk)^b
    def make_weighted_sampler(dataset: TensorDataset):
        _Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
        a, b = 0.8, 1.0
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = float(map_s[int(s)])
            wk = float(map_k[k])
            w.append((ws ** (-a)) * (wk ** (-b)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)

    # Optimizador + Focal Loss
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)
    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.25 * alpha_mean  # both_fists
    alpha[3] = 1.05 * alpha_mean  # both_feet
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # LR scheduler Warmup+Cosine
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA (solo global)
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento global con early stopping por F1 macro
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                p = mdl(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    for ep in range(1, EPOCHS+1):
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        acc, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["tr_acc"], label="tr_acc")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST (global) ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    cm = confusion_matrix(gts, preds, labels=[0,1,2,3])
    print(cm)
    print(f"↳ Matriz de confusión guardada: confusion_global_fold{fold}.png")
    # plot cm
    fig = plt.figure(figsize=(4.8,4.2))
    plt.imshow(cm, cmap='Blues'); plt.title(f"Confusion — Fold {fold} Global")
    plt.xlabel("pred"); plt.ylabel("true")
    plt.colorbar(); plt.tight_layout()
    plt.savefig(f"confusion_global_fold{fold}.png", dpi=140); plt.close(fig)

    # =========================
    # FINE-TUNING PROGRESIVO POR SUJETO (4-fold CV interno)
    # =========================
    # agrupamos por sujeto en el TEST
    subjects = np.unique(sub_te)
    subj_to_idx = {s: np.where(sub_te == s)[0] for s in subjects}

    def make_ft_optimizer(model):
        # dos grupos de params: head (lr alto), backbone (lr bajo)
        head_params = list(model.head.parameters())
        backbone_params = [p for n,p in model.named_parameters() if not n.startswith('head.')]
        return torch.optim.AdamW([
            {'params': backbone_params, 'lr': FT_LR_BACKBONE},
            {'params': head_params,     'lr': FT_LR_HEAD}
        ], weight_decay=FT_WD)

    def freeze_backbone(model, freeze=True):
        for n,p in model.named_parameters():
            if not n.startswith('head.'):
                p.requires_grad_(not freeze)

    def ft_make_criterion_from_counts(counts):
        inv = counts.sum() / (4.0 * np.maximum(counts.astype(np.float32), 1.0))
        a = torch.tensor(inv, dtype=torch.float32, device=device)
        am = a.mean().item()
        a[2] = FT_ALPHA_BOOST_FISTS * am
        a[3] = FT_ALPHA_BOOST_FEET  * am
        return FocalLoss(alpha=a, gamma=1.5, reduction='mean')

    def evaluate_tensor(model_t, X, y):
        model_t.eval()
        with torch.no_grad():
            xb = torch.tensor(X, dtype=torch.float32, device=device)
            p = model_t(xb).argmax(1).cpu().numpy()
        return accuracy_score(y, p), f1_score(y, p, average='macro'), p

    # Para reporte agregado
    ft_correct, ft_total = 0, 0
    all_true, all_pred = [], []

    # usamos como base los pesos del mejor global (EMA si existía)
    base_state = (ema.ema if (ema is not None) else model).state_dict()

    for s in subjects:
        idx = subj_to_idx[s]
        Xs, ys = X_te_std[idx], y_te[idx]

        if len(np.unique(ys)) < 2 or len(ys) < 20:
            # sujeto con muy pocos ejemplos/clases: eval directo
            eval_model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                                           n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)
            eval_model.load_state_dict(base_state)
            acc_s, f1_s, pred_s = evaluate_tensor(eval_model, Xs, ys)
            ft_correct += int(acc_s*len(ys)); ft_total += len(ys)
            all_true.append(ys); all_pred.append(pred_s)
            continue

        skf = StratifiedKFold(n_splits=FT_N_FOLDS, shuffle=True, random_state=RANDOM_STATE + fold + int(s))
        for tr_idx, te_idx in skf.split(Xs, ys):
            Xtr, ytr = Xs[tr_idx].copy(), ys[tr_idx].copy()
            Xte_s, yte_s = Xs[te_idx].copy(), ys[te_idx].copy()

            # estandarización por sujeto (fit en FT-train, aplicar a FT-test)
            Xtr_std, Xte_std = standardize_per_channel(Xtr, Xte_s)

            # modelo fresh desde el global
            ft_model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                                         n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)
            ft_model.load_state_dict(base_state)

            # optim + criterio (α de este sujeto/split)
            opt_ft = make_ft_optimizer(ft_model)
            alpha_counts = np.bincount(ytr, minlength=4)
            ft_crit = ft_make_criterion_from_counts(alpha_counts)

            # loaders FT
            ds_tr = TensorDataset(torch.tensor(Xtr_std), torch.tensor(ytr).long())
            ds_te = TensorDataset(torch.tensor(Xte_std), torch.tensor(yte_s).long())
            ld_tr = DataLoader(ds_tr, batch_size=FT_BATCH, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)
            ld_te = DataLoader(ds_te, batch_size=FT_BATCH, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

            # ETAPA 1: freeze backbone
            freeze_backbone(ft_model, True)
            best_f1_s, best_state_s, wait_s = -1, None, 0
            for ep in range(1, FT_FREEZE_EPOCHS+1):
                ft_model.train()
                for xb, yb in ld_tr:
                    xb = xb.to(device); yb = yb.to(device)
                    xb = augment_batch_ft(xb)
                    opt_ft.zero_grad()
                    logits = ft_model(xb)
                    loss = ft_crit(logits, yb)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                    opt_ft.step()

                # eval fold
                acc_s, f1_s = evaluate_tensor(ft_model, Xte_std, yte_s)[:2]
                improved = f1_s > best_f1_s + 1e-4
                if improved:
                    best_f1_s = f1_s
                    best_state_s = {k: v.detach().cpu() for k,v in ft_model.state_dict().items()}
                    wait_s = 0
                else:
                    wait_s += 1
                if wait_s >= FT_PATIENCE:
                    break

            if best_state_s is not None:
                ft_model.load_state_dict(best_state_s)

            # ETAPA 2: unfreeze (todo entrenable) con early stopping
            freeze_backbone(ft_model, False)
            best_f1_s2, best_state_s2, wait_s2 = -1, None, 0
            for ep in range(1, FT_UNFREEZE_EPOCHS+1):
                ft_model.train()
                for xb, yb in ld_tr:
                    xb = xb.to(device); yb = yb.to(device)
                    xb = augment_batch_ft(xb)
                    opt_ft.zero_grad()
                    logits = ft_model(xb)
                    loss = ft_crit(logits, yb)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                    opt_ft.step()

                acc_s, f1_s = evaluate_tensor(ft_model, Xte_std, yte_s)[:2]
                improved = f1_s > best_f1_s2 + 1e-4
                if improved:
                    best_f1_s2 = f1_s
                    best_state_s2 = {k: v.detach().cpu() for k,v in ft_model.state_dict().items()}
                    wait_s2 = 0
                else:
                    wait_s2 += 1
                if wait_s2 >= FT_PATIENCE:
                    break

            if best_state_s2 is not None:
                ft_model.load_state_dict(best_state_s2)

            # métricas del fold del sujeto
            acc_s, f1_s, pred_s = evaluate_tensor(ft_model, Xte_std, yte_s)
            ft_correct += int(acc_s*len(yte_s)); ft_total += len(yte_s)
            all_true.append(yte_s); all_pred.append(pred_s)

    # resumen FT por sujeto
    all_true = np.concatenate(all_true)
    all_pred = np.concatenate(all_pred)
    ft_acc = accuracy_score(all_true, all_pred)
    ft_f1m = f1_score(all_true, all_pred, average='macro')
    print(f"  Fine-tuning PROGRESIVO (por sujeto, {FT_N_FOLDS}-fold CV) acc={ft_acc:.4f}")
    print(f"  Δ(FT-Global) = {ft_acc - acc:+.4f}")
    cm_ft = confusion_matrix(all_true, all_pred, labels=[0,1,2,3])
    print("Confusion matrix FT (rows=true, cols=pred):")
    print(cm_ft)
    fig = plt.figure(figsize=(4.8,4.2))
    plt.imshow(cm_ft, cmap='Greens'); plt.title(f"Confusion — Fold {fold} FT")
    plt.xlabel("pred"); plt.ylabel("true")
    plt.colorbar(); plt.tight_layout()
    plt.savefig(f"confusion_ft_fold{fold}.png", dpi=140); plt.close(fig)
    print(f"↳ Matriz de confusión FT guardada: confusion_ft_fold{fold}.png")

    return acc, f1m, ft_acc

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds, ft_acc_folds = [], [], []
    for fold in range(1, 6):
        acc, f1m, ft_acc = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")
        ft_acc_folds.append(f"{ft_acc:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))
    ft_acc_mean = float(np.mean([float(a) for a in ft_acc_folds]))
    delta_mean = ft_acc_mean - acc_mean

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")
    print(f"Fine-tune PROGRESIVO folds: {ft_acc_folds}")
    print(f"Fine-tune PROGRESIVO mean: {ft_acc_mean:.4f}")
    print(f"Δ(FT-Global) mean: {delta_mean:+.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  8.29it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  8.49it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  7.74it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2326 | train_acc=0.2772 | val_acc=0.2627 | val_f1m=0.2225 | LR=0.000125
  Época   2 | train_loss=0.2126 | train_acc=0.3257 | val_acc=0.3825 | val_f1m=0.3362 | LR=0.000250
  Época   3 | train_loss=0.1839 | train_acc=0.4577 | val_acc=0.4381 | val_f1m=0.4333 | LR=0.000375
  Época   4 | train_loss=0.1797 | train_acc=0.4844 | val_acc=0.4571 | val_f1m=0.4561 | LR=0.000500
  Época   5 | train_loss=0.1724 | train_acc=0.5137 | val_acc=0.4508 | val_f1m=0.4461 | LR=0.000625
  Época   6 | train_loss=0.1673 | train_acc=0.5348 | val_acc=0.4476 | val_f1m=0.4405 | LR=0.000750
  Época   7 | train_loss=0.1650 | train_acc=0.5311 | val_acc=0.4397 | val_f1m=0.4288 | LR=0.000875
  Época   8 | train_loss=0.1653 | train_acc=0.5359 | val_acc=0.4611 | val_f1m=0.4632 | LR=0.001000
  Época   9 | train_loss=0.1664 | train_acc=0.5226 | val_acc=0.4484 | val_f1m=0.4503 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold2: 100%|██████████| 67/67 [00:07<00:00,  8.46it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.51it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.53it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2279 | train_acc=0.2733 | val_acc=0.3056 | val_f1m=0.2554 | LR=0.000125
  Época   2 | train_loss=0.2164 | train_acc=0.3218 | val_acc=0.3889 | val_f1m=0.3563 | LR=0.000250
  Época   3 | train_loss=0.1937 | train_acc=0.4271 | val_acc=0.4349 | val_f1m=0.4283 | LR=0.000375
  Época   4 | train_loss=0.1842 | train_acc=0.4701 | val_acc=0.4603 | val_f1m=0.4592 | LR=0.000500
  Época   5 | train_loss=0.1778 | train_acc=0.4989 | val_acc=0.4667 | val_f1m=0.4591 | LR=0.000625
  Época   6 | train_loss=0.1776 | train_acc=0.4893 | val_acc=0.4619 | val_f1m=0.4597 | LR=0.000750
  Época   7 | train_loss=0.1742 | train_acc=0.5032 | val_acc=0.4770 | val_f1m=0.4738 | LR=0.000875
  Época   8 | train_loss=0.1750 | train_acc=0.4957 | val_acc=0.4746 | val_f1m=0.4715 | LR=0.001000
  Época   9 | train_loss=0.1705 | train_acc=0.5215 | val_acc=0.4730 | val_f1m=0.4739 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold3: 100%|██████████| 67/67 [00:07<00:00,  8.46it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.42it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.42it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2271 | train_acc=0.2788 | val_acc=0.3048 | val_f1m=0.2318 | LR=0.000125
  Época   2 | train_loss=0.2121 | train_acc=0.3431 | val_acc=0.4151 | val_f1m=0.4163 | LR=0.000250
  Época   3 | train_loss=0.1900 | train_acc=0.4437 | val_acc=0.4508 | val_f1m=0.4547 | LR=0.000375
  Época   4 | train_loss=0.1801 | train_acc=0.4813 | val_acc=0.4635 | val_f1m=0.4623 | LR=0.000500
  Época   5 | train_loss=0.1720 | train_acc=0.5146 | val_acc=0.4849 | val_f1m=0.4840 | LR=0.000625
  Época   6 | train_loss=0.1697 | train_acc=0.5251 | val_acc=0.4937 | val_f1m=0.4964 | LR=0.000750
  Época   7 | train_loss=0.1700 | train_acc=0.5229 | val_acc=0.4857 | val_f1m=0.4877 | LR=0.000875
  Época   8 | train_loss=0.1673 | train_acc=0.5393 | val_acc=0.4937 | val_f1m=0.4947 | LR=0.001000
  Época   9 | train_loss=0.1642 | train_acc=0.5304 | val_acc=0.5143 | val_f1m=0.5126 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.47it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.46it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.44it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_loss=0.2284 | train_acc=0.2549 | val_acc=0.2794 | val_f1m=0.1887 | LR=0.000125
  Época   2 | train_loss=0.2121 | train_acc=0.3410 | val_acc=0.3960 | val_f1m=0.3735 | LR=0.000250
  Época   3 | train_loss=0.1935 | train_acc=0.4301 | val_acc=0.4429 | val_f1m=0.4377 | LR=0.000375
  Época   4 | train_loss=0.1836 | train_acc=0.4729 | val_acc=0.4794 | val_f1m=0.4802 | LR=0.000500
  Época   5 | train_loss=0.1762 | train_acc=0.5051 | val_acc=0.4865 | val_f1m=0.4867 | LR=0.000625
  Época   6 | train_loss=0.1741 | train_acc=0.5068 | val_acc=0.4794 | val_f1m=0.4797 | LR=0.000750
  Época   7 | train_loss=0.1763 | train_acc=0.5009 | val_acc=0.4825 | val_f1m=0.4826 | LR=0.000875
  Época   8 | train_loss=0.1705 | train_acc=0.5254 | val_acc=0.4817 | val_f1m=0.4722 | LR=0.001000
  Época   9 | train_loss=0.1674 | train_acc=0.5205 | val_acc=0.4929 | val_f1m=0.4927 | LR=0.001000
  Época  10 | train_loss=0.1

Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.49it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.61it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.66it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_loss=0.2280 | train_acc=0.2640 | val_acc=0.2659 | val_f1m=0.2042 | LR=0.000125
  Época   2 | train_loss=0.2091 | train_acc=0.3528 | val_acc=0.4159 | val_f1m=0.4055 | LR=0.000250
  Época   3 | train_loss=0.1839 | train_acc=0.4732 | val_acc=0.4254 | val_f1m=0.4233 | LR=0.000375
  Época   4 | train_loss=0.1776 | train_acc=0.4839 | val_acc=0.4413 | val_f1m=0.4399 | LR=0.000500
  Época   5 | train_loss=0.1759 | train_acc=0.4998 | val_acc=0.4571 | val_f1m=0.4528 | LR=0.000625
  Época   6 | train_loss=0.1751 | train_acc=0.4989 | val_acc=0.4325 | val_f1m=0.4330 | LR=0.000750
  Época   7 | train_loss=0.1735 | train_acc=0.5019 | val_acc=0.4683 | val_f1m=0.4706 | LR=0.000875
  Época   8 | train_loss=0.1705 | train_acc=0.5075 | val_acc=0.4571 | val_f1m=0.4602 | LR=0.001000
  Época   9 | train_loss=0.1692 | train_acc=0.5131 | val_acc=0.4476 | val_f1m=0.4471 | LR=0.001000
  Época  10 | train_loss=0.1

## DOS CLASES

In [12]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64        # múltiplo de 4
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18  # ≈ 18% de sujetos del train → val
VAL_STRAT_SUBJECT = True # estratifica por etiqueta dominante de cada sujeto

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False       # respetamos tu setup original
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2            # dropout en conv stack
P_DROP_ENCODER = 0.3    # dropout antes del encoder

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW en TEST
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5
COMBINE_TTA_AND_SUBWIN = False

# Sampler balanceado
USE_WEIGHTED_SAMPLER = True

# EMA
USE_EMA = True
EMA_DECAY = 0.9995

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
# -----> SOLO 2 CLASES
CLASS_NAMES = ['left', 'right']

# Solo usaremos los runs de imaginación L/R; ignoramos los de fists/feet
IMAGERY_RUNS_LR = {4, 8, 12}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto) — 2 clases (left/right)")
print(f"🔧 Configuración: 2c (L/R), 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        # Usar SOLO runs L/R
        if run not in IMAGERY_RUNS_LR:
            continue

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        # T1 = left, T2 = right
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            y_run.append(0 if lab == 'T1' else 1)  # 0=left, 1=right
        y_run = np.array(y_run, dtype=int)

        X_list.append(X)
        y_list.append(y_run)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :] = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO (GroupNorm en conv)
# =========================
def make_gn(num_channels, num_groups=8):
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=False
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# FOCAL LOSS (ajustada a 2 clases)
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# AUGMENTS
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# EMA de pesos
# =========================
class ModelEMA:
    def __init__(self, model: nn.Module, decay: float = 0.9995, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        ema = type(model)()
        ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        d = self.decay
        if self._updates < 1000:
            d = (self._updates / 1000.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                esd[k] = msd[k]
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(2, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=2)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)

# =========================
# TRAIN/EVAL por FOLD (con train/val/test)
# =========================
def train_one_fold(fold:int, device):
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación por sujetos (determinista y opcionalmente estratificado)
    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # Weighted sampler templado (por sujeto/label)
    def make_weighted_sampler(dataset: TensorDataset):
        _Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
        a, b = 0.8, 1.0
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = float(map_s[int(s)])
            wk = float(map_k[k])
            w.append((ws ** (-a)) * (wk ** (-b)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo (n_cls=2)
    model = EEGCNNTransformer(n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)

    # Optimizador + Focal Loss (alpha para 2 clases)
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    class_counts = np.bincount(y_tr, minlength=2).astype(np.float32)
    inv = class_counts.sum() / (2.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # Warmup+Cosine (min_factor=0.1)
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1

    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))

    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento con early stopping por F1 macro (EMA en val/test si aplica)
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                p = mdl(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    for ep in range(1, EPOCHS+1):
        # ---- Train ----
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # ---- Val (EMA si aplica) ----
        acc, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Cargamos mejor estado guardado (EMA si estaba activo)
    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["tr_acc"], label="tr_acc")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento (2 clases)")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(gts, preds, labels=[0,1]))

    return acc, f1m

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n============================================================")
    print("RESULTADOS FINALES (2 clases: left/right)")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto) — 2 clases (left/right)
🔧 Configuración: 2c (L/R), 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 67/67 [00:03<00:00, 17.11it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:00<00:00, 17.16it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:01<00:00, 17.18it/s]


[Fold 1/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1543 | train_acc=0.5064 | val_acc=0.5825 | val_f1m=0.5682 | LR=0.000125
  Época   2 | train_loss=0.1046 | train_acc=0.6834 | val_acc=0.7063 | val_f1m=0.7060 | LR=0.000250
  Época   3 | train_loss=0.0808 | train_acc=0.7893 | val_acc=0.7429 | val_f1m=0.7428 | LR=0.000375
  Época   4 | train_loss=0.0776 | train_acc=0.8149 | val_acc=0.7587 | val_f1m=0.7587 | LR=0.000500
  Época   5 | train_loss=0.0701 | train_acc=0.8262 | val_acc=0.7333 | val_f1m=0.7309 | LR=0.000625
  Época   6 | train_loss=0.0698 | train_acc=0.8344 | val_acc=0.7730 | val_f1m=0.7729 | LR=0.000750
  Época   7 | train_loss=0.0678 | train_acc=0.8312 | val_acc=0.7619 | val_f1m=0.7619 | LR=0.000875
  Época   8 | train_loss=0.0652 | train_acc=0.8426 | val_acc=0.7190 | val_f1m=0.7131 | LR=0.001000
  Época   9 | train_loss=0.0688 | train_acc=0.8323 | val_acc=0.7651 | val_f1m=0.7650 | LR=0.001000
  Época  10 | train_loss=0.073

Cargando train fold2: 100%|██████████| 67/67 [00:04<00:00, 15.84it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:00<00:00, 15.76it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:01<00:00, 16.22it/s]


[Fold 2/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1289 | train_acc=0.5235 | val_acc=0.5762 | val_f1m=0.5722 | LR=0.000125
  Época   2 | train_loss=0.1076 | train_acc=0.6731 | val_acc=0.7111 | val_f1m=0.7111 | LR=0.000250
  Época   3 | train_loss=0.0869 | train_acc=0.7793 | val_acc=0.7286 | val_f1m=0.7242 | LR=0.000375
  Época   4 | train_loss=0.0806 | train_acc=0.7864 | val_acc=0.7540 | val_f1m=0.7494 | LR=0.000500
  Época   5 | train_loss=0.0825 | train_acc=0.7896 | val_acc=0.7952 | val_f1m=0.7952 | LR=0.000625
  Época   6 | train_loss=0.0732 | train_acc=0.8145 | val_acc=0.7603 | val_f1m=0.7603 | LR=0.000750
  Época   7 | train_loss=0.0748 | train_acc=0.8099 | val_acc=0.7778 | val_f1m=0.7747 | LR=0.000875
  Época   8 | train_loss=0.0749 | train_acc=0.8106 | val_acc=0.7508 | val_f1m=0.7428 | LR=0.001000
  Época   9 | train_loss=0.0761 | train_acc=0.8106 | val_acc=0.7794 | val_f1m=0.7792 | LR=0.001000
  Época  10 | train_loss=0.069

Cargando train fold3: 100%|██████████| 67/67 [00:03<00:00, 16.77it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:00<00:00, 16.84it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:01<00:00, 16.38it/s]


[Fold 3/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1278 | train_acc=0.5181 | val_acc=0.5270 | val_f1m=0.4369 | LR=0.000125
  Época   2 | train_loss=0.1172 | train_acc=0.6027 | val_acc=0.7524 | val_f1m=0.7523 | LR=0.000250
  Época   3 | train_loss=0.0903 | train_acc=0.7672 | val_acc=0.7984 | val_f1m=0.7973 | LR=0.000375
  Época   4 | train_loss=0.0877 | train_acc=0.7832 | val_acc=0.7921 | val_f1m=0.7921 | LR=0.000500
  Época   5 | train_loss=0.0747 | train_acc=0.8198 | val_acc=0.7937 | val_f1m=0.7935 | LR=0.000625
  Época   6 | train_loss=0.0789 | train_acc=0.8021 | val_acc=0.7746 | val_f1m=0.7739 | LR=0.000750
  Época   7 | train_loss=0.0700 | train_acc=0.8298 | val_acc=0.7921 | val_f1m=0.7910 | LR=0.000875
  Época   8 | train_loss=0.0758 | train_acc=0.8117 | val_acc=0.7746 | val_f1m=0.7738 | LR=0.001000
  Época   9 | train_loss=0.0723 | train_acc=0.8244 | val_acc=0.7889 | val_f1m=0.7876 | LR=0.001000
  Época  10 | train_loss=0.069

Cargando train fold4: 100%|██████████| 68/68 [00:04<00:00, 16.79it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:00<00:00, 16.74it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:01<00:00, 16.73it/s]


[Fold 4/5] Entrenando modelo global... (n_train=2856 | n_val=630 | n_test=840)
  Época   1 | train_loss=0.1285 | train_acc=0.5081 | val_acc=0.5524 | val_f1m=0.4999 | LR=0.000125
  Época   2 | train_loss=0.1099 | train_acc=0.6516 | val_acc=0.7238 | val_f1m=0.7233 | LR=0.000250
  Época   3 | train_loss=0.0932 | train_acc=0.7528 | val_acc=0.7603 | val_f1m=0.7599 | LR=0.000375
  Época   4 | train_loss=0.0815 | train_acc=0.7899 | val_acc=0.7524 | val_f1m=0.7523 | LR=0.000500
  Época   5 | train_loss=0.0778 | train_acc=0.8057 | val_acc=0.7651 | val_f1m=0.7649 | LR=0.000625
  Época   6 | train_loss=0.0837 | train_acc=0.7847 | val_acc=0.7714 | val_f1m=0.7684 | LR=0.000750
  Época   7 | train_loss=0.0745 | train_acc=0.8144 | val_acc=0.7857 | val_f1m=0.7842 | LR=0.000875
  Época   8 | train_loss=0.0732 | train_acc=0.8113 | val_acc=0.7794 | val_f1m=0.7784 | LR=0.001000
  Época   9 | train_loss=0.0708 | train_acc=0.8176 | val_acc=0.7698 | val_f1m=0.7693 | LR=0.001000
  Época  10 | train_loss=0.069

Cargando train fold5: 100%|██████████| 68/68 [00:03<00:00, 17.03it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:00<00:00, 16.61it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:01<00:00, 16.28it/s]


[Fold 5/5] Entrenando modelo global... (n_train=2856 | n_val=630 | n_test=840)
  Época   1 | train_loss=0.1265 | train_acc=0.5224 | val_acc=0.5524 | val_f1m=0.5190 | LR=0.000125
  Época   2 | train_loss=0.1102 | train_acc=0.6579 | val_acc=0.6968 | val_f1m=0.6950 | LR=0.000250
  Época   3 | train_loss=0.0859 | train_acc=0.7763 | val_acc=0.7206 | val_f1m=0.7205 | LR=0.000375
  Época   4 | train_loss=0.0811 | train_acc=0.7994 | val_acc=0.7317 | val_f1m=0.7316 | LR=0.000500
  Época   5 | train_loss=0.0791 | train_acc=0.8029 | val_acc=0.7286 | val_f1m=0.7286 | LR=0.000625
  Época   6 | train_loss=0.0750 | train_acc=0.8193 | val_acc=0.6952 | val_f1m=0.6830 | LR=0.000750
  Época   7 | train_loss=0.0791 | train_acc=0.7980 | val_acc=0.7587 | val_f1m=0.7586 | LR=0.000875
  Época   8 | train_loss=0.0720 | train_acc=0.8186 | val_acc=0.7524 | val_f1m=0.7512 | LR=0.001000
  Época   9 | train_loss=0.0726 | train_acc=0.8165 | val_acc=0.7651 | val_f1m=0.7650 | LR=0.001000
  Época  10 | train_loss=0.069

## DOS CLASES FINETUNING

In [14]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18
VAL_STRAT_SUBJECT = True

# Prepro global
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2
P_DROP_ENCODER = 0.3

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW en TEST
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5
COMBINE_TTA_AND_SUBWIN = False

# Sampler balanceado
USE_WEIGHTED_SAMPLER = True

# EMA (solo GLOBAL). En FT se desactiva.
USE_EMA = True
EMA_DECAY = 0.9995

# Fine-tuning (por sujeto) — sigue activo, pero ahora con n_cls=2
FT_N_FOLDS = 4
FT_FREEZE_EPOCHS = 8
FT_UNFREEZE_EPOCHS = 8
FT_PATIENCE = 6
FT_BATCH = 64
FT_LR_HEAD = 1e-3
FT_LR_BACKBONE = 2e-4
FT_WD = 1e-3
FT_AUG = dict(p_jitter=0.25, p_noise=0.25, p_chdrop=0.10, max_jitter_frac=0.02, noise_std=0.02, max_chdrop=1)

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right']  # <-- 2 clases

# Solo runs de imaginación L/R
IMAGERY_RUNS_LR = {4, 8, 12}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto)")
print(f"🔧 Configuración: 2c (L/R), 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []
    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        # Usar SOLO runs L/R
        if run not in IMAGERY_RUNS_LR:
            continue

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)

        sfreq = raw.info['sfreq']
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        # Mantener solo T1/T2 -> left/right
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        # 0=left(T1), 1=right(T2)
        y_run = np.array([0 if inv[c] == 'T1' else 1 for c in ev_codes], dtype=int)

        X_list.append(X)
        y_list.append(y_run)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :] = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO (GroupNorm en conv)
# =========================
def make_gn(num_channels, num_groups=8):
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=False
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# FOCAL LOSS (2 clases)
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# AUGMENTS
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# Versión suave para FT
def augment_batch_ft(xb):
    return augment_batch(xb, **FT_AUG)

# =========================
# EMA de pesos (solo global)
# =========================
class ModelEMA:
    def __init__(self, model: nn.Module, decay: float = 0.9995, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        ema = type(model)()
        ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        d = self.decay
        if self._updates < 1000:
            d = (self._updates / 1000.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                esd[k] = msd[k]
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(2, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=2)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)

# =========================
# TRAIN/EVAL GLOBAL POR FOLD
# =========================
def train_one_fold(fold:int, device):
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # Weighted sampler templado: w = (1/ws)^a * (1/wk)^b
    def make_weighted_sampler(dataset: TensorDataset):
        _Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
        a, b = 0.8, 1.0
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = float(map_s[int(s)])
            wk = float(map_k[k])
            w.append((ws ** (-a)) * (wk ** (-b)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo (2 clases)
    model = EEGCNNTransformer(n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)

    # Optimizador + Focal Loss (2 clases)
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)
    class_counts = np.bincount(y_tr, minlength=2).astype(np.float32)
    inv = class_counts.sum() / (2.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # LR scheduler Warmup+Cosine
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA (solo global)
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento global con early stopping por F1 macro
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                p = mdl(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    for ep in range(1, EPOCHS+1):
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        acc, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["tr_acc"], label="tr_acc")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento (2 clases)")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST (global) ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    cm = confusion_matrix(gts, preds, labels=[0,1])
    print(cm)
    fig = plt.figure(figsize=(4.8,4.2))
    plt.imshow(cm, cmap='Blues'); plt.title(f"Confusion — Fold {fold} Global (2 clases)")
    plt.xlabel("pred"); plt.ylabel("true")
    plt.colorbar(); plt.tight_layout()
    plt.savefig(f"confusion_global_fold{fold}.png", dpi=140); plt.close(fig)
    print(f"↳ Matriz de confusión guardada: confusion_global_fold{fold}.png")

    # =========================
    # FINE-TUNING PROGRESIVO POR SUJETO (4-fold CV interno) — 2 clases
    # =========================
    subjects = np.unique(sub_te)
    subj_to_idx = {s: np.where(sub_te == s)[0] for s in subjects}

    def make_ft_optimizer(model):
        head_params = list(model.head.parameters())
        backbone_params = [p for n,p in model.named_parameters() if not n.startswith('head.')]
        return torch.optim.AdamW([
            {'params': backbone_params, 'lr': FT_LR_BACKBONE},
            {'params': head_params,     'lr': FT_LR_HEAD}
        ], weight_decay=FT_WD)

    def freeze_backbone(model, freeze=True):
        for n,p in model.named_parameters():
            if not n.startswith('head.'):
                p.requires_grad_(not freeze)

    def ft_make_criterion_from_counts(counts):
        inv = counts.sum() / (2.0 * np.maximum(counts.astype(np.float32), 1.0))
        a = torch.tensor(inv, dtype=torch.float32, device=device)
        return FocalLoss(alpha=a, gamma=1.5, reduction='mean')

    def evaluate_tensor(model_t, X, y):
        model_t.eval()
        with torch.no_grad():
            xb = torch.tensor(X, dtype=torch.float32, device=device)
            p = model_t(xb).argmax(1).cpu().numpy()
        return accuracy_score(y, p), f1_score(y, p, average='macro'), p

    all_true, all_pred = [], []

    # Base: mejores pesos globales (EMA si existe)
    base_state = (ema.ema if (ema is not None) else model).state_dict()

    for s in subjects:
        idx = subj_to_idx[s]
        Xs, ys = X_te_std[idx], y_te[idx]

        # Si no hay ambas clases o muy pocos ejemplos, solo eval directa
        if len(np.unique(ys)) < 2 or len(ys) < 20:
            eval_model = EEGCNNTransformer(n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                                           n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)
            eval_model.load_state_dict(base_state)
            _, _, pred_s = evaluate_tensor(eval_model, Xs, ys)
            all_true.append(ys); all_pred.append(pred_s)
            continue

        skf = StratifiedKFold(n_splits=FT_N_FOLDS, shuffle=True, random_state=RANDOM_STATE + fold + int(s))
        for tr_idx, te_idx in skf.split(Xs, ys):
            Xtr, ytr = Xs[tr_idx].copy(), ys[tr_idx].copy()
            Xte_s, yte_s = Xs[te_idx].copy(), ys[te_idx].copy()

            # Estandarización por sujeto (fit en FT-train)
            Xtr_std, Xte_std = standardize_per_channel(Xtr, Xte_s)

            # Modelo fresh desde el global
            ft_model = EEGCNNTransformer(n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                                         n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)
            ft_model.load_state_dict(base_state)

            # optim + criterio (α de este sujeto/split)
            opt_ft = make_ft_optimizer(ft_model)
            alpha_counts = np.bincount(ytr, minlength=2)
            ft_crit = ft_make_criterion_from_counts(alpha_counts)

            # loaders FT
            ds_tr = TensorDataset(torch.tensor(Xtr_std), torch.tensor(ytr).long())
            ds_te = TensorDataset(torch.tensor(Xte_std), torch.tensor(yte_s).long())
            ld_tr = DataLoader(ds_tr, batch_size=FT_BATCH, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)
            ld_te = DataLoader(ds_te, batch_size=FT_BATCH, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

            # ETAPA 1: freeze backbone
            freeze_backbone(ft_model, True)
            best_f1_s, best_state_s, wait_s = -1, None, 0
            for ep in range(1, FT_FREEZE_EPOCHS+1):
                ft_model.train()
                for xb, yb in ld_tr:
                    xb = xb.to(device); yb = yb.to(device)
                    xb = augment_batch_ft(xb)
                    opt_ft.zero_grad()
                    logits = ft_model(xb)
                    loss = ft_crit(logits, yb)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                    opt_ft.step()

                acc_s, f1_s = evaluate_tensor(ft_model, Xte_std, yte_s)[:2]
                improved = f1_s > best_f1_s + 1e-4
                if improved:
                    best_f1_s = f1_s
                    best_state_s = {k: v.detach().cpu() for k,v in ft_model.state_dict().items()}
                    wait_s = 0
                else:
                    wait_s += 1
                if wait_s >= FT_PATIENCE:
                    break

            if best_state_s is not None:
                ft_model.load_state_dict(best_state_s)

            # ETAPA 2: unfreeze (todo entrenable)
            freeze_backbone(ft_model, False)
            best_f1_s2, best_state_s2, wait_s2 = -1, None, 0
            for ep in range(1, FT_UNFREEZE_EPOCHS+1):
                ft_model.train()
                for xb, yb in ld_tr:
                    xb = xb.to(device); yb = yb.to(device)
                    xb = augment_batch_ft(xb)
                    opt_ft.zero_grad()
                    logits = ft_model(xb)
                    loss = ft_crit(logits, yb)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                    opt_ft.step()

                acc_s, f1_s = evaluate_tensor(ft_model, Xte_std, yte_s)[:2]
                improved = f1_s > best_f1_s2 + 1e-4
                if improved:
                    best_f1_s2 = f1_s
                    best_state_s2 = {k: v.detach().cpu() for k,v in ft_model.state_dict().items()}
                    wait_s2 = 0
                else:
                    wait_s2 += 1
                if wait_s2 >= FT_PATIENCE:
                    break

            if best_state_s2 is not None:
                ft_model.load_state_dict(best_state_s2)

            # métricas del fold del sujeto
            _, _, pred_s = evaluate_tensor(ft_model, Xte_std, yte_s)
            all_true.append(yte_s); all_pred.append(pred_s)

    # resumen FT por sujeto
    all_true = np.concatenate(all_true) if len(all_true) else np.array([], dtype=int)
    all_pred = np.concatenate(all_pred) if len(all_pred) else np.array([], dtype=int)
    if len(all_true) > 0:
        ft_acc = accuracy_score(all_true, all_pred)
        ft_f1m = f1_score(all_true, all_pred, average='macro')
        print(f"  Fine-tuning PROGRESIVO (por sujeto, {FT_N_FOLDS}-fold CV) acc={ft_acc:.4f}")
        print(f"  Δ(FT-Global) = {ft_acc - acc:+.4f}")
        cm_ft = confusion_matrix(all_true, all_pred, labels=[0,1])
        print("Confusion matrix FT (rows=true, cols=pred):")
        print(cm_ft)
        fig = plt.figure(figsize=(4.8,4.2))
        plt.imshow(cm_ft, cmap='Greens'); plt.title(f"Confusion — Fold {fold} FT (2 clases)")
        plt.xlabel("pred"); plt.ylabel("true")
        plt.colorbar(); plt.tight_layout()
        plt.savefig(f"confusion_ft_fold{fold}.png", dpi=140); plt.close(fig)
        print(f"↳ Matriz de confusión FT guardada: confusion_ft_fold{fold}.png")
    else:
        ft_acc = float('nan')
        print("  Fine-tuning PROGRESIVO: no se generaron splits válidos (sujetos con una sola clase o muy pocos ensayos).")

    return acc, f1m, ft_acc

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds, ft_acc_folds = [], [], []
    for fold in range(1, 6):
        acc, f1m, ft_acc = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")
        ft_acc_folds.append("nan" if (ft_acc != ft_acc) else f"{ft_acc:.4f}")  # maneja NaN

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))
    valid_ft = [float(a) for a in ft_acc_folds if a != "nan"]
    ft_acc_mean = float(np.mean(valid_ft)) if len(valid_ft) else float('nan')
    delta_mean = (ft_acc_mean - acc_mean) if (ft_acc_mean == ft_acc_mean) else float('nan')

    print("\n============================================================")
    print("RESULTADOS FINALES (2 clases: left/right)")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")
    print(f"Fine-tune PROGRESIVO folds: {ft_acc_folds}")
    if ft_acc_mean == ft_acc_mean:
        print(f"Fine-tune PROGRESIVO mean: {ft_acc_mean:.4f}")
        print(f"Δ(FT-Global) mean: {delta_mean:+.4f}")
    else:
        print("Fine-tune PROGRESIVO mean: N/A (insuficientes sujetos/conjuntos válidos)")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto)
🔧 Configuración: 2c (L/R), 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 67/67 [00:03<00:00, 16.94it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:00<00:00, 17.04it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:01<00:00, 16.95it/s]


[Fold 1/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1543 | train_acc=0.5064 | val_acc=0.5825 | val_f1m=0.5682 | LR=0.000125
  Época   2 | train_loss=0.1046 | train_acc=0.6834 | val_acc=0.7063 | val_f1m=0.7060 | LR=0.000250
  Época   3 | train_loss=0.0808 | train_acc=0.7893 | val_acc=0.7429 | val_f1m=0.7428 | LR=0.000375
  Época   4 | train_loss=0.0776 | train_acc=0.8149 | val_acc=0.7587 | val_f1m=0.7587 | LR=0.000500
  Época   5 | train_loss=0.0701 | train_acc=0.8262 | val_acc=0.7333 | val_f1m=0.7309 | LR=0.000625
  Época   6 | train_loss=0.0698 | train_acc=0.8344 | val_acc=0.7730 | val_f1m=0.7729 | LR=0.000750
  Época   7 | train_loss=0.0678 | train_acc=0.8312 | val_acc=0.7619 | val_f1m=0.7619 | LR=0.000875
  Época   8 | train_loss=0.0652 | train_acc=0.8426 | val_acc=0.7190 | val_f1m=0.7131 | LR=0.001000
  Época   9 | train_loss=0.0688 | train_acc=0.8323 | val_acc=0.7651 | val_f1m=0.7650 | LR=0.001000
  Época  10 | train_loss=0.073

Cargando train fold2: 100%|██████████| 67/67 [00:03<00:00, 17.01it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:00<00:00, 17.27it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:01<00:00, 17.21it/s]


[Fold 2/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1262 | train_acc=0.5448 | val_acc=0.6270 | val_f1m=0.6247 | LR=0.000125
  Época   2 | train_loss=0.1055 | train_acc=0.6795 | val_acc=0.7429 | val_f1m=0.7423 | LR=0.000250
  Época   3 | train_loss=0.0851 | train_acc=0.7758 | val_acc=0.7857 | val_f1m=0.7857 | LR=0.000375
  Época   4 | train_loss=0.0827 | train_acc=0.7772 | val_acc=0.7810 | val_f1m=0.7806 | LR=0.000500
  Época   5 | train_loss=0.0752 | train_acc=0.8063 | val_acc=0.7698 | val_f1m=0.7698 | LR=0.000625
  Época   6 | train_loss=0.0748 | train_acc=0.8109 | val_acc=0.8000 | val_f1m=0.7999 | LR=0.000750
  Época   7 | train_loss=0.0756 | train_acc=0.7950 | val_acc=0.7778 | val_f1m=0.7776 | LR=0.000875
  Época   8 | train_loss=0.0740 | train_acc=0.8042 | val_acc=0.7714 | val_f1m=0.7713 | LR=0.001000
  Época   9 | train_loss=0.0709 | train_acc=0.8227 | val_acc=0.7508 | val_f1m=0.7506 | LR=0.001000
  Época  10 | train_loss=0.071

Cargando train fold3: 100%|██████████| 67/67 [00:04<00:00, 16.46it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:00<00:00, 16.79it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:01<00:00, 16.80it/s]


[Fold 3/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1298 | train_acc=0.5156 | val_acc=0.6000 | val_f1m=0.5860 | LR=0.000125
  Época   2 | train_loss=0.1162 | train_acc=0.6006 | val_acc=0.7365 | val_f1m=0.7361 | LR=0.000250
  Época   3 | train_loss=0.0930 | train_acc=0.7559 | val_acc=0.7619 | val_f1m=0.7614 | LR=0.000375
  Época   4 | train_loss=0.0809 | train_acc=0.7999 | val_acc=0.7730 | val_f1m=0.7721 | LR=0.000500
  Época   5 | train_loss=0.0843 | train_acc=0.7846 | val_acc=0.7921 | val_f1m=0.7920 | LR=0.000625
  Época   6 | train_loss=0.0759 | train_acc=0.8113 | val_acc=0.8063 | val_f1m=0.8062 | LR=0.000750
  Época   7 | train_loss=0.0734 | train_acc=0.8120 | val_acc=0.7825 | val_f1m=0.7819 | LR=0.000875
  Época   8 | train_loss=0.0696 | train_acc=0.8333 | val_acc=0.8016 | val_f1m=0.8009 | LR=0.001000
  Época   9 | train_loss=0.0709 | train_acc=0.8262 | val_acc=0.7905 | val_f1m=0.7903 | LR=0.001000
  Época  10 | train_loss=0.066

Cargando train fold4: 100%|██████████| 68/68 [00:04<00:00, 16.57it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:00<00:00, 16.47it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:01<00:00, 16.35it/s]


[Fold 4/5] Entrenando modelo global... (n_train=2856 | n_val=630 | n_test=840)
  Época   1 | train_loss=0.1266 | train_acc=0.5284 | val_acc=0.5857 | val_f1m=0.5583 | LR=0.000125
  Época   2 | train_loss=0.1082 | train_acc=0.6569 | val_acc=0.7063 | val_f1m=0.7043 | LR=0.000250
  Época   3 | train_loss=0.0931 | train_acc=0.7486 | val_acc=0.7603 | val_f1m=0.7597 | LR=0.000375
  Época   4 | train_loss=0.0826 | train_acc=0.7927 | val_acc=0.7524 | val_f1m=0.7512 | LR=0.000500
  Época   5 | train_loss=0.0843 | train_acc=0.7777 | val_acc=0.7825 | val_f1m=0.7825 | LR=0.000625
  Época   6 | train_loss=0.0792 | train_acc=0.8029 | val_acc=0.7476 | val_f1m=0.7422 | LR=0.000750
  Época   7 | train_loss=0.0748 | train_acc=0.8141 | val_acc=0.7698 | val_f1m=0.7691 | LR=0.000875
  Época   8 | train_loss=0.0791 | train_acc=0.7969 | val_acc=0.7857 | val_f1m=0.7857 | LR=0.001000
  Época   9 | train_loss=0.0668 | train_acc=0.8309 | val_acc=0.8143 | val_f1m=0.8142 | LR=0.001000
  Época  10 | train_loss=0.072

Cargando train fold5: 100%|██████████| 68/68 [00:04<00:00, 16.52it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:00<00:00, 16.80it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:01<00:00, 16.85it/s]


[Fold 5/5] Entrenando modelo global... (n_train=2856 | n_val=630 | n_test=840)
  Época   1 | train_loss=0.1309 | train_acc=0.5224 | val_acc=0.5635 | val_f1m=0.5252 | LR=0.000125
  Época   2 | train_loss=0.1096 | train_acc=0.6551 | val_acc=0.6571 | val_f1m=0.6536 | LR=0.000250
  Época   3 | train_loss=0.0879 | train_acc=0.7749 | val_acc=0.7413 | val_f1m=0.7397 | LR=0.000375
  Época   4 | train_loss=0.0841 | train_acc=0.7903 | val_acc=0.7444 | val_f1m=0.7422 | LR=0.000500
  Época   5 | train_loss=0.0771 | train_acc=0.8064 | val_acc=0.7381 | val_f1m=0.7377 | LR=0.000625
  Época   6 | train_loss=0.0722 | train_acc=0.8204 | val_acc=0.7286 | val_f1m=0.7214 | LR=0.000750
  Época   7 | train_loss=0.0783 | train_acc=0.8074 | val_acc=0.7413 | val_f1m=0.7406 | LR=0.000875
  Época   8 | train_loss=0.0726 | train_acc=0.8204 | val_acc=0.7698 | val_f1m=0.7692 | LR=0.001000
  Época   9 | train_loss=0.0682 | train_acc=0.8368 | val_acc=0.7444 | val_f1m=0.7441 | LR=0.001000
  Época  10 | train_loss=0.070