In [7]:
# -*- coding: utf-8 -*-
# CNN + Transformer (PyTorch) para PhysioNet/BCI2000 MI — IMAGINERÍA (4 clases) con 8 canales.
# Incluye reproducibilidad (seed=42) y determinismo CUDA para cuBLAS/cuDNN en notebook.

# =========================
# Reproducibilidad (poner ANTES de importar torch)
# =========================
import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # determinismo cuBLAS

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

# =========================
# REPRODUCIBILIDAD (seed + determinismo)
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinismo cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Evitar TF32 (puede romper determinismo en algunas GPU)
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    # reforzar determinismo cuando sea posible
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = seed_everything.__defaults__[0] + worker_id if seed_everything.__defaults__ else 42 + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG (edita a tu gusto)
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'                     # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

FOLD = 1
EPOCHS = 40
BATCH_SIZE = 64
LR = 1e-3
RESAMPLE_HZ = None              # None para mantener original
DO_NOTCH = True                 # 60 Hz
DO_BANDPASS = False             # band-pass (4–38) desactivado por defecto
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False                   # re-referencia promedio sobre 8 canales

# Ventana temporal (seg)
TMIN, TMAX = 0.5, 4.5

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

# Runs IMAGINERÍA
IMAGERY_RUNS_LR = {4, 8, 12}    # T1=left, T2=right
IMAGERY_RUNS_BF = {6, 10, 14}   # T1=both_fists, T2=both_feet

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Device: {DEVICE}")

# =========================
# UTILIDADES
# =========================
def normalize_ch_name(name: str) -> str:
    """Normaliza: quita no-alfanum, pasa a upper. 'Fc3.' -> 'FC3', 'Cz..' -> 'CZ'."""
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    """Mapea canales del EDF a EXPECTED_8 aunque tengan puntos/minúsculas."""
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    """Devuelve lista de EDFs IMAGERY para un sujeto S###: R04,R06,R08,R10,R12,R14."""
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        pattern = str(subj_dir / f"{subject_id}R{r:02d}.edf")
        edfs.extend(glob(pattern))
    return sorted(edfs)

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    """Carga EDFs de imaginería, aplica preproc y devuelve (X[N,8,T], y[N], sfreq)."""
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        raise FileNotFoundError(f"No imagery EDF files for {subject_id} under {DATA_RAW}")

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

        # Selección de 8 canales
        raw = pick_8_channels(raw)

        # --- PREPROC OPCIONAL ---
        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        # Resample
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        # Eventos (T0/T1/T2)
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        # Mantener solo T1/T2 (T0 ignorado)
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()  # (n_epochs, 8, T)

        # Construir y según RUN
        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}  # id -> 'T1'/'T2'
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)  # left/right
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)  # both_fists/feet
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        keep_mask = y_run >= 0
        X = X[keep_mask]
        y = y_run[keep_mask]

        if len(y) == 0:
            continue

        X_list.append(X)
        y_list.append(y)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0, 8, 1)), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)

    if len(set([round(s) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Inconsistent sampling rates: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            train_sub = list(item.get('train', []))
            test_sub  = list(item.get('test', []))
            return train_sub, test_sub
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def class_count_summary(y, name):
    bc = np.bincount(y, minlength=4)
    print(f"{name} counts:", dict(zip(CLASS_NAMES, bc)), "| total =", bc.sum())

# =========================
# MODELO: CNN + Transformer (con dropout)
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        # x: (B, C, T)
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)          # regularización extra
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)  # (B,1,D)
        z = torch.cat([cls_tok, z], dim=1)    # (B, 1+L, D)
        z = self.encoder(z)                   # (B, 1+L, D)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# ENTRENAMIENTO / EVAL
# =========================
def standardize_per_channel(train_X, test_X):
    """Estandariza por canal usando SOLO estadísticas del train. Entradas (N,C,T)."""
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    test_X  = test_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        test_X[:, c, :]  = (test_X[:, c, :] - mu) / sd
    return train_X, test_X

def train_one_fold(fold:int, resample_hz:int, do_notch:bool, do_bandpass:bool,
                   bp_lo:float, bp_hi:float, epochs:int, batch_size:int, lr:float,
                   device:str=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):

    # --- sujetos fold ---
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # --- carga datos ---
    X_tr_list, y_tr_list, X_te_list, y_te_list = [], [], [], []
    sfreq = None

    for sid in tqdm(train_sub, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys); sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys); sfreq = sf if sfreq is None else sfreq

    if len(X_tr_list) == 0 or len(X_te_list) == 0:
        raise RuntimeError("Datos insuficientes tras carga de sujetos.")

    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)

    # --- diagnósticos de desbalance ---
    print("train counts:", dict(zip(CLASS_NAMES, np.bincount(y_tr, minlength=4))), "| total =", len(y_tr))
    print("test  counts:", dict(zip(CLASS_NAMES, np.bincount(y_te, minlength=4))), "| total =", len(y_te))

    # --- estandarización sin leakage ---
    X_tr, X_te = standardize_per_channel(X_tr, X_te)

    # --- DataLoaders (reproducibles) ---
    g = torch.Generator(device="cpu"); g.manual_seed(RANDOM_STATE)
    tr_ld = DataLoader(
        TensorDataset(torch.tensor(X_tr), torch.tensor(y_tr).long()),
        batch_size=batch_size, shuffle=True, drop_last=False,
        generator=g, worker_init_fn=seed_worker
    )
    te_ld = DataLoader(
        TensorDataset(torch.tensor(X_te), torch.tensor(y_te).long()),
        batch_size=batch_size, shuffle=False, drop_last=False,
        generator=g, worker_init_fn=seed_worker
    )

    # --- Modelo, optimizador, pérdida ---
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
    crit = torch.nn.CrossEntropyLoss()

    best_acc, best_state = 0.0, None

    for ep in range(1, epochs+1):
        model.train()
        tr_loss, n_seen = 0.0, 0
        for xb, yb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb); n_seen += len(yb)
        tr_loss /= max(1, n_seen)

        # --- Eval ---
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        if acc > best_acc:
            best_acc = acc
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}

        print(f"Epoch {ep:03d} | train_loss={tr_loss:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f}")

    # cargar el mejor
    if best_state is not None:
        model.load_state_dict(best_state)

    # --- reporte final ---
    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for xb, yb in te_ld:
            xb = xb.to(device)
            p = model(xb).argmax(dim=1).cpu().numpy()
            preds.append(p); gts.append(yb.numpy())
    preds = np.concatenate(preds); gts = np.concatenate(gts)

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    cm = confusion_matrix(gts, preds, labels=[0,1,2,3])

    print(f"\n=== RESULTADOS (fold {fold}) ===")
    print(f"Acc: {acc:.4f} | Macro-F1: {f1m:.4f}")
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)

    return acc, f1m, cm, model

# =========================
# EJECUCIÓN (Notebook)
# =========================
acc, f1m, cm, model = train_one_fold(
    fold=FOLD,
    resample_hz=RESAMPLE_HZ,
    do_notch=DO_NOTCH,
    do_bandpass=DO_BANDPASS,
    bp_lo=BP_LO,
    bp_hi=BP_HI,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
)


🚀 Device: cuda


Cargando train fold1: 100%|██████████| 82/82 [00:12<00:00,  6.56it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:03<00:00,  6.55it/s]


train counts: {'left': np.int64(1767), 'right': np.int64(1748), 'both_fists': np.int64(1752), 'both_feet': np.int64(1761)} | total = 7028
test  counts: {'left': np.int64(448), 'right': np.int64(446), 'both_fists': np.int64(442), 'both_feet': np.int64(450)} | total = 1786




Epoch 001 | train_loss=1.3710 | val_acc=0.3533 | val_f1m=0.3243
Epoch 002 | train_loss=1.2879 | val_acc=0.3897 | val_f1m=0.3910
Epoch 003 | train_loss=1.2595 | val_acc=0.3763 | val_f1m=0.3440
Epoch 004 | train_loss=1.2336 | val_acc=0.3757 | val_f1m=0.3722
Epoch 005 | train_loss=1.2297 | val_acc=0.4149 | val_f1m=0.4140
Epoch 006 | train_loss=1.2192 | val_acc=0.4009 | val_f1m=0.3996
Epoch 007 | train_loss=1.2103 | val_acc=0.3992 | val_f1m=0.3996
Epoch 008 | train_loss=1.2072 | val_acc=0.3852 | val_f1m=0.3780
Epoch 009 | train_loss=1.1913 | val_acc=0.4065 | val_f1m=0.4018
Epoch 010 | train_loss=1.1832 | val_acc=0.3852 | val_f1m=0.3849
Epoch 011 | train_loss=1.1808 | val_acc=0.4071 | val_f1m=0.4083
Epoch 012 | train_loss=1.1699 | val_acc=0.3852 | val_f1m=0.3858
Epoch 013 | train_loss=1.1525 | val_acc=0.3824 | val_f1m=0.3767
Epoch 014 | train_loss=1.1547 | val_acc=0.3886 | val_f1m=0.3849
Epoch 015 | train_loss=1.1468 | val_acc=0.3757 | val_f1m=0.3682
Epoch 016 | train_loss=1.1294 | val_acc=

In [21]:
# -*- coding: utf-8 -*-
# CNN + Transformer (PyTorch) para PhysioNet/BCI2000 MI — IMAGINERÍA (4 clases) con 8 canales.
# Incluye: class-weights + label smoothing, warmup+cosine+early stopping, sampler balanceado,
# augmentations ligeras, TTA estilo EEGNet y reproducibilidad fuerte (seed=42).
# Ventana: 6 s (TMIN=-1.0, TMAX=5.0). Z-score por época activable con ZSCORE_PER_EPOCH.
# Imprime métricas por fold (classification_report) y un runner de 5 folds tipo EEGNet.

# =========================
# Reproducibilidad (poner ANTES de importar torch)
# =========================
import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # determinismo cuBLAS

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# =========================
# REPRODUCIBILIDAD (seed + determinismo)
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinismo cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # evitar TF32 (puede romper determinismo)
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    # reforzar determinismo
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG (edita a tu gusto)
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'                     # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

FOLD = 5
EPOCHS = 60
BATCH_SIZE = 64
LR = 1e-3
RESAMPLE_HZ = None          # None: mantiene 160 Hz original
DO_NOTCH = True             # ganador
DO_BANDPASS = False         # ganador
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False              # ganador

# Normalización
ZSCORE_PER_EPOCH = False     # <- activa/desactiva z-score por época (como EEGNet)

# Hiperparámetros del modelo
D_MODEL = 128              # (si quieres probar más capacidad: 160/192)
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2               # sube a 0.3 si ves sobreajuste

# Ventana temporal (seg) — igual a EEGNet (6 s)
TMIN, TMAX = -1.0, 5.0

# Sliding windows / TTA en evaluación
SW_MODE = 'tta'   # 'none' | 'subwin' | 'tta'
SW_ENABLE = True  # si quieres apagar todo, pon False

# Para TTA de desplazamientos cortos ±25 ms (EEGNet)
TTA_SHIFTS_S = [-0.025, -0.0125, 0.0, 0.0125, 0.025]

# (solo si usas 'subwin'; no se usa en 'tta' EEGNet)
SW_LEN   = 2.0
SW_STRIDE = 0.5

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

# Runs IMAGINERÍA
IMAGERY_RUNS_LR = {4, 8, 12}    # T1=left, T2=right
IMAGERY_RUNS_BF = {6, 10, 14}   # T1=both_fists, T2=both_feet

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Device: {DEVICE}")

# =========================
# UTILIDADES
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        pattern = str(subj_dir / f"{subject_id}R{r:02d}.edf")
        edfs.extend(glob(pattern))
    return sorted(edfs)

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        raise FileNotFoundError(f"No imagery EDF files for {subject_id} under {DATA_RAW}")

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

        # Selección de 8 canales
        raw = pick_8_channels(raw)

        # --- PREPROC OPCIONAL ---
        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        # Resample
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        # Eventos (T0/T1/T2)
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        # Mantener solo T1/T2
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()  # (n_epochs, 8, T)

        # ---- Z-SCORE POR ÉPOCA (opcional, como EEGNet) ----
        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)              # (N,8,1)
            sd = X.std(axis=2, keepdims=True) + eps         # (N,8,1)
            X = (X - mu) / sd

        # Construir y según RUN
        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}  # id -> 'T1'/'T2'
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)  # left/right
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)  # both_fists/feet
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        keep_mask = y_run >= 0
        X = X[keep_mask]
        y = y_run[keep_mask]

        if len(y) == 0:
            continue

        X_list.append(X)
        y_list.append(y)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0, 8, 1)), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)

    if len(set([round(s) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Inconsistent sampling rates: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            train_sub = list(item.get('train', []))
            test_sub  = list(item.get('test', []))
            return train_sub, test_sub
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def class_count_summary(y, name):
    bc = np.bincount(y, minlength=4)
    print(f"{name} counts:", dict(zip(CLASS_NAMES, bc)), "| total =", bc.sum())

# =========================
# MODELO: CNN + Transformer
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)  # (B,1,D)
        z = torch.cat([cls_tok, z], dim=1)    # (B, 1+L, D)
        z = self.encoder(z)                   # (B, 1+L, D)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# ENTRENAMIENTO / EVAL
# =========================
def standardize_per_channel(train_X, test_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    test_X  = test_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        test_X[:, c, :]  = (test_X[:, c, :] - mu) / sd
    return train_X, test_X

# --- Augmentations ligeras ---
def augment_batch(xb, p_jitter=0.25, p_noise=0.25, p_chdrop=0.15,
                  max_jitter_frac=0.02, noise_std=0.02):
    """
    xb: torch.Tensor (B,C,T) (se asume ya normalizado)
    - jitter temporal (roll)
    - ruido gaussiano
    - channel dropout (apagar 1 canal aleatorio)
    """
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop:
        k = 1
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    """
    Promedia logits sobre sub-ventanas de longitud sw_len (seg).
    Entrada:  X (N,C,T) estandarizado
    Salida:   (N,4)
    """
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1]))
    st = max(1, st)

    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]  # (C,T)
            acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]  # (4,)
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)  # (N,4)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    """
    Test-Time Augmentation (EEGNet-like): 5 desplazamientos cortos ±25ms manteniendo longitud T.
    Entrada:  X (N,C,T) estandarizado
    Salida:   (N,4)
    """
    model.eval()
    T = X.shape[-1]
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]  # (C,T)
            acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]  # (4,)
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0)
            out.append(acc)
    return np.stack(out, axis=0)  # (N,4)

def load_fold_subject_ids_and_counts(fold:int):
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    return train_sub, test_sub

def train_one_fold(fold:int, resample_hz:int, do_notch:bool, do_bandpass:bool,
                   bp_lo:float, bp_hi:float, epochs:int, batch_size:int, lr:float,
                   device:str=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):

    # --- sujetos fold ---
    train_sub, test_sub = load_fold_subject_ids_and_counts(fold)

    # --- carga datos ---
    X_tr_list, y_tr_list, X_te_list, y_te_list = [], [], [], []
    sfreq = None

    for sid in tqdm(train_sub, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys); sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys); sfreq = sf if sfreq is None else sfreq

    if len(X_tr_list) == 0 or len(X_te_list) == 0:
        raise RuntimeError("Datos insuficientes tras carga de sujetos.")

    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)

    # --- diagnósticos de desbalance ---
    # (y cabecera al estilo EEGNet con n_train / n_test)
    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_test={len(y_te)})")
    print("train counts:", dict(zip(CLASS_NAMES, np.bincount(y_tr, minlength=4))), "| total =", len(y_tr))
    print("test  counts:", dict(zip(CLASS_NAMES, np.bincount(y_te, minlength=4))), "| total =", len(y_te))

    # --- Normalización ---
    if ZSCORE_PER_EPOCH:
        # Ya se normalizó por época dentro de load_subject_epochs → NO hacer z-score global
        X_tr_std, X_te_std = X_tr, X_te
    else:
        # Sin z-score por época → aplicar z-score global por canal usando SOLO train
        X_tr_std, X_te_std = standardize_per_channel(X_tr, X_te)

    # --- DataLoaders (sampler balanceado + reproducible) ---
    train_ds = TensorDataset(torch.tensor(X_tr_std), torch.tensor(y_tr).long())
    class_counts = np.bincount(y_tr, minlength=4)
    class_weights_vec = class_counts.sum() / (4.0 * np.maximum(class_counts, 1))
    sample_weights = class_weights_vec[train_ds.tensors[1].numpy()]
    sampler = WeightedRandomSampler(
        weights=torch.tensor(sample_weights, dtype=torch.float32),
        num_samples=len(train_ds), replacement=True,
        generator=torch.Generator().manual_seed(RANDOM_STATE)
    )
    tr_ld = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                       drop_last=False, worker_init_fn=seed_worker)

    te_ld = DataLoader(
        TensorDataset(torch.tensor(X_te_std), torch.tensor(y_te).long()),
        batch_size=batch_size, shuffle=False, drop_last=False,
        worker_init_fn=seed_worker
    )

    # --- Modelo ---
    model = EEGCNNTransformer(
        n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS, p_drop=P_DROP
    ).to(device)

    # --- Optimizador ---
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)

    # (1) Class-weights + label smoothing
    class_weights = torch.tensor(class_weights_vec, dtype=torch.float32, device=device)
    crit = torch.nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

    # (2) Warmup + Cosine LR + Early stopping por macro-F1
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = epochs
    warmup_epochs = max(1, min(5, epochs//10))
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return 0.5 * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    best_f1, best_state, patience, wait = 0.0, None, 10, 0

    for ep in range(1, epochs+1):
        model.train()
        tr_loss, n_seen = 0.0, 0
        for xb, yb in tr_ld:
            # augs en GPU (se asume input ya normalizado)
            xb = xb.to(device)
            xb = augment_batch(xb)
            yb = yb.to(device)

            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb); n_seen += len(yb)
        tr_loss /= max(1, n_seen)

        # --- Eval en test set (mismo patrón que tu script actual) ---
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        # Log estilo EEGNet (usa acc/f1m del split de test como "val_*" porque no tenemos val separado)
        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        if wait >= patience:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # cargar el mejor
    if best_state is not None:
        model.load_state_dict(best_state)

    # --- reporte final con TTA/subwin/none ---
    model.eval()
    sfreq_used = RESAMPLE_HZ if RESAMPLE_HZ is not None else int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)

    elif SW_MODE == 'subwin':
        logits = subwindow_logits(model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)  # (N,4)
        preds = logits.argmax(axis=1)
        gts = y_te

    elif SW_MODE == 'tta':
        logits = time_shift_tta_logits(model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)  # (N,4)
        preds = logits.argmax(axis=1)
        gts = y_te

    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    cm = confusion_matrix(gts, preds, labels=[0,1,2,3])

    # ---- Salida estilo EEGNet ----
    print(f"[Fold {fold}/5] Global acc={acc:.4f}")
    print("\n" + classification_report(gts, preds, target_names=[
        'Left','Right','Both Fists','Both Feet'
    ], digits=4))

    print(f"=== RESULTADOS (fold {fold}) ===")
    print(f"Acc: {acc:.4f} | Macro-F1: {f1m:.4f}")
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)

    return acc, f1m, cm, model

# =========================
# Runner multi-fold con resumen (igual estilo EEGNet)
# =========================
def run_all_folds(n_folds=5):
    accs, f1s = [], []
    for f in range(1, n_folds+1):
        print("\n" + "="*12 + f"  FOLD {f}  " + "="*12)
        acc, f1m, cm, _ = train_one_fold(
            fold=f,
            resample_hz=RESAMPLE_HZ,
            do_notch=DO_NOTCH,
            do_bandpass=DO_BANDPASS,
            bp_lo=BP_LO,
            bp_hi=BP_HI,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LR,
        )
        accs.append(acc); f1s.append(f1m)
    if accs:
        print("\n" + "="*60)
        print("RESULTADOS FINALES (Transformer)")
        print("="*60)
        print("Acc folds:", [f"{a:.4f}" for a in accs])
        print("F1m folds:", [f"{a:.4f}" for a in f1s])
        print(f"Acc mean: {np.mean(accs):.4f} | std: {np.std(accs):.4f}")
        print(f"F1m mean: {np.mean(f1s):.4f} | std: {np.std(f1s):.4f}")

# =========================
# EJECUCIÓN
# =========================
if __name__ == "__main__":
    # Opción A: correr todos los folds 1..5 con reporte por fold y resumen
    run_all_folds(n_folds=5)

    # Opción B: correr un solo fold
    # acc, f1m, cm, model = train_one_fold(
    #     fold=FOLD,
    #     resample_hz=RESAMPLE_HZ,
    #     do_notch=DO_NOTCH,
    #     do_bandpass=DO_BANDPASS,
    #     bp_lo=BP_LO,
    #     bp_hi=BP_HI,
    #     epochs=EPOCHS,
    #     batch_size=BATCH_SIZE,
    #     lr=LR,
    # )


🚀 Device: cuda



Cargando train fold1: 100%|██████████| 82/82 [00:09<00:00,  8.35it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.35it/s]


[Fold 1/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
train counts: {'left': np.int64(1722), 'right': np.int64(1722), 'both_fists': np.int64(1722), 'both_feet': np.int64(1722)} | total = 6888
test  counts: {'left': np.int64(441), 'right': np.int64(441), 'both_fists': np.int64(441), 'both_feet': np.int64(441)} | total = 1764




  Época   1 | train_loss=1.3483 | val_acc=0.3656 | val_f1m=0.3498 | LR=0.000400
  Época   2 | train_loss=1.2134 | val_acc=0.4427 | val_f1m=0.4359 | LR=0.000600
  Época   3 | train_loss=1.1752 | val_acc=0.4598 | val_f1m=0.4615 | LR=0.000800
  Época   4 | train_loss=1.1564 | val_acc=0.4495 | val_f1m=0.4522 | LR=0.001000
  Época   5 | train_loss=1.1524 | val_acc=0.4507 | val_f1m=0.4365 | LR=0.001000
  Época   6 | train_loss=1.1156 | val_acc=0.4717 | val_f1m=0.4685 | LR=0.000999
  Época   7 | train_loss=1.1040 | val_acc=0.4745 | val_f1m=0.4572 | LR=0.000997
  Época   8 | train_loss=1.1060 | val_acc=0.4705 | val_f1m=0.4688 | LR=0.000993
  Época   9 | train_loss=1.1008 | val_acc=0.4422 | val_f1m=0.4248 | LR=0.000987
  Época  10 | train_loss=1.0739 | val_acc=0.4586 | val_f1m=0.4385 | LR=0.000980
  Época  11 | train_loss=1.0749 | val_acc=0.4830 | val_f1m=0.4818 | LR=0.000971
  Época  12 | train_loss=1.0540 | val_acc=0.4615 | val_f1m=0.4579 | LR=0.000961
  Época  13 | train_loss=1.0441 | val_ac

Cargando train fold2: 100%|██████████| 82/82 [00:09<00:00,  8.31it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.19it/s]


[Fold 2/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
train counts: {'left': np.int64(1722), 'right': np.int64(1722), 'both_fists': np.int64(1722), 'both_feet': np.int64(1722)} | total = 6888
test  counts: {'left': np.int64(441), 'right': np.int64(441), 'both_fists': np.int64(441), 'both_feet': np.int64(441)} | total = 1764




  Época   1 | train_loss=1.3581 | val_acc=0.4286 | val_f1m=0.4178 | LR=0.000400
  Época   2 | train_loss=1.2361 | val_acc=0.4864 | val_f1m=0.4735 | LR=0.000600
  Época   3 | train_loss=1.2014 | val_acc=0.5249 | val_f1m=0.5194 | LR=0.000800
  Época   4 | train_loss=1.1905 | val_acc=0.5283 | val_f1m=0.5318 | LR=0.001000
  Época   5 | train_loss=1.1873 | val_acc=0.5340 | val_f1m=0.5350 | LR=0.001000
  Época   6 | train_loss=1.1505 | val_acc=0.5368 | val_f1m=0.5404 | LR=0.000999
  Época   7 | train_loss=1.1486 | val_acc=0.5459 | val_f1m=0.5367 | LR=0.000997
  Época   8 | train_loss=1.1518 | val_acc=0.5193 | val_f1m=0.4985 | LR=0.000993
  Época   9 | train_loss=1.1339 | val_acc=0.5329 | val_f1m=0.5064 | LR=0.000987
  Época  10 | train_loss=1.1026 | val_acc=0.5482 | val_f1m=0.5315 | LR=0.000980
  Época  11 | train_loss=1.1050 | val_acc=0.5465 | val_f1m=0.5389 | LR=0.000971
  Época  12 | train_loss=1.0913 | val_acc=0.5527 | val_f1m=0.5528 | LR=0.000961
  Época  13 | train_loss=1.0917 | val_ac

Cargando train fold3: 100%|██████████| 82/82 [00:09<00:00,  8.30it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.33it/s]


[Fold 3/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
train counts: {'left': np.int64(1722), 'right': np.int64(1722), 'both_fists': np.int64(1722), 'both_feet': np.int64(1722)} | total = 6888
test  counts: {'left': np.int64(441), 'right': np.int64(441), 'both_fists': np.int64(441), 'both_feet': np.int64(441)} | total = 1764




  Época   1 | train_loss=1.3863 | val_acc=0.3254 | val_f1m=0.2574 | LR=0.000400
  Época   2 | train_loss=1.2550 | val_acc=0.4677 | val_f1m=0.4658 | LR=0.000600
  Época   3 | train_loss=1.1853 | val_acc=0.4773 | val_f1m=0.4771 | LR=0.000800
  Época   4 | train_loss=1.1739 | val_acc=0.4575 | val_f1m=0.4612 | LR=0.001000
  Época   5 | train_loss=1.1491 | val_acc=0.4864 | val_f1m=0.4779 | LR=0.001000
  Época   6 | train_loss=1.1223 | val_acc=0.4881 | val_f1m=0.4881 | LR=0.000999
  Época   7 | train_loss=1.1163 | val_acc=0.4683 | val_f1m=0.4501 | LR=0.000997
  Época   8 | train_loss=1.1067 | val_acc=0.4688 | val_f1m=0.4621 | LR=0.000993
  Época   9 | train_loss=1.0928 | val_acc=0.4813 | val_f1m=0.4665 | LR=0.000987
  Época  10 | train_loss=1.0775 | val_acc=0.4813 | val_f1m=0.4804 | LR=0.000980
  Época  11 | train_loss=1.0863 | val_acc=0.4717 | val_f1m=0.4595 | LR=0.000971
  Época  12 | train_loss=1.0623 | val_acc=0.4586 | val_f1m=0.4511 | LR=0.000961
  Época  13 | train_loss=1.0616 | val_ac

Cargando train fold4: 100%|██████████| 83/83 [00:10<00:00,  8.27it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.26it/s]


[Fold 4/5] Entrenando modelo global... (n_train=6972 | n_test=1680)
train counts: {'left': np.int64(1743), 'right': np.int64(1743), 'both_fists': np.int64(1743), 'both_feet': np.int64(1743)} | total = 6972
test  counts: {'left': np.int64(420), 'right': np.int64(420), 'both_fists': np.int64(420), 'both_feet': np.int64(420)} | total = 1680




  Época   1 | train_loss=1.3840 | val_acc=0.4030 | val_f1m=0.4010 | LR=0.000400
  Época   2 | train_loss=1.2280 | val_acc=0.4929 | val_f1m=0.4864 | LR=0.000600
  Época   3 | train_loss=1.1767 | val_acc=0.4726 | val_f1m=0.4500 | LR=0.000800
  Época   4 | train_loss=1.1629 | val_acc=0.4994 | val_f1m=0.4880 | LR=0.001000
  Época   5 | train_loss=1.1573 | val_acc=0.5018 | val_f1m=0.4922 | LR=0.001000
  Época   6 | train_loss=1.1326 | val_acc=0.5018 | val_f1m=0.4944 | LR=0.000999
  Época   7 | train_loss=1.1316 | val_acc=0.5071 | val_f1m=0.5104 | LR=0.000997
  Época   8 | train_loss=1.1127 | val_acc=0.5131 | val_f1m=0.5082 | LR=0.000993
  Época   9 | train_loss=1.1066 | val_acc=0.5119 | val_f1m=0.5065 | LR=0.000987
  Época  10 | train_loss=1.0880 | val_acc=0.4940 | val_f1m=0.4961 | LR=0.000980
  Época  11 | train_loss=1.0813 | val_acc=0.5095 | val_f1m=0.5083 | LR=0.000971
  Época  12 | train_loss=1.0809 | val_acc=0.5244 | val_f1m=0.5223 | LR=0.000961
  Época  13 | train_loss=1.0706 | val_ac

Cargando train fold5: 100%|██████████| 83/83 [00:09<00:00,  8.40it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.30it/s]


[Fold 5/5] Entrenando modelo global... (n_train=6972 | n_test=1680)
train counts: {'left': np.int64(1743), 'right': np.int64(1743), 'both_fists': np.int64(1743), 'both_feet': np.int64(1743)} | total = 6972
test  counts: {'left': np.int64(420), 'right': np.int64(420), 'both_fists': np.int64(420), 'both_feet': np.int64(420)} | total = 1680




  Época   1 | train_loss=1.3911 | val_acc=0.3518 | val_f1m=0.3082 | LR=0.000400
  Época   2 | train_loss=1.2738 | val_acc=0.5071 | val_f1m=0.5038 | LR=0.000600
  Época   3 | train_loss=1.1954 | val_acc=0.5208 | val_f1m=0.5012 | LR=0.000800
  Época   4 | train_loss=1.1863 | val_acc=0.5393 | val_f1m=0.5232 | LR=0.001000
  Época   5 | train_loss=1.1657 | val_acc=0.5232 | val_f1m=0.5089 | LR=0.001000
  Época   6 | train_loss=1.1465 | val_acc=0.5399 | val_f1m=0.5326 | LR=0.000999
  Época   7 | train_loss=1.1450 | val_acc=0.5054 | val_f1m=0.4860 | LR=0.000997
  Época   8 | train_loss=1.1193 | val_acc=0.5214 | val_f1m=0.5132 | LR=0.000993
  Época   9 | train_loss=1.1186 | val_acc=0.5488 | val_f1m=0.5431 | LR=0.000987
  Época  10 | train_loss=1.1034 | val_acc=0.5333 | val_f1m=0.5356 | LR=0.000980
  Época  11 | train_loss=1.0964 | val_acc=0.5339 | val_f1m=0.5147 | LR=0.000971
  Época  12 | train_loss=1.0925 | val_acc=0.5423 | val_f1m=0.5327 | LR=0.000961
  Época  13 | train_loss=1.0821 | val_ac

In [23]:
# -*- coding: utf-8 -*-
# CNN + Transformer (PyTorch) para PhysioNet/BCI2000 MI — IMAGINERÍA (4 clases) con 8 canales.
# Eval inter-sujeto con Kfold5.json (igual que EEGNet):
# - Split por sujetos (train/test) desde JSON.
# - Dentro de train: GroupShuffleSplit 85/15 por sujetos para validación.
# - Logs por época: train_acc, val_acc, val_f1m, LR + Early stopping (val_f1m).
# - Reporte por fold: classification_report + confusion_matrix + curva de entrenamiento.
# - Al final: resumen con accuracies por fold y media.
# Ventana: 6 s (TMIN=-1.0, TMAX=5.0). Z-score por época activable con ZSCORE_PER_EPOCH.

# =========================
# Reproducibilidad (poner ANTES de importar torch)
# =========================
import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # determinismo cuBLAS

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import GroupShuffleSplit

import matplotlib.pyplot as plt

# =========================
# REPRODUCIBILIDAD (seed + determinismo)
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # determinismo cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # evitar TF32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'                     # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

# Entrenamiento
EPOCHS = 60
BATCH_SIZE = 64
LR = 1e-3

# Preproc
RESAMPLE_HZ = None          # None: 160 Hz original
DO_NOTCH = True
DO_BANDPASS = False
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False

# Normalización
ZSCORE_PER_EPOCH = False     # <- activa/desactiva z-score por época (como EEGNet)

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2

# Ventana temporal (seg) — como EEGNet (6 s)
TMIN, TMAX = -1.0, 5.0

# TTA evaluación (estilo EEGNet: 5 shifts ±25ms)
SW_MODE = 'tta'   # 'none' | 'subwin' | 'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.025, -0.0125, 0.0, 0.0125, 0.025]

# (solo si usas 'subwin'; no se usa en 'tta')
SW_LEN   = 2.0
SW_STRIDE = 0.5

# Constantes dataset
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

# Runs IMAGINERÍA
IMAGERY_RUNS_LR = {4, 8, 12}    # T1=left, T2=right
IMAGERY_RUNS_BF = {6, 10, 14}   # T1=both_fists, T2=both_feet

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# =========================
# UTILIDADES
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        pattern = str(subj_dir / f"{subject_id}R{r:02d}.edf")
        edfs.extend(glob(pattern))
    return sorted(edfs)

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        raise FileNotFoundError(f"No imagery EDF files for {subject_id} under {DATA_RAW}")

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

        # Selección de 8 canales
        raw = pick_8_channels(raw)

        # --- PREPROC OPCIONAL ---
        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        # Resample
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        # Eventos (T0/T1/T2)
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        # Mantener solo T1/T2
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()  # (n_epochs, 8, T)

        # ---- Z-SCORE POR ÉPOCA (opcional, como EEGNet) ----
        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)              # (N,8,1)
            sd = X.std(axis=2, keepdims=True) + eps         # (N,8,1)
            X = (X - mu) / sd

        # Construir y según RUN
        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}  # id -> 'T1'/'T2'
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)  # left/right
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)  # both_fists/feet
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        keep_mask = y_run >= 0
        X = X[keep_mask]
        y = y_run[keep_mask]

        if len(y) == 0:
            continue

        X_list.append(X)
        y_list.append(y)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0, 8, 1)), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)

    if len(set([round(s) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Inconsistent sampling rates: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            train_sub = list(item.get('train', []))
            test_sub  = list(item.get('test', []))
            return train_sub, test_sub
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def standardize_fit(X_train):
    """Devuelve μ/σ por canal (C,) calculados sobre (N,C,T) de train."""
    C = X_train.shape[1]
    mu = np.zeros(C, dtype=np.float32)
    sd = np.ones(C, dtype=np.float32)
    for c in range(C):
        mu[c] = X_train[:, c, :].mean()
        sd_c = X_train[:, c, :].std()
        sd[c] = sd_c if sd_c > 1e-6 else 1.0
    return mu, sd

def standardize_apply(X, mu, sd):
    X = X.astype(np.float32)
    for c in range(X.shape[1]):
        X[:, c, :] = (X[:, c, :] - mu[c]) / sd[c]
    return X

def plot_training_curves(history, fname):
    try:
        plt.figure(figsize=(6,4))
        plt.plot(history['train_acc'], label='train_acc')
        plt.plot(history['val_acc'], label='val_acc')
        plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training curve'); plt.legend()
        plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()
        print(f"↳ Curva de entrenamiento guardada: {fname}")
    except Exception as e:
        print(f"[plot] No se pudo guardar {fname}: {e}")

# =========================
# MODELO: CNN + Transformer
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)  # (B,1,D)
        z = torch.cat([cls_tok, z], dim=1)    # (B, 1+L, D)
        z = self.encoder(z)                   # (B, 1+L, D)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# AUGS + EVAL HELPERS
# =========================
def augment_batch(xb, p_jitter=0.25, p_noise=0.25, p_chdrop=0.15,
                  max_jitter_frac=0.02, noise_std=0.02):
    """
    xb: torch.Tensor (B,C,T) (se asume ya normalizado)
    - jitter temporal (roll)
    - ruido gaussiano
    - channel dropout (apagar 1 canal aleatorio)
    """
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop:
        k = 1
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1]))
    st = max(1, st)

    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]  # (C,T)
            acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    """EEGNet-like TTA: 5 desplazamientos cortos ±25ms manteniendo longitud T."""
    model.eval()
    T = X.shape[-1]
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]  # (C,T)
            acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0)
            out.append(acc)
    return np.stack(out, axis=0)

# =========================
# ENTRENAMIENTO / EVAL (un fold)
# =========================
def train_one_fold(fold:int, resample_hz:int, do_notch:bool, do_bandpass:bool,
                   bp_lo:float, bp_hi:float, epochs:int, batch_size:int, lr:float,
                   device:str=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):

    # --- sujetos fold (del JSON) ---
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # --- carga datos por sujeto (sin limitar muestras) ---
    X_tr_list, y_tr_list, X_te_list, y_te_list = [], [], [], []
    sfreq = None

    # también guardamos el tamaño por sujeto para construir groups_tr
    train_groups_counts = []  # lista de (sid_int, count)

    for sid in tqdm(train_sub, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0:
            continue
        X_tr_list.append(Xs); y_tr_list.append(ys); sfreq = sf if sfreq is None else sfreq
        train_groups_counts.append((subject_id_to_int(sid), len(ys)))

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0:
            continue
        X_te_list.append(Xs); y_te_list.append(ys); sfreq = sf if sfreq is None else sfreq

    if len(X_tr_list) == 0 or len(X_te_list) == 0:
        raise RuntimeError("Datos insuficientes tras carga de sujetos.")

    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)

    # --- construir vector groups_tr (id de sujeto por época de train) ---
    groups_tr = []
    for sid_int, cnt in train_groups_counts:
        groups_tr.extend([sid_int]*cnt)
    groups_tr = np.array(groups_tr[:len(y_tr)], dtype=int)

    # --- split de validación por sujetos (85/15) ---
    gss = GroupShuffleSplit(n_splits=1, test_size=0.15, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = gss.split(X_tr, y_tr, groups=groups_tr)
    X_tr_tr, y_tr_tr = X_tr[tr_idx], y_tr[tr_idx]
    X_tr_va, y_tr_va = X_tr[va_idx], y_tr[va_idx]

    n_train, n_val, n_test = len(y_tr_tr), len(y_tr_va), len(y_te)
    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={n_train} | n_val={n_val} | n_test={n_test})")

    # --- Normalización sin leakage ---
    if ZSCORE_PER_EPOCH:
        Xtr_std, Xva_std, Xte_std = X_tr_tr, X_tr_va, X_te
    else:
        mu, sd = standardize_fit(X_tr_tr)                 # μ/σ SOLO con train_tr
        Xtr_std = standardize_apply(X_tr_tr, mu, sd)
        Xva_std = standardize_apply(X_tr_va, mu, sd)
        Xte_std = standardize_apply(X_te,    mu, sd)

    # --- DataLoaders ---
    # sampler balanceado sobre y_tr_tr
    class_counts = np.bincount(y_tr_tr, minlength=4)
    class_weights_vec = class_counts.sum() / (4.0 * np.maximum(class_counts, 1))
    sample_weights = class_weights_vec[y_tr_tr]

    tr_ds = TensorDataset(torch.tensor(Xtr_std), torch.tensor(y_tr_tr).long())
    va_ds = TensorDataset(torch.tensor(Xva_std), torch.tensor(y_tr_va).long())
    te_ds = TensorDataset(torch.tensor(Xte_std), torch.tensor(y_te).long())

    sampler = WeightedRandomSampler(
        weights=torch.tensor(sample_weights, dtype=torch.float32),
        num_samples=len(tr_ds), replacement=True,
        generator=torch.Generator().manual_seed(RANDOM_STATE)
    )
    tr_ld = DataLoader(tr_ds, batch_size=batch_size, sampler=sampler,
                       drop_last=False, worker_init_fn=seed_worker)

    # loaders sin augs para evaluar
    tr_eval_ld = DataLoader(tr_ds, batch_size=batch_size, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    va_ld      = DataLoader(va_ds, batch_size=batch_size, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld      = DataLoader(te_ds, batch_size=batch_size, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # --- Modelo ---
    model = EEGCNNTransformer(
        n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS, p_drop=P_DROP
    ).to(device)

    # --- Optimizador ---
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)

    # (1) Class-weights + label smoothing (basados en y_tr_tr)
    class_weights = torch.tensor(class_weights_vec, dtype=torch.float32, device=device)
    crit = torch.nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

    # (2) Warmup + Cosine LR + Early stopping por val_f1m
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = epochs
    warmup_epochs = max(1, min(5, epochs//10))
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return 0.5 * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    best_f1, best_state, patience, wait = 0.0, None, 10, 0

    # ---- helpers de evaluación rápida (sin TTA) ----
    def _eval(dl):
        model.eval(); preds, gts = [], []
        with torch.no_grad():
            for xb, yb in dl:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    # historia para curva
    history = {'train_acc': [], 'val_acc': []}

    # =========================
    # LOOP DE ENTRENAMIENTO
    # =========================
    for ep in range(1, epochs+1):
        model.train()
        tr_loss, n_seen = 0.0, 0
        for xb, yb in tr_ld:
            xb = xb.to(device)
            xb = augment_batch(xb)  # augs SOLO en train
            yb = yb.to(device)

            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb); n_seen += len(yb)
        tr_loss /= max(1, n_seen)

        # --- métricas en train (sin augs) y val ---
        train_acc, train_f1 = _eval(tr_eval_ld)
        val_acc,   val_f1   = _eval(va_ld)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        improved = val_f1 > best_f1 + 1e-4
        if improved:
            best_f1 = val_f1
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={train_acc:.4f} | "
              f"val_acc={val_acc:.4f} | val_f1m={val_f1:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        if wait >= patience:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # cargar el mejor
    if best_state is not None:
        model.load_state_dict(best_state)

    # guardar curva
    plot_training_curves(history, f"training_curve_fold{fold}.png")

    # =========================
    # EVALUACIÓN FINAL EN TEST (con TTA como EEGNet)
    # =========================
    model.eval()
    sfreq_used = RESAMPLE_HZ if RESAMPLE_HZ is not None else int(round(Xte_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)

    elif SW_MODE == 'subwin':
        logits = subwindow_logits(model, Xte_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        preds = logits.argmax(axis=1)
        gts = y_te

    elif SW_MODE == 'tta':
        logits = time_shift_tta_logits(model, Xte_std, sfreq_used, TTA_SHIFTS_S, device)
        preds = logits.argmax(axis=1)
        gts = y_te

    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    cm = confusion_matrix(gts, preds, labels=[0,1,2,3])

    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=CLASS_NAMES, digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)

    return acc, f1m, cm, model

# =========================
# RUNNER: TODOS LOS FOLDS del JSON
# =========================
def run_all_folds():
    # leer orden de folds desde el JSON
    with open(FOLDS_JSON, 'r') as f:
        payload = json.load(f)
    folds_info = payload.get('folds', [])
    fold_ids = [int(x.get('fold', -1)) for x in folds_info]
    fold_ids = [f for f in fold_ids if f > 0]
    if not fold_ids:
        raise ValueError("No se hallaron folds válidos en el JSON.")

    global_accs = []
    for i, fold in enumerate(sorted(fold_ids), start=1):
        acc, f1m, cm, model = train_one_fold(
            fold=fold,
            resample_hz=RESAMPLE_HZ,
            do_notch=DO_NOTCH,
            do_bandpass=DO_BANDPASS,
            bp_lo=BP_LO,
            bp_hi=BP_HI,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LR,
        )
        global_accs.append(acc)

    # resumen final
    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_accs])
    if len(global_accs) > 0:
        print(f"Global mean: {np.mean(global_accs):.4f}")

# =========================
# MAIN
# =========================
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
    print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")
    run_all_folds()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 82/82 [00:09<00:00,  8.46it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.51it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)




  Época   1 | train_loss=1.3768 | train_acc=0.4282 | val_acc=0.4020 | val_f1m=0.3849 | LR=0.000400
  Época   2 | train_loss=1.2056 | train_acc=0.5064 | val_acc=0.4478 | val_f1m=0.4141 | LR=0.000600
  Época   3 | train_loss=1.1660 | train_acc=0.5342 | val_acc=0.4744 | val_f1m=0.4706 | LR=0.000800
  Época   4 | train_loss=1.1217 | train_acc=0.5012 | val_acc=0.4597 | val_f1m=0.4452 | LR=0.001000
  Época   5 | train_loss=1.1386 | train_acc=0.5192 | val_acc=0.4689 | val_f1m=0.4490 | LR=0.001000
  Época   6 | train_loss=1.1240 | train_acc=0.5730 | val_acc=0.4890 | val_f1m=0.4799 | LR=0.000999
  Época   7 | train_loss=1.1121 | train_acc=0.5818 | val_acc=0.4844 | val_f1m=0.4819 | LR=0.000997
  Época   8 | train_loss=1.0740 | train_acc=0.5945 | val_acc=0.5064 | val_f1m=0.5076 | LR=0.000993
  Época   9 | train_loss=1.0799 | train_acc=0.5849 | val_acc=0.4808 | val_f1m=0.4749 | LR=0.000987
  Época  10 | train_loss=1.0823 | train_acc=0.5952 | val_acc=0.4918 | val_f1m=0.4852 | LR=0.000980
  Época  1

Cargando train fold2: 100%|██████████| 82/82 [00:09<00:00,  8.56it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.55it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)




  Época   1 | train_loss=1.3762 | train_acc=0.3913 | val_acc=0.3956 | val_f1m=0.3899 | LR=0.000400
  Época   2 | train_loss=1.2558 | train_acc=0.4898 | val_acc=0.4579 | val_f1m=0.4376 | LR=0.000600
  Época   3 | train_loss=1.2206 | train_acc=0.4891 | val_acc=0.4652 | val_f1m=0.4489 | LR=0.000800
  Época   4 | train_loss=1.1632 | train_acc=0.5219 | val_acc=0.4872 | val_f1m=0.4793 | LR=0.001000
  Época   5 | train_loss=1.1626 | train_acc=0.5252 | val_acc=0.4753 | val_f1m=0.4661 | LR=0.001000
  Época   6 | train_loss=1.1770 | train_acc=0.5090 | val_acc=0.4615 | val_f1m=0.4608 | LR=0.000999
  Época   7 | train_loss=1.1624 | train_acc=0.5235 | val_acc=0.4661 | val_f1m=0.4519 | LR=0.000997
  Época   8 | train_loss=1.1202 | train_acc=0.5361 | val_acc=0.5064 | val_f1m=0.5107 | LR=0.000993
  Época   9 | train_loss=1.1177 | train_acc=0.5286 | val_acc=0.4679 | val_f1m=0.4461 | LR=0.000987
  Época  10 | train_loss=1.1140 | train_acc=0.5599 | val_acc=0.4771 | val_f1m=0.4579 | LR=0.000980
  Época  1

Cargando train fold3: 100%|██████████| 82/82 [00:09<00:00,  8.56it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.54it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)




  Época   1 | train_loss=1.3924 | train_acc=0.3357 | val_acc=0.2921 | val_f1m=0.2554 | LR=0.000400
  Época   2 | train_loss=1.2548 | train_acc=0.4947 | val_acc=0.4826 | val_f1m=0.4500 | LR=0.000600
  Época   3 | train_loss=1.1909 | train_acc=0.5188 | val_acc=0.5092 | val_f1m=0.4922 | LR=0.000800
  Época   4 | train_loss=1.1453 | train_acc=0.5504 | val_acc=0.5302 | val_f1m=0.5215 | LR=0.001000
  Época   5 | train_loss=1.1365 | train_acc=0.5488 | val_acc=0.5385 | val_f1m=0.5284 | LR=0.001000
  Época   6 | train_loss=1.1531 | train_acc=0.5324 | val_acc=0.5174 | val_f1m=0.5181 | LR=0.000999
  Época   7 | train_loss=1.1202 | train_acc=0.5450 | val_acc=0.5128 | val_f1m=0.4961 | LR=0.000997
  Época   8 | train_loss=1.0959 | train_acc=0.5666 | val_acc=0.5421 | val_f1m=0.5465 | LR=0.000993
  Época   9 | train_loss=1.0916 | train_acc=0.5663 | val_acc=0.5266 | val_f1m=0.5108 | LR=0.000987
  Época  10 | train_loss=1.0746 | train_acc=0.5393 | val_acc=0.5238 | val_f1m=0.4958 | LR=0.000980
  Época  1

Cargando train fold4: 100%|██████████| 83/83 [00:10<00:00,  8.19it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.48it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5880 | n_val=1092 | n_test=1680)




  Época   1 | train_loss=1.3715 | train_acc=0.4145 | val_acc=0.3810 | val_f1m=0.3803 | LR=0.000400
  Época   2 | train_loss=1.2499 | train_acc=0.4844 | val_acc=0.4725 | val_f1m=0.4507 | LR=0.000600
  Época   3 | train_loss=1.1861 | train_acc=0.5214 | val_acc=0.4963 | val_f1m=0.4885 | LR=0.000800
  Época   4 | train_loss=1.1565 | train_acc=0.5367 | val_acc=0.4927 | val_f1m=0.4787 | LR=0.001000
  Época   5 | train_loss=1.1603 | train_acc=0.5335 | val_acc=0.4789 | val_f1m=0.4707 | LR=0.001000
  Época   6 | train_loss=1.1359 | train_acc=0.5440 | val_acc=0.4835 | val_f1m=0.4717 | LR=0.000999
  Época   7 | train_loss=1.1281 | train_acc=0.5253 | val_acc=0.4753 | val_f1m=0.4447 | LR=0.000997
  Época   8 | train_loss=1.1045 | train_acc=0.5529 | val_acc=0.5082 | val_f1m=0.5061 | LR=0.000993
  Época   9 | train_loss=1.1021 | train_acc=0.5536 | val_acc=0.4927 | val_f1m=0.4917 | LR=0.000987
  Época  10 | train_loss=1.1016 | train_acc=0.5883 | val_acc=0.5055 | val_f1m=0.4993 | LR=0.000980
  Época  1

Cargando train fold5: 100%|██████████| 83/83 [00:09<00:00,  8.34it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.44it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5880 | n_val=1092 | n_test=1680)




  Época   1 | train_loss=1.3815 | train_acc=0.3781 | val_acc=0.3608 | val_f1m=0.3231 | LR=0.000400
  Época   2 | train_loss=1.2467 | train_acc=0.4827 | val_acc=0.4505 | val_f1m=0.4213 | LR=0.000600
  Época   3 | train_loss=1.1956 | train_acc=0.5114 | val_acc=0.4789 | val_f1m=0.4674 | LR=0.000800
  Época   4 | train_loss=1.1715 | train_acc=0.4974 | val_acc=0.4496 | val_f1m=0.4342 | LR=0.001000
  Época   5 | train_loss=1.1565 | train_acc=0.5310 | val_acc=0.4844 | val_f1m=0.4836 | LR=0.001000
  Época   6 | train_loss=1.1611 | train_acc=0.5333 | val_acc=0.4606 | val_f1m=0.4499 | LR=0.000999
  Época   7 | train_loss=1.1409 | train_acc=0.5384 | val_acc=0.4881 | val_f1m=0.4735 | LR=0.000997
  Época   8 | train_loss=1.1086 | train_acc=0.5600 | val_acc=0.4826 | val_f1m=0.4806 | LR=0.000993
  Época   9 | train_loss=1.1179 | train_acc=0.5463 | val_acc=0.4835 | val_f1m=0.4690 | LR=0.000987
  Época  10 | train_loss=1.1123 | train_acc=0.5679 | val_acc=0.4982 | val_f1m=0.4865 | LR=0.000980
  Época  1

In [26]:
# -*- coding: utf-8 -*-
# CNN + Transformer (PyTorch) para PhysioNet/BCI2000 MI — IMAGINERÍA (4 clases) con 8 canales.
# Cambios implementados:
# (2) Optimización: LR=5e-4, EPOCHS=80, BATCH=48, patience=12 (warmup+cosine se mantiene).
# (3) Peso extra a 'both_fists' (clase 2): *1.45 en class-weights.
# (4) Augment MI específico (solo clases 2/3): temporal cutout (70–150 ms, p=0.6) + time-warp ±4%.
# (6) Arquitectura: d_model=160, n_heads=5, dropout=0.20, WD=1e-3.
# Reporte final incluye F1 mean en CV 5-fold.

import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'                     # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

# Entrenamiento / K-Fold
FOLDS = [1,2,3,4,5]
EPOCHS = 80
BATCH_SIZE = 48
LR = 5e-4
PATIENCE = 12

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False          # (dejamos como estaba si tu baseline lo tenía así)
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False     # si lo activas, recuerda que ya normaliza en load_subject_epochs

# Modelo
D_MODEL = 160
N_HEADS = 5
N_LAYERS = 2
P_DROP = 0.20

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# Sliding / TTA (igual que antes)
SW_MODE = 'tta'   # 'none' | 'subwin' | 'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.025, -0.0125, 0.0, 0.0125, 0.025]

# Para subventanas (si se usa)
SW_LEN   = 2.0
SW_STRIDE = 0.5

# Dataset
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}    # T1=left, T2=right
IMAGERY_RUNS_BF = {6, 10, 14}   # T1=both_fists, T2=both_feet

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        pattern = str(subj_dir / f"{subject_id}R{r:02d}.edf")
        edfs.extend(glob(pattern))
    return sorted(edfs)

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        raise FileNotFoundError(f"No imagery EDF files for {subject_id} under {DATA_RAW}")

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')

        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}  # id -> 'T1'/'T2'
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        keep_mask = y_run >= 0
        X = X[keep_mask]
        y = y_run[keep_mask]

        if len(y) == 0:
            continue

        X_list.append(X)
        y_list.append(y)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0, 8, 1)), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)

    if len(set([round(s) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Inconsistent sampling rates: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            train_sub = list(item.get('train', []))
            test_sub  = list(item.get('test', []))
            return train_sub, test_sub
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def standardize_per_channel(train_X, test_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    test_X  = test_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        test_X[:, c, :]  = (test_X[:, c, :] - mu) / sd
    return train_X, test_X

# =========================
# MODELO
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=160, n_heads=5, n_layers=2, p_drop=0.20):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        enc_ly = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.10, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_ly, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)  # (B,1,D)
        z = torch.cat([cls_tok, z], dim=1)    # (B, 1+L, D)
        z = self.encoder(z)                   # (B, 1+L, D)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# AUGMENTATIONS
# =========================
def augment_generic(xb, p_jitter=0.25, p_noise=0.25, p_chdrop=0.15,
                    max_jitter_frac=0.02, noise_std=0.02):
    """
    xb: torch.Tensor (B,C,T), ya normalizado. Genérico (todas las clases)
    """
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop:
        k = 1
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

def _temporal_cutout_np(x_np, sfreq, min_ms=70, max_ms=150):
    """
    x_np: (C,T) numpy. Corta a cero un segmento temporal continuo (edge padding no necesario).
    """
    C, T = x_np.shape
    min_len = max(1, int(round((min_ms/1000.0)*sfreq)))
    max_len = max(1, int(round((max_ms/1000.0)*sfreq)))
    w = np.random.randint(min_len, max_len+1)
    if w >= T:
        start = 0
        end = T
    else:
        start = np.random.randint(0, T - w + 1)
        end = start + w
    x_np[:, start:end] = 0.0
    return x_np

def _time_warp_np(x_np, max_warp_frac=0.04):
    """
    Simple time-warp: estira/encoge la señal en el eje temporal con factor ~ U[1-α, 1+α]
    y re-muestrea a la longitud original.
    """
    C, T = x_np.shape
    factor = 1.0 + np.random.uniform(-max_warp_frac, max_warp_frac)
    new_T = max(2, int(round(T * factor)))
    # Interpolación lineal por canal
    orig_t = np.linspace(0.0, 1.0, T, dtype=np.float32)
    new_t  = np.linspace(0.0, 1.0, new_T, dtype=np.float32)
    warped = np.zeros((C, new_T), dtype=np.float32)
    for c in range(C):
        warped[c] = np.interp(new_t, orig_t, x_np[c])
    # Reajuste a T original
    resampled = np.zeros((C, T), dtype=np.float32)
    if new_T == T:
        resampled = warped
    elif new_T > T:
        # recorte centrado
        start = (new_T - T) // 2
        resampled = warped[:, start:start+T]
    else:
        # pad centrado con borde
        pad = T - new_T
        left = pad // 2
        right = pad - left
        resampled[:, :left] = warped[:, 0:1]
        resampled[:, left:left+new_T] = warped
        resampled[:, left+new_T:] = warped[:, -1:]
    return resampled

def augment_mi_specific(xb, yb, sfreq, p_cutout=0.6, cutout_ms=(70,150), p_warp=0.5, warp_frac=0.04):
    """
    Augment específico para MI en clases 2/3 (both_fists / both_feet):
    - temporal cutout (70–150 ms) con p=0.6
    - time-warp ±4% con p=0.5
    xb: torch.Tensor (B,C,T) en GPU
    yb: torch.LongTensor (B,)
    """
    xb_np = xb.detach().cpu().numpy()
    y_np  = yb.detach().cpu().numpy()
    B, C, T = xb_np.shape
    for i in range(B):
        if y_np[i] in (2, 3):
            if np.random.rand() < p_cutout:
                xb_np[i] = _temporal_cutout_np(xb_np[i], sfreq, min_ms=cutout_ms[0], max_ms=cutout_ms[1])
            if np.random.rand() < p_warp:
                xb_np[i] = _time_warp_np(xb_np[i], max_warp_frac=warp_frac)
    return torch.tensor(xb_np, dtype=xb.dtype, device=xb.device)

# =========================
# EVAL helpers
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1]))
    st = max(1, st)

    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]
            acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]
            acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0)
            out.append(acc)
    return np.stack(out, axis=0)

# =========================
# TRAIN LOOP (por fold)
# =========================
def train_one_fold(fold:int, resample_hz:int, do_notch:bool, do_bandpass:bool,
                   bp_lo:float, bp_hi:float, epochs:int, batch_size:int, lr:float,
                   patience:int, device:str=torch.device('cuda' if torch.cuda.is_available() else 'cpu')):

    # sujetos fold
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # carga datos
    X_tr_list, y_tr_list, X_te_list, y_te_list = [], [], [], []
    sfreq = None

    for sid in tqdm(train_sub, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys); sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, resample_hz, do_notch, do_bandpass, DO_CAR, bp_lo, bp_hi)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys); sfreq = sf if sfreq is None else sfreq

    if len(X_tr_list) == 0 or len(X_te_list) == 0:
        raise RuntimeError("Datos insuficientes tras carga de sujetos.")

    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)

    # Normalización (global por canal, si no usas zscore por época)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_te_std = X_tr, X_te
    else:
        X_tr_std, X_te_std = standardize_per_channel(X_tr, X_te)

    # DataLoaders + sampler balanceado por clase (reproducible)
    train_ds = TensorDataset(torch.tensor(X_tr_std), torch.tensor(y_tr).long())
    class_counts = np.bincount(y_tr, minlength=4)
    class_weights_vec = class_counts.sum() / (4.0 * np.maximum(class_counts, 1))
    # (3) empuje a both_fists (clase 2)
    class_weights_vec[2] = class_weights_vec[2] * 1.45

    sample_weights = class_weights_vec[train_ds.tensors[1].numpy()]
    sampler = WeightedRandomSampler(
        weights=torch.tensor(sample_weights, dtype=torch.float32),
        num_samples=len(train_ds), replacement=True,
        generator=torch.Generator().manual_seed(RANDOM_STATE)
    )
    tr_ld = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                       drop_last=False, worker_init_fn=seed_worker)

    te_ld = DataLoader(
        TensorDataset(torch.tensor(X_te_std), torch.tensor(y_te).long()),
        batch_size=batch_size, shuffle=False, drop_last=False,
        worker_init_fn=seed_worker
    )

    # Modelo
    model = EEGCNNTransformer(
        n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS, p_drop=P_DROP
    ).to(device)

    # Optimizador / scheduler
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)  # (6) WD=1e-3

    # Loss con class-weights + label smoothing
    class_weights = torch.tensor(class_weights_vec, dtype=torch.float32, device=device)
    crit = torch.nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)

    # Warmup + Cosine LR + Early stopping por macro-F1
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = epochs
    warmup_epochs = max(1, min(8, epochs//10))  # un poco más largo que antes
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return 0.5 * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # Entrenamiento (validación en test set durante training para monitoreo, como venías usando)
    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(X_tr_std)} | n_test={len(X_te_std)})")
    best_f1, best_state, wait = 0.0, None, 0

    for ep in range(1, epochs+1):
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)

            # (4) Augments: primero genéricos, luego MI-específico (clases 2/3)
            xb = augment_generic(xb)
            xb = augment_mi_specific(xb, yb, sfreq, p_cutout=0.6, cutout_ms=(70,150), p_warp=0.5, warp_frac=0.04)

            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()

            tr_loss += loss.item() * len(yb)
            n_seen  += len(yb)
            tr_correct += (logits.argmax(dim=1) == yb).sum().item()

        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # "Validación" (monitoreo) en test (como en tus logs)
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        if wait >= patience:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    # --- Evaluación final (con TTA si corresponde) ---
    model.eval()
    sfreq_used = RESAMPLE_HZ if RESAMPLE_HZ is not None else int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)

    elif SW_MODE == 'subwin':
        logits = subwindow_logits(model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        preds = logits.argmax(axis=1)
        gts = y_te

    elif SW_MODE == 'tta':
        logits = time_shift_tta_logits(model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
        preds = logits.argmax(axis=1)
        gts = y_te

    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    cm = confusion_matrix(gts, preds, labels=[0,1,2,3])

    print(f"\n[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=['left','right','both_fists','both_feet'], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(cm)

    return acc, f1m, cm, model

# =========================
# EJECUCIÓN K-FOLD (1..5) + resumen final con F1 mean
# =========================
all_acc, all_f1 = [], []
for FOLD in FOLDS:
    acc, f1m, cm, model = train_one_fold(
        fold=FOLD,
        resample_hz=RESAMPLE_HZ,
        do_notch=DO_NOTCH,
        do_bandpass=DO_BANDPASS,
        bp_lo=BP_LO,
        bp_hi=BP_HI,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        lr=LR,
        patience=PATIENCE,
    )
    all_acc.append(acc)
    all_f1.append(f1m)

print("\n" + "="*60)
print("RESULTADOS FINALES")
print("="*60)
acc_str = [f"{a:.4f}" for a in all_acc]
f1_str  = [f"{f:.4f}" for f in all_f1]
print(f"Global folds (ACC): {acc_str}")
print(f"Global mean ACC: {np.mean(all_acc):.4f}")
print(f"F1 folds (MACRO): {f1_str}")
print(f"F1 mean (MACRO): {np.mean(all_f1):.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=80, BATCH=48, LR=0.0005 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 82/82 [00:09<00:00,  8.26it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.31it/s]


[Fold 1/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
  Época   1 | train_loss=1.3280 | train_acc=0.3368 | val_acc=0.2483 | val_f1m=0.1005 | LR=0.000125
  Época   2 | train_loss=1.2326 | train_acc=0.3979 | val_acc=0.4014 | val_f1m=0.3829 | LR=0.000188
  Época   3 | train_loss=1.1439 | train_acc=0.4859 | val_acc=0.4059 | val_f1m=0.3950 | LR=0.000250
  Época   4 | train_loss=1.1294 | train_acc=0.4974 | val_acc=0.4365 | val_f1m=0.4316 | LR=0.000313
  Época   5 | train_loss=1.1158 | train_acc=0.5090 | val_acc=0.4297 | val_f1m=0.4138 | LR=0.000375
  Época   6 | train_loss=1.1106 | train_acc=0.5148 | val_acc=0.4546 | val_f1m=0.4546 | LR=0.000438
  Época   7 | train_loss=1.0905 | train_acc=0.5189 | val_acc=0.4620 | val_f1m=0.4632 | LR=0.000500
  Época   8 | train_loss=1.0943 | train_acc=0.5215 | val_acc=0.4524 | val_f1m=0.4430 | LR=0.000500
  Época   9 | train_loss=1.0813 | train_acc=0.5315 | val_acc=0.4320 | val_f1m=0.4149 | LR=0.000500
  Época  10 | train_loss=1.0707 | train_a

Cargando train fold2: 100%|██████████| 82/82 [00:09<00:00,  8.24it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.27it/s]


[Fold 2/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
  Época   1 | train_loss=1.3313 | train_acc=0.3310 | val_acc=0.2500 | val_f1m=0.1000 | LR=0.000125
  Época   2 | train_loss=1.2677 | train_acc=0.3743 | val_acc=0.4399 | val_f1m=0.4312 | LR=0.000188
  Época   3 | train_loss=1.1965 | train_acc=0.4480 | val_acc=0.4144 | val_f1m=0.3824 | LR=0.000250
  Época   4 | train_loss=1.1410 | train_acc=0.4871 | val_acc=0.4824 | val_f1m=0.4712 | LR=0.000313
  Época   5 | train_loss=1.1403 | train_acc=0.4801 | val_acc=0.4841 | val_f1m=0.4710 | LR=0.000375
  Época   6 | train_loss=1.1223 | train_acc=0.5010 | val_acc=0.5079 | val_f1m=0.5000 | LR=0.000438
  Época   7 | train_loss=1.0854 | train_acc=0.5271 | val_acc=0.5442 | val_f1m=0.5454 | LR=0.000500
  Época   8 | train_loss=1.1148 | train_acc=0.5128 | val_acc=0.5221 | val_f1m=0.5183 | LR=0.000500
  Época   9 | train_loss=1.0924 | train_acc=0.5228 | val_acc=0.5227 | val_f1m=0.5070 | LR=0.000500
  Época  10 | train_loss=1.0793 | train_a

Cargando train fold3: 100%|██████████| 82/82 [00:10<00:00,  8.17it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.08it/s]


[Fold 3/5] Entrenando modelo global... (n_train=6888 | n_test=1764)
  Época   1 | train_loss=1.3268 | train_acc=0.3323 | val_acc=0.2755 | val_f1m=0.1552 | LR=0.000125
  Época   2 | train_loss=1.2266 | train_acc=0.4265 | val_acc=0.4694 | val_f1m=0.4716 | LR=0.000188
  Época   3 | train_loss=1.1599 | train_acc=0.4778 | val_acc=0.4280 | val_f1m=0.4189 | LR=0.000250
  Época   4 | train_loss=1.1196 | train_acc=0.5042 | val_acc=0.4399 | val_f1m=0.4349 | LR=0.000313
  Época   5 | train_loss=1.1143 | train_acc=0.5038 | val_acc=0.4240 | val_f1m=0.4074 | LR=0.000375
  Época   6 | train_loss=1.1130 | train_acc=0.5122 | val_acc=0.4325 | val_f1m=0.4224 | LR=0.000438
  Época   7 | train_loss=1.0708 | train_acc=0.5392 | val_acc=0.4626 | val_f1m=0.4670 | LR=0.000500
  Época   8 | train_loss=1.0849 | train_acc=0.5273 | val_acc=0.4490 | val_f1m=0.4418 | LR=0.000500
  Época   9 | train_loss=1.0700 | train_acc=0.5438 | val_acc=0.4609 | val_f1m=0.4566 | LR=0.000500
  Época  10 | train_loss=1.0446 | train_a

Cargando train fold4: 100%|██████████| 83/83 [00:09<00:00,  8.33it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.24it/s]


[Fold 4/5] Entrenando modelo global... (n_train=6972 | n_test=1680)
  Época   1 | train_loss=1.3366 | train_acc=0.3196 | val_acc=0.2524 | val_f1m=0.1092 | LR=0.000125
  Época   2 | train_loss=1.2454 | train_acc=0.3987 | val_acc=0.3304 | val_f1m=0.2582 | LR=0.000188
  Época   3 | train_loss=1.1598 | train_acc=0.4768 | val_acc=0.4631 | val_f1m=0.4557 | LR=0.000250
  Época   4 | train_loss=1.1347 | train_acc=0.4974 | val_acc=0.4851 | val_f1m=0.4800 | LR=0.000313
  Época   5 | train_loss=1.1064 | train_acc=0.5149 | val_acc=0.4429 | val_f1m=0.4177 | LR=0.000375
  Época   6 | train_loss=1.1061 | train_acc=0.5158 | val_acc=0.4673 | val_f1m=0.4515 | LR=0.000438
  Época   7 | train_loss=1.0961 | train_acc=0.5275 | val_acc=0.4482 | val_f1m=0.4389 | LR=0.000500
  Época   8 | train_loss=1.1113 | train_acc=0.5146 | val_acc=0.4827 | val_f1m=0.4712 | LR=0.000500
  Época   9 | train_loss=1.0856 | train_acc=0.5273 | val_acc=0.4994 | val_f1m=0.4867 | LR=0.000500
  Época  10 | train_loss=1.0705 | train_a

Cargando train fold5: 100%|██████████| 83/83 [00:10<00:00,  7.57it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  7.46it/s]


[Fold 5/5] Entrenando modelo global... (n_train=6972 | n_test=1680)
  Época   1 | train_loss=1.3340 | train_acc=0.3210 | val_acc=0.2518 | val_f1m=0.1040 | LR=0.000125
  Época   2 | train_loss=1.2390 | train_acc=0.3985 | val_acc=0.3833 | val_f1m=0.3330 | LR=0.000188
  Época   3 | train_loss=1.1695 | train_acc=0.4740 | val_acc=0.5405 | val_f1m=0.5431 | LR=0.000250
  Época   4 | train_loss=1.1480 | train_acc=0.4859 | val_acc=0.5119 | val_f1m=0.4974 | LR=0.000313
  Época   5 | train_loss=1.1248 | train_acc=0.4944 | val_acc=0.5012 | val_f1m=0.4931 | LR=0.000375
  Época   6 | train_loss=1.1140 | train_acc=0.5040 | val_acc=0.5268 | val_f1m=0.5248 | LR=0.000438
  Época   7 | train_loss=1.1024 | train_acc=0.5215 | val_acc=0.5268 | val_f1m=0.5281 | LR=0.000500
  Época   8 | train_loss=1.1145 | train_acc=0.5050 | val_acc=0.5256 | val_f1m=0.5237 | LR=0.000500
  Época   9 | train_loss=1.0938 | train_acc=0.5234 | val_acc=0.5185 | val_f1m=0.5191 | LR=0.000500
  Época  10 | train_loss=1.0753 | train_a

In [None]:
# -*- coding: utf-8 -*-
# CNN+Transformer para MI (4 clases, 8 canales) con:
# (1) Focal Loss, (3) BatchSampler balanceado clase+sujeto (solo train),
# (6) LR warmup+cosine, early stopping por F1 macro,
# y split train/val sujeto-aware por fold. Guarda curvas por fold y resumen final.

import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import BatchSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42
def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64        # múltiplo de 4
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18  # ≈ 18% de sujetos del train → val
VAL_STRAT_SUBJECT = True # split por sujetos (no mezcla muestras de un mismo sujeto)

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA (sólo en TEST)
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.025, -0.0125, 0.0, 0.0125, 0.025]
SW_LEN, SW_STRIDE = 2.0, 0.5

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :]  = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# (1) FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# (3) SAMPLER BALANCEADO CLASE+SUJETO
# =========================
class ClassSubjectBalancedBatchSampler(BatchSampler):
    def __init__(self, labels: np.ndarray, subjects: np.ndarray, batch_size: int,
                 n_classes: int = 4, min_subjects_per_class: int = 2, generator=None):
        assert batch_size % n_classes == 0, "BATCH_SIZE debe ser múltiplo de n_clases."
        self.labels = labels.astype(int)
        self.subjects = subjects.astype(int)
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.per_class = batch_size // n_classes
        self.min_subj = min_subjects_per_class
        self.generator = generator if generator is not None else random.Random(RANDOM_STATE)
        self.by_class = {c: {} for c in range(n_classes)}
        for idx, (y, s) in enumerate(zip(self.labels, self.subjects)):
            self.by_class[y].setdefault(s, []).append(idx)
        for c in range(n_classes):
            for s in self.by_class[c]:
                self.generator.shuffle(self.by_class[c][s])
        self.class_subjects = {c: list(self.by_class[c].keys()) for c in range(n_classes)}
        for c in range(n_classes):
            self.generator.shuffle(self.class_subjects[c])
        self.num_batches = int(np.ceil(len(labels) / batch_size))

    def __iter__(self):
        cursors = {c: {s: 0 for s in self.class_subjects[c]} for c in range(self.n_classes)}
        subj_ptr = {c: 0 for c in range(self.n_classes)}
        for _ in range(self.num_batches):
            batch = []
            for c in range(self.n_classes):
                subj_list = self.class_subjects[c]
                if len(subj_list) == 0: continue
                take_subj = min(max(self.min_subj, 1), len(subj_list), self.per_class)
                base = self.per_class // take_subj
                rem = self.per_class - base * take_subj
                chosen = []
                for k in range(take_subj):
                    s_idx = (subj_ptr[c] + k) % len(subj_list)
                    chosen.append(subj_list[s_idx])
                subj_ptr[c] = (subj_ptr[c] + take_subj) % len(subj_list)
                per_subj = {s: base for s in chosen}
                for k in range(rem): per_subj[chosen[k % take_subj]] += 1
                for s in chosen:
                    want = per_subj[s]
                    pool = self.by_class[c].get(s, [])
                    cur = cursors[c].get(s, 0)
                    taken = []
                    for _t in range(want):
                        if cur >= len(pool):
                            if len(pool) == 0: break
                            self.generator.shuffle(pool); cur = 0
                        taken.append(pool[cur]); cur += 1
                    cursors[c][s] = cur
                    batch.extend(taken)
            if len(batch) < self.batch_size:
                all_idx = np.arange(len(self.labels))
                need = self.batch_size - len(batch)
                batch.extend(self.generator.choices(all_idx, k=need))
            if len(batch) > self.batch_size:
                batch = batch[:self.batch_size]
            yield batch

    def __len__(self):
        return self.num_batches

# =========================
# AUGMENTS (ligeros)
# =========================
def augment_batch(xb, p_jitter=0.25, p_noise=0.25, p_chdrop=0.15,
                  max_jitter_frac=0.02, noise_std=0.02):
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop:
        k = 1
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# TRAIN/EVAL por FOLD (con train/val/test)
# =========================
def train_one_fold(fold:int, device):
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación por sujetos (determinista)
    rng = random.Random(RANDOM_STATE + fold)
    tr_subjects = train_sub.copy()
    rng.shuffle(tr_subjects)
    n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
    val_subjects = sorted(tr_subjects[:n_val_subj])
    train_subjects = sorted(tr_subjects[n_val_subj:])

    # Carga TRAIN
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    # Carga VAL
    X_val_list, y_val_list, sub_val_list = [], [], []
    # Carga TEST
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)  # reutiliza stats de train

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # (3) Sampler balanceado en TRAIN
    labels_np = y_tr
    subjects_np = sub_tr
    gen = random.Random(RANDOM_STATE + 100*fold)
    batch_sampler = ClassSubjectBalancedBatchSampler(labels_np, subjects_np, batch_size=BATCH_SIZE,
                                                     n_classes=4, min_subjects_per_class=2, generator=gen)
    tr_ld  = DataLoader(tr_ds, batch_sampler=batch_sampler, worker_init_fn=seed_worker)
    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP).to(device)

    # Optimizador + (1) Focal Loss
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.22 * alpha_mean   # downweight leve para both_fists
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # (6) Warmup+Cosine
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # Entrenamiento con early stopping por F1 macro (en VAL)
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    for ep in range(1, EPOCHS+1):
        # ---- Train ----
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # ---- Val ----
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in val_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Guarda mejor estado
    if best_state is not None:
        model.load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["val_f1m"], label="val_f1m")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST (con TTA/subwindows si aplica) ----
    model.eval()
    # Estimar sfreq para TTA/subwindow
    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        # longitud en puntos / duración ventana
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE == 'subwin':
        logits = subwindow_logits(model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        preds = logits.argmax(axis=1); gts = y_te
    elif SW_MODE == 'tta':
        logits = time_shift_tta_logits(model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(gts, preds, labels=[0,1,2,3]))

    return acc, f1m

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1:   0%|          | 0/67 [00:00<?, ?it/s]

Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  8.33it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  8.33it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.36it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2274 | train_acc=0.2674 | val_acc=0.3087 | val_f1m=0.2489 | LR=0.000125
  Época   2 | train_loss=0.2048 | train_acc=0.3810 | val_acc=0.4913 | val_f1m=0.4946 | LR=0.000250
  Época   3 | train_loss=0.1868 | train_acc=0.4636 | val_acc=0.5008 | val_f1m=0.5057 | LR=0.000375
  Época   4 | train_loss=0.1786 | train_acc=0.4966 | val_acc=0.4913 | val_f1m=0.4942 | LR=0.000500
  Época   5 | train_loss=0.1761 | train_acc=0.5016 | val_acc=0.5317 | val_f1m=0.5321 | LR=0.000625
  Época   6 | train_loss=0.1736 | train_acc=0.5071 | val_acc=0.5262 | val_f1m=0.5316 | LR=0.000750
  Época   7 | train_loss=0.1715 | train_acc=0.5211 | val_acc=0.5222 | val_f1m=0.5212 | LR=0.000875
  Época   8 | train_loss=0.1686 | train_acc=0.5202 | val_acc=0.5397 | val_f1m=0.5403 | LR=0.001000
  Época   9 | train_loss=0.1674 | train_acc=0.5337 | val_acc=0.5341 | val_f1m=0.5378 | LR=0.001000
  Época  10 | train_loss=0.1622 | train_acc=0.5398 | val_acc=0.5294 | val_f1m=0.5349 | LR=0.000999
  Época  1

Cargando train fold2: 100%|██████████| 67/67 [00:08<00:00,  8.36it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.40it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.32it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2292 | train_acc=0.2470 | val_acc=0.2952 | val_f1m=0.2315 | LR=0.000125
  Época   2 | train_loss=0.2078 | train_acc=0.3665 | val_acc=0.4476 | val_f1m=0.4491 | LR=0.000250
  Época   3 | train_loss=0.1913 | train_acc=0.4425 | val_acc=0.4778 | val_f1m=0.4805 | LR=0.000375
  Época   4 | train_loss=0.1841 | train_acc=0.4647 | val_acc=0.4881 | val_f1m=0.4892 | LR=0.000500
  Época   5 | train_loss=0.1803 | train_acc=0.4860 | val_acc=0.4889 | val_f1m=0.4831 | LR=0.000625
  Época   6 | train_loss=0.1791 | train_acc=0.4906 | val_acc=0.5103 | val_f1m=0.5115 | LR=0.000750
  Época   7 | train_loss=0.1791 | train_acc=0.4918 | val_acc=0.5071 | val_f1m=0.5075 | LR=0.000875
  Época   8 | train_loss=0.1751 | train_acc=0.5044 | val_acc=0.4929 | val_f1m=0.4920 | LR=0.001000
  Época   9 | train_loss=0.1714 | train_acc=0.5110 | val_acc=0.4960 | val_f1m=0.5003 | LR=0.001000
  Época  10 | train_loss=0.1692 | train_acc=0.5147 | val_acc=0.5032 | val_f1m=0.5016 | LR=0.000999
  Época  1

Cargando train fold3: 100%|██████████| 67/67 [00:08<00:00,  8.34it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.36it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2278 | train_acc=0.2562 | val_acc=0.2889 | val_f1m=0.2374 | LR=0.000125
  Época   2 | train_loss=0.2066 | train_acc=0.3674 | val_acc=0.3810 | val_f1m=0.3703 | LR=0.000250
  Época   3 | train_loss=0.1837 | train_acc=0.4721 | val_acc=0.4365 | val_f1m=0.4376 | LR=0.000375
  Época   4 | train_loss=0.1764 | train_acc=0.4989 | val_acc=0.4460 | val_f1m=0.4447 | LR=0.000500
  Época   5 | train_loss=0.1708 | train_acc=0.5101 | val_acc=0.4611 | val_f1m=0.4610 | LR=0.000625
  Época   6 | train_loss=0.1694 | train_acc=0.5165 | val_acc=0.4365 | val_f1m=0.4282 | LR=0.000750
  Época   7 | train_loss=0.1677 | train_acc=0.5218 | val_acc=0.4675 | val_f1m=0.4650 | LR=0.000875
  Época   8 | train_loss=0.1640 | train_acc=0.5350 | val_acc=0.4500 | val_f1m=0.4475 | LR=0.001000
  Época   9 | train_loss=0.1634 | train_acc=0.5309 | val_acc=0.4754 | val_f1m=0.4711 | LR=0.001000
  Época  10 | train_loss=0.1566 | train_acc=0.5550 | val_acc=0.4722 | val_f1m=0.4691 | LR=0.000999
  Época  1

Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.35it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.32it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.30it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2276 | train_acc=0.2585 | val_acc=0.3087 | val_f1m=0.2695 | LR=0.000125
  Época   2 | train_loss=0.2024 | train_acc=0.3858 | val_acc=0.4190 | val_f1m=0.4108 | LR=0.000250
  Época   3 | train_loss=0.1844 | train_acc=0.4747 | val_acc=0.4278 | val_f1m=0.4153 | LR=0.000375
  Época   4 | train_loss=0.1797 | train_acc=0.4865 | val_acc=0.4532 | val_f1m=0.4504 | LR=0.000500
  Época   5 | train_loss=0.1768 | train_acc=0.4977 | val_acc=0.4429 | val_f1m=0.4408 | LR=0.000625
  Época   6 | train_loss=0.1736 | train_acc=0.5106 | val_acc=0.4881 | val_f1m=0.4896 | LR=0.000750
  Época   7 | train_loss=0.1730 | train_acc=0.5115 | val_acc=0.4683 | val_f1m=0.4640 | LR=0.000875
  Época   8 | train_loss=0.1681 | train_acc=0.5243 | val_acc=0.4611 | val_f1m=0.4504 | LR=0.001000
  Época   9 | train_loss=0.1666 | train_acc=0.5198 | val_acc=0.4659 | val_f1m=0.4626 | LR=0.001000
  Época  10 | train_loss=0.1624 | train_acc=0.5415 | val_acc=0.4571 | val_f1m=0.4561 | LR=0.000999
  Época  1

Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.30it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.34it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.32it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2276 | train_acc=0.2649 | val_acc=0.3000 | val_f1m=0.2215 | LR=0.000125
  Época   2 | train_loss=0.2103 | train_acc=0.3601 | val_acc=0.4802 | val_f1m=0.4749 | LR=0.000250
  Época   3 | train_loss=0.1945 | train_acc=0.4285 | val_acc=0.4944 | val_f1m=0.4835 | LR=0.000375
  Época   4 | train_loss=0.1885 | train_acc=0.4562 | val_acc=0.5206 | val_f1m=0.5210 | LR=0.000500
  Época   5 | train_loss=0.1837 | train_acc=0.4667 | val_acc=0.5079 | val_f1m=0.4916 | LR=0.000625
  Época   6 | train_loss=0.1812 | train_acc=0.4771 | val_acc=0.5413 | val_f1m=0.5343 | LR=0.000750
  Época   7 | train_loss=0.1780 | train_acc=0.4842 | val_acc=0.5365 | val_f1m=0.5381 | LR=0.000875
  Época   8 | train_loss=0.1809 | train_acc=0.4712 | val_acc=0.5206 | val_f1m=0.5238 | LR=0.001000
  Época   9 | train_loss=0.1747 | train_acc=0.4977 | val_acc=0.5484 | val_f1m=0.5462 | LR=0.001000
  Época  10 | train_loss=0.1704 | train_acc=0.5062 | val_acc=0.5468 | val_f1m=0.5491 | LR=0.000999
  Época  1

In [2]:
# -*- coding: utf-8 -*-
# CNN+Transformer para MI (4 clases, 8 canales) con:
# (1) Focal Loss
# (4) Aumentos más fuertes y estables
# (6) LR warmup+cosine, early stopping por F1 macro
# Split train/val sujeto-aware por fold. Guarda curvas por fold y resumen final.

import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42
def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64        # múltiplo de 4
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18  # ≈ 18% de sujetos del train → val
VAL_STRAT_SUBJECT = True # split por sujetos (no mezcla muestras de un mismo sujeto)

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False       # <- dejamos apagado para ceñirnos a 1/4/6
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA (sólo en TEST)
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.05, -0.025, 0.0, 0.025, 0.05]
SW_LEN, SW_STRIDE = 4.5, 2.0

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :]  = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO
# =========================
class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ELU()
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.bn(x)
        return self.act(x)

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2, p_drop=0.2):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            nn.BatchNorm1d(32), nn.ELU(),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop)
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# (1) FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# (4) AUGMENTS MÁS FUERTES (pero estables)
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    """
    - jitter: shift temporal pequeño (±3% de la longitud)
    - noise: ruido gaussiano leve
    - chdrop: apagado de 1 canal aleatorio
    """
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# TRAIN/EVAL por FOLD (con train/val/test)
# =========================
def train_one_fold(fold:int, device):
    # --- sujetos por fold ---
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación por sujetos (determinista)
    rng = random.Random(RANDOM_STATE + fold)
    tr_subjects = train_sub.copy()
    rng.shuffle(tr_subjects)
    n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
    val_subjects = sorted(tr_subjects[:n_val_subj])
    train_subjects = sorted(tr_subjects[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)  # reutiliza stats de train

    # Datasets + DataLoaders (sin sampler especial, solo shuffle en TRAIN)
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)
    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP).to(device)

    # Optimizador + (1) Focal Loss (ligero upweight a both_fists)
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    # Aumentamos un poco el peso de both_fists para mejorar su recall
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.5 * alpha_mean
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # (6) Warmup+Cosine
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # Entrenamiento con early stopping por F1 macro (en VAL)
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    for ep in range(1, EPOCHS+1):
        # ---- Train ----
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)  # (4) aumentos más fuertes
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # ---- Val ----
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in val_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Guarda mejor estado
    if best_state is not None:
        model.load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["val_f1m"], label="val_f1m")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST (con TTA/subwindows si aplica) ----
    model.eval()
    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE == 'subwin':
        logits = subwindow_logits(model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        preds = logits.argmax(axis=1); gts = y_te
    elif SW_MODE == 'tta':
        logits = time_shift_tta_logits(model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(gts, preds, labels=[0,1,2,3]))

    return acc, f1m

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1:   0%|          | 0/67 [00:00<?, ?it/s]

Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  8.31it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.34it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2252 | train_acc=0.2623 | val_acc=0.3333 | val_f1m=0.2658 | LR=0.000125
  Época   2 | train_loss=0.2027 | train_acc=0.3680 | val_acc=0.4857 | val_f1m=0.4747 | LR=0.000250
  Época   3 | train_loss=0.1827 | train_acc=0.4572 | val_acc=0.5286 | val_f1m=0.5237 | LR=0.000375
  Época   4 | train_loss=0.1775 | train_acc=0.4893 | val_acc=0.4968 | val_f1m=0.5010 | LR=0.000500
  Época   5 | train_loss=0.1710 | train_acc=0.5060 | val_acc=0.5476 | val_f1m=0.5431 | LR=0.000625
  Época   6 | train_loss=0.1712 | train_acc=0.4998 | val_acc=0.5373 | val_f1m=0.5388 | LR=0.000750
  Época   7 | train_loss=0.1694 | train_acc=0.5089 | val_acc=0.5222 | val_f1m=0.5286 | LR=0.000875
  Época   8 | train_loss=0.1691 | train_acc=0.5052 | val_acc=0.5016 | val_f1m=0.4929 | LR=0.001000
  Época   9 | train_loss=0.1673 | train_acc=0.5171 | val_acc=0.5135 | val_f1m=0.5114 | LR=0.001000
  Época  10 | train_loss=0.1623 | train_acc=0.5362 | val_acc=0.5159 | val_f1m=0.5154 | LR=0.000999
  Época  1

Cargando train fold2: 100%|██████████| 67/67 [00:07<00:00,  8.39it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.38it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.41it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2269 | train_acc=0.2653 | val_acc=0.3000 | val_f1m=0.1935 | LR=0.000125
  Época   2 | train_loss=0.2049 | train_acc=0.3605 | val_acc=0.4667 | val_f1m=0.4441 | LR=0.000250
  Época   3 | train_loss=0.1860 | train_acc=0.4449 | val_acc=0.5190 | val_f1m=0.5179 | LR=0.000375
  Época   4 | train_loss=0.1823 | train_acc=0.4582 | val_acc=0.4889 | val_f1m=0.4842 | LR=0.000500
  Época   5 | train_loss=0.1785 | train_acc=0.4712 | val_acc=0.4992 | val_f1m=0.4978 | LR=0.000625
  Época   6 | train_loss=0.1795 | train_acc=0.4670 | val_acc=0.4762 | val_f1m=0.4758 | LR=0.000750
  Época   7 | train_loss=0.1777 | train_acc=0.4739 | val_acc=0.5008 | val_f1m=0.5046 | LR=0.000875
  Época   8 | train_loss=0.1750 | train_acc=0.4790 | val_acc=0.4952 | val_f1m=0.5000 | LR=0.001000
  Época   9 | train_loss=0.1731 | train_acc=0.4956 | val_acc=0.5024 | val_f1m=0.5077 | LR=0.001000
  Época  10 | train_loss=0.1683 | train_acc=0.5020 | val_acc=0.4810 | val_f1m=0.4832 | LR=0.000999
  Época  1

Cargando train fold3: 100%|██████████| 67/67 [00:08<00:00,  8.37it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.26it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




  Época   1 | train_loss=0.2267 | train_acc=0.2573 | val_acc=0.2508 | val_f1m=0.1168 | LR=0.000125
  Época   2 | train_loss=0.2075 | train_acc=0.3527 | val_acc=0.3770 | val_f1m=0.3656 | LR=0.000250
  Época   3 | train_loss=0.1819 | train_acc=0.4677 | val_acc=0.4127 | val_f1m=0.4029 | LR=0.000375
  Época   4 | train_loss=0.1715 | train_acc=0.5009 | val_acc=0.4413 | val_f1m=0.4353 | LR=0.000500
  Época   5 | train_loss=0.1680 | train_acc=0.5092 | val_acc=0.4405 | val_f1m=0.4346 | LR=0.000625
  Época   6 | train_loss=0.1712 | train_acc=0.4941 | val_acc=0.4603 | val_f1m=0.4574 | LR=0.000750
  Época   7 | train_loss=0.1654 | train_acc=0.5185 | val_acc=0.4286 | val_f1m=0.4211 | LR=0.000875
  Época   8 | train_loss=0.1628 | train_acc=0.5270 | val_acc=0.4651 | val_f1m=0.4652 | LR=0.001000
  Época   9 | train_loss=0.1575 | train_acc=0.5371 | val_acc=0.4492 | val_f1m=0.4514 | LR=0.001000
  Época  10 | train_loss=0.1560 | train_acc=0.5464 | val_acc=0.4738 | val_f1m=0.4755 | LR=0.000999
  Época  1

Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.45it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.39it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2283 | train_acc=0.2651 | val_acc=0.2762 | val_f1m=0.1801 | LR=0.000125
  Época   2 | train_loss=0.2025 | train_acc=0.3741 | val_acc=0.3690 | val_f1m=0.3387 | LR=0.000250
  Época   3 | train_loss=0.1819 | train_acc=0.4641 | val_acc=0.4468 | val_f1m=0.4394 | LR=0.000375
  Época   4 | train_loss=0.1769 | train_acc=0.4842 | val_acc=0.4476 | val_f1m=0.4466 | LR=0.000500
  Época   5 | train_loss=0.1728 | train_acc=0.4902 | val_acc=0.4492 | val_f1m=0.4452 | LR=0.000625
  Época   6 | train_loss=0.1741 | train_acc=0.4951 | val_acc=0.4579 | val_f1m=0.4571 | LR=0.000750
  Época   7 | train_loss=0.1696 | train_acc=0.4988 | val_acc=0.4317 | val_f1m=0.4049 | LR=0.000875
  Época   8 | train_loss=0.1722 | train_acc=0.5056 | val_acc=0.4103 | val_f1m=0.3774 | LR=0.001000
  Época   9 | train_loss=0.1647 | train_acc=0.5242 | val_acc=0.4627 | val_f1m=0.4630 | LR=0.001000
  Época  10 | train_loss=0.1637 | train_acc=0.5264 | val_acc=0.4810 | val_f1m=0.4755 | LR=0.000999
  Época  1

Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.35it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.42it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.39it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)




  Época   1 | train_loss=0.2259 | train_acc=0.2584 | val_acc=0.2905 | val_f1m=0.1841 | LR=0.000125
  Época   2 | train_loss=0.2056 | train_acc=0.3554 | val_acc=0.4302 | val_f1m=0.4172 | LR=0.000250
  Época   3 | train_loss=0.1902 | train_acc=0.4361 | val_acc=0.4786 | val_f1m=0.4278 | LR=0.000375
  Época   4 | train_loss=0.1858 | train_acc=0.4531 | val_acc=0.5008 | val_f1m=0.4907 | LR=0.000500
  Época   5 | train_loss=0.1816 | train_acc=0.4589 | val_acc=0.5302 | val_f1m=0.5374 | LR=0.000625
  Época   6 | train_loss=0.1802 | train_acc=0.4697 | val_acc=0.5349 | val_f1m=0.5353 | LR=0.000750
  Época   7 | train_loss=0.1802 | train_acc=0.4596 | val_acc=0.5341 | val_f1m=0.5223 | LR=0.000875
  Época   8 | train_loss=0.1764 | train_acc=0.4767 | val_acc=0.5524 | val_f1m=0.5572 | LR=0.001000
  Época   9 | train_loss=0.1773 | train_acc=0.4690 | val_acc=0.5302 | val_f1m=0.5305 | LR=0.001000
  Época  10 | train_loss=0.1756 | train_acc=0.4723 | val_acc=0.5103 | val_f1m=0.5103 | LR=0.000999
  Época  1

In [1]:
# -*- coding: utf-8 -*-
# CNN+Transformer para MI (4 clases, 8 canales) con:
# (1) Focal Loss
# (4) Aumentos más fuertes y estables
# (6) LR warmup+cosine, early stopping por F1 macro
# Cambios propuestos:
# - GroupNorm en conv (en lugar de BN)
# - EMA de pesos para val/test
# - Balanced sampler por sujeto/clase (WeightedRandomSampler)
# - Dropout antes del encoder = 0.3
# - Cosine min_factor=0.2
# - Ensemble TTA + Subwindows en test (opcional)
# - Val estratificada por sujeto (opcional)

import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import StratifiedShuffleSplit

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42
def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64        # múltiplo de 4
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18  # ≈ 18% de sujetos del train → val
VAL_STRAT_SUBJECT = True # <- ACTIVADO: estratifica por etiqueta dominante de cada sujeto

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False       # dejamos apagado para respetar tu setup original
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2            # dropout en conv stack
P_DROP_ENCODER = 0.3    # dropout antes del encoder (↑ respecto a 0.2)

# Ventana temporal
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW en TEST
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.05, -0.025, 0.0, 0.025, 0.05]
SW_LEN, SW_STRIDE = 4.5, 2.0
COMBINE_TTA_AND_SUBWIN = True  # Promedia logits TTA y Subwindow (50/50) si True

# Sampler balanceado
USE_WEIGHTED_SAMPLER = True

# EMA
USE_EMA = True
EMA_DECAY = 0.999  # 0.999 ~ suave y efectivo

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :]  = (other_X[:, c, :] - mu) / sd
    return train_X, other_X

# =========================
# MODELO (GroupNorm en conv)
# =========================
def make_gn(num_channels, num_groups=8):
    # Ajusta grupos para que dividan a num_channels
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)
    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

class EEGCNNTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)  # ↑ 0.3
        self.pos_encoding = None
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            batch_first=True, activation='gelu', dropout=0.1, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        return self.head(cls)

# =========================
# (1) FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# (4) AUGMENTS MÁS FUERTES (pero estables)
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    """
    - jitter: shift temporal pequeño (±3% de la longitud)
    - noise: ruido gaussiano leve
    - chdrop: apagado de 1 canal aleatorio
    """
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# EMA de pesos
# =========================
class ModelEMA:
    def __init__(self, model: nn.Module, decay: float = 0.999, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        ema = type(model)()
        ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        d = self.decay
        if self._updates < 200:
            # warmup EMA decay para primeras iteraciones
            d = 0.0 + (self._updates / 200.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                esd[k] = msd[k]
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(4, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    """
    Estima una etiqueta dominante por sujeto (argmax de su histograma de clases).
    Se usa sólo para estratificar la selección de sujetos de validación.
    """
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=4)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)

# =========================
# TRAIN/EVAL por FOLD (con train/val/test)
# =========================
def train_one_fold(fold:int, device):
    # --- sujetos por fold ---
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación por sujetos (determinista y opcionalmente estratificado)
    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        # Reemplaza -1 (sujetos sin trials) por la moda para no romper el split
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        # sss.split(X, y) — X no importa, pasamos índice
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)  # reutiliza stats de train

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # Weighted sampler por sujeto/clase
    def make_weighted_sampler(dataset: TensorDataset):
        Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        # frec por sujeto
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        # frec por (sujeto,clase)
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)  # asumiendo clases <10
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
        # peso = 1 / (frec_sujeto * frec_clase_en_sujeto)
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = map_s[int(s)]
            wk = map_k[k]
            w.append(1.0 / (float(ws) * float(wk)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                              n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER).to(device)

    # Optimizador + (1) Focal Loss (ligero upweight a both_fists)
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    # Aumentamos un poco el peso de both_fists para mejorar su recall
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.5 * alpha_mean
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # (6) Warmup+Cosine (min_factor 0.2)
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.2
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento con early stopping por F1 macro (en VAL; usando EMA si está activo)
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                p = mdl(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        return acc, f1m

    for ep in range(1, EPOCHS+1):
        # ---- Train ----
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)  # (4) aumentos más fuertes
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # ---- Val (con EMA si aplica) ----
        acc, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            # guardamos el estado EMA si existe; si no, el del modelo
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Cargamos mejor estado guardado (EMA si estaba activo)
    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)

    # ---- Guardar curva ----
    fig = plt.figure(figsize=(8,4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["val_f1m"], label="val_f1m")
    ax1.plot(hist["ep"], hist["val_acc"], label="val_acc")
    ax1.set_xlabel("Época"); ax1.set_title(f"Fold {fold} — Curva de entrenamiento")
    ax1.legend(); ax1.grid(True, alpha=0.3)
    out_png = f"training_curve_fold{fold}.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST (con TTA/subwindows y ensemble si aplica) ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(gts, preds, labels=[0,1,2,3]))

    return acc, f1m

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  8.06it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:02<00:00,  7.47it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.07it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)




RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
