# 4 Clases

### Configuración Inicial y Reproducibilidad

In [19]:
import os
# These environment variables are set before torch is imported further down in
# this cell, so they are in place before any CUDA context is created.
# Restrict the process to physical GPU index 7.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
# NOTE(review): setting PYTHONHASHSEED at runtime does not change the hash seed
# of the already-running interpreter; it only affects child processes.
os.environ['PYTHONHASHSEED'] = '42'
# cuBLAS workspace configuration for deterministic GEMMs; presumably needed
# because seed_everything() enables torch.use_deterministic_algorithms(True).
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random, time, copy
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    confusion_matrix,
    classification_report,
    precision_score,
    recall_score,
)
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
from sklearn.manifold import TSNE
from scipy import stats

### Configuración de Reproducibilidad

Esta sección establece las semillas y configuraciones necesarias para garantizar resultados reproducibles en el entrenamiento del modelo para clasificación de 4 clases.

In [20]:
# =========================
# REPRODUCIBILITY
# =========================
# Global seed: passed to seed_everything(), used by seed_worker(), and offset
# by the per-function np.random.RandomState generators that sample trials.
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    """
    Seed every source of randomness used here (Python, NumPy, PyTorch CPU and
    CUDA) and switch PyTorch into deterministic mode.

    Disabling TF32 and requesting deterministic algorithms are best-effort:
    any exception raised by those settings (e.g. on older PyTorch builds) is
    ignored.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        cudnn.allow_tf32 = False
    except Exception:
        pass

    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    """
    DataLoader ``worker_init_fn``: seed Python's and NumPy's global RNGs with
    ``RANDOM_STATE + worker_id``.

    The seed depends only on ``worker_id``, so it repeats whenever the workers
    are recreated.
    """
    seed_for_worker = worker_id + RANDOM_STATE
    random.seed(seed_for_worker)
    np.random.seed(seed_for_worker)

# Seed everything once, when this cell runs.
seed_everything(RANDOM_STATE)

### Configuración de Hiperparámetros para 4 Clases

Definición de las rutas, los hiperparámetros de entrenamiento y del modelo, y las opciones de preprocesamiento de los datos EEG para la clasificación de 4 clases.

In [21]:
# =========================
# CONFIG
# =========================
# NOTE(review): Path('..').resolve().parent goes two levels above the current
# working directory; confirm this is the intended project root.
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / '00_folds' / 'Kfold5.json'

# Output directory for the 4-class experiment (relative to the working directory)
OUTPUT_DIR = Path("4clases")

# Training hyperparameters
EPOCHS = 60
BATCH_SIZE = 64
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 12

# Validation split inside each fold (by subject)
VAL_SUBJECT_FRAC = 0.18
VAL_STRAT_SUBJECT = True

# Global preprocessing options passed to load_subject_epochs
# (RESAMPLE_HZ = None keeps the native sampling rate)
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Model parameters
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
P_DROP = 0.2
P_DROP_ENCODER = 0.3

# Epoch window in seconds relative to each T1/T2 event
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW at TEST time (shifts, lengths and strides are in seconds)
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5
COMBINE_TTA_AND_SUBWIN = False

# Class-balanced sampler
USE_WEIGHTED_SAMPLER = True

# EMA (GLOBAL training only). Disabled during FT.
USE_EMA = True
EMA_DECAY = 0.9995

# Per-subject fine-tuning (FT)
FT_N_FOLDS = 4
FT_FREEZE_EPOCHS = 8           # stage 1 (backbone frozen)
FT_UNFREEZE_EPOCHS = 8         # stage 2 (backbone unfrozen)
FT_PATIENCE = 6                # early stopping in FT
FT_BATCH = 64
FT_LR_HEAD = 1e-3              # head (about 5x the backbone LR)
FT_LR_BACKBONE = 2e-4          # backbone
FT_WD = 1e-3                   # smaller weight decay in FT
# Milder augmentation: keyword arguments forwarded to augment_batch by augment_batch_ft
FT_AUG = dict(p_jitter=0.25, p_noise=0.25, p_chdrop=0.10,
              max_jitter_frac=0.02, noise_std=0.02, max_chdrop=1)
# NOTE(review): presumably multipliers on the focal-loss alpha of the
# both_fists / both_feet classes during FT; confirm at the call site.
FT_ALPHA_BOOST_FISTS = 1.25
FT_ALPHA_BOOST_FEET  = 1.05

# Subject and channel configuration
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
# Position in this list == integer class label (0..3)
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

# Imagery runs: in LR runs T1/T2 -> left/right (0/1);
# in BF runs T1/T2 -> both_fists/both_feet (2/3)
IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Usando dispositivo: {DEVICE}")
print(" INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)")
print(f" Configuración: 4c, 8 canales, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

# Model tag and approximate FLOPs (in GFLOPs, presumably per forward pass) for the final table
MODEL_TAG = "nb4_h2_4c"
MODEL_FLOPS_G = 0.03

 Usando dispositivo: cuda
 INICIANDO EXPERIMENTO CON CNN+Transformer (K-Fold por sujeto como EEGNet)
 Configuración: 4c, 8 canales, 6s | EPOCHS=60, BATCH=64, LR=0.001 | ZSCORE_PER_EPOCH=False


### Funciones de Utilidad para I/O

Funciones para localizar los archivos EDF de cada sujeto, normalizar los nombres de los canales y seleccionar los 8 canales motores.

In [22]:
# =========================
# UTILIDADES I/O
# =========================
def normalize_ch_name(name: str) -> str:
    """
    Canonicalize a channel label.

    Every character outside ASCII ``[A-Za-z0-9]`` is dropped and the rest is
    upper-cased, e.g. ``'Fc3.'`` -> ``'FC3'``.
    """
    cleaned = ''.join(ch for ch in name if ch.isascii() and ch.isalnum())
    return cleaned.upper()

# normalize_ch_name() forms of EXPECTED_8, in the same order.
NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    """
    Keep only the 8 motor channels listed in EXPECTED_8.

    Labels are compared after normalize_ch_name(), so punctuation in the EDF
    channel names is ignored. Returns the result of ``raw.pick``.

    Raises:
        RuntimeError: if any required channel is missing.
    """
    available = raw.info['ch_names']
    by_norm = {normalize_ch_name(label): label for label in available}
    selection = []
    for wanted_norm, wanted in zip(NORMALIZED_TARGETS, EXPECTED_8):
        match = by_norm.get(wanted_norm)
        if match is None:
            raise RuntimeError(f"Canal requerido '{wanted}' no encontrado. Disponibles: {available}")
        selection.append(match)
    return raw.pick(picks=selection)

def list_subject_imagery_edfs(subject_id: str) -> list:
    """
    Return the sorted paths (as strings) of the motor-imagery EDF runs
    R04, R06, R08, R10, R12 and R14 found under ``DATA_RAW/<subject_id>``.

    Runs whose file does not exist are simply absent from the result.
    """
    subject_dir = DATA_RAW / subject_id
    patterns = (
        str(subject_dir / f"{subject_id}R{run:02d}.edf")
        for run in (4, 6, 8, 10, 12, 14)
    )
    return sorted(path for pattern in patterns for path in glob(pattern))

def subject_id_to_int(s: str) -> int:
    """
    Parse the number of a subject id such as ``'S001'`` -> ``1``.

    The id must start with ``S`` or ``s`` followed by digits (trailing text
    is ignored); anything else yields ``-1``.
    """
    match = re.match(r'[Ss](\d+)', s)
    if match is None:
        return -1
    return int(match.group(1))

### Funciones de Medición de Latencia y P-Values

Herramientas para medir la latencia de inferencia (ms por forward pass) y calcular p-values por canal mediante ANOVA de una vía.

In [23]:
# =========================
# Medición de latencia (ms por batch)
# =========================
def estimate_latency_ms(model, example_batch, device, n_warmup=10, n_iters=50):
    """
    Estimate the mean latency (ms) of one forward pass over a fixed batch.

    Args:
        model: module to benchmark; it is switched to eval mode.
        example_batch: array-like (B, C, T), reused as input for every pass.
        device: device (or device string) where the batch is placed; it
            should match the model's device.
        n_warmup: untimed passes run first to stabilise GPU/cuDNN state.
        n_iters: timed passes; the result is total time / max(1, n_iters).

    Returns:
        Mean wall-clock milliseconds per forward pass.
    """
    model.eval()
    xb = torch.as_tensor(example_batch, dtype=torch.float32, device=device)
    # Only synchronize when the batch lives on a CUDA device:
    # torch.cuda.synchronize() rejects CPU devices even when CUDA is available.
    sync = xb.device.type == 'cuda'

    with torch.no_grad():
        for _ in range(n_warmup):
            model(xb)
        if sync:
            torch.cuda.synchronize(xb.device)

        t0 = time.perf_counter()
        for _ in range(n_iters):
            model(xb)
        if sync:
            torch.cuda.synchronize(xb.device)
        elapsed = time.perf_counter() - t0

    return elapsed / max(1, n_iters) * 1000.0  # ms


def compute_channel_pvalues_4c(X, y, n_classes=4):
    """
    Per-channel one-way ANOVA p-values for differences between classes.

    The feature per trial and channel is the mean power ``mean_t(X**2)``.

    Args:
        X: array-like (N, C, T).
        y: array-like (N,) with integer labels; only labels in
            ``0..n_classes-1`` are used.
        n_classes: number of classes to compare.

    Returns:
        Float array (C,) of p-values. A channel gets 1.0 when the test is
        undefined: fewer than two populated classes, or a degenerate input
        (e.g. all values identical, no within-group degrees of freedom) for
        which ``f_oneway`` returns NaN.

    Raises:
        ValueError: if X is not 3-D.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=int)

    if X.ndim != 3:
        raise ValueError("X debe ser (N, C, T)")

    # feature = mean power per trial and channel
    feats = (X ** 2).mean(axis=2)   # (N, C)
    n_channels = feats.shape[1]

    # Class masks do not depend on the channel: build them once.
    masks = [y == cls for cls in range(n_classes)]
    masks = [m for m in masks if m.any()]

    pvals = np.ones(n_channels, dtype=float)
    if len(masks) < 2:
        # With fewer than two populated groups the test is meaningless.
        return pvals

    for ch in range(n_channels):
        _, p = stats.f_oneway(*(feats[m, ch] for m in masks))
        # Map NaN (degenerate input) to "no evidence" instead of letting it
        # propagate into topomaps/tables.
        pvals[ch] = p if np.isfinite(p) else 1.0

    return pvals

### Funciones de Saliency y Topomaps

Herramientas para visualizar importancia de canales y crear mapas topográficos.

In [24]:
# =========================
# Channel saliency (Grad*Input agregado por canal) + topomap
# =========================
def compute_channel_saliency(
    model,
    X,              # numpy (N, C, T), standardized (test set)
    device,
    max_trials=256,
    batch_size=32,
    seed=None,
):
    """
    Per-channel importance from |Grad * Input|.

    For a random subset of at most ``max_trials`` trials, the gradient of the
    predicted-class logit w.r.t. the input is multiplied by the input. Its
    absolute value is summed over time, and the result is averaged over the
    selected trials, each trial weighted equally. The vector is then min-max
    normalized to [0, 1], so the least salient channel maps to 0. This is the
    quantity shown by the "Channel importance (saliency)" topomap.

    Args:
        model: network mapping (B, C, T) -> (B, n_cls) logits; set to eval mode.
        X: array-like (N, C, T).
        device: device where the model lives.
        max_trials: maximum number of trials, sampled without replacement.
        batch_size: trials per backward pass.
        seed: seed for the trial subset; defaults to ``RANDOM_STATE + 777``.

    Returns:
        float32 array (C,).

    Raises:
        ValueError: if X contains no trials.
    """
    X = np.asarray(X, dtype=np.float32)
    N, C, _ = X.shape
    if N == 0:
        raise ValueError("X vacío en compute_channel_saliency")

    if seed is None:
        seed = RANDOM_STATE + 777
    n_use = min(max_trials, N)
    rng = np.random.RandomState(seed)
    idx = rng.choice(np.arange(N), size=n_use, replace=False)

    model.eval()
    saliency_sum = np.zeros(C, dtype=np.float64)

    # Gradients are needed even if the caller runs inside torch.no_grad().
    with torch.enable_grad():
        for start in range(0, n_use, batch_size):
            xb = torch.tensor(X[idx[start:start + batch_size]],
                              dtype=torch.float32, device=device,
                              requires_grad=True)      # (B, C, T)

            model.zero_grad(set_to_none=True)
            logits = model(xb)                         # (B, n_cls)
            preds = logits.argmax(dim=1)               # target class = predicted class
            # Sum of each sample's predicted-class logit; in eval mode the
            # samples are assumed not to interact, so each gets its own gradient.
            score = logits.gather(1, preds[:, None]).sum()
            score.backward()

            contrib = (xb.grad * xb.detach()).abs().sum(dim=-1)   # (B, C)
            # Accumulate per-trial sums so a short last batch is not over-weighted.
            saliency_sum += contrib.sum(dim=0).cpu().numpy()

    # The parameter gradients created above are not needed by the caller.
    model.zero_grad(set_to_none=True)

    saliency = saliency_sum / max(1, n_use)
    # Min-max normalization to [0, 1] for the topomap
    saliency = saliency - saliency.min()
    maxv = saliency.max()
    if maxv > 0:
        saliency = saliency / maxv
    return saliency.astype(np.float32)


def save_channel_saliency_topomap(
    model,
    X_te_std,
    sfreq_used,
    fold,
    run_tag,
    output_dir: Path = OUTPUT_DIR
):
    """
    Compute per-channel saliency on ``X_te_std`` and save it as a topomap PNG
    in ``<output_dir>/fold<fold>/``.

    Returns the (C,) normalized saliency vector that was plotted.
    """
    target_dir = output_dir / f"fold{fold}"
    target_dir.mkdir(parents=True, exist_ok=True)

    channel_scores = compute_channel_saliency(
        model=model, X=X_te_std, device=DEVICE, max_trials=256, batch_size=32,
    )

    png_path = target_dir / f"topomap_saliency_fold{fold}_{run_tag}.png"
    plot_topomap_from_values(
        channel_scores,
        make_mne_info_8ch(sfreq_used),
        title=f"Channel importance (saliency) — Fold {fold} [{run_tag}]",
        out_path=png_path,
        cmap="Reds",
        cbar_label="Normalized saliency",
    )
    print(f"↳ Topomap saliency guardado: {png_path}")
    return channel_scores



def make_mne_info_8ch(sfreq: float):
    """
    Build an ``mne.Info`` for the EXPECTED_8 EEG channels at ``sfreq`` Hz and
    try to attach the ``standard_1020`` montage.

    Attaching the montage is best-effort. If it fails, a warning is emitted
    and the Info is returned without channel positions. Topomap plotting
    presumably fails later in that case, so the warning points to the root
    cause.
    """
    import warnings

    info = mne.create_info(
        ch_names=EXPECTED_8,
        sfreq=float(sfreq),
        ch_types="eeg"
    )
    try:
        montage = mne.channels.make_standard_montage("standard_1020")
        info.set_montage(montage)
    except Exception as err:  # best-effort, but never silent
        warnings.warn(f"Could not set standard_1020 montage on 8-channel info: {err}")
    return info


def plot_topomap_from_values(values, info, title: str, out_path: str,
                             cmap="viridis", cbar_label="Value",
                             vmin=None, vmax=None):
    """
    Render one value per channel as a scalp topomap and save it.

    ``out_path`` may be a ``str`` or a ``Path``. When ``vmin`` and/or ``vmax``
    is given, the image color limits are overridden after plotting. The
    figure is closed after saving.
    """
    data = np.asarray(values, dtype=float)
    figure, axis = plt.subplots(figsize=(4.2, 4.0))

    image, _contours = mne.viz.plot_topomap(
        data, info, axes=axis, show=False, cmap=cmap
    )

    if vmin is not None or vmax is not None:
        image.set_clim(vmin=vmin, vmax=vmax)

    axis.set_title(title, fontsize=10)
    colorbar = figure.colorbar(image, ax=axis, shrink=0.7)
    colorbar.set_label(cbar_label, fontsize=9)
    figure.tight_layout()
    figure.savefig(str(out_path), dpi=160, bbox_inches="tight")
    plt.close(figure)
    print(f"↳ Topomap guardado: {out_path}")

### Carga y Preprocesamiento de Datos

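Funciones para cargar y preprocesar las épocas EEG de cada sujeto, leer la partición de sujetos de cada fold y estandarizar por canal con las estadísticas del conjunto de entrenamiento.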

In [25]:
def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool,
                        do_bandpass: bool, do_car: bool, bp_lo: float, bp_hi: float):
    """
    Load and preprocess the motor-imagery epochs of one subject.

    For each imagery run returned by list_subject_imagery_edfs:
      - the 8 motor channels are picked;
      - optional preprocessing is applied: 60 Hz notch, band-pass, average
        reference, resampling;
      - epochs in [TMIN, TMAX] around the T1/T2 annotations are extracted and
        z-scored per epoch and channel when ZSCORE_PER_EPOCH is set.

    Labels:
      - runs in IMAGERY_RUNS_LR: left (0) / right (1) for T1 / T2;
      - runs in IMAGERY_RUNS_BF: both_fists (2) / both_feet (3) for T1 / T2.

    Args:
        subject_id: subject folder name under DATA_RAW.
        resample_hz: target rate in Hz, or None / <= 0 to keep the native rate.
        do_notch, do_bandpass, do_car: enable each preprocessing step.
        bp_lo, bp_hi: band-pass edges in Hz.

    Returns:
        (X, y, sfreq), where X is float32 (N, 8, T), y is int (N,) and sfreq
        is the sampling rate of the data. When nothing is loaded, returns an
        empty (0, 8, 1) array, an empty label array and ``None``.

    Raises:
        RuntimeError: if the runs have different sampling rates.
    """
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []
    for edf_path in edfs:
        # Run number from the file name, e.g. '...R04.edf' -> 4
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)

        sfreq = raw.info['sfreq']
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        # Only the T1/T2 annotations become epochs.
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        # Labels come from epochs.events (not the raw events) so they stay
        # aligned with the rows of X.
        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    # Validate the rates BEFORE concatenating: runs at different rates give
    # epochs of different lengths, and np.concatenate would otherwise fail
    # with an opaque shape error instead of this message.
    if len({int(round(s)) for s in sfreq_list}) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)
    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    """
    Read the train/test subject lists of one fold from a folds JSON file.

    Expected layout: ``{"folds": [{"fold": k, "train": [...], "test": [...]}, ...]}``.
    The first entry whose ``fold`` matches (compared as int) is used; missing
    ``train``/``test`` keys give empty lists.

    Returns:
        (train_subjects, test_subjects) as lists.

    Raises:
        ValueError: if the fold is not present.
    """
    with open(folds_json, 'r') as fh:
        payload = json.load(fh)

    match = next(
        (entry for entry in payload.get('folds', [])
         if int(entry.get('fold', -1)) == int(fold)),
        None,
    )
    if match is None:
        raise ValueError(f"Fold {fold} not found in {folds_json}")
    return list(match.get('train', [])), list(match.get('test', []))

def standardize_per_channel(train_X, other_X):
    """
    Z-score each channel of both arrays using the mean/std of ``train_X``.

    Statistics are taken over all trials and samples of a channel in the
    training data; a std <= 1e-6 is replaced by 1.0. Both inputs are copied
    to float32, so the caller's arrays are not modified.

    Returns:
        (train_standardized, other_standardized) as float32 arrays.
    """
    train = train_X.astype(np.float32)
    other = other_X.astype(np.float32)
    for ch in range(train.shape[1]):
        ref = train[:, ch, :]
        mean_c = ref.mean()
        std_c = ref.std()
        if not std_c > 1e-6:
            std_c = 1.0
        train[:, ch, :] = (ref - mean_c) / std_c
        other[:, ch, :] = (other[:, ch, :] - mean_c) / std_c
    return train, other

### Funciones de Visualización Grad-CAM++ para 4 Clases

Herramientas para visualizar activaciones del modelo mediante Grad-CAM++.

In [26]:
def save_cam_grid_4classes(
    X_te_std, y_te, model_to_probe,
    sfreq_eff, fold, run_tag,
    channel_mode="mean", output_dir: Path = OUTPUT_DIR
):
    """
    Save two Grad-CAM++ figures built from one example trial per class.

      1) EEG + Grad-CAM++ grid, one row per class:
         ``<output_dir>/fold<fold>/eeg_cam_grid4_perclass_fold<fold>_<run_tag>.png``
      2) The 1-D Grad-CAM++ curves of the same trials:
         ``<output_dir>/fold<fold>/gradcampp_perclass_fold<fold>_<run_tag>.png``

    For each class, a correctly classified trial is picked at random (seeded
    by RANDOM_STATE and the fold). If the class has no correct prediction,
    any trial of that class is used and its row title says so. Classes absent
    from ``y_te`` are skipped.

    Args:
        X_te_std: array (N, C, T) of standardized test trials.
        y_te: array-like (N,) of integer labels.
        model_to_probe: classifier returning (B, n_cls) logits.
        sfreq_eff: sampling rate of X_te_std in Hz (for the time axis).
        fold, run_tag: used in titles and file names.
        channel_mode: forwarded to _plot_overlay_one.
        output_dir: root output directory.
    """
    fold_dir = output_dir / f"fold{fold}"
    fold_dir.mkdir(parents=True, exist_ok=True)

    if len(X_te_std) == 0:
        print("[WARN] X_te_std vacío; no se genera grid 4c.")
        return

    model_to_probe.eval()
    rng = np.random.RandomState(RANDOM_STATE + 2000 + fold)
    y_te = np.asarray(y_te)

    # 1) Predictions on the whole test set, in chunks of 64 trials
    with torch.no_grad():
        logits_chunks = []
        for i in range(0, len(X_te_std), 64):
            xb = torch.tensor(X_te_std[i:i + 64], dtype=torch.float32, device=DEVICE)
            logits_chunks.append(model_to_probe(xb).cpu().numpy())
        logits = np.concatenate(logits_chunks, axis=0)
    preds = logits.argmax(axis=1)

    # 2) One example per class: a correct trial when available, otherwise any
    #    trial of that class (flagged in the row title below).
    class_to_idx = {}
    is_correct = {}
    for c in range(len(CLASS_NAMES)):
        idxs = np.where((y_te == c) & (preds == c))[0]
        is_correct[c] = len(idxs) > 0
        if not is_correct[c]:
            idxs = np.where(y_te == c)[0]
            if len(idxs) == 0:
                continue
        class_to_idx[c] = int(rng.choice(idxs, size=1)[0])

    if not class_to_idx:
        print("[WARN] No hay ejemplos disponibles para ninguna clase; no se genera grid 4c.")
        return

    classes = sorted(class_to_idx)
    n_rows = len(classes)

    # =======================
    # (A) EEG + Grad-CAM++ (one row per class)
    # =======================
    fig, axes = plt.subplots(n_rows, 1, figsize=(10, 2.8 * n_rows),
                             sharex=True, squeeze=False)
    axes = axes[:, 0]
    cams_per_class = {}

    for ax, c in zip(axes, classes):
        x_epoch = X_te_std[class_to_idx[c]]  # (C, T)
        xb = torch.tensor(x_epoch[None, ...], dtype=torch.float32, device=DEVICE)

        # Grad-CAM++ targeted at class c
        cam_vec = gradcampp_1d(model_to_probe, xb, target_class=c)
        cams_per_class[c] = cam_vec

        kind = ("Correct example" if is_correct[c]
                else "Misclassified example (no correct trial)")
        _plot_overlay_one(
            ax,
            x_epoch,
            cam_vec,
            sfreq=sfreq_eff,
            tmin=TMIN,
            title=f"{kind} — class: {CLASS_NAMES[c].replace('_', ' ')}",
            channel_mode=channel_mode,
            cls_idx=c
        )

    axes[-1].set_xlabel("Time (s)")
    fig.suptitle(
        f"EEG + Grad-CAM++ (1 ejemplo por clase) — Fold {fold}  [{run_tag}]",
        y=1.02,
        fontsize=12
    )
    fig.tight_layout()
    out_png = fold_dir / f"eeg_cam_grid4_perclass_fold{fold}_{run_tag}.png"
    fig.savefig(out_png, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"↳ Grid 4 clases (EEG+CAM) guardado: {out_png}")

    # =======================
    # (B) 1-D CAM curves for the same trials
    # =======================
    fig2, axes2 = plt.subplots(n_rows, 1, figsize=(6, 2.0 * n_rows),
                               sharex=True, squeeze=False)
    axes2 = axes2[:, 0]

    for ax2, c in zip(axes2, classes):
        ax2.plot(cams_per_class[c])
        ax2.set_ylabel("CAM")
        ax2.set_ylim(0, 1.05)
        ax2.grid(True, alpha=0.3)
        ax2.set_title(f"Grad-CAM++ conv_t — class {CLASS_NAMES[c].replace('_', ' ')}")

    axes2[-1].set_xlabel("Time (conv feature index)")
    fig2.tight_layout()
    out_png2 = fold_dir / f"gradcampp_perclass_fold{fold}_{run_tag}.png"
    fig2.savefig(out_png2, dpi=140, bbox_inches="tight")
    plt.close(fig2)
    print(f"↳ Grad-CAM++ 1D por clase guardado: {out_png2}")

### Arquitectura del Modelo: CNN + Transformer para 4 Clases

Implementación del modelo híbrido CNN-Transformer para clasificación de 4 clases de EEG.

In [27]:
# =========================
# MODELO (GroupNorm en conv) + atención expuesta
# =========================
def make_gn(num_channels, num_groups=8):
    """
    Build a GroupNorm over ``num_channels`` channels.

    Uses the largest group count <= ``min(num_groups, num_channels)`` that
    divides ``num_channels``, stopping at 1.
    """
    groups = min(num_groups, num_channels)
    while groups > 1 and num_channels % groups:
        groups -= 1
    return nn.GroupNorm(groups, num_channels)

class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise-separable 1-D convolution block.

    The stages are, in order:
      - depthwise conv (groups=in_ch, no bias);
      - pointwise 1x1 conv (no bias);
      - GroupNorm (via make_gn);
      - ELU;
      - Dropout.
    """
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        # Submodule names and registration order define the state_dict keys.
        self.dw = nn.Conv1d(in_channels=in_ch, out_channels=in_ch, kernel_size=k,
                            stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_channels=in_ch, out_channels=out_ch,
                            kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        for stage in (self.dw, self.pw, self.norm, self.act, self.dropout):
            x = stage(x)
        return x

class _CustomEncoderLayer(nn.Module):
    """
    Transformer encoder layer (multi-head self-attention + GELU feed-forward)
    that also returns the per-head attention weights.

    Args:
        d_model: token embedding size.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout used in attention, feed-forward and residual paths.
        batch_first: passed to nn.MultiheadAttention (True -> (B, L, D) input).
        norm_first: if True, LayerNorm is applied before each sub-block
            (pre-norm). Otherwise it is applied after each residual sum
            (post-norm).

    forward(src) returns ``(output, attn_weights)``. ``attn_weights`` keeps
    one map per head (``average_attn_weights=False``), i.e. (B, heads, L, L)
    for batched input.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first, norm_first):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.GELU()
        self.norm_first = norm_first

    def _sa_block(self, x):
        """Self-attention sub-block; returns (dropped-out output, attention weights)."""
        sa_out, attn_weights = self.self_attn(
            x, x, x, need_weights=True, average_attn_weights=False
        )  # attn: [B, heads, L, L]
        return self.dropout1(sa_out), attn_weights

    def _ff_block(self, x):
        """Feed-forward sub-block: Linear -> GELU -> Dropout -> Linear -> Dropout."""
        return self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(x)))))

    def forward(self, src):
        if self.norm_first:
            sa, attn_weights = self._sa_block(self.norm1(src))
            src = src + sa
            src = src + self._ff_block(self.norm2(src))
        else:
            sa, attn_weights = self._sa_block(src)
            src = self.norm1(src + sa)
            src = self.norm2(src + self._ff_block(src))
        return src, attn_weights

class _CustomEncoder(nn.Module):
    """
    Stack of ``num_layers`` independent deep copies of ``layer``.

    Each layer must return ``(output, attention)``. forward returns the last
    layer's output and the list of per-layer attentions, in layer order.
    """
    def __init__(self, layer, num_layers):
        super().__init__()
        clones = [copy.deepcopy(layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(clones)

    def forward(self, src):
        attentions = []
        for block in self.layers:
            src, attn = block(src)
            attentions.append(attn)
        return src, attentions

class EEGCNNTransformer(nn.Module):
    """
    Hybrid CNN + Transformer classifier for multichannel EEG (4 classes by default).

    Pipeline:
      - convolutional stem ``conv_t`` (three stride-2 stages);
      - 1x1 projection to ``d_model``, then dropout;
      - sinusoidal positional encoding;
      - a learnable CLS token is prepended;
      - Transformer encoder;
      - LayerNorm + Linear head on the CLS embedding.

    When ``capture_attn`` is True, the per-layer attention maps of the last
    forward pass are kept and returned by ``get_last_attention_maps()``.
    """
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3, capture_attn: bool = True):
        super().__init__()
        self.capture_attn = capture_attn
        self._last_attn = None
        # Cached positional encoding; a plain attribute (not a buffer), so it
        # is not part of the state_dict and is not moved by .to().
        self.pos_encoding = None

        # Convolutional stem
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
            DepthwiseSeparableConv(32, 64, k=31, s=2, p=15, p_drop=p_drop),
            DepthwiseSeparableConv(64, 128, k=15, s=2, p=7,  p_drop=p_drop),
        )
        self.proj = nn.Conv1d(128, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)

        # Transformer encoder
        enc_layer = _CustomEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
            dropout=0.1, batch_first=True, norm_first=False
        )
        self.encoder = _CustomEncoder(enc_layer, num_layers=n_layers)

        # CLS token and classification head
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        """
        Sinusoidal positional encoding of shape (L, d), float32 on CPU:
        sin on even feature indices, cos on odd ones.
        """
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward_features(self, x):
        """
        Return the CLS embedding before the head and capture the attentions.
        x: (B, C, T)
        """
        z = self.conv_t(x)           # (B, 128, T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)
        B, L, D = z.shape
        # Rebuild the cached encoding when the sequence length, width, device
        # or dtype changes (e.g. after model.to(other_device)); otherwise the
        # stale cache causes a device mismatch or silent dtype promotion.
        if (self.pos_encoding is None
                or self.pos_encoding.shape[0] != L
                or self.pos_encoding.shape[1] != D
                or self.pos_encoding.device != z.device
                or self.pos_encoding.dtype != z.dtype):
            self.pos_encoding = self._positional_encoding(L, D).to(
                device=z.device, dtype=z.dtype
            )
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)  # (B, 1+L, D)

        z, attn_list = self.encoder(z)
        if self.capture_attn:
            self._last_attn = attn_list

        cls = z[:, 0, :]             # (B, d_model)
        return cls

    def forward(self, x):
        """
        Full forward pass: logits of shape (B, n_cls).
        """
        cls = self.forward_features(x)
        return self.head(cls)

    def get_last_attention_maps(self):
        """
        Return the per-layer attention maps of the last forward pass
        (None before the first pass or when capture_attn is False).
        """
        return self._last_attn

### Funciones de Pérdida y Aumentación de Datos

Implementación de Focal Loss y técnicas de aumentación de datos para EEG.

In [28]:
# =========================
# FOCAL LOSS
# =========================
class FocalLoss(nn.Module):
    """
    Multi-class focal loss with per-class weights, used against class imbalance.

    The per-sample loss is ``-alpha[y] * (1 - p_y)**gamma * log(p_y)``, where
    ``p_y`` is the softmax probability of the true class. ``alpha`` is
    normalized to sum to 1. With ``gamma = 0`` each term is the alpha-weighted
    negative log-likelihood.

    Args:
        alpha: 1-D tensor of per-class weights (normalized internally).
        gamma: focusing parameter; larger values down-weight easy examples.
        reduction: 'mean', 'sum', or anything else for per-sample losses.
    """
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B, C)
        idx = torch.arange(target.shape[0], device=logits.device)
        logpt = logp[idx, target]
        pt = logpt.exp()
        # alpha is a plain attribute (not a buffer), so loss.to(device) does
        # not move it; align device and dtype with the logits here.
        at = self.alpha.to(device=logits.device, dtype=logp.dtype)[target]
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

# =========================
# AUGMENTS
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    """
    Randomly augment a batch of EEG trials and return the augmented copy.

    Whether each augmentation runs is decided once per batch with NumPy's
    global RNG. Per-sample parameters are drawn with torch:
      - temporal jitter (prob. ``p_jitter``): circular shift of each trial by
        up to ``max(1, T * max_jitter_frac)`` samples;
      - additive Gaussian noise (prob. ``p_noise``) with std ``noise_std``;
      - channel dropout (prob. ``p_chdrop``): zero ``min(max_chdrop, C)``
        random channels per trial.

    Args:
        xb: tensor (B, C, T). It is NOT modified; a copy is augmented.

    Returns:
        Tensor (B, C, T) on the same device as ``xb``.
    """
    # Work on a copy: jitter and channel dropout below write in place.
    xb = xb.clone()
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1,
                               size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# Versión suave para FT
def augment_batch_ft(xb):
    """
    Milder augmentation for fine-tuning: augment_batch() called with the
    keyword arguments stored in FT_AUG.
    """
    ft_kwargs = FT_AUG
    return augment_batch(xb, **ft_kwargs)

### EMA y Técnicas de Inferencia

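Implementación de la media móvil exponencial (EMA) de los pesos y de técnicas de inferencia robusta: test-time augmentation (TTA) con desplazamientos temporales y promediado sobre subventanas.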

In [29]:
# =========================
# EMA de pesos (solo global)
# =========================
class ModelEMA:
    """
    Exponential moving average (EMA) of a model's weights.

    ``self.ema`` is a frozen copy of the model. On every ``update`` its
    floating-point state (parameters and float buffers) follows
    ``ema = d * ema + (1 - d) * model``; non-float state (e.g. integer
    counters) is copied verbatim. During the first 1000 updates ``d`` ramps
    linearly from 0 to ``decay``, so early weights are forgotten quickly.

    Args:
        model: model to track.
        decay: target EMA decay.
        device: device for the EMA copy (defaults to the model's device).
    """
    def __init__(self, model: nn.Module, decay: float = 0.9995, device=None):
        self.ema = self._clone(model).to(
            device if device is not None else next(model.parameters()).device
        )
        self.decay = decay
        self._updates = 0
        # The first update uses d == 0, i.e. it copies the model exactly.
        self.update(model, force=True)

    def _clone(self, model):
        """
        Return a gradient-free copy of ``model`` with identical weights.

        ``copy.deepcopy`` preserves whatever constructor arguments the model
        was built with. If it fails (e.g. the model holds non-leaf tensors
        from a previous forward pass), the model is rebuilt with its default
        constructor arguments and the state dict is loaded into it.
        """
        try:
            ema = copy.deepcopy(model)
        except RuntimeError:
            ema = type(model)()  # requires the constructor defaults to match
            ema.load_state_dict(model.state_dict())
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        """
        Blend the current weights of ``model`` into the EMA copy.

        ``force`` is accepted for backward compatibility and is not used.
        """
        d = self.decay
        if self._updates < 1000:
            d = (self._updates / 1000.0) * self.decay
        msd = model.state_dict()
        for k, ema_t in self.ema.state_dict().items():
            src = msd[k].detach().to(ema_t.device)
            if ema_t.dtype.is_floating_point:
                ema_t.mul_(d).add_(src, alpha=1.0 - d)
            else:
                # Write into the EMA tensor itself: rebinding the state_dict
                # entry would leave the EMA model's buffer untouched.
                ema_t.copy_(src)
        self._updates += 1

# =========================
# INFERENCIA TTA / SUBWINDOW
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    """
    Average a model's logits over sliding sub-windows of each trial.

    Windows are ``round(sw_len * sfreq)`` samples long (clipped to the trial
    length) and start every ``round(sw_stride * sfreq)`` samples. Only
    windows fully inside the trial are used. All windows of one trial are
    scored in a single batch, so results can differ from per-window calls
    only by floating-point rounding.

    Args:
        model: classifier taking (B, C, W) and returning (B, n_cls) logits.
        X: array (N, C, T).
        sfreq: sampling rate in Hz.
        sw_len, sw_stride: window length and stride in seconds.
        device: device on which to run the model.

    Returns:
        Array (N, n_cls) with the mean logits of each trial.
    """
    model.eval()
    T = X.shape[-1]
    wl = max(1, min(int(round(sw_len * sfreq)), T))
    st = max(1, int(round(sw_stride * sfreq)))
    starts = range(0, max(1, T - wl + 1), st)
    out = []
    with torch.no_grad():
        for x in X:
            segs = np.stack([x[:, s:s + wl] for s in starts], axis=0)  # (n_win, C, wl)
            xb = torch.tensor(segs, dtype=torch.float32, device=device)
            out.append(model(xb).detach().cpu().numpy().mean(axis=0))
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    """
    Test-time augmentation: average logits over time-shifted copies of each trial.

    A positive shift of ``s`` samples drops the first ``s`` samples and
    repeats the last sample at the end. A negative shift drops the last
    samples and repeats the first one at the start. Either way the trial
    length T is preserved. All shifted copies of one trial are scored in a
    single batch.

    Args:
        model: classifier taking (B, C, T) and returning (B, n_cls) logits.
        X: array (N, C, T).
        sfreq: sampling rate in Hz.
        shifts_s: shifts in seconds, rounded to whole samples and clipped to
            at most T - 1 samples in magnitude.
        device: device on which to run the model.

    Returns:
        Array (N, n_cls) with the mean logits of each trial.
    """
    model.eval()
    T = X.shape[-1]
    # Clipping avoids edge-padding an empty slice when |shift| >= T.
    lim = max(T - 1, 0)
    shifts = [max(-lim, min(lim, int(round(sh * sfreq)))) for sh in shifts_s]

    def _shifted(x0, shift):
        """Shift x0 (C, T) by `shift` samples, edge-padding to keep length T."""
        if shift > 0:
            return np.pad(x0[:, shift:], ((0, 0), (0, shift)), mode='edge')
        if shift < 0:
            return np.pad(x0[:, :shift], ((0, 0), (-shift, 0)), mode='edge')
        return x0

    out = []
    with torch.no_grad():
        for x0 in X:
            batch = np.stack([_shifted(x0, sh) for sh in shifts], axis=0)  # (n_shifts, C, T)
            xb = torch.tensor(batch, dtype=torch.float32, device=device)
            out.append(model(xb).detach().cpu().numpy().mean(axis=0))
    return np.stack(out, axis=0)

### Herramientas de Visualización y Métricas

Funciones para dibujar la matriz de confusión, construir las etiquetas de estratificación por sujeto y calcular las métricas de clasificación.

In [30]:
# =========================
# Helper: matriz de confusión sin colorbar, con texto
# =========================
def plot_confusion_with_text(cm, class_names, title, out_path, cmap='Blues'):
    """
    Render a confusion matrix without a colorbar, write the count inside
    every cell, and save the figure to ``out_path``.

    Args:
        cm: 2-D count matrix (rows = true label, columns = predicted label).
        class_names: label names; underscores become line breaks in the
            tick labels.
        title: axes title.
        out_path: destination file; accepts str or pathlib.Path (it is
            converted with ``str()`` before saving).
        cmap: matplotlib colormap passed to ``imshow``.
    """
    fig, ax = plt.subplots(figsize=(4.8, 4.2))
    ax.imshow(cm, cmap=cmap)

    wrapped = [name.replace('_', '\n') for name in class_names]
    positions = range(len(class_names))
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True label")
    ax.set_xticks(positions)
    ax.set_yticks(positions)
    ax.set_xticklabels(wrapped)
    ax.set_yticklabels(wrapped)

    # Cells above half of the peak count get white text for contrast.
    peak = cm.max() if cm.size else 1
    half_peak = peak * 0.5
    n_rows, n_cols = cm.shape[0], cm.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            count = cm[r, c]
            ax.text(c, r, str(count),
                    ha='center', va='center',
                    color='white' if count > half_peak else 'black',
                    fontsize=11, fontweight='bold')

    fig.tight_layout()
    fig.savefig(str(out_path), dpi=140)
    plt.close(fig)

# =========================
# Utilidades splits estratificados por sujeto
# =========================
def build_subject_label_map(subject_ids):
    """
    Return, for each subject, its most frequent label (ties resolved to the
    lowest label index), for use as a stratification key.

    Epochs are loaded through ``load_subject_epochs`` using the global
    preprocessing settings. Subjects with no epochs are marked ``-1``.

    Returns:
        np.ndarray of ints, one entry per element of ``subject_ids``.
    """
    def _dominant_label(sid):
        _X, labels, _sf = load_subject_epochs(
            sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI
        )
        if len(labels) == 0:
            return -1
        counts = np.bincount(labels, minlength=4)
        return int(np.argmax(counts))

    return np.array([_dominant_label(sid) for sid in subject_ids], dtype=int)

# =========================
# MÉTRICAS COMPLETAS
# =========================
def compute_metrics_multiclass(y_true, y_pred, n_classes=4):
    """
    Compute multiclass classification metrics.

    Returns a dict with:
        acc: accuracy.
        f1_macro, f1_weighted: macro / weighted F1.
        prec_macro, rec_macro: macro precision / recall.
        spec_macro: macro specificity.
        sens_macro: macro sensitivity (identical to ``rec_macro``).
        cm: confusion matrix over labels ``0..n_classes-1``.

    Notes:
        * F1, precision and recall all use ``zero_division=0``, so an
          undefined per-class score contributes 0 instead of raising
          ``UndefinedMetricWarning``.
        * F1, precision and recall average over the labels present in
          ``y_true``/``y_pred`` (sklearn default). Specificity averages
          over all ``n_classes`` rows of the confusion matrix.
        * Specificity of class c is TN / (TN + FP). A tiny epsilon avoids
          division by zero, which yields 0 when TN + FP == 0.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 everywhere for consistency with precision/recall.
    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    prec_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)

    cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    # One-vs-rest counts for every class at once.
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)
    spec_per_class = tn / (tn + fp + 1e-12)
    spec_macro = float(np.mean(spec_per_class))

    metrics = dict(
        acc=acc,
        f1_macro=f1_macro,
        f1_weighted=f1_weighted,
        prec_macro=prec_macro,
        rec_macro=rec_macro,
        spec_macro=spec_macro,
        sens_macro=rec_macro,  # sensitivity == macro recall
        cm=cm,
    )
    return metrics

### Herramientas de Interpretabilidad: Grad-CAM++

Funciones para analizar y visualizar el comportamiento del modelo mediante Grad-CAM++.

In [31]:
# =========================
# Interpretabilidad: Grad-CAM++ 1D sobre conv_t
# =========================
def gradcampp_1d(model, x, target_class=None):
    """
    Compute a 1D Grad-CAM++ map over the ``model.conv_t`` layer.

    Args:
        model: torch module exposing a ``conv_t`` submodule; ``model(x)``
            must return logits that can be indexed as ``logits[:, cls]``.
        x: input tensor, expected shape (1, C, T). Only the first batch
            element's activations and gradients are used.
        target_class: class index to explain. If None, the argmax of the
            logits is used (``.item()`` assumes batch size 1).

    Returns:
        np.ndarray over the time axis of ``conv_t``'s output (T', which may
        differ from T), passed through ReLU and min-max normalized to [0, 1].

    Side effects:
        Puts ``model`` in eval mode and leaves parameter ``.grad`` filled by
        the backward pass. The hooks registered on ``conv_t`` are always
        removed, even if the forward or backward pass raises.
    """
    model.eval()
    feats = {}
    grads = {}

    def f_hook(m, inp, out):
        feats['f'] = out.detach()

    def b_hook(m, gin, gout):
        grads['g'] = gout[0].detach()

    handle_f = model.conv_t.register_forward_hook(f_hook)
    handle_b = model.conv_t.register_full_backward_hook(b_hook)
    try:
        logits = model(x)
        if target_class is None:
            cls = int(logits.argmax(1).item())
        else:
            cls = int(target_class)

        score = logits[:, cls].sum()
        model.zero_grad(set_to_none=True)
        # Single backward pass: the graph does not need to be retained.
        score.backward()
    finally:
        # Never leave hooks attached to the model, even on failure.
        handle_f.remove()
        handle_b.remove()

    F = feats['f'][0]  # (C', T') activations of conv_t
    G = grads['g'][0]  # (C', T') gradients of the class score w.r.t. them
    eps = 1e-8
    # Grad-CAM++ closed-form alpha: G^2 / (2 G^2 + sum_t(F * G^3)).
    alpha_num = G ** 2
    alpha_den = 2 * G ** 2 + (F * G ** 3).sum(dim=-1, keepdim=True) + eps
    alpha = alpha_num / alpha_den
    weights = (alpha.clamp(min=0) * G.clamp(min=0)).sum(dim=-1)  # (C',)

    cam = (weights[:, None] * F).sum(dim=0)   # (T',)
    cam = torch.relu(cam)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + eps)
    return cam.cpu().numpy()

def _upsample_cam_to_signal(cam_vec, T):
    """
    Interpola el vector CAM a la longitud temporal original.
    """
    x_src = np.linspace(0.0, 1.0, num=len(cam_vec))
    x_tgt = np.linspace(0.0, 1.0, num=T)
    return np.interp(x_tgt, x_src, cam_vec)

def _plot_overlay_one(ax, x_epoch, cam_vec, sfreq, tmin, title,
                      channel_mode="mean", cls_idx=None):
    """
    Draw one EEG trace with its 1D Grad-CAM++ overlay on ``ax``.

    Args:
        ax: matplotlib Axes to draw on. A twin y-axis is added for the CAM.
        x_epoch: 2-D epoch indexed as (C, T).
        cam_vec: CAM vector. It is linearly resampled to T and min-max
            normalized.
        sfreq: sampling frequency (Hz) for the time axis.
        tmin: time (s) of the first sample.
        title: axes title.
        channel_mode: selects the plotted signal.
            - "mean": average over all channels.
            - An int (Python or NumPy integer): that channel row.
            - A name in ``EXPECTED_8``: that channel.
            - Anything else falls back to the channel mean.
        cls_idx: 0/1/2/3 selects the trace color (green/blue/orange/purple);
            any other value gives black.

    Samples whose CAM value is at or above the 80th percentile are shaded
    in red.
    """
    C, T = x_epoch.shape
    cam_T = _upsample_cam_to_signal(cam_vec, T)
    cam_T = (cam_T - cam_T.min()) / (cam_T.max() - cam_T.min() + 1e-8)

    mean_label = f"mean({C}ch)"
    # Check integers first so NumPy integer indices (e.g. np.int64) are
    # honored instead of silently falling back to the mean trace.
    if isinstance(channel_mode, (int, np.integer)):
        sig = x_epoch[channel_mode]
        sig_label = f"ch#{channel_mode}"
    elif channel_mode == "mean":
        sig = x_epoch.mean(axis=0)
        sig_label = mean_label
    elif channel_mode in EXPECTED_8:
        sig = x_epoch[EXPECTED_8.index(channel_mode)]
        sig_label = channel_mode
    else:
        sig = x_epoch.mean(axis=0)
        sig_label = mean_label

    t = tmin + np.arange(T) / float(sfreq)
    thr = np.quantile(cam_T, 0.80)
    mask = cam_T >= thr

    # Colors per class (4 classes)
    color_map = {
        0: "green",    # left
        1: "blue",     # right
        2: "orange",   # both_fists
        3: "purple",   # both_feet
    }
    sig_color = color_map.get(cls_idx, "black")

    ax.plot(t, sig, label=f"EEG ({sig_label})", color=sig_color)

    ax.fill_between(
        t,
        sig.min(),
        sig.max(),
        where=mask,
        alpha=0.25,
        color="red",
        label="Alta importancia (Grad-CAM++)"
    )

    ax.set_title(title)
    ax.set_ylabel("EEG (z)")
    ax.grid(True, alpha=0.25)

    ax2 = ax.twinx()
    ax2.plot(t, cam_T, linestyle="--", alpha=0.6, label="CAM (norm.)")
    ax2.set_ylim(0, 1.05)
    ax2.set_ylabel("CAM")
    h1, l1 = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax.legend(h1 + h2, l1 + l2, loc="upper left")

def save_cam_grid_for_5_subjects(X_te_std, y_te, sub_te, model_to_probe,
                                 sfreq_eff, fold, run_tag,
                                 n_subjects=5, channel_mode="mean",
                                 output_dir: Path = OUTPUT_DIR):
    """
    Save a grid of EEG + Grad-CAM++ overlays on TEST data, using ONLY
    correctly classified epochs.

    Up to ``n_subjects`` subjects with at least one correct prediction are
    drawn at random, seeded with ``RANDOM_STATE + 1000 + fold``. For each
    subject, one correct epoch is plotted with ``_plot_overlay_one``. The PNG
    is written under ``output_dir / f"fold{fold}"``.

    Args:
        X_te_std: standardized test epochs, indexed as (N, C, T).
        y_te: true labels, one per epoch (numpy array).
        sub_te: subject id per epoch (numpy array).
        model_to_probe: model used for predictions and Grad-CAM++. It must
            expose ``conv_t``, as required by ``gradcampp_1d``.
        sfreq_eff: sampling frequency (Hz) for the time axis.
        fold: fold number, used in the seed, directory and file names.
        run_tag: tag appended to the title and file name.
        n_subjects: maximum number of rows (subjects) in the grid.
        channel_mode: forwarded to ``_plot_overlay_one``.
        output_dir: base output directory.
    """
    # Create the fold directory
    fold_dir = output_dir / f"fold{fold}"
    fold_dir.mkdir(parents=True, exist_ok=True)

    model_to_probe.eval()
    rng = np.random.RandomState(RANDOM_STATE + 1000 + fold)

    # Guard: np.concatenate([]) below would raise on an empty test set.
    if len(X_te_std) == 0:
        print("[WARN] TEST vacío; no se genera grid.")
        return

    # 1) Predictions on the whole test set, in chunks of 64 epochs.
    with torch.no_grad():
        logits_chunks = []
        for i in range(0, len(X_te_std), 64):
            xb = torch.tensor(X_te_std[i:i+64], dtype=torch.float32, device=DEVICE)
            logits_chunks.append(model_to_probe(xb).cpu().numpy())
        logits = np.concatenate(logits_chunks, axis=0)
    preds = logits.argmax(axis=1)
    correct_mask = (preds == y_te)

    if correct_mask.sum() == 0:
        print("[WARN] No hay ejemplos correctamente clasificados en TEST; no se genera grid.")
        return

    X_corr   = X_te_std[correct_mask]
    y_corr   = y_te[correct_mask]
    sub_corr = sub_te[correct_mask]

    uniq_subjects = np.unique(sub_corr)
    n_sel = min(n_subjects, len(uniq_subjects))
    # plt.subplots(0, 1) would raise, so bail out when there is nothing to draw.
    if n_sel <= 0:
        print("[WARN] No hay sujetos con aciertos; no se genera grid.")
        return

    sel_subjects = rng.choice(uniq_subjects, size=n_sel, replace=False)
    print(f"[INFO] Grid con {len(sel_subjects)} sujetos (solo aciertos).")

    fig, axes = plt.subplots(len(sel_subjects), 1,
                             figsize=(10, 2.8*len(sel_subjects)), sharex=True)
    if len(sel_subjects) == 1:
        axes = [axes]

    for ax, s in zip(axes, sel_subjects):
        idxs = np.where(sub_corr == s)[0]
        if len(idxs) == 0:
            ax.set_visible(False)
            continue

        i_local = int(rng.choice(idxs, size=1)[0])

        x_epoch = X_corr[i_local]  # (C, T)
        true_cls = int(y_corr[i_local])

        xb = torch.tensor(x_epoch[None, ...], dtype=torch.float32, device=DEVICE)
        with torch.no_grad():
            pred_cls = int(model_to_probe(xb).argmax(1).item())

        # Single-sample re-check disagreed with the batched prediction:
        # hide the row instead of leaving an empty, visible axis.
        if pred_cls != true_cls:
            ax.set_visible(False)
            continue

        cam_vec = gradcampp_1d(model_to_probe, xb, target_class=pred_cls)
        title = f"S{int(s):03d}  |  correct ({CLASS_NAMES[true_cls]})"
        _plot_overlay_one(
            ax,
            x_epoch,
            cam_vec,
            sfreq=sfreq_eff,
            tmin=TMIN,
            title=title,
            channel_mode=channel_mode,
            cls_idx=true_cls
        )

    axes[-1].set_xlabel("Time (s)")
    plt.suptitle(f"EEG + Grad-CAM++ (correct) — Fold {fold}  [{run_tag}]",
                 y=1.02, fontsize=12)
    plt.tight_layout()

    out_png = fold_dir / f"eeg_cam_grid5_correct_fold{fold}_{run_tag}.png"
    plt.savefig(out_png, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"↳ Grid (solo CORRECTOS) guardado: {out_png}")

### Función Principal de Entrenamiento para 4 Clases

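Función que entrena y evalúa un fold completo del modelo global para clasificación de 4 clases, genera las figuras de interpretabilidad y aplica fine-tuning progresivo por sujeto.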

In [32]:
# =========================
# TRAIN/EVAL GLOBAL POR FOLD
# =========================
def train_one_fold(fold: int, device):
    """
    Entrena y evalúa el modelo para un fold específico (4 clases).
    
    Args:
        fold: Número del fold (1-5)
        device: Dispositivo para entrenamiento (CPU/GPU)
    
    Returns:
        Tupla con métricas y resultados del fold
    """
    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    # Crear directorio de salida para este fold
    fold_dir = OUTPUT_DIR / f"fold{fold}"
    fold_dir.mkdir(parents=True, exist_ok=True)
    
    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            mask = y_dom >= 0
            moda = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = moda
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(
            n_splits=1, test_size=n_val_subj,
            random_state=RANDOM_STATE + fold
        )
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # Carga TRAIN/VAL/TEST
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(
            sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI
        )
        if len(ys) == 0:
            continue
        X_tr_list.append(Xs)
        y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(
            sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI
        )
        if len(ys) == 0:
            continue
        X_val_list.append(Xs)
        y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(
            sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI
        )
        if len(ys) == 0:
            continue
        X_te_list.append(Xs)
        y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0)
    y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0)
    y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0)
    y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Entrenando modelo global... "
          f"(n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Normalización por canal (fit en TRAIN y aplicar a VAL/TEST)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),
                           torch.tensor(y_tr).long(),
                           torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std),
                           torch.tensor(y_val).long(),
                           torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),
                           torch.tensor(y_te).long(),
                           torch.tensor(sub_te).long())

    # Weighted sampler templado
    def make_weighted_sampler(dataset: TensorDataset):
        _Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s: c for s, c in zip(uniq_s, cnt_s)}
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k: c for k, c in zip(uniq_k, cnt_k)}
        a, b = 0.8, 1.0
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = float(map_s[int(s)])
            wk = float(map_k[k])
            w.append((ws ** (-a)) * (wk ** (-b)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(
            weights=torch.tensor(w, dtype=torch.double),
            num_samples=len(yb_np), replacement=True
        )
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE,
                            sampler=tr_sampler, drop_last=False,
                            worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE,
                            shuffle=True,  drop_last=False,
                            worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE,
                        shuffle=False, drop_last=False,
                        worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE,
                        shuffle=False, drop_last=False,
                        worker_init_fn=seed_worker)

    # Modelo
    model = EEGCNNTransformer(
        n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
        n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
        capture_attn=True
    ).to(device)

    # Optimizador + Focal Loss
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)
    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    inv = class_counts.sum() / (4.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    alpha_mean = alpha.mean().item()
    alpha[2] = 1.25 * alpha_mean  # both_fists
    alpha[3] = 1.05 * alpha_mean  # both_feet
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # LR scheduler Warmup+Cosine
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1

    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))

    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA (solo global)
    if USE_EMA:
        ema = ModelEMA(model, decay=EMA_DECAY, device=device)
    else:
        ema = None

    # Entrenamiento global con early stopping por F1 macro
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "val_loss": [],
            "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)
    t_train0 = time.perf_counter()

    def evaluate_on(loader, use_ema=True):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        total_loss, n_total = 0.0, 0
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device)
                yb = yb.to(device)
                logits = mdl(xb)
                loss = crit(logits, yb)
                total_loss += loss.item() * len(yb)
                n_total += len(yb)
                p = logits.argmax(dim=1).cpu().numpy()
                preds.append(p)
                gts.append(yb.cpu().numpy())
        preds = np.concatenate(preds)
        gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        val_loss = total_loss / max(1, n_total)
        return val_loss, acc, f1m

    for ep in range(1, EPOCHS+1):
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device)
            yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()

        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        val_loss, acc_val, f1m = evaluate_on(val_ld, use_ema=True)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["val_loss"].append(val_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc_val)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | "
              f"train_acc={tr_acc:.4f} | val_loss={val_loss:.4f} | "
              f"val_acc={acc_val:.4f} | val_f1m={f1m:.4f} | "
              f"LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    train_time = time.perf_counter() - t_train0
    if torch.cuda.is_available():
        train_mem = torch.cuda.max_memory_allocated(device) / (1024**2)
    else:
        train_mem = float('nan')

    # Cargamos los mejores pesos tanto en ema.ema (si existe) como en model,
    # y reactivamos gradients en model para usarlo en Grad-CAM++
    if best_state is not None:
        if ema is not None:
            ema.ema.load_state_dict(best_state)
        model.load_state_dict(best_state)

    for p in model.parameters():
        p.requires_grad_(True)

    # ---- Guardar curva (pérdidas) ----
    fig = plt.figure(figsize=(8, 4.5))
    ax1 = plt.gca()
    ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
    ax1.plot(hist["ep"], hist["val_loss"], label="val_loss")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_title(f"Fold {fold} — Training Curve (4 classes)")
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    out_png = fold_dir / f"training_curve_fold{fold}_4c.png"
    plt.tight_layout()
    plt.savefig(out_png, dpi=140)
    plt.close(fig)
    print(f"↳ Curva de entrenamiento guardada: {out_png}")

    # ---- Evaluación final en TEST (global) ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)
    t_test0 = time.perf_counter()

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p)
                gts.append(yb.numpy())
        preds = np.concatenate(preds)
        gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used,
                                         SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std,
                                               sfreq_used, TTA_SHIFTS_S, device)

        if COMBINE_TTA_AND_SUBWIN:
            if logits_tta is None:
                logits_tta = time_shift_tta_logits(eval_model, X_te_std,
                                                   sfreq_used, TTA_SHIFTS_S, device)
            if logits_sw is None:
                logits_sw  = subwindow_logits(eval_model, X_te_std, sfreq_used,
                                              SW_LEN, SW_STRIDE, device)
            logits = 0.5 * logits_tta + 0.5 * logits_sw
        else:
            logits = logits_tta if logits_tta is not None else logits_sw

        preds = logits.argmax(axis=1)
        gts = y_te
    else:
        raise ValueError(f"SW_MODE desconocido: {SW_MODE}")

    test_time = time.perf_counter() - t_test0
    if torch.cuda.is_available():
        test_mem = torch.cuda.max_memory_allocated(device) / (1024**2)
    else:
        test_mem = float('nan')

    metrics_global = compute_metrics_multiclass(gts, preds, n_classes=4)
    acc = metrics_global["acc"]
    f1m = metrics_global["f1_macro"]

    print(f"[Fold {fold}/5] Global acc={acc:.4f} | f1_macro={f1m:.4f}\n")
    print(classification_report(
        gts, preds,
        target_names=[c.replace('_', ' ') for c in CLASS_NAMES],
        digits=4
    ))
    print("Confusion matrix (rows=true, cols=pred):")
    cm = metrics_global["cm"]
    print(cm)
    print(
        f"precision_macro={metrics_global['prec_macro']:.4f} | "
        f"recall_macro={metrics_global['rec_macro']:.4f} | "
        f"specificity_macro={metrics_global['spec_macro']:.4f} | "
        f"sensitivity_macro={metrics_global['sens_macro']:.4f} | "
        f"f1_weighted={metrics_global['f1_weighted']:.4f}"
    )

    # matriz de confusión global (imagen)
    cm_path = fold_dir / f"confusion_global_fold{fold}_4c.png"
    plot_confusion_with_text(
        cm=cm,
        class_names=CLASS_NAMES,
        title=f"Confusion — Fold {fold} Global (4 classes)",
        out_path=cm_path,
        cmap='Blues'
    )
    print(f"↳ Matriz de confusión global guardada: {cm_path}")

    # ===== Topomap de saliency (channel importance) =====
    saliency_fold = None
    try:
        saliency_fold = save_channel_saliency_topomap(
            model=eval_model,
            X_te_std=X_te_std,
            sfreq_used=sfreq_used,
            fold=fold,
            run_tag=MODEL_TAG,
            output_dir=OUTPUT_DIR
        )
    except Exception as e:
        print(f"[WARN] Topomap saliency no generado: {e}")

    # ===== Topomap de p-values 4 clases (TASK EFFECT) =====
    try:
        pvals_4c = compute_channel_pvalues_4c(X_te_std, y_te, n_classes=4)
        pvals_log = -np.log10(pvals_4c + 1e-12)  # comparable a los de 2c

        info_topo = make_mne_info_8ch(sfreq_used)
        topo_p_png = fold_dir / f"topomap_pvalues_4c_fold{fold}_{MODEL_TAG}.png"
        plot_topomap_from_values(
            pvals_log,
            info_topo,
            title=f"-log10(p) task effect (4 classes) — Fold {fold} [{MODEL_TAG}]",
            out_path=topo_p_png,
            cmap="viridis",
            cbar_label="-log10(p)"
        )
        print(f"↳ Topomap p-values 4c guardado: {topo_p_png}")
    except Exception as e:
        print(f"[WARN] Topomap p-values 4c no generado: {e}")

    # ---- t-SNE 2D de embeddings CLS (GLOBAL) ----
    print("↳ Calculando t-SNE 2D de CLS embeddings (global)...")
    eval_model.eval()
    with torch.no_grad():
        X_te_tensor = torch.tensor(X_te_std, dtype=torch.float32, device=device)
        cls_emb = eval_model.forward_features(X_te_tensor).cpu().numpy()

    tsne = TSNE(
        n_components=2,
        random_state=RANDOM_STATE + fold,
        perplexity=30,
        init='random',
        learning_rate='auto'
    )
    z = tsne.fit_transform(cls_emb)

    fig = plt.figure(figsize=(5.5, 5.0))
    for cls_idx, cls_name in enumerate(CLASS_NAMES):
        mask = (gts == cls_idx)
        plt.scatter(
            z[mask, 0], z[mask, 1],
            s=10, alpha=0.7, label=cls_name.replace('_', ' ')
        )
    plt.xlabel("t-SNE dimension 1")
    plt.ylabel("t-SNE dimension 2")
    plt.title(f"t-SNE of CLS embeddings — Fold {fold} Global (4 classes)")
    plt.legend(markerscale=2, fontsize=8)
    plt.grid(True, alpha=0.3)
    tsne_path = fold_dir / f"tsne_global_fold{fold}_4c.png"
    plt.tight_layout()
    plt.savefig(tsne_path, dpi=140)
    plt.close(fig)
    print(f"↳ t-SNE global guardado: {tsne_path}")

    # =========================
    # Interpretabilidad (attention + Grad-CAM++)
    # =========================
    try:
        # Usamos eval_model (EMA) para atención y t-SNE,
        # y model (con grad) para Grad-CAM++
        base_model = eval_model   # para attention maps
        cam_model  = model        # para Grad-CAM++

        # Attention maps (primera capa, media en heads)
        base_model.eval()
        with torch.no_grad():
            xb_val, yb_val, _ = next(iter(val_ld))
        xb_val = xb_val[:8].to(device)
        _ = base_model(xb_val)
        attn_list = base_model.get_last_attention_maps()
        if attn_list and len(attn_list) > 0:
            # attn_list[0]: [B, heads, T, T]
            A = attn_list[-1].mean(dim=1)[0].detach().cpu().numpy() # (T, T) del primer ejemplo
            plt.figure(figsize=(5, 4))
            plt.imshow(A, aspect='auto')
            plt.title(f"Attention (last layer, mean heads) — Fold {fold}")
            plt.colorbar()
            plt.tight_layout()
            attn_png = fold_dir / f"attention_map_fold{fold}_4c.png"
            plt.savefig(attn_png, dpi=140)
            plt.close()
            print(f"↳ Mapa de atención guardado: {attn_png}")

        # =========================
        # Grad-CAM++: 4 clases
        # =========================
        sfreq_eff = sfreq_used
        save_cam_grid_4classes(
            X_te_std=X_te_std,
            y_te=y_te,
            model_to_probe=cam_model,
            sfreq_eff=sfreq_eff,
            fold=fold,
            run_tag=MODEL_TAG,
            channel_mode="mean",
            output_dir=OUTPUT_DIR
        )
    except Exception as e:
        print(f"[WARN] Interpretabilidad no generada: {e}")

    # =========================
    # FINE-TUNING PROGRESIVO POR SUJETO (4-fold CV interno)
    # =========================
    subjects = np.unique(sub_te)
    subj_to_idx = {s: np.where(sub_te == s)[0] for s in subjects}

    def make_ft_optimizer(model):
        head_params = list(model.head.parameters())
        backbone_params = [p for n, p in model.named_parameters()
                           if not n.startswith('head.')]
        return torch.optim.AdamW([
            {'params': backbone_params, 'lr': FT_LR_BACKBONE},
            {'params': head_params,     'lr': FT_LR_HEAD}
        ], weight_decay=FT_WD)

    def freeze_backbone(model, freeze=True):
        for n, p in model.named_parameters():
            if not n.startswith('head.'):
                p.requires_grad_(not freeze)

    def ft_make_criterion_from_counts(counts):
        inv = counts.sum() / (4.0 * np.maximum(counts.astype(np.float32), 1.0))
        a = torch.tensor(inv, dtype=torch.float32, device=device)
        am = a.mean().item()
        a[2] = FT_ALPHA_BOOST_FISTS * am
        a[3] = FT_ALPHA_BOOST_FEET  * am
        return FocalLoss(alpha=a, gamma=1.5, reduction='mean')

    def evaluate_tensor(model_t, X, y):
        model_t.eval()
        with torch.no_grad():
            xb = torch.tensor(X, dtype=torch.float32, device=device)
            p = model_t(xb).argmax(1).cpu().numpy()
        return accuracy_score(y, p), f1_score(y, p, average='macro'), p

    all_true, all_pred = [], []
    base_state = (ema.ema if (ema is not None) else model).state_dict()

    for s in subjects:
        idx = subj_to_idx[s]
        Xs, ys = X_te_std[idx], y_te[idx]

        if len(np.unique(ys)) < 2 or len(ys) < 20:
            eval_model_local = EEGCNNTransformer(
                n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                capture_attn=False
            ).to(device)
            eval_model_local.load_state_dict(base_state)
            acc_s, f1_s, pred_s = evaluate_tensor(eval_model_local, Xs, ys)
            all_true.append(ys)
            all_pred.append(pred_s)
            continue

        skf = StratifiedKFold(n_splits=FT_N_FOLDS, shuffle=True,
                              random_state=RANDOM_STATE + fold + int(s))
        for tr_idx, te_idx in skf.split(Xs, ys):
            Xtr, ytr = Xs[tr_idx].copy(), ys[tr_idx].copy()
            Xte_s, yte_s = Xs[te_idx].copy(), ys[te_idx].copy()

            Xtr_std, Xte_std = standardize_per_channel(Xtr, Xte_s)

            ft_model = EEGCNNTransformer(
                n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                capture_attn=False
            ).to(device)
            ft_model.load_state_dict(base_state)

            opt_ft = make_ft_optimizer(ft_model)
            alpha_counts = np.bincount(ytr, minlength=4)
            ft_crit = ft_make_criterion_from_counts(alpha_counts)

            ds_tr = TensorDataset(torch.tensor(Xtr_std),
                                  torch.tensor(ytr).long())
            ds_te = TensorDataset(torch.tensor(Xte_std),
                                  torch.tensor(yte_s).long())
            ld_tr = DataLoader(ds_tr, batch_size=FT_BATCH,
                               shuffle=True,  drop_last=False,
                               worker_init_fn=seed_worker)
            ld_te = DataLoader(ds_te, batch_size=FT_BATCH,
                               shuffle=False, drop_last=False,
                               worker_init_fn=seed_worker)

            # ETAPA 1: freeze backbone
            freeze_backbone(ft_model, True)
            best_f1_s, best_state_s, wait_s = -1, None, 0
            for ep in range(1, FT_FREEZE_EPOCHS+1):
                ft_model.train()
                for xb, yb in ld_tr:
                    xb = xb.to(device)
                    yb = yb.to(device)
                    xb = augment_batch_ft(xb)
                    opt_ft.zero_grad()
                    logits = ft_model(xb)
                    loss = ft_crit(logits, yb)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                    opt_ft.step()

                acc_s, f1_s = evaluate_tensor(ft_model, Xte_std, yte_s)[:2]
                improved = f1_s > best_f1_s + 1e-4
                if improved:
                    best_f1_s = f1_s
                    best_state_s = {k: v.detach().cpu()
                                    for k, v in ft_model.state_dict().items()}
                    wait_s = 0
                else:
                    wait_s += 1
                if wait_s >= FT_PATIENCE:
                    break

            if best_state_s is not None:
                ft_model.load_state_dict(best_state_s)

            # ETAPA 2: unfreeze con early stopping
            freeze_backbone(ft_model, False)
            best_f1_s2, best_state_s2, wait_s2 = -1, None, 0
            for ep in range(1, FT_UNFREEZE_EPOCHS+1):
                ft_model.train()
                for xb, yb in ld_tr:
                    xb = xb.to(device)
                    yb = yb.to(device)
                    xb = augment_batch_ft(xb)
                    opt_ft.zero_grad()
                    logits = ft_model(xb)
                    loss = ft_crit(logits, yb)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                    opt_ft.step()

                acc_s, f1_s = evaluate_tensor(ft_model, Xte_std, yte_s)[:2]
                improved = f1_s > best_f1_s2 + 1e-4
                if improved:
                    best_f1_s2 = f1_s
                    best_state_s2 = {k: v.detach().cpu()
                                     for k, v in ft_model.state_dict().items()}
                    wait_s2 = 0
                else:
                    wait_s2 += 1
                if wait_s2 >= FT_PATIENCE:
                    break

            if best_state_s2 is not None:
                ft_model.load_state_dict(best_state_s2)

            acc_s, f1_s, pred_s = evaluate_tensor(ft_model, Xte_std, yte_s)
            all_true.append(yte_s)
            all_pred.append(pred_s)

    all_true = np.concatenate(all_true)
    all_pred = np.concatenate(all_pred)
    metrics_ft = compute_metrics_multiclass(all_true, all_pred, n_classes=4)
    ft_acc = metrics_ft["acc"]
    ft_f1m = metrics_ft["f1_macro"]

    print(f"  Fine-tuning PROGRESIVO (por sujeto, {FT_N_FOLDS}-fold CV) "
          f"acc={ft_acc:.4f} | f1_macro={ft_f1m:.4f}")
    print(f"  Δ(FT-Global) = {ft_acc - acc:+.4f}")
    cm_ft = metrics_ft["cm"]
    print("Confusion matrix FT (rows=true, cols=pred):")
    print(cm_ft)
    print(
        f"precision_macro={metrics_ft['prec_macro']:.4f} | "
        f"recall_macro={metrics_ft['rec_macro']:.4f} | "
        f"specificity_macro={metrics_ft['spec_macro']:.4f} | "
        f"sensitivity_macro={metrics_ft['sens_macro']:.4f} | "
        f"f1_weighted={metrics_ft['f1_weighted']:.4f}"
    )

    cm_ft_path = fold_dir / f"confusion_ft_fold{fold}_4c.png"
    plot_confusion_with_text(
        cm=cm_ft,
        class_names=CLASS_NAMES,
        title=f"Confusion — Fold {fold} FT (4 classes)",
        out_path=cm_ft_path,
        cmap='Greens'
    )
    print(f"↳ Matriz de confusión FT guardada: {cm_ft_path}")

    # =========================
    # Métricas de consumo (params + latencia)
    # =========================
    params_total = sum(p.numel() for p in model.parameters())
    params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    params_m = params_total / 1e6

    # Latencia por batch (B=8) usando el modelo global (EMA si existe)
    latency_batch_ms = float('nan')
    example_batch = None
    if len(X_te_std) >= 8:
        example_batch = X_te_std[:8]
    elif len(X_te_std) > 0:
        example_batch = X_te_std   # batch más pequeño si no hay 8 trials

    if example_batch is not None:
        latency_batch_ms = estimate_latency_ms(eval_model, example_batch, device)

    print(f"[Consumo] params_total={params_total:,} | "
          f"params_trainable={params_trainable:,} | "
          f"FLOPs~{MODEL_FLOPS_G*1e3:.1f}M | "
          f"latency/batch={latency_batch_ms:.2f} ms (B=8)")

    return (
        metrics_global,
        metrics_ft,
        train_time,
        train_mem,
        test_time,
        test_mem,
        params_m,
        latency_batch_ms,
        pvals_4c,    
        sfreq_used,  
        saliency_fold,
    )

### Loop Principal y Resumen Final

Bloque principal que ejecuta los 5 folds y genera el resumen final: métricas de clasificación (global y fine-tuning), métricas computacionales y topomaps promedio.

In [33]:
# =========================
# 5-FOLD LOOP + SUMMARY
# =========================
if __name__ == "__main__":
    # Create the main output directory for the 4-class experiment.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Fold indices passed to train_one_fold. Every per-fold list below is
    # filled in this order, and n_folds feeds the "N-fold mean" plot titles.
    fold_ids = list(range(1, 6))
    n_folds = len(fold_ids)

    global_metrics_folds = []
    ft_metrics_folds = []
    train_times, train_mems = [], []
    test_times, test_mems = [], []
    params_m_list = []
    latencies_batch_ms = []
    pvals_folds = []
    sfreq_folds = []
    saliency_folds = []

    for fold in fold_ids:
        (mg, mft, t_tr, mem_tr, t_te, mem_te,
         params_m, lat_ms, pvals_4c, sfreq_used, saliency_fold) = train_one_fold(fold, DEVICE)

        global_metrics_folds.append(mg)
        ft_metrics_folds.append(mft)
        train_times.append(t_tr)
        train_mems.append(mem_tr)
        test_times.append(t_te)
        test_mems.append(mem_te)
        params_m_list.append(params_m)
        latencies_batch_ms.append(lat_ms)
        pvals_folds.append(pvals_4c)
        sfreq_folds.append(sfreq_used)
        saliency_folds.append(saliency_fold)

    def mean_of(metric_list, key):
        """Return the mean of ``m[key]`` across a list of per-fold metric dicts."""
        return float(np.mean([m[key] for m in metric_list]))

    print("\n============================================================")
    print("RESULTADOS FINALES (4 clases)")
    print("============================================================")

    print("Global folds (ACC):", [f"{m['acc']:.4f}" for m in global_metrics_folds])
    print("Global mean ACC:   ", f"{mean_of(global_metrics_folds, 'acc'):.4f}")
    print("Global F1 folds (MACRO):", [f"{m['f1_macro']:.4f}" for m in global_metrics_folds])
    print("Global F1 mean (MACRO): ", f"{mean_of(global_metrics_folds, 'f1_macro'):.4f}")
    print("Global F1 folds (WEIGHTED):", [f"{m['f1_weighted']:.4f}" for m in global_metrics_folds])
    print("Global F1 mean (WEIGHTED): ", f"{mean_of(global_metrics_folds, 'f1_weighted'):.4f}")
    print("Global PREC mean (MACRO):  ", f"{mean_of(global_metrics_folds, 'prec_macro'):.4f}")
    print("Global REC/SENS mean (MACRO):", f"{mean_of(global_metrics_folds, 'rec_macro'):.4f}")
    print("Global SPEC mean (MACRO):  ", f"{mean_of(global_metrics_folds, 'spec_macro'):.4f}")

    print("\nFT folds (ACC):", [f"{m['acc']:.4f}" for m in ft_metrics_folds])
    print("FT mean ACC:    ", f"{mean_of(ft_metrics_folds, 'acc'):.4f}")
    print("FT F1 folds (MACRO):", [f"{m['f1_macro']:.4f}" for m in ft_metrics_folds])
    print("FT F1 mean (MACRO): ", f"{mean_of(ft_metrics_folds, 'f1_macro'):.4f}")
    # Per-fold weighted F1 for FT, mirroring the Global section above.
    print("FT F1 folds (WEIGHTED):", [f"{m['f1_weighted']:.4f}" for m in ft_metrics_folds])
    print("FT F1 mean (WEIGHTED): ", f"{mean_of(ft_metrics_folds, 'f1_weighted'):.4f}")
    print("FT PREC mean (MACRO):  ", f"{mean_of(ft_metrics_folds, 'prec_macro'):.4f}")
    print("FT REC/SENS mean (MACRO):", f"{mean_of(ft_metrics_folds, 'rec_macro'):.4f}")
    print("FT SPEC mean (MACRO):  ", f"{mean_of(ft_metrics_folds, 'spec_macro'):.4f}")

    delta_mean = mean_of(ft_metrics_folds, 'acc') - mean_of(global_metrics_folds, 'acc')
    print(f"\nΔ(FT-Global) mean (ACC): {delta_mean:+.4f}")

    # ==================== Computational metrics table ====================
    train_time_mean = float(np.mean(train_times))
    train_time_std  = float(np.std(train_times))
    test_time_mean  = float(np.mean(test_times))
    test_time_std   = float(np.std(test_times))

    train_mem_mean = float(np.mean(train_mems))
    train_mem_std  = float(np.std(train_mems))
    test_mem_mean  = float(np.mean(test_mems))
    test_mem_std   = float(np.std(test_mems))

    params_m_mean = float(np.mean(params_m_list))

    # A fold with no test trials reports NaN latency, hence nanmean/nanstd.
    lat_mean = float(np.nanmean(latencies_batch_ms))
    lat_std  = float(np.nanstd(latencies_batch_ms))

    print("\n================ Computational metrics ================")
    print(f"Model tag: {MODEL_TAG}")
    print("Phase        time (s)           memory (MB)        FLOPs (g)   params (m)")
    print("-----------  -----------------  -----------------  ----------  ----------")
    print(f"model specs  N/A                N/A                    {MODEL_FLOPS_G:0.3f}      {params_m_mean:0.3f}")
    print(f"train        {train_time_mean:6.2f} ({train_time_std:5.2f})    "
          f"{train_mem_mean:6.1f} ({train_mem_std:4.1f})         N/A        N/A")
    print(f"test         {test_time_mean:6.2f} ({test_time_std:5.2f})    "
          f"{test_mem_mean:6.1f} ({test_mem_std:4.1f})         N/A        N/A")
    print(f"latency/batch (B=8): {lat_mean:.2f} ms ({lat_std:.2f})")

    # ==================== Mean 4-class p-value topomap ====================
    # Best-effort: a plotting failure only prints a warning.
    try:
        # Stack per-fold p-values as [fold, channel] and average over folds.
        pvals_stack = np.stack(pvals_folds, axis=0)      # (n_folds, C)
        pvals_mean = pvals_stack.mean(axis=0)            # (C,)

        # The epsilon avoids log10(0).
        pvals_mean_log = -np.log10(pvals_mean + 1e-12)

        # All folds are expected to share one sampling rate. Warn when they do
        # not, because their mean would then match none of the folds.
        sfreq_arr = np.asarray(sfreq_folds, dtype=float)
        if np.unique(sfreq_arr).size > 1:
            print(f"[WARN] sfreq distinto entre folds: {sfreq_arr.tolist()}; se usa la media.")
        sfreq_mean = float(sfreq_arr.mean())
        info_topo_mean = make_mne_info_8ch(sfreq_mean)

        topo_mean_png = OUTPUT_DIR / f"topomap_pvalues_4c_mean_{MODEL_TAG}.png"
        plot_topomap_from_values(
            pvals_mean_log,
            info_topo_mean,
            title=f"-log10(p) task effect (4 classes) — {n_folds}-fold mean [{MODEL_TAG}]",
            out_path=topo_mean_png,
            cmap="viridis",
            cbar_label="-log10(p)"
        )
        print(f"↳ Topomap p-values 4c PROMEDIO guardado: {topo_mean_png}")
    except Exception as e:
        print(f"[WARN] Topomap p-values 4c promedio no generado: {e}")

    # ==================== Mean saliency topomap ====================
    # Best-effort: a plotting failure only prints a warning.
    try:
        sal_stack = np.stack(saliency_folds, axis=0)   # (n_folds, C)
        sal_mean  = sal_stack.mean(axis=0)             # (C,)

        # Min-max normalise to [0, 1]. If every channel has the same value,
        # the result stays all zeros instead of dividing by zero.
        sal_mean = sal_mean - sal_mean.min()
        maxv = sal_mean.max()
        if maxv > 0:
            sal_mean = sal_mean / maxv

        sfreq_mean = float(np.mean(sfreq_folds))
        info_topo_mean = make_mne_info_8ch(sfreq_mean)

        topo_sal_mean_png = OUTPUT_DIR / f"topomap_saliency_mean_{MODEL_TAG}.png"
        plot_topomap_from_values(
            sal_mean,
            info_topo_mean,
            title=f"Channel importance (saliency) — {n_folds}-fold mean [{MODEL_TAG}]",
            out_path=topo_sal_mean_png,
            cmap="Reds",
            cbar_label="Normalized saliency",
        )
        print(f"↳ Topomap saliency PROMEDIO guardado: {topo_sal_mean_png}")
    except Exception as e:
        print(f"[WARN] Topomap saliency promedio no generado: {e}")

Cargando train fold1: 100%|██████████| 67/67 [00:07<00:00,  8.38it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  8.31it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.40it/s]


[Fold 1/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2329 | train_acc=0.2608 | val_loss=0.2246 | val_acc=0.2754 | val_f1m=0.2302 | LR=0.000125
  Época   2 | train_loss=0.2114 | train_acc=0.3360 | val_loss=0.2184 | val_acc=0.3794 | val_f1m=0.3413 | LR=0.000250
  Época   3 | train_loss=0.1843 | train_acc=0.4723 | val_loss=0.2042 | val_acc=0.4357 | val_f1m=0.4333 | LR=0.000375
  Época   4 | train_loss=0.1765 | train_acc=0.4913 | val_loss=0.1951 | val_acc=0.4619 | val_f1m=0.4583 | LR=0.000500
  Época   5 | train_loss=0.1711 | train_acc=0.5217 | val_loss=0.2070 | val_acc=0.4659 | val_f1m=0.4611 | LR=0.000625
  Época   6 | train_loss=0.1676 | train_acc=0.5323 | val_loss=0.2075 | val_acc=0.4460 | val_f1m=0.4437 | LR=0.000750
  Época   7 | train_loss=0.1631 | train_acc=0.5430 | val_loss=0.1967 | val_acc=0.4619 | val_f1m=0.4530 | LR=0.000875
  Época   8 | train_loss=0.1649 | train_acc=0.5300 | val_loss=0.1947 | val_acc=0.4635 | val_f1m=0.46

Cargando train fold2: 100%|██████████| 67/67 [00:08<00:00,  8.37it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.44it/s]


[Fold 2/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2279 | train_acc=0.2656 | val_loss=0.2207 | val_acc=0.2857 | val_f1m=0.1819 | LR=0.000125
  Época   2 | train_loss=0.2113 | train_acc=0.3451 | val_loss=0.1998 | val_acc=0.4302 | val_f1m=0.4000 | LR=0.000250
  Época   3 | train_loss=0.1895 | train_acc=0.4517 | val_loss=0.1896 | val_acc=0.4468 | val_f1m=0.4368 | LR=0.000375
  Época   4 | train_loss=0.1805 | train_acc=0.4732 | val_loss=0.1890 | val_acc=0.4659 | val_f1m=0.4670 | LR=0.000500
  Época   5 | train_loss=0.1812 | train_acc=0.4794 | val_loss=0.1880 | val_acc=0.4675 | val_f1m=0.4608 | LR=0.000625
  Época   6 | train_loss=0.1738 | train_acc=0.5039 | val_loss=0.1961 | val_acc=0.4571 | val_f1m=0.4350 | LR=0.000750
  Época   7 | train_loss=0.1758 | train_acc=0.4941 | val_loss=0.1879 | val_acc=0.4762 | val_f1m=0.4731 | LR=0.000875
  Época   8 | train_loss=0.1692 | train_acc=0.5105 | val_loss=0.1832 | val_acc=0.4952 | val_f1m=0.49

Cargando train fold3: 100%|██████████| 67/67 [00:08<00:00,  8.26it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.37it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  8.33it/s]


[Fold 3/5] Entrenando modelo global... (n_train=5628 | n_val=1260 | n_test=1764)
  Época   1 | train_loss=0.2287 | train_acc=0.2660 | val_loss=0.2213 | val_acc=0.2770 | val_f1m=0.2047 | LR=0.000125
  Época   2 | train_loss=0.2158 | train_acc=0.3186 | val_loss=0.1954 | val_acc=0.3897 | val_f1m=0.3571 | LR=0.000250
  Época   3 | train_loss=0.1877 | train_acc=0.4550 | val_loss=0.1960 | val_acc=0.4421 | val_f1m=0.4402 | LR=0.000375
  Época   4 | train_loss=0.1791 | train_acc=0.4940 | val_loss=0.1838 | val_acc=0.4746 | val_f1m=0.4767 | LR=0.000500
  Época   5 | train_loss=0.1769 | train_acc=0.4998 | val_loss=0.1830 | val_acc=0.4817 | val_f1m=0.4795 | LR=0.000625
  Época   6 | train_loss=0.1730 | train_acc=0.5130 | val_loss=0.1769 | val_acc=0.4913 | val_f1m=0.4935 | LR=0.000750
  Época   7 | train_loss=0.1683 | train_acc=0.5128 | val_loss=0.1804 | val_acc=0.5024 | val_f1m=0.5033 | LR=0.000875
  Época   8 | train_loss=0.1674 | train_acc=0.5272 | val_loss=0.1790 | val_acc=0.4992 | val_f1m=0.49

Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.32it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.43it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.41it/s]


[Fold 4/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_loss=0.2275 | train_acc=0.2759 | val_loss=0.2215 | val_acc=0.2778 | val_f1m=0.1869 | LR=0.000125
  Época   2 | train_loss=0.2116 | train_acc=0.3568 | val_loss=0.2051 | val_acc=0.4016 | val_f1m=0.3481 | LR=0.000250
  Época   3 | train_loss=0.1907 | train_acc=0.4463 | val_loss=0.1958 | val_acc=0.4452 | val_f1m=0.4428 | LR=0.000375
  Época   4 | train_loss=0.1829 | train_acc=0.4674 | val_loss=0.1880 | val_acc=0.4865 | val_f1m=0.4875 | LR=0.000500
  Época   5 | train_loss=0.1753 | train_acc=0.5007 | val_loss=0.1928 | val_acc=0.4817 | val_f1m=0.4825 | LR=0.000625
  Época   6 | train_loss=0.1773 | train_acc=0.4939 | val_loss=0.1888 | val_acc=0.4833 | val_f1m=0.4758 | LR=0.000750
  Época   7 | train_loss=0.1763 | train_acc=0.4930 | val_loss=0.1855 | val_acc=0.4976 | val_f1m=0.4941 | LR=0.000875
  Época   8 | train_loss=0.1741 | train_acc=0.4995 | val_loss=0.1770 | val_acc=0.5032 | val_f1m=0.50

Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.34it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.45it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.42it/s]


[Fold 5/5] Entrenando modelo global... (n_train=5712 | n_val=1260 | n_test=1680)
  Época   1 | train_loss=0.2276 | train_acc=0.2670 | val_loss=0.2297 | val_acc=0.2651 | val_f1m=0.1791 | LR=0.000125
  Época   2 | train_loss=0.2098 | train_acc=0.3568 | val_loss=0.2146 | val_acc=0.3841 | val_f1m=0.3675 | LR=0.000250
  Época   3 | train_loss=0.1882 | train_acc=0.4625 | val_loss=0.2151 | val_acc=0.4159 | val_f1m=0.4107 | LR=0.000375
  Época   4 | train_loss=0.1842 | train_acc=0.4692 | val_loss=0.1982 | val_acc=0.4540 | val_f1m=0.4561 | LR=0.000500
  Época   5 | train_loss=0.1762 | train_acc=0.5004 | val_loss=0.2008 | val_acc=0.4341 | val_f1m=0.4344 | LR=0.000625
  Época   6 | train_loss=0.1755 | train_acc=0.5035 | val_loss=0.1991 | val_acc=0.4437 | val_f1m=0.4371 | LR=0.000750
  Época   7 | train_loss=0.1709 | train_acc=0.5060 | val_loss=0.1973 | val_acc=0.4556 | val_f1m=0.4515 | LR=0.000875
  Época   8 | train_loss=0.1702 | train_acc=0.5254 | val_loss=0.2072 | val_acc=0.4413 | val_f1m=0.44