In [2]:
# -*- coding: utf-8 -*-
# EEGNet + (SIN balanceo por sujeto/clase) + Kfold5.json (solo train/test) +
# Augments y TTA iguales al CNN+Transformer + Fine-tuning progresivo (igual que antes)

import os, re, json, random, copy, itertools
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

from pathlib import Path
from glob import glob
from datetime import datetime

import numpy as np
import pandas as pd
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / '00_folds' / 'Kfold5.json'

# Entrenamiento global
EPOCHS_GLOBAL   = 100
BATCH_SIZE      = 64
LR_INIT         = 1e-2
SGDR_T0         = 6
SGDR_Tmult      = 2
GLOBAL_PATIENCE = 10
LOG_EVERY       = 5

# Validación desde sujetos de train (igual a tu CNN+TF)
VAL_SUBJECT_FRAC   = 0.18
VAL_STRAT_SUBJECT  = True

# Fine-tuning progresivo
CALIB_CV_FOLDS = 4
FT_EPOCHS   = 30
FT_BASE_LR  = 5e-5
FT_HEAD_LR  = 1e-3
FT_L2SP     = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Ventana temporal y prepro
FS = 160.0
TMIN, TMAX = -1.0, 5.0
NORM_EPOCH_ZSCORE = True   # por-época canal-a-canal

# Excluir sujetos (igual que antes)
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Runs y canales
IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES_4C = ['Left','Right','Both Fists','Both Feet']

# TTA estilo CNN+TF
SW_MODE = 'tta'                     # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5
COMBINE_TTA_AND_SUBWIN = False      # mantenemos apagado (como tu comentario)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")
print("🧠 INICIANDO EXPERIMENTO CON EEGNet (mismo Kfold y TTA/Augments que CNN+TF)")

# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in sorted(list(IMAGERY_RUNS_LR | IMAGERY_RUNS_BF)):
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def load_subject_epochs(subject_id: str):
    """Devuelve X (N, T, C) y y (N,) con 8 canales y FS=160 si es necesario."""
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0, 1, 1), dtype=np.float32), np.empty((0,), dtype=int), None
    X_list, y_list, sfreq_list = [], [], []
    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)
        # re-muestreo a FS si hace falta
        if abs(raw.info['sfreq'] - FS) > 1e-6:
            raw.resample(FS)
        sfreq = raw.info['sfreq']
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue
        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data().astype(np.float32)      # (N, C, T)
        # z-score por época/canal si está activo
        if NORM_EPOCH_ZSCORE:
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd
        # mapear etiquetas por run
        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue
        # convertimos a (N, T, C) para EEGNet utilidades previas
        X_tc = np.transpose(X[mask], (0, 2, 1))   # (N, T, C)
        X_list.append(X_tc)
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)
    if len(X_list) == 0:
        return np.empty((0, 1, 1), dtype=np.float32), np.empty((0,), dtype=int), None
    X_all = np.concatenate(X_list, axis=0)
    y_all = np.concatenate(y_list, axis=0)
    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")
    return X_all, y_all, sfreq_list[0]

def standardize_per_channel_trainfit(X_tr_TxC, X_other_TxC):
    """Estandariza por canal usando estadísticas de TRAIN, sobre matrices (N, T, C)."""
    Xtr = X_tr_TxC.astype(np.float32).copy()
    Xot = X_other_TxC.astype(np.float32).copy()
    C = Xtr.shape[2]
    for c in range(C):
        mu = Xtr[:, :, c].mean()
        sd = Xtr[:, :, c].std()
        sd = sd if sd > 1e-6 else 1.0
        Xtr[:, :, c] = (Xtr[:, :, c] - mu) / sd
        Xot[:, :, c] = (Xot[:, :, c] - mu) / sd
    return Xtr, Xot

def build_subject_label_map(subject_ids):
    """Etiqueta dominante por sujeto (para estratificar la selección de val)."""
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=4)
        y_dom_list.append(int(np.argmax(binc)))
    return np.array(y_dom_list, dtype=int)

# =========================
# AUGMENTS (estilo CNN+TF) SOBRE (B,1,T,C)
# =========================
def augment_batch_eegnet(xb_1tC, p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
                         max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1):
    """
    Aplica las mismas ideas de augment que tu CNN+Transformer, pero adaptadas a (B,1,T,C).
    Internamente convierte a (B,C,T), aplica, y regresa a (B,1,T,C).
    """
    B, _, T, C = xb_1tC.shape
    xb = xb_1tC[:, 0].permute(0, 2, 1).contiguous()   # (B,C,T)
    # jitter temporal
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T * max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    # ruido gaussiano
    if np.random.rand() < p_noise:
        xb = xb + noise_std * torch.randn_like(xb)
    # channel drop
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    # volver a (B,1,T,C)
    xb_1tC = xb.permute(0, 2, 1).unsqueeze(1).contiguous()
    return xb_1tC

# =========================
# TTA estilo CNN+TF (time-shifts)
# =========================
@torch.no_grad()
def time_shift_tta_logits_eegnet(model, X_TxC, sfreq, shifts_s, device):
    """
    X_TxC: (N, T, C) numpy
    Genera logits promediados tras desplazar en el tiempo.
    """
    model.eval()
    N, T, C = X_TxC.shape
    out = []
    for i in range(N):
        x0 = X_TxC[i]  # (T,C)
        acc = []
        for sh in shifts_s:
            shift = int(round(sh * sfreq))
            if shift == 0:
                x = x0
            elif shift > 0:
                # desplaza a la derecha, rellena borde
                pad = np.tile(x0[:1, :], (shift, 1))
                x = np.concatenate([pad, x0[:-shift, :]], axis=0)
            else:
                # desplaza a la izquierda
                shift = -shift
                pad = np.tile(x0[-1:, :], (shift, 1))
                x = np.concatenate([x0[shift:, :], pad], axis=0)
            # a tensor (1,1,T,C)
            xb = torch.tensor(x, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)
            logit = model(xb).detach().cpu().numpy()[0]
            acc.append(logit)
        out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

# =========================
# MODELO EEGNet
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__()
        self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

class EEGNet(nn.Module):
    """
    Entrada: x de forma (B, 1, T, C)
    """
    def __init__(self, n_ch: int, n_classes: int,
                 F1: int = 24, D: int = 2, kernel_t: int = 64, k_sep: int = 16,
                 pool1_t: int = 4, pool2_t: int = 6,
                 drop1_p: float = 0.35, drop2_p: float = 0.6,
                 chdrop_p: float = 0.10):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.F1 = F1
        self.D = D
        self.F2 = F1 * D
        self.kernel_t = kernel_t
        self.k_sep = k_sep
        self.pool1_t = pool1_t
        self.pool2_t = pool2_t

        self.chdrop = ChannelDropout(p=chdrop_p)
        self.act = nn.ELU()

        # Bloque 1: temporal
        self.conv_temporal = nn.Conv2d(1, F1, kernel_size=(kernel_t, 1),
                                       padding=(kernel_t // 2, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1, momentum=0.99, eps=1e-3)

        # Bloque 2: depthwise (espacial)
        self.conv_depthwise = nn.Conv2d(F1, self.F2, kernel_size=(1, n_ch),
                                        groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool1 = nn.AvgPool2d(kernel_size=(pool1_t, 1), stride=(pool1_t, 1))
        self.drop1 = nn.Dropout(drop1_p)

        # Bloque 3: separable temporal
        self.conv_sep_depth = nn.Conv2d(self.F2, self.F2, kernel_size=(k_sep, 1),
                                        groups=self.F2, padding=(k_sep // 2, 0), bias=False)
        self.conv_sep_point = nn.Conv2d(self.F2, self.F2, kernel_size=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2, momentum=0.99, eps=1e-3)
        self.pool2 = nn.AvgPool2d(kernel_size=(pool2_t, 1), stride=(pool2_t, 1))
        self.drop2 = nn.Dropout(drop2_p)

        self.flatten = nn.Flatten()
        # Cabeza dinámica (se arma en forward según T)
        self.fc = None
        self.out = None
        self._T_in = None

    def _build_head(self, T_in: int, device: torch.device):
        T1 = T_in // self.pool1_t
        T2 = T1 // self.pool2_t
        feat_dim = self.F2 * T2 * 1
        self.fc  = nn.Linear(feat_dim, 128, bias=True).to(device)
        self.out = nn.Linear(128, self.n_classes, bias=True).to(device)
        self._T_in = T_in

    def ensure_head(self, T_in: int, device: torch.device):
        if (self.fc is None) or (self.out is None) or (self._T_in != T_in):
            self._build_head(T_in, device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, T, C = x.shape
        self.ensure_head(T, x.device)
        x = self.chdrop(x)  # ChannelDropout

        z = self.conv_temporal(x)
        z = self.bn1(z); z = self.act(z)

        z = self.conv_depthwise(z)   # (B, F2, T, 1)
        z = self.bn2(z); z = self.act(z)
        z = self.pool1(z)
        z = self.drop1(z)

        z = self.conv_sep_depth(z)
        z = self.conv_sep_point(z)
        z = self.bn3(z); z = self.act(z)
        z = self.pool2(z)
        z = self.drop2(z)

        z = self.flatten(z)
        z = self.fc(z); z = self.act(z)
        z = self.out(z)
        return z

# =========================
# LOSS (igual que tu EEGNet previo: soft CE ponderada)
# =========================
class WeightedSoftCrossEntropy(nn.Module):
    def __init__(self, class_weights=None, label_smoothing=0.0):
        super().__init__()
        self.register_buffer('w', None if class_weights is None else class_weights.clone().float())
        self.ls = float(label_smoothing)
    def forward(self, logits, target_probs):
        if self.ls > 0:
            K = logits.size(1)
            target_probs = (1-self.ls)*target_probs + self.ls*(1.0/K)
        logp = torch.log_softmax(logits, dim=1)
        loss_per_class = -(target_probs * logp)
        if self.w is not None:
            loss_per_class = loss_per_class * self.w.unsqueeze(0)
        loss = loss_per_class.sum(dim=1).mean()
        return loss

def make_class_weight_tensor(y_indices, n_classes, boost_bfists=1.20, boost_bfeet=1.20, device=DEVICE):
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts           # inverso de frecuencia
    w = w / w.mean()                    # normaliza a media 1
    # boosts manuales
    w[2] *= float(boost_bfists)         # clase 2: Both Fists
    w[3] *= float(boost_bfeet)          # clase 3: Both Feet (NUEVO)
    w = w / w.mean()                    # re-normaliza a media 1 (importante)
    return torch.tensor(w, dtype=torch.float32, device=device)


# =========================
# TRAIN/EVAL HELPERS
# =========================
@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    layers = []
    if hasattr(model, 'conv_depthwise'): layers.append(model.conv_depthwise)
    if hasattr(model, 'conv_sep_point'): layers.append(model.conv_sep_point)
    if hasattr(model, 'fc'):             layers.append(model.fc)
    if hasattr(model, 'out'):            layers.append(model.out)
    for layer in layers:
        if hasattr(layer, 'weight') and layer.weight is not None:
            w = layer.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy')
    plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i, j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm_norm[i, j] > thresh else "black")
    plt.ylabel('True'); plt.xlabel('Predicted')
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

# =========================
# FINE-TUNING PROGRESIVO (igual que antes)
# =========================
def _freeze_for_mode(model, mode):
    for p in model.parameters(): p.requires_grad = False
    if mode == 'out':
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'head':
        for p in model.fc.parameters():  p.requires_grad = True
        for p in model.out.parameters(): p.requires_grad = True
    elif mode == 'spatial+head':
        for p in model.conv_depthwise.parameters(): p.requires_grad = True
        for p in model.bn2.parameters():           p.requires_grad = True
        for p in model.conv_sep_depth.parameters():p.requires_grad = True
        for p in model.conv_sep_point.parameters():p.requires_grad = True
        for p in model.bn3.parameters():           p.requires_grad = True
        for p in model.fc.parameters():            p.requires_grad = True
        for p in model.out.parameters():           p.requires_grad = True
    else:
        raise ValueError(mode)

def _class_weights(y_np, n_classes):
    counts = np.bincount(y_np, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    weights = counts.sum() / counts
    weights = weights / weights.mean()
    return torch.tensor(weights, dtype=torch.float32, device=DEVICE)

def _train_one_mode(model, X_cal_TxC, y_cal, n_classes, mode,
                    epochs=FT_EPOCHS, batch_size=32,
                    head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                    l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO):
    # split interno (sobre ensayos del sujeto)
    sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=RANDOM_STATE)
    (tr_idx, va_idx), = sss.split(X_cal_TxC, y_cal)
    Xtr, ytr = X_cal_TxC[tr_idx], y_cal[tr_idx]
    Xva, yva = X_cal_TxC[va_idx], y_cal[va_idx]

    # a tensores (B,1,T,C)
    def to_tensor(X_TxC, y_idx):
        xb = torch.from_numpy(X_TxC).float().unsqueeze(1)   # (B,1,T,C)
        yb = torch.from_numpy(y_idx).long()
        return xb, yb

    xb_tr, yb_tr = to_tensor(Xtr, ytr)
    xb_va, yb_va = to_tensor(Xva, yva)

    _freeze_for_mode(model, mode)
    if mode == 'spatial+head':
        base_params = (list(model.conv_depthwise.parameters()) +
                       list(model.bn2.parameters()) +
                       list(model.conv_sep_depth.parameters()) +
                       list(model.conv_sep_point.parameters()) +
                       list(model.bn3.parameters()))
        head_params = list(model.fc.parameters()) + list(model.out.parameters())
        train_params = base_params + head_params
        opt = optim.Adam([
            {"params": base_params, "lr": base_lr},
            {"params": head_params, "lr": head_lr},
        ])
    else:
        train_params = list(model.parameters())  # se respeta freeze por requires_grad
        opt = optim.Adam([p for p in train_params if p.requires_grad], lr=head_lr)

    # referencia L2SP
    ref = [p.detach().clone().to(p.device) for p in [p for p in train_params if p.requires_grad]]
    class_w = _class_weights(ytr, n_classes)
    crit = nn.CrossEntropyLoss(weight=class_w)

    best_state = copy.deepcopy(model.state_dict())
    best_val = float('inf'); bad = 0

    for _ in range(epochs):
        model.train()
        # mini-batches
        for s in range(0, xb_tr.size(0), batch_size):
            xbt = xb_tr[s:s+batch_size].to(DEVICE)
            ybt = yb_tr[s:s+batch_size].to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xbt)
            loss = crit(logits, ybt)
            # L2SP
            reg = 0.0
            idx = 0
            for p in train_params:
                if p.requires_grad:
                    reg = reg + torch.sum((p - ref[idx])**2)
                    idx += 1
            loss = loss + l2sp_lambda * reg
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)

        # validación
        model.eval()
        with torch.no_grad():
            val_loss = 0.0; nval = 0
            for s in range(0, xb_va.size(0), batch_size):
                xbv = xb_va[s:s+batch_size].to(DEVICE)
                ybv = yb_va[s:s+batch_size].to(DEVICE)
                logits = model(xbv)
                loss = crit(logits, ybv)
                val_loss += loss.item() * xbv.size(0); nval += xbv.size(0)
            val_loss /= max(1, nval)

        if val_loss + 1e-7 < best_val:
            best_val = val_loss; bad = 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= patience: break

    model.load_state_dict(best_state)
    return model

@torch.no_grad()
def predict_numpy(model, X_np_TxC, device):
    model.eval()
    xb = torch.from_numpy(X_np_TxC).float().unsqueeze(1).to(device)  # (B,1,T,C)
    logits = model(xb)
    return logits.argmax(dim=1).cpu().numpy()

def subject_cv_finetune_predict_progressive(model_global, Xs_TxC, ys, device,
                                            n_splits=CALIB_CV_FOLDS, n_classes=4):
    from sklearn.model_selection import StratifiedKFold
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)
    for tr_idx, te_idx in skf.split(Xs_TxC, ys):
        Xcal, ycal = Xs_TxC[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs_TxC[te_idx], ys[te_idx]

        # OUT
        m_out = copy.deepcopy(model_global)
        _train_one_mode(m_out, Xcal, ycal, n_classes, mode='out',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_out = predict_numpy(m_out, Xho, device); acc_out = (yhat_out == yho).mean()

        # HEAD
        m_head = copy.deepcopy(model_global)
        _train_one_mode(m_head, Xcal, ycal, n_classes, mode='head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                        patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_head = predict_numpy(m_head, Xho, device); acc_head = (yhat_head == yho).mean()

        # SPATIAL+HEAD
        m_sp = copy.deepcopy(model_global)
        _train_one_mode(m_sp, Xcal, ycal, n_classes, mode='spatial+head',
                        epochs=FT_EPOCHS, head_lr=FT_HEAD_LR, base_lr=FT_BASE_LR,
                        l2sp_lambda=FT_L2SP, patience=FT_PATIENCE, val_ratio=FT_VAL_RATIO)
        yhat_sp = predict_numpy(m_sp, Xho, device); acc_sp = (yhat_sp == yho).mean()

        best_idx = int(np.argmax([acc_out, acc_head, acc_sp]))
        yhat_best = [yhat_out, yhat_head, yhat_sp][best_idx]

        y_true_full[te_idx] = yho; y_pred_full[te_idx] = yhat_best

    return y_true_full, y_pred_full

# =========================
# TRAIN/EVAL por FOLD (con train/val/test desde Kfold5.json)
# =========================
def train_one_fold(fold:int, device):
    # sujetos del fold (solo usamos listas train/test del JSON)
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # split de validación por sujetos desde train (estratificado por etiqueta dominante como en tu CNN+TF)
    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)
    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects   = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # cargar datos (SIN balanceo: todos los ensayos)
    X_tr_list, y_tr_list = [], []
    X_val_list, y_val_list = [], []
    X_te_list, y_te_list = [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys); sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys); sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys); sfreq = sf if sfreq is None else sfreq

    # concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    X_te  = np.concatenate(X_te_list,  axis=0); y_te  = np.concatenate(y_te_list,  axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # estandarización por canal con stats de TRAIN
    X_tr_std, X_val_std = standardize_per_channel_trainfit(X_tr, X_val)
    _,        X_te_std  = standardize_per_channel_trainfit(X_tr, X_te)

    # datasets → tensores de entrada (B,1,T,C) para EEGNet
    def to_tensor_1tC(X_TxC, y_idx):
        xb = torch.tensor(X_TxC, dtype=torch.float32).unsqueeze(1)  # (B,1,T,C)
        yb = torch.tensor(y_idx, dtype=torch.long)
        return TensorDataset(xb, yb)

    tr_ds  = to_tensor_1tC(X_tr_std, y_tr)
    val_ds = to_tensor_1tC(X_val_std, y_val)
    te_ds  = to_tensor_1tC(X_te_std,  y_te)

    tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)
    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # modelo + opt + scheduler
    n_classes = 4
    model = EEGNet(n_ch=8, n_classes=n_classes,
                   F1=24, D=2, kernel_t=64, k_sep=16,
                   pool1_t=4, pool2_t=6, drop1_p=0.35, drop2_p=0.6,
                   chdrop_p=0.10).to(device)

    opt = optim.Adam(model.parameters(), lr=LR_INIT)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)

    # criterio suave ponderado (+20% Both Fists) con label smoothing
    class_weights = make_class_weight_tensor(y_tr, n_classes, boost_bfists=1.25, boost_bfeet=1.05, device=device)
    criterion_soft = WeightedSoftCrossEntropy(class_weights, label_smoothing=0.05)

    # métricas rápidas
    @torch.no_grad()
    def eval_acc(loader):
        model.eval()
        y_true, y_pred = [], []
        for xb, yb in loader:
            xb = xb.to(device)
            # TTA en validación? Normalmente NO (dejamos off)
            logits = model(xb)
            pred = logits.argmax(dim=1).cpu().numpy()
            y_true.append(yb.numpy()); y_pred.append(pred)
        y_true = np.concatenate(y_true); y_pred = np.concatenate(y_pred)
        return (y_true == y_pred).mean()

    # entrenamiento global (con augments estilo CNN+TF)
    history = {'train_acc': [], 'val_acc': []}
    best_state = copy.deepcopy(model.state_dict()); best_val = -1.0; bad = 0

    for epoch in range(1, EPOCHS_GLOBAL + 1):
        model.train()
        correct, seen = 0, 0
        for xb, yb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            # AUGMENTS (mismas ideas que CNN+TF)
            xb = augment_batch_eegnet(xb,
                                      p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
                                      max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1)
            # loss suave (one-hot)
            yt = torch.nn.functional.one_hot(yb, num_classes=n_classes).float()
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion_soft(logits, yt)
            loss.backward(); opt.step()
            apply_max_norm(model, max_value=2.0, p=2.0)
            correct += (logits.argmax(1) == yb).sum().item()
            seen += yb.size(0)
        scheduler.step(epoch - 1 + 1e-8)

        tr_acc = correct / max(1, seen)
        va_acc = eval_acc(val_ld)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(va_acc)

        if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
            cur_lr = opt.param_groups[0]['lr']
            print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

        if va_acc > best_val + 1e-4:
            best_val = va_acc; best_state = copy.deepcopy(model.state_dict()); bad = 0
        else:
            bad += 1
            if bad >= GLOBAL_PATIENCE:
                print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                break

    # guardar curva
    curve_path = f"training_curve_fold{fold}.png"
    plot_training_curves(history, curve_path)
    print(f"↳ Curva de entrenamiento guardada: {curve_path}")

    # cargar mejor estado
    model.load_state_dict(best_state)

    # ===== EVALUACIÓN EN TEST (TTA estilo CNN+TF) =====
    if (not SW_ENABLE) or SW_MODE == 'none':
        y_true, y_pred = [], []
        with torch.no_grad():
            for xb, yb in te_ld:
                xb = xb.to(device)
                logits = model(xb)
                pred = logits.argmax(1).cpu().numpy()
                y_true.append(yb.numpy()); y_pred.append(pred)
        y_true = np.concatenate(y_true); y_pred = np.concatenate(y_pred)
    elif SW_MODE in ('tta', 'subwin'):
        sfreq_used = int(round(X_te_std.shape[1] / (TMAX - TMIN)))  # T = X_te_std.shape[1]
        logits_tta = None; logits_sw = None
        if SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits_eegnet(model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
        else:
            # subwindow opcional; si quisieras activarlo, implementa análogo a tu CNN+TF
            raise NotImplementedError("SW_MODE='subwin' no implementado aquí (se pidió TTA).")
        logits = logits_tta
        y_pred = logits.argmax(axis=1); y_true = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    acc_global = accuracy_score(y_true, y_pred)
    f1m_global = f1_score(y_true, y_pred, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc_global:.4f}")
    print(classification_report(y_true, y_pred, target_names=CLASS_NAMES_4C, digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(y_true, y_pred, labels=[0,1,2,3]))

    # ---------- Fine-tuning PROGRESIVO por sujeto (igual que antes) ----------
    # agrupamos por sujeto en el TEST
    # reconstruimos arrays por sujeto usando las mismas X_te_std/y_te
    # Primero necesitamos el vector de sujetos por ensayo:
    sub_te = []
    for sid in test_sub:
        Xs, ys, _ = load_subject_epochs(sid)
        if len(ys) == 0: continue
        sub_te += [subject_id_to_int(sid)] * len(ys)
    sub_te = np.array(sub_te, dtype=int)

    y_true_ft_all, y_pred_ft_all = [], []
    used_subjects = 0
    for sid in np.unique(sub_te):
        idx = np.where(sub_te == sid)[0]
        Xs_TxC = X_te_std[idx]
        ys_sub = y_te[idx]
        if len(ys_sub) < CALIB_CV_FOLDS or len(np.unique(ys_sub)) < 2:
            continue
        y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(
            model, Xs_TxC, ys_sub, DEVICE, n_splits=CALIB_CV_FOLDS, n_classes=4
        )
        y_true_ft_all.append(y_true_subj); y_pred_ft_all.append(y_pred_subj)
        used_subjects += 1

    if len(y_true_ft_all) > 0:
        y_true_ft_all = np.concatenate(y_true_ft_all)
        y_pred_ft_all = np.concatenate(y_pred_ft_all)
        acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
        print(f"  Fine-tuning PROGRESIVO (por sujeto, {CALIB_CV_FOLDS}-fold CV) acc={acc_ft:.4f} | sujetos={used_subjects}")
        print(f"  Δ(FT-Global) = {acc_ft - acc_global:+.4f}")
    else:
        acc_ft = np.nan
        print("  Fine-tuning PROGRESIVO no ejecutado (sujeto(s) con muestras insuficientes).")

    # matriz de confusión global (sin FT)
    cm_png = f"confusion_global_fold{fold}.png"
    plot_confusion(y_true, y_pred, CLASS_NAMES_4C, title=f"Confusion Matrix - Fold {fold}", fname=cm_png)
    print(f"↳ Matriz de confusión guardada: {cm_png}")

    return acc_global, f1m_global, acc_ft

# =========================
# LOOP 5 FOLDS + RESUMEN
# =========================
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO CON EEGNet + Augments/TTA estilo CNN+TF (sin balanceo)")
    print(f"🔧 Configuración: 4c, 8 canales, 6s | EPOCHS_GLOBAL={EPOCHS_GLOBAL}, BATCH={BATCH_SIZE}, LR={LR_INIT} | ZSCORE_PER_EPOCH={NORM_EPOCH_ZSCORE}")
    acc_folds, f1_folds = [], []
    ft_folds = []

    for fold in range(1, 6):
        acc, f1m, acc_ft = train_one_fold(fold, DEVICE)  # <-- recibe acc_ft
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")
        # Permite NaN si algún sujeto no tuvo muestras suficientes para FT
        ft_folds.append("nan" if (acc_ft is None or (isinstance(acc_ft, float) and np.isnan(acc_ft))) else f"{acc_ft:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))
    # Media de FT ignorando 'nan'
    ft_vals = [float(x) for x in ft_folds if x != "nan"]
    ft_mean = float(np.mean(ft_vals)) if len(ft_vals) > 0 else float("nan")
    delta_mean = (ft_mean - acc_mean) if len(ft_vals) > 0 else float("nan")

    print("\n============================================================")
    print("RESULTADOS FINALES")
    print("============================================================")
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")
    # ---- NUEVO BLOQUE RESUMEN FT ----
    print(f"Fine-tune PROGRESIVO folds: {ft_folds}")
    if len(ft_vals) > 0:
        print(f"Fine-tune PROGRESIVO mean: {ft_mean:.4f}")
        print(f"Δ(FT-Global) mean: {delta_mean:+.4f}")
    else:
        print("Fine-tune PROGRESIVO mean: nan")
        print("Δ(FT-Global) mean: nan")



🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO CON EEGNet (mismo Kfold y TTA/Augments que CNN+TF)
🧠 INICIANDO EXPERIMENTO CON EEGNet + Augments/TTA estilo CNN+TF (sin balanceo)
🔧 Configuración: 4c, 8 canales, 6s | EPOCHS_GLOBAL=100, BATCH=64, LR=0.01 | ZSCORE_PER_EPOCH=True


Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  7.62it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:02<00:00,  7.30it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  7.13it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_acc=0.3520 | val_acc=0.3944 | LR=0.01000
  Época   5 | train_acc=0.4232 | val_acc=0.4310 | LR=0.00250
  Época  10 | train_acc=0.4351 | val_acc=0.4246 | LR=0.00854
  Época  15 | train_acc=0.4531 | val_acc=0.4452 | LR=0.00250
  Época  20 | train_acc=0.4474 | val_acc=0.4286 | LR=0.00996
  Época  25 | train_acc=0.4439 | val_acc=0.4349 | LR=0.00854
  Early stopping en época 25 (mejor val_acc=0.4452)
↳ Curva de entrenamiento guardada: training_curve_fold1.png
[Fold 1/5] Global acc=0.4507
              precision    recall  f1-score   support

        Left     0.6461    0.3560    0.4591       441
       Right     0.5372    0.4422    0.4851       441
  Both Fists     0.3447    0.5283    0.4172       441
   Both Feet     0.4357    0.4762    0.4550       441

    accuracy                         0.4507      1764
   macro avg     0.4909    0.4507    0.4541      1764
weighted avg     0.4909    0.450

Cargando train fold2: 100%|██████████| 67/67 [00:05<00:00, 11.34it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00, 11.63it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:01<00:00, 11.35it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_acc=0.3109 | val_acc=0.3683 | LR=0.01000
  Época   5 | train_acc=0.3850 | val_acc=0.4484 | LR=0.00250
  Época  10 | train_acc=0.3984 | val_acc=0.4238 | LR=0.00854
  Época  15 | train_acc=0.4227 | val_acc=0.4500 | LR=0.00250
  Época  20 | train_acc=0.4152 | val_acc=0.4444 | LR=0.00996
  Época  25 | train_acc=0.4129 | val_acc=0.4365 | LR=0.00854
  Época  30 | train_acc=0.4241 | val_acc=0.4532 | LR=0.00565
  Época  35 | train_acc=0.4314 | val_acc=0.4492 | LR=0.00250
  Época  40 | train_acc=0.4456 | val_acc=0.4532 | LR=0.00038
  Early stopping en época 44 (mejor val_acc=0.4611)
↳ Curva de entrenamiento guardada: training_curve_fold2.png
[Fold 2/5] Global acc=0.5079
              precision    recall  f1-score   support

        Left     0.6480    0.4717    0.5459       441
       Right     0.6559    0.4626    0.5426       441
  Both Fists     0.3901    0.6599    0.4903       441
   Both Feet

Cargando train fold3: 100%|██████████| 67/67 [00:07<00:00,  9.00it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  9.20it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  7.84it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_acc=0.3264 | val_acc=0.4381 | LR=0.01000
  Época   5 | train_acc=0.4243 | val_acc=0.4817 | LR=0.00250
  Época  10 | train_acc=0.4248 | val_acc=0.4810 | LR=0.00854
  Época  15 | train_acc=0.4471 | val_acc=0.4960 | LR=0.00250
  Época  20 | train_acc=0.4383 | val_acc=0.4770 | LR=0.00996
  Época  25 | train_acc=0.4367 | val_acc=0.4825 | LR=0.00854
  Early stopping en época 25 (mejor val_acc=0.4960)
↳ Curva de entrenamiento guardada: training_curve_fold3.png
[Fold 3/5] Global acc=0.4700
              precision    recall  f1-score   support

        Left     0.6026    0.4195    0.4947       441
       Right     0.6341    0.4127    0.5000       441
  Both Fists     0.3757    0.6032    0.4630       441
   Both Feet     0.4242    0.4444    0.4341       441

    accuracy                         0.4700      1764
   macro avg     0.5092    0.4700    0.4729      1764
weighted avg     0.5092    0.470

Cargando train fold4: 100%|██████████| 68/68 [00:06<00:00, 11.30it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00, 11.50it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:01<00:00, 10.82it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_acc=0.3270 | val_acc=0.4341 | LR=0.01000
  Época   5 | train_acc=0.4027 | val_acc=0.4683 | LR=0.00250
  Época  10 | train_acc=0.4210 | val_acc=0.4698 | LR=0.00854
  Época  15 | train_acc=0.4275 | val_acc=0.4698 | LR=0.00250
  Época  20 | train_acc=0.4259 | val_acc=0.4635 | LR=0.00996
  Época  25 | train_acc=0.4272 | val_acc=0.4833 | LR=0.00854
  Época  30 | train_acc=0.4389 | val_acc=0.4873 | LR=0.00565
  Época  35 | train_acc=0.4611 | val_acc=0.4802 | LR=0.00250
  Época  40 | train_acc=0.4582 | val_acc=0.4952 | LR=0.00038
  Early stopping en época 41 (mejor val_acc=0.4992)
↳ Curva de entrenamiento guardada: training_curve_fold4.png
[Fold 4/5] Global acc=0.4857
              precision    recall  f1-score   support

        Left     0.5932    0.4548    0.5148       420
       Right     0.6271    0.4524    0.5256       420
  Both Fists     0.3776    0.6429    0.4758       420
   Both Feet

Cargando train fold5: 100%|██████████| 68/68 [00:09<00:00,  7.32it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  9.30it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.62it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_acc=0.3328 | val_acc=0.3937 | LR=0.01000
  Época   5 | train_acc=0.4090 | val_acc=0.4103 | LR=0.00250
  Época  10 | train_acc=0.4039 | val_acc=0.4127 | LR=0.00854
  Época  15 | train_acc=0.4319 | val_acc=0.4278 | LR=0.00250
  Early stopping en época 19 (mejor val_acc=0.4341)
↳ Curva de entrenamiento guardada: training_curve_fold5.png
[Fold 5/5] Global acc=0.5089
              precision    recall  f1-score   support

        Left     0.6087    0.5333    0.5685       420
       Right     0.5455    0.5000    0.5217       420
  Both Fists     0.4173    0.5643    0.4798       420
   Both Feet     0.5125    0.4381    0.4724       420

    accuracy                         0.5089      1680
   macro avg     0.5210    0.5089    0.5106      1680
weighted avg     0.5210    0.5089    0.5106      1680

Confusion matrix (rows=true, cols=pred):
[[224  31 123  42]
 [ 21 210 105  84]
 [ 76  58 237  49]
 