# EEG Motor Imagery (Left/Right): CNN + Transformer
This notebook trains and evaluates a CNN+Transformer model for binary (left/right) classification
of EEG motor imagery using subject-wise K-fold cross-validation, with options for:
- Preprocessing
- Class balancing (WeightedRandomSampler)
- EMA of the weights
- Inference with TTA and/or subwindows
- Interpretability
- An architecture sweep and a final 5-fold evaluation of the best hyperparameter set


### Initial Setup and Reproducibility

In [47]:
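# -*- coding: utf-8 -*-
import os
# Machine-specific: pins the run to GPU index 7; adjust for your hardware.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
# Note: setting PYTHONHASHSEED here only affects subprocesses; the running
# interpreter fixed its hash seed at startup.
os.environ['PYTHONHASHSEED'] = '42'
# Needed for deterministic cuBLAS behavior on CUDA >= 10.2.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'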

import re, json, random, copy, time, csv
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from sklearn.metrics import (
    accuracy_score, f1_score, confusion_matrix, classification_report,
    precision_score, recall_score
)
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
from scipy import stats
from sklearn.manifold import TSNE  # used for the t-SNE plots

### Reproducibility Settings

This section sets the random seeds and backend flags needed to make training runs reproducible.

In [48]:
# =========================
# REPRODUCIBILITY
# =========================
RANDOM_STATE = 42

def seed_everything(seed: int = 42):
    """
    Seeds every random number generator in use (Python, NumPy, PyTorch)
    and enables deterministic backend settings where available.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    """
    Seeds a DataLoader worker (NumPy and Python RNGs) from RANDOM_STATE + worker_id.
    """
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)
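
To make the role of the per-worker offset in `seed_worker` concrete, here is a minimal sketch using only the standard library. `derive_worker_seed` and `draw` are hypothetical helpers introduced for illustration and are not part of the notebook. The sketch shows that the same worker id reproduces the same random stream, while different worker ids get different streams.

```python
import random

BASE_SEED = 42  # mirrors RANDOM_STATE in the notebook

def derive_worker_seed(worker_id: int, base_seed: int = BASE_SEED) -> int:
    """Hypothetical helper: same offset rule as seed_worker (base + worker_id)."""
    return base_seed + worker_id

def draw(seed: int, n: int = 3) -> list:
    """Draw n floats from an independent generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(n)]

# Same worker id: identical stream on every run
run_a = draw(derive_worker_seed(0))
run_b = draw(derive_worker_seed(0))

# Different worker id: a different stream
other = draw(derive_worker_seed(1))

print(run_a == run_b, run_a == other)
```

In a PyTorch `DataLoader`, a function like `seed_worker` is normally passed as `worker_init_fn`, usually together with a seeded `torch.Generator` via the `generator` argument.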

### Hyperparameter Configuration

Defines all hyperparameters for the model and for EEG preprocessing.

In [49]:
# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / '00_folds' / 'Kfold5.json'

# Training hyperparameters
EPOCHS = 60
BATCH_SIZE = 64
BASE_LR = 5e-4
WARMUP_EPOCHS = 4
PATIENCE = 8

# Per-fold validation split (by subject)
VAL_SUBJECT_FRAC = 0.18
VAL_STRAT_SUBJECT = True

# Global preprocessing
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Model parameters
D_MODEL = 144
N_HEADS = 4
N_LAYERS = 1
P_DROP = 0.1
P_DROP_ENCODER = 0.1

# Time window (seconds, relative to the event)
TMIN, TMAX = -1.0, 5.0

# TTA / SUBWINDOW at TEST time
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5
COMBINE_TTA_AND_SUBWIN = False

# Balanced sampler
USE_WEIGHTED_SAMPLER = True

# EMA (GLOBAL training only). Disabled during fine-tuning.
USE_EMA = True
EMA_DECAY = 0.9995

# Fine-tuning (per subject), with n_cls=2
FT_N_FOLDS = 4
FT_FREEZE_EPOCHS = 8
FT_UNFREEZE_EPOCHS = 8
FT_PATIENCE = 6
FT_BATCH = 64
FT_LR_HEAD = 1e-3
FT_LR_BACKBONE = 2e-4
FT_WD = 1e-3
FT_AUG = dict(p_jitter=0.25, p_noise=0.25, p_chdrop=0.10, max_jitter_frac=0.02, noise_std=0.02, max_chdrop=1)

# Subject and channel configuration
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right']  # 2 classes

# Only left/right imagery runs
IMAGERY_RUNS_LR = {4, 8, 12}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Using device: {DEVICE}")
print(" STARTING EXPERIMENT WITH CNN+Transformer (subject-wise K-Fold)")
print(f" Config: 2c (L/R), 8 channels, 6s | EPOCHS={EPOCHS}, BATCH={BATCH_SIZE}, LR={BASE_LR} | ZSCORE_PER_EPOCH={ZSCORE_PER_EPOCH}")

 Using device: cuda
 STARTING EXPERIMENT WITH CNN+Transformer (subject-wise K-Fold)
 Config: 2c (L/R), 8 channels, 6s | EPOCHS=60, BATCH=64, LR=0.0005 | ZSCORE_PER_EPOCH=False
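
The time-related settings above are given in seconds, but the model and the inference helpers work in samples. As a worked example, the sketch below converts them under the assumption of a 160 Hz sampling rate. That rate is an assumption; the configuration does not state it and `RESAMPLE_HZ` is `None`. The rounding mirrors what `time_shift_tta_logits` and `subwindow_logits` do.

```python
SFREQ = 160.0            # assumption: sampling rate in Hz (not stated in the config)
TMIN, TMAX = -1.0, 5.0
TTA_SHIFTS_S = [-0.075, -0.05, -0.025, 0.0, 0.025, 0.05, 0.075]
SW_LEN, SW_STRIDE = 4.5, 1.5

# MNE epochs include both endpoints, hence the +1
n_times = int(round((TMAX - TMIN) * SFREQ)) + 1

# Same rounding as time_shift_tta_logits
tta_shifts = [int(round(s * SFREQ)) for s in TTA_SHIFTS_S]

# Same windowing as subwindow_logits
wl = int(round(SW_LEN * SFREQ))
st = int(round(SW_STRIDE * SFREQ))
starts = list(range(0, max(1, n_times - wl + 1), st))

print(n_times, tta_shifts, starts)
```

Under that assumption, each epoch has 961 samples, the TTA shifts are multiples of 4 samples, and a 4.5 s window with a 1.5 s stride fits only twice into a 6 s epoch.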


### I/O Utility Functions

Functions for loading and normalizing the EEG data.

In [50]:
# =========================
# I/O UTILITIES
# =========================
def normalize_ch_name(name: str) -> str:
    """
    Normalizes a channel name: strips non-alphanumeric characters and upper-cases it.
    """
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    """
    Selects the 8 motor channels listed in EXPECTED_8; raises if any is missing.
    """
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Required channel '{target_orig}' not found. Available: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    """
    Lists the motor-imagery EDF files for a subject.
    """
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    """
    Converts a subject ID string to an integer; returns -1 if the pattern does not match.
    """
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    """
    Loads and preprocesses the EEG epochs for one subject.
    Returns (X, y, sfreq), with X of shape (N, 8, T) and labels 0 for T1, 1 for T2.
    """
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []
    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        # Use ONLY the left/right runs
        if run not in IMAGERY_RUNS_LR:
            continue

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')
        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)

        sfreq = raw.info['sfreq']
        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')

        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = np.array([0 if inv[c] == 'T1' else 1 for c in ev_codes], dtype=int)

        X_list.append(X)
        y_list.append(y_run)
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Inconsistent sampling rates: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    """
    Loads the train and test subjects for a given fold.
    """
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel(train_X, other_X):
    """
    Standardizes each channel using statistics computed on the training set only.
    """
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    for c in range(C):
        mu = train_X[:, c, :].mean()
        sd = train_X[:, c, :].std()
        sd = sd if sd > 1e-6 else 1.0
        train_X[:, c, :] = (train_X[:, c, :] - mu) / sd
        other_X[:, c, :] = (other_X[:, c, :] - mu) / sd
    return train_X, other_X
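
As an illustration of why `pick_8_channels` matches on normalized names, the sketch below runs the same normalization on a few made-up channel labels with padding dots and mixed case. The labels in `raw_names` are hypothetical examples, not read from any file. Only the `re` module is needed.

```python
import re

EXPECTED_8 = ['C3', 'C4', 'Cz', 'CP3', 'CP4', 'FC3', 'FC4', 'FCz']

def normalize_ch_name(name: str) -> str:
    """Drop non-alphanumeric characters and upper-case the rest."""
    return re.sub(r'[^A-Za-z0-9]', '', name).upper()

# Hypothetical raw channel labels, as they might appear in an EDF header
raw_names = ['Fc3.', 'Fcz.', 'Fc4.', 'C3..', 'Cz..', 'C4..', 'Cp3.', 'Cp4.', 'Oz..']

# Map normalized name -> original label, then look up each target
norm_map = {normalize_ch_name(ch): ch for ch in raw_names}
picked = [norm_map[normalize_ch_name(target)] for target in EXPECTED_8]

print(picked)
```

The lookup returns the original labels, so they can be passed back to the recording object unchanged.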

### Model Architecture: CNN + Transformer

Implementation of the hybrid CNN-Transformer model for EEG classification.

In [51]:
# =========================
# MODEL (GroupNorm in the conv blocks) with attention capture and variants
# =========================
def make_gn(num_channels, num_groups=8):
    """
    Builds a GroupNorm layer, lowering the group count until it divides num_channels.
    """
    g = min(num_groups, num_channels)
    while num_channels % g != 0 and g > 1:
        g -= 1
    return nn.GroupNorm(g, num_channels)

class DepthwiseSeparableConv(nn.Module):
    """
    Depthwise separable convolution, used to reduce the parameter count.
    """
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

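class _CustomEncoderLayer(nn.Module):
    """
    Custom Transformer encoder layer with multi-head attention.
    forward() always applies post-norm and returns the per-head attention weights;
    the norm_first and return_attn arguments are stored but not consulted.
    """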
    def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first, norm_first, return_attn=True):
        super().__init__()
        self.return_attn = return_attn
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.GELU()
        self.norm_first = norm_first

    def forward(self, src):
        sa_out, attn_weights = self.self_attn(src, src, src, need_weights=True, average_attn_weights=False)
        src = self.norm1(src + self.dropout1(sa_out))
        ff = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff))
        return src, attn_weights  # [B, heads, T, T]

class _CustomEncoder(nn.Module):
    """
    Transformer encoder made of a stack of identical layers.
    """
    def __init__(self, layer, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])

    def forward(self, src):
        attn_list = []
        out = src
        for lyr in self.layers:
            out, attn = lyr(out)
            attn_list.append(attn)
        return out, attn_list

class EEGCNNTransformer(nn.Module):
    """
    Hybrid CNN-Transformer model for EEG signal classification.
    """
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3, n_dw_blocks=2,
                 k_list=(31,15,7), s_list=(2,2,2), p_list=(15,7,3), capture_attn=False):
        super().__init__()
        self.capture_attn = capture_attn
        self._last_attn = None
        self.pos_encoding = None

        # Initial convolutional stem
        stem = [
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
        ]
        
        # Depthwise separable blocks
        blocks = []
        in_c, out_cs = 32, [64, 128, 256, 256, 256]
        n_dw_blocks = int(np.clip(n_dw_blocks, 0, len(out_cs)))
        for i in range(n_dw_blocks):
            k = k_list[i] if i < len(k_list) else 7
            s = s_list[i] if i < len(s_list) else 2
            p = p_list[i] if i < len(p_list) else k//2
            blocks.append(DepthwiseSeparableConv(in_c, out_cs[i], k=k, s=s, p=p, p_drop=p_drop))
            in_c = out_cs[i]

        self.conv_t = nn.Sequential(*stem, *blocks)
        self.proj = nn.Conv1d(in_c if n_dw_blocks>0 else 32, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)

        # Encoder Transformer
        enc = _CustomEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
                                  dropout=0.1, batch_first=True, norm_first=False, return_attn=True)
        self.encoder = _CustomEncoder(enc, num_layers=n_layers)

        # CLS token and classification head
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        """
        Sinusoidal positional encoding for a sequence of length L and width d.
        """
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        # Convolutional layers
        z = self.conv_t(x)           # (B, C', T')
        z = self.proj(z)             # (B, d_model, T')
        z = self.dropout(z)
        z = z.transpose(1, 2)        # (B, T', d_model)

        # Positional encoding (cached; rebuilt if the shape or the device changed)
        B, L, D = z.shape
        pe = self.pos_encoding
        if (pe is None) or (pe.shape[0] != L) or (pe.shape[1] != D) or (pe.device != z.device):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]

        # Prepend the CLS token
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)  # (B, 1+L, D)

        # Transformer encoder
        z, attn_list = self.encoder(z)
        if self.capture_attn:
            self._last_attn = attn_list

        # Classify from the CLS token
        cls = z[:, 0, :]
        return self.head(cls)

    def get_last_attention_maps(self):
        """
        Returns the attention maps from the last forward pass (None unless capture_attn is set).
        """
        return self._last_attn
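
To see how many tokens the Transformer receives, the standard 1-D convolution length formula can be applied layer by layer. The sketch below does this for the stem and the two default depthwise-separable blocks (`n_dw_blocks=2`), assuming an input of 961 samples (6 s at 160 Hz with both endpoints). The input length is an assumption, and `conv1d_out_len` is a helper introduced here for illustration.

```python
def conv1d_out_len(length: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel) // stride + 1

T_IN = 961  # assumption: 6 s epoch at 160 Hz, endpoints included

# (kernel, stride, padding): stem, then the two default depthwise-separable blocks
layers = [(129, 2, 64), (31, 2, 15), (15, 2, 7)]

lengths = [T_IN]
for k, s, p in layers:
    lengths.append(conv1d_out_len(lengths[-1], k, s, p))

n_tokens = lengths[-1] + 1  # +1 for the CLS token
print(lengths, n_tokens)
```

Each stride-2 layer roughly halves the sequence, so the attention matrix per head stays small (about 122 x 122 under these assumptions). A sweep that changes `n_dw_blocks` or the strides would change this length.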

### Loss Function and Data Augmentation

Implementation of Focal Loss and data augmentation techniques for EEG.

In [52]:
# =========================
# FOCAL LOSS (2 classes)
# =========================
class FocalLoss(nn.Module):
    """
    Focal Loss for handling class imbalance.
    Down-weights easy examples so that training focuses on the hard ones.
    """
    def __init__(self, alpha: torch.Tensor, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha / alpha.sum()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)      # (B,C)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        at = self.alpha.to(logits.device)[target]  # alpha is a plain attribute, so move it explicitly
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        if self.reduction == 'mean': return loss.mean()
        if self.reduction == 'sum':  return loss.sum()
        return loss

# =========================
# AUGMENTS
# =========================
def augment_batch(xb, p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
                  max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1):
    """
    Applies data augmentation to a batch of EEG signals.
    Techniques: temporal jitter (circular shift), Gaussian noise, channel dropout.
    Works on a copy, so the caller's tensor is not modified.
    """
    xb = xb.clone()
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

def augment_batch_ft(xb):
    """
    Applies the augmentation settings used for fine-tuning (FT_AUG).
    """
    return augment_batch(xb, **FT_AUG)
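
The effect of the focal term is easiest to see on single examples. The sketch below compares plain cross-entropy with the focal loss at `gamma=1.5` (the default of `FocalLoss`) for a few true-class probabilities. It uses `alpha_t = 1.0` for readability, which is a simplification: the class above normalizes `alpha` to sum to 1. The helper names are introduced here for illustration.

```python
import math

def cross_entropy_term(p_t: float) -> float:
    """Per-sample cross-entropy for true-class probability p_t."""
    return -math.log(p_t)

def focal_term(p_t: float, gamma: float = 1.5, alpha_t: float = 1.0) -> float:
    """Per-sample focal loss: -alpha_t * (1 - p_t)^gamma * log(p_t)."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

for p_t in (0.9, 0.6, 0.3):
    ce = cross_entropy_term(p_t)
    fl = focal_term(p_t)
    # The ratio equals the modulating factor (1 - p_t)^gamma
    print(f"p_t={p_t:.1f}  CE={ce:.4f}  focal={fl:.4f}  ratio={fl / ce:.4f}")
```

A confidently correct example (`p_t = 0.9`) keeps only a few percent of its cross-entropy loss, while a hard example (`p_t = 0.3`) keeps more than half, which is the rebalancing the docstring describes.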

### EMA and Inference Techniques

Implementation of an exponential moving average (EMA) of the weights and of two averaging schemes for more robust inference.

In [53]:
# =========================
# EMA of the weights (global training only)
# =========================
class ModelEMA:
    """
    Exponential Moving Average of the model weights.
    Keeps a smoothed copy of the weights, which is typically more stable at evaluation time.
    """
    def __init__(self, model: nn.Module, decay: float = 0.9995, device=None):
        self.ema = self._clone(model).to(device if device is not None else next(model.parameters()).device)
        self.decay = decay
        self._updates = 0
        self.update(model, force=True)

    def _clone(self, model):
        import copy
        ema = copy.deepcopy(model)
        for p in ema.parameters():
            p.requires_grad_(False)
        return ema

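    @torch.no_grad()
    def update(self, model: nn.Module, force: bool = False):
        # `force` is accepted but not used. The first call (with _updates == 0)
        # gives d = 0, i.e. a plain copy of the model weights.
        d = self.decay
        if self._updates < 1000:
            # Linear warm-up of the decay over the first 1000 updates
            d = (self._updates / 1000.0) * self.decay
        msd = model.state_dict()
        esd = self.ema.state_dict()
        for k in esd.keys():
            if esd[k].dtype.is_floating_point:
                esd[k].mul_(d).add_(msd[k].detach(), alpha=1.0 - d)
            else:
                # Copy in place: rebinding esd[k] would only change the dict entry
                esd[k].copy_(msd[k])
        self._updates += 1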

# =========================
# TTA / SUBWINDOW INFERENCE
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    """
    Subwindow inference: averages the logits over sliding windows of each epoch.
    """
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1])); st = max(1, st)
    out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x = X[i]; acc = []
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='edge')
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            acc = np.mean(np.stack(acc, axis=0), axis=0) if len(acc) else np.zeros(2, dtype=np.float32)
            out.append(acc)
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    """
    Test-time augmentation: averages the logits over time-shifted, edge-padded copies.
    """
    model.eval()
    T = X.shape[-1]; out = []
    with torch.no_grad():
        for i in range(X.shape[0]):
            x0 = X[i]; acc = []
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='edge')[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='edge')[:, :T]
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)
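
The warm-up in `ModelEMA.update` means the effective decay is not constant early in training. The sketch below replays the same update rule on a single scalar "weight" to show the behavior. It is a pure-Python illustration; `ema_decay_at` and `run_ema` are hypothetical helpers, and the step counts are made up.

```python
def ema_decay_at(update_idx: int, decay: float = 0.9995, warmup: int = 1000) -> float:
    """Decay applied at a given update count, with the linear warm-up of ModelEMA."""
    if update_idx < warmup:
        return (update_idx / warmup) * decay
    return decay

def run_ema(values, decay: float = 0.9995):
    """Track a scalar 'weight' with the same rule: ema = d * ema + (1 - d) * w."""
    ema = 0.0
    history = []
    for i, w in enumerate(values):
        d = ema_decay_at(i, decay)
        ema = d * ema + (1.0 - d) * w
        history.append(ema)
    return history

# A weight that jumps from 1.0 to 2.0 after 500 steps
weights = [1.0] * 500 + [2.0] * 1500
hist = run_ema(weights)
print(hist[0], hist[499], hist[-1])
```

The first update uses a decay of 0, so the EMA starts as an exact copy. After the jump the EMA follows quickly while the decay is still small, then more and more slowly as the decay approaches 0.9995.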

### Metrics and Visualization Tools

Functions for computing metrics and visualizing results.

In [54]:
# =========================
# Extended metrics
# =========================
def compute_metrics(y_true, y_pred):
    """
    Computes extended classification metrics: macro precision, recall, F1,
    specificity and sensitivity, plus the confusion matrix.
    """
    prec_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec_macro  = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1_macro   = f1_score(y_true, y_pred, average='macro', zero_division=0)

    cm = confusion_matrix(y_true, y_pred, labels=[0,1])
    TN0 = cm[1,1]; FP0 = cm[1,0]
    spec0 = TN0 / (TN0 + FP0 + 1e-12)
    TN1 = cm[0,0]; FP1 = cm[0,1]
    spec1 = TN1 / (TN1 + FP1 + 1e-12)
    spec_macro = (spec0 + spec1) / 2.0

    recs = recall_score(y_true, y_pred, labels=[0,1], average=None, zero_division=0)
    sens_macro = float(recs.mean())

    return {
        "precision_macro": float(prec_macro),
        "recall_macro": float(rec_macro),
        "f1_macro": float(f1_macro),
        "specificity_macro": float(spec_macro),
        "sensitivity_macro": float(sens_macro),
        "cm": cm.tolist()
    }

# =========================
# Helper: annotated confusion matrix
# =========================
def plot_confusion_with_text(cm, class_names, title, out_path, cmap='Blues'):
    """
    Plots a confusion matrix with the cell counts annotated.
    """
    fig, ax = plt.subplots(figsize=(4.8, 4.2))
    im = ax.imshow(cm, cmap=cmap)
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True label")
    ax.set_xticks(range(len(class_names))); ax.set_xticklabels(class_names)
    ax.set_yticks(range(len(class_names))); ax.set_yticklabels(class_names)

    vmax = cm.max() if cm.size else 1
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, str(cm[i, j]),
                    ha='center', va='center',
                    color='white' if cm[i, j] > vmax/2 else 'black',
                    fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.savefig(out_path, dpi=140)
    plt.close(fig)
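
The specificity bookkeeping in `compute_metrics` can be checked by hand on a small example. The sketch below uses a made-up 2x2 confusion matrix (rows are true labels, columns are predictions, as in scikit-learn) and computes recall and specificity with each class taken in turn as the positive one. `per_class_rates` is a helper introduced for illustration.

```python
def per_class_rates(cm):
    """Return [(recall, specificity), ...] per class for a 2x2 confusion matrix.

    Rows are true labels, columns are predictions.
    """
    out = []
    for pos in (0, 1):
        neg = 1 - pos
        tp, fn = cm[pos][pos], cm[pos][neg]
        tn, fp = cm[neg][neg], cm[neg][pos]
        out.append((tp / (tp + fn), tn / (tn + fp)))
    return out

# Made-up counts for illustration only
cm = [[40, 10],   # true left:  40 predicted left, 10 predicted right
      [5, 45]]    # true right:  5 predicted left, 45 predicted right

(rec0, spec0), (rec1, spec1) = per_class_rates(cm)
sens_macro = (rec0 + rec1) / 2
spec_macro = (spec0 + spec1) / 2
print(rec0, spec0, rec1, spec1, sens_macro, spec_macro)
```

In the binary case the specificity of one class is the recall of the other, so macro specificity and macro sensitivity coincide. The two values reported by `compute_metrics` carry the same information here and would only differ with more than two classes.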

### Interpretability Tools

Functions for analyzing and visualizing the model's behavior.

In [55]:
# =========================
# INTERPRETABILITY FUNCTIONS
# =========================

# ===== Per-channel saliency =====
def channel_saliency(model, xb, target_class=None):
    """
    Per-channel importance based on |∂score/∂x| averaged over time.
    xb: tensor (1, C, T) on the same device as the model.
    Returns: vector (C,) min-max normalized to [0, 1].
    """
    model.eval()
    xb = xb.clone().detach().requires_grad_(True)

    logits = model(xb)
    if target_class is None:
        cls = logits.argmax(1)
    else:
        cls = torch.tensor([int(target_class)], device=xb.device)
    score = logits[0, cls]

    model.zero_grad(set_to_none=True)
    if xb.grad is not None:
        xb.grad.zero_()
    score.backward()

    grad = xb.grad.detach()[0]     # (C, T)
    sal = grad.abs().mean(dim=1)   # (C,)
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)
    return sal.cpu().numpy()

# ===== Topomaps =====
def make_mne_info_8ch(sfreq: float):
    """
    Creates an MNE Info object for the 8 motor channels.
    """
    info = mne.create_info(
        ch_names=EXPECTED_8,
        sfreq=float(sfreq),
        ch_types="eeg"
    )
    try:
        montage = mne.channels.make_standard_montage("standard_1020")
        info.set_montage(montage)
    except Exception:
        pass
    return info

def plot_topomap_from_values(values, info, title: str, out_path: Path,
                             cmap="Reds", cbar_label="Value",
                             vmin=None, vmax=None):
    """
    Draws and saves a topomap from a vector of per-channel values.
    values: array (n_channels,)
    """
    values = np.asarray(values, dtype=float)
    fig, ax = plt.subplots(figsize=(4.2, 4.0))
    im, _ = mne.viz.plot_topomap(
        values, info, axes=ax, show=False, cmap=cmap
    )
    if (vmin is not None) or (vmax is not None):
        im.set_clim(vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=10)
    cbar = plt.colorbar(im, ax=ax, shrink=0.7)
    cbar.set_label(cbar_label, fontsize=9)
    plt.tight_layout()
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"↳ Topomap saved: {out_path}")

# ===== -log10(p) statistical analysis =====
def compute_channel_pvalues_lr(X, y):
    """
    Per-channel p-values for the left vs right difference.
    Uses mean power (X^2 averaged over time) as the feature.
    X: (N, C, T)  | y: (N,) with labels 0=left, 1=right
    Returns: vector (C,) of p-values (Welch's t-test via ttest_ind).
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=int)
    left = X[y == 0]
    right = X[y == 1]

    if (left.size == 0) or (right.size == 0):
        return np.ones(X.shape[1], dtype=float)

    feat_left = (left ** 2).mean(axis=2)   # (N_left, C)
    feat_right = (right ** 2).mean(axis=2)

    pvals = []
    for ch in range(feat_left.shape[1]):
        _, p = stats.ttest_ind(
            feat_left[:, ch],
            feat_right[:, ch],
            equal_var=False
        )
        pvals.append(p)
    return np.array(pvals, dtype=float)

# ===== t-SNE of CLS embeddings =====
def compute_cls_embeddings(model, X_data, device):
    """
    Computes the model's CLS embeddings.
    X_data: (N, C, T)
    Returns: (N, d_model) CLS embeddings
    """
    model.eval()
    embeddings = []
    with torch.no_grad():
        for i in range(0, len(X_data), 64):
            xb = torch.tensor(X_data[i:i+64], dtype=torch.float32, device=device)
            # Replicate forward() up to the CLS token (dropout is a no-op in eval mode)
            z = model.conv_t(xb)           # (B, C', T')
            z = model.proj(z)              # (B, d_model, T')
            z = z.transpose(1, 2)          # (B, T', d_model)

            # Add the positional encoding (rebuilt if the shape or the device differs)
            B, L, D = z.shape
            pe = model.pos_encoding
            if pe is None or pe.shape[0] != L or pe.shape[1] != D or pe.device != z.device:
                pos_enc = model._positional_encoding(L, D).to(z.device)
            else:
                pos_enc = pe
            z = z + pos_enc[None, :, :]

            # Prepend the CLS token
            cls_tok = model.cls.expand(B, -1, -1)
            z_with_cls = torch.cat([cls_tok, z], dim=1)  # (B, 1+L, D)

            # Run the encoder (only the CLS output is needed)
            z_encoder, _ = model.encoder(z_with_cls)
            cls_embedding = z_encoder[:, 0, :]  # CLS token (B, D)
            embeddings.append(cls_embedding.cpu().numpy())

    embeddings = np.concatenate(embeddings, axis=0)
    return embeddings

def compute_tsne_embeddings(model, X_data, device, n_components=2, perplexity=30):
    """
    Computes CLS embeddings and projects them with t-SNE.
    X_data: (N, C, T)
    Returns: (tsne_result, embeddings)
    """
    # Get the CLS embeddings
    embeddings = compute_cls_embeddings(model, X_data, device)

    # Apply t-SNE. The iteration count is left at the library default (1000) because
    # the keyword was renamed from n_iter to max_iter in scikit-learn 1.5.
    tsne = TSNE(n_components=n_components, perplexity=perplexity,
                random_state=RANDOM_STATE, init='pca')
    tsne_result = tsne.fit_transform(embeddings)
    return tsne_result, embeddings

def plot_tsne(tsne_result, y_true, title, out_path: Path):
    """
    Plots the t-SNE projection colored by class.
    """
    y_true = np.asarray(y_true)  # allows plain lists as labels
    fig, ax = plt.subplots(figsize=(7, 5))

    # One color per class
    colors = ['green', 'blue']  # 0: left (green), 1: right (blue)

    for cls_idx in range(2):
        mask = y_true == cls_idx
        if mask.any():
            ax.scatter(tsne_result[mask, 0], tsne_result[mask, 1], 
                      c=colors[cls_idx], label=CLASS_NAMES[cls_idx], 
                      alpha=0.7, s=25, edgecolors='w', linewidth=0.5)
    
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    ax.legend(title="Class", fontsize=9)
    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.5)

    # Add a summary box with the sample counts
    n_samples = len(y_true)
    n_left = np.sum(y_true == 0)
    n_right = np.sum(y_true == 1)
    info_text = f"Total: {n_samples} | Left: {n_left} | Right: {n_right}"
    ax.text(0.02, 0.98, info_text, transform=ax.transAxes, fontsize=9,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"↳ t-SNE (CLS embeddings) saved: {out_path}")
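
The p-values returned by `compute_channel_pvalues_lr` are usually easier to read on a topomap after a -log10 transform, where larger means more significant. The sketch below shows that transform with a floor to avoid taking the log of zero. The p-values are made up for illustration, and `neg_log10_p` is a hypothetical helper.

```python
import math

def neg_log10_p(pvals, floor: float = 1e-300):
    """Map p-values to -log10(p); `floor` avoids log(0) for underflowed p-values."""
    return [-math.log10(max(p, floor)) for p in pvals]

CHANNELS = ['C3', 'C4', 'Cz', 'CP3', 'CP4', 'FC3', 'FC4', 'FCz']
# Made-up p-values for illustration only
pvals = [0.001, 0.004, 0.6, 0.03, 0.05, 0.2, 0.9, 0.75]

scores = neg_log10_p(pvals)
threshold = -math.log10(0.05)   # about 1.30

for ch, s in zip(CHANNELS, scores):
    flag = '*' if s > threshold else ''
    print(f"{ch:>4}: {s:5.2f} {flag}")
```

The resulting vector has one value per channel, so it can be passed to `plot_topomap_from_values` in the same way as the saliency vector. Note that these are uncorrected per-channel tests; with 8 channels a multiple-comparison correction may be appropriate.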

### Compute Cost Metrics and Hyperparameters

Functions for estimating computational cost and for managing hyperparameters.

In [56]:
# =========================
# Compute cost / hyperparameters
# =========================
def count_params(model, trainable_only=False):
    """
    Counts the model's parameters.
    """
    if trainable_only:
        return int(sum(p.numel() for p in model.parameters() if p.requires_grad))
    return int(sum(p.numel() for p in model.parameters()))

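def estimate_transformer_flops(L, d, n_heads, n_layers):
    """
    Rough cost estimate for the Transformer encoder, counted as multiply-accumulates
    (about half the FLOP count). L is the sequence length seen by the encoder.
    n_heads is not used: at fixed d the total does not depend on the head count.
    """
    d_ff = 2*d
    proj = 4 * L * (d*d)         # Q, K, V, Out
    attn = 2 * (L*L*d)           # scores + apply to V
    ff   = 2 * L * (d * d_ff)    # FFN (2 linear layers)
    return n_layers * (proj + attn + ff)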

@torch.no_grad()
def benchmark_latency(model, X_sample, device, n_warm=10, n_runs=50):
    """
    Measures mean inference latency (seconds per forward pass) in eval mode.
    """
    was_training = model.training
    model.eval()
    xb = torch.as_tensor(X_sample, dtype=torch.float32, device=device)

    for _ in range(n_warm):
        _ = model(xb)
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    t0 = time.perf_counter()  # monotonic, higher resolution than time.time()
    for _ in range(n_runs):
        _ = model(xb)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    model.train(was_training)  # restore the caller's mode
    return elapsed / n_runs

def current_hparams():
    """
    Returns a dictionary with all current hyperparameters.
    """
    return {
        "EPOCHS": EPOCHS, "BATCH_SIZE": BATCH_SIZE, "BASE_LR": BASE_LR,
        "WARMUP_EPOCHS": WARMUP_EPOCHS, "PATIENCE": PATIENCE,
        "RESAMPLE_HZ": RESAMPLE_HZ, "DO_NOTCH": DO_NOTCH, "DO_BANDPASS": DO_BANDPASS,
        "BP_LO": BP_LO, "BP_HI": BP_HI, "DO_CAR": DO_CAR, "ZSCORE_PER_EPOCH": ZSCORE_PER_EPOCH,
        "D_MODEL": D_MODEL, "N_HEADS": N_HEADS, "N_LAYERS": N_LAYERS,
        "P_DROP": P_DROP, "P_DROP_ENCODER": P_DROP_ENCODER,
        "TMIN": TMIN, "TMAX": TMAX, "SW_MODE": SW_MODE, "SW_ENABLE": SW_ENABLE,
        "TTA_SHIFTS_S": TTA_SHIFTS_S, "SW_LEN": SW_LEN, "SW_STRIDE": SW_STRIDE,
        "COMBINE_TTA_AND_SUBWIN": COMBINE_TTA_AND_SUBWIN,
        "USE_WEIGHTED_SAMPLER": USE_WEIGHTED_SAMPLER,
        "USE_EMA": USE_EMA, "EMA_DECAY": EMA_DECAY,
        "FT_N_FOLDS": FT_N_FOLDS, "FT_FREEZE_EPOCHS": FT_FREEZE_EPOCHS, "FT_UNFREEZE_EPOCHS": FT_UNFREEZE_EPOCHS,
        "FT_PATIENCE": FT_PATIENCE, "FT_BATCH": FT_BATCH, "FT_LR_HEAD": FT_LR_HEAD,
        "FT_LR_BACKBONE": FT_LR_BACKBONE, "FT_WD": FT_WD, "FT_AUG": FT_AUG,
    }

# =========================
# Output / file helpers
# =========================
def ensure_dir(p: Path):
    """
    Create the directory (and its parents) if it does not exist; return the path.
    """
    p.mkdir(parents=True, exist_ok=True)
    return p

def make_run_dirs(base_dir: Path, tag: str, fold: int):
    """
    Create and return (run_dir, fold_dir) under base_dir / tag.
    """
    run_dir = ensure_dir(base_dir / tag)
    fold_dir = ensure_dir(run_dir / f"fold{fold}")
    return run_dir, fold_dir

def save_csv_rows(csv_path: Path, fieldnames: list, rows: list):
    """
    Append rows (dicts) to a CSV file, writing the header only if the file is new.
    """
    write_header = not csv_path.exists()
    with open(csv_path, "a", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            w.writeheader()
        w.writerows(rows)
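
The append-with-header-once pattern used by `save_csv_rows` can be exercised in isolation. The following is a minimal sketch using only the standard library; the helper name `append_rows`, the file name `summary.csv`, and the metric values are placeholders invented for the example, not results from this notebook.

```python
import csv
import tempfile
from pathlib import Path

def append_rows(csv_path, fieldnames, rows):
    """Append dict rows; the header is written only when the file is created."""
    write_header = not csv_path.exists()
    with open(csv_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerows(rows)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "summary.csv"          # hypothetical output file
    fields = ["fold", "acc"]
    append_rows(path, fields, [{"fold": 1, "acc": 0.81}])  # placeholder values
    append_rows(path, fields, [{"fold": 2, "acc": 0.79}])  # second call: no new header
    with open(path, newline="") as f:
        lines = f.read().splitlines()

print(lines)
```

Because the header check depends on the file already existing, calling the helper repeatedly with different `fieldnames` would yield a file whose header does not match later rows, so each CSV should keep one fixed schema.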

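### Main Training Function

This function runs one complete fold: it splits subjects into train/validation/test, trains the global model with early stopping on validation macro-F1, evaluates on the test subjects, and optionally produces interpretability artifacts, per-subject fine-tuning, and cost metrics.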
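
One part of the training loop that is easy to misread is the learning-rate schedule, which combines a linear warm-up with a cosine decay down to a floor. The sketch below isolates that multiplier in plain Python, following the same formula as the `lr_lambda` closure in the training function; the name `warmup_cosine_factor` and the values `total_epochs=20`, `warmup_epochs=4` are assumptions chosen for illustration.

```python
import math

def warmup_cosine_factor(epoch, total_epochs, warmup_epochs, min_factor=0.1):
    """Multiplier applied to the base learning rate at a given (0-based) epoch."""
    if epoch < warmup_epochs:
        # Linear warm-up: 1/warmup, 2/warmup, ..., 1.0
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    progress = min(1.0, max(0.0, progress))
    # Cosine decay from 1.0 towards min_factor
    return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + math.cos(math.pi * progress))

# Hypothetical schedule length, for illustration only
factors = [warmup_cosine_factor(e, total_epochs=20, warmup_epochs=4) for e in range(20)]
print([round(f, 3) for f in factors])
```

The multiplier never drops below `min_factor`, so the learning rate stays at 10% of the base value or higher for the whole run.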

In [57]:
# =========================
# GLOBAL TRAIN/EVAL PER FOLD
# =========================
def train_one_fold(
    fold: int,
    device,
    model_factory=None,
    run_tag="base",
    out_dir: Path = None,
    seed=None,
    save_artifacts: bool = True,
    do_ft: bool = True
):
    """
    Train and evaluate the model on one fold.

    Args:
        fold: Fold number (1-5)
        device: Device used for training (CPU/GPU)
        model_factory: Callable that builds a model instance
        run_tag: Identifier for this run
        out_dir: Output directory
        seed: Seed for reproducibility
        save_artifacts: Whether to save training artifacts
        do_ft: Whether to run per-subject fine-tuning

    Returns:
        (acc, f1m, ft_acc, cons, met): global test accuracy, global test macro-F1,
        fine-tuning accuracy (NaN if skipped), cost dict, and extended metrics dict.
    """
    # Fixed seed per run (keeps comparisons fair across combos/folds)
    base = RANDOM_STATE if seed is None else int(seed)
    seed_everything(base)

    def load_fold_subjects_local(folds_json: Path, fold: int):
        return load_fold_subjects(folds_json, fold)

    # --- output directories ---
    if save_artifacts:
        base_out = Path(".") if out_dir is None else Path(out_dir)
        run_dir, fold_dir = make_run_dirs(base_out, run_tag, fold)
        print(f"[IO] Saving artifacts to: {fold_dir.resolve()}")
    else:
        run_dir = None
        fold_dir = None
        print("[LITE] GridSearch: no artifacts will be saved per combination.")

    # --- subject split ---
    train_sub, test_sub = load_fold_subjects_local(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    rng = np.random.RandomState(RANDOM_STATE + fold)
    tr_subjects = sorted(train_sub)

    if VAL_STRAT_SUBJECT and len(tr_subjects) > 1:
        y_dom = build_subject_label_map(tr_subjects)
        if np.any(y_dom < 0):
            # Subjects without epochs (label -1) get the modal label so they can be stratified
            mask = y_dom >= 0
            mode_label = int(np.bincount(y_dom[mask]).argmax()) if mask.sum() > 0 else 0
            y_dom[~mask] = mode_label
        n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
        sss = StratifiedShuffleSplit(n_splits=1, test_size=n_val_subj, random_state=RANDOM_STATE + fold)
        idx = np.arange(len(tr_subjects))
        _, val_idx = next(sss.split(idx, y_dom))
        val_subjects = sorted([tr_subjects[i] for i in val_idx])
        train_subjects = [s for s in tr_subjects if s not in val_subjects]
    else:
        tr_subjects_shuf = tr_subjects.copy()
        rng.shuffle(tr_subjects_shuf)
        n_val_subj = max(1, int(round(len(tr_subjects_shuf) * VAL_SUBJECT_FRAC)))
        val_subjects = sorted(tr_subjects_shuf[:n_val_subj])
        train_subjects = sorted(tr_subjects_shuf[n_val_subj:])

    # --- data loading ---
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Loading train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Loading val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Loading test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenate
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    sub_tr = np.concatenate(sub_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    sub_val = np.concatenate(sub_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)
    sub_te = np.concatenate(sub_te_list, axis=0)

    print(f"[Fold {fold}/5] Training global model... (n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)})")

    # Per-channel normalization (skipped when each epoch is already z-scored)
    if ZSCORE_PER_EPOCH:
        X_tr_std, X_val_std, X_te_std = X_tr, X_val, X_te
    else:
        X_tr_std, X_val_std = standardize_per_channel(X_tr, X_val)
        _,        X_te_std  = standardize_per_channel(X_tr, X_te)

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std),  torch.tensor(y_tr).long(),  torch.tensor(sub_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long(), torch.tensor(sub_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std),  torch.tensor(y_te).long(),  torch.tensor(sub_te).long())

    # Weighted sampler
    def make_weighted_sampler(dataset: TensorDataset):
        _Xb, yb, sb = dataset.tensors
        yb_np = yb.numpy()
        sb_np = sb.numpy()
        uniq_s, cnt_s = np.unique(sb_np, return_counts=True)
        map_s = {s:c for s,c in zip(uniq_s, cnt_s)}
        key = sb_np.astype(np.int64) * 10 + yb_np.astype(np.int64)
        uniq_k, cnt_k = np.unique(key, return_counts=True)
        map_k = {k:c for k,c in zip(uniq_k, cnt_k)}
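        # Per-trial weight = n_subject^(-a) * n_(subject, class)^(-b): down-weights trials
        # from large subjects and from the over-represented class within each subject
        a, b = 0.8, 1.0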
        w = []
        for s, y in zip(sb_np, yb_np):
            k = int(s)*10 + int(y)
            ws = float(map_s[int(s)])
            wk = float(map_k[k])
            w.append((ws ** (-a)) * (wk ** (-b)))
        w = np.array(w, dtype=np.float64)
        w = w / (w.mean() + 1e-12)
        sampler = WeightedRandomSampler(weights=torch.tensor(w, dtype=torch.double),
                                        num_samples=len(yb_np), replacement=True)
        return sampler

    if USE_WEIGHTED_SAMPLER:
        tr_sampler = make_weighted_sampler(tr_ds)
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, sampler=tr_sampler, drop_last=False, worker_init_fn=seed_worker)
    else:
        tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)

    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # --- Time/memory meters per phase ---
    train_time_s = None
    train_mem_mb = None
    test_time_s  = None
    test_mem_mb  = None

    # Model
    if model_factory is None:
        model = EEGCNNTransformer(n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                                  n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                                  n_dw_blocks=2, capture_attn=True).to(device)
    else:
        model = model_factory().to(device)  # no-op if the factory already placed the model on `device`

    # Opt + Loss
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)
    class_counts = np.bincount(y_tr, minlength=2).astype(np.float32)
    inv = class_counts.sum() / (2.0 * np.maximum(class_counts, 1.0))
    alpha = torch.tensor(inv, dtype=torch.float32, device=device)
    crit = FocalLoss(alpha=alpha, gamma=1.5, reduction='mean')

    # LR scheduler Warmup+Cosine
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # EMA
    ema = ModelEMA(model, decay=EMA_DECAY, device=device) if USE_EMA else None

    # Global training (tracking the train/val gap).
    # best_f1 starts below any attainable F1 so the first epoch always yields a checkpoint.
    best_f1, best_state, wait = -1.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "val_loss": [], "lr": []}

    # Metrics at the epoch with the best validation F1
    best_val_acc = float('nan')
    best_val_f1  = float('nan')
    best_tr_acc  = float('nan')
    best_tr_f1   = float('nan')

    def evaluate_on(loader, use_ema=True, loss_fn=None):
        mdl = ema.ema if (ema is not None and use_ema) else model
        mdl.eval()
        preds, gts = [], []
        tot_loss, n_seen = 0.0, 0
        with torch.no_grad():
            for xb, yb, _sb in loader:
                xb = xb.to(device); yb = yb.to(device)
                logits = mdl(xb)
                if loss_fn is not None:
                    L = loss_fn(logits, yb).item()
                    tot_loss += L * len(yb); n_seen += len(yb)
                p = logits.argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.cpu().numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')
        vloss = (tot_loss / max(1,n_seen)) if n_seen>0 else float('nan')
        return acc, f1m, vloss, preds, gts

    # === Start TRAINING timer + peak-memory tracking ===
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)
    t_train_start = time.time()

    for ep in range(1, EPOCHS+1):
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb, _sb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
            if ema is not None:
                ema.update(model)

            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        acc, f1m, vloss, _, _ = evaluate_on(val_ld, use_ema=True, loss_fn=crit)

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["val_loss"].append(vloss)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Epoch {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_loss={vloss:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f} | LR={scheduler.get_last_lr()[0]:.6f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            ref_model = ema.ema if ema is not None else model
            best_state = {k: v.detach().cpu() for k, v in ref_model.state_dict().items()}
            wait = 0

            # Evaluate train with the same model/epoch to measure the train/val gap
            # (note: tr_ld resamples with replacement when USE_WEIGHTED_SAMPLER is on)
            tr_acc_eval, tr_f1_eval, _, _, _ = evaluate_on(tr_ld, use_ema=True, loss_fn=None)
            best_tr_acc = tr_acc_eval
            best_tr_f1  = tr_f1_eval
            best_val_acc = acc
            best_val_f1  = f1m
        else:
            wait += 1

        scheduler.step()
        if wait >= PATIENCE:
            print(f"  Early stopping at epoch {ep} (best val_f1m={best_f1:.4f})")
            break

    # === Stop TRAINING timer ===
    t_train_end = time.time()
    train_time_s = t_train_end - t_train_start
    if torch.cuda.is_available():
        train_mem_mb = torch.cuda.max_memory_allocated(device) / (1024.0 ** 2)

    # Restore best state
    if best_state is not None:
        (ema.ema if (ema is not None) else model).load_state_dict(best_state)
        model.load_state_dict(best_state)
    for p in model.parameters():
        p.requires_grad_(True)

    if save_artifacts and fold_dir is not None:
        torch.save(best_state, fold_dir / f"model_global_fold{fold}_{run_tag}.pth")
        print(f"↳ Global model saved to: {fold_dir / f'model_global_fold{fold}_{run_tag}.pth'}")

    # ---- Curves (only when artifacts are saved) ----
    if save_artifacts:
        # ONLY train_loss and val_loss, correct scale
        fig = plt.figure(figsize=(8, 4.5))
        ax1 = plt.gca()
        ax1.plot(hist["ep"], hist["tr_loss"], label="train_loss")
        ax1.plot(hist["ep"], hist["val_loss"], label="val_loss")
        ax1.set_xlabel("Epoch")
        ax1.set_ylabel("Loss")
        ax1.set_title(f"Fold {fold} — Training Curve (2 classes) [{run_tag}]")
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        out_png = fold_dir / f"training_curve_fold{fold}_{run_tag}.png"
        plt.tight_layout()
        plt.savefig(out_png, dpi=140)
        plt.close(fig)
        print(f"↳ Training curve saved: {out_png}")
        # Save curve history (epoch, train_loss, val_loss) for the global curve
        hist_path = fold_dir / f"history_fold{fold}_{run_tag}.npz"
        np.savez(
            hist_path,
            ep=np.array(hist["ep"], dtype=np.int32),
            tr_loss=np.array(hist["tr_loss"], dtype=np.float32),
            val_loss=np.array(hist["val_loss"], dtype=np.float32),
        )
        print(f"↳ History saved: {hist_path}")

    # ---- Final evaluation on TEST (global) ----
    eval_model = ema.ema if (ema is not None) else model
    eval_model.eval()

    sfreq_used = RESAMPLE_HZ
    if sfreq_used is None:
        sfreq_used = int(round(X_te_std.shape[-1] / (TMAX - TMIN)))

    # === Start TEST timer + peak-memory tracking ===
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats(device)
    t_test_start = time.time()

    if (not SW_ENABLE) or SW_MODE == 'none':
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb, _sb in te_ld:
                xb = xb.to(device)
                p = eval_model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        preds = np.concatenate(preds); gts = np.concatenate(gts)
    elif SW_MODE in ('subwin', 'tta'):
        logits_tta = None
        logits_sw  = None
        if SW_MODE == 'subwin':
            logits_sw = subwindow_logits(eval_model, X_te_std, sfreq_used, SW_LEN, SW_STRIDE, device)
        elif SW_MODE == 'tta':
            logits_tta = time_shift_tta_logits(eval_model, X_te_std, sfreq_used, TTA_SHIFTS_S, device)
        logits = logits_tta if logits_tta is not None else logits_sw
        preds = logits.argmax(axis=1); gts = y_te
    else:
        raise ValueError(f"Unknown SW_MODE: {SW_MODE}")

    # === Stop TEST timer ===
    t_test_end = time.time()
    test_time_s = t_test_end - t_test_start
    if torch.cuda.is_available():
        test_mem_mb = torch.cuda.max_memory_allocated(device) / (1024.0 ** 2)

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    print(f"[Fold {fold}/5] Global acc={acc:.4f} | f1_macro={f1m:.4f}\n")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix (rows=true, cols=pred):")
    cm = confusion_matrix(gts, preds, labels=[0,1])
    print(cm)

    # Extended metrics are ALWAYS computed (even when artifacts are not saved)
    met = compute_metrics(gts, preds)
    print(
        f"precision_macro={met['precision_macro']:.4f} | "
        f"recall_macro={met['recall_macro']:.4f} | "
        f"specificity_macro={met['specificity_macro']:.4f} | "
        f"sensitivity_macro={met['sensitivity_macro']:.4f}"
    )

    if save_artifacts:
        cm_png = fold_dir / f"confusion_global_fold{fold}_{run_tag}.png"
        plot_confusion_with_text(
            cm=cm,
            class_names=CLASS_NAMES,
            title=f"Confusion — Fold {fold} Global (2 classes) [{run_tag}]",
            out_path=cm_png,
            cmap='Blues'
        )
        print(f"↳ Confusion matrix saved: {cm_png}")

    # =========================
    # INTERPRETABILITY
    # =========================
    if save_artifacts:
        try:
            model_to_probe = model
            model_to_probe.eval()

            # Validation minibatch for illustrative examples
            with torch.no_grad():
                xb_val, yb_val, _ = next(iter(val_ld))
            xb_val = xb_val[:8].to(device)

            # --- 1. Attention maps (Transformer) ---
            _ = model_to_probe(xb_val)
            attn_list = model_to_probe.get_last_attention_maps()
            if attn_list and len(attn_list) > 0:
                # attn_list[0]: [B, heads, T, T]
                A = attn_list[0][0].mean(dim=0).detach().cpu().numpy()
                plt.figure(figsize=(5, 4))
                plt.imshow(A, aspect="auto")
                plt.title(f"Attention (layer1, mean heads) [{run_tag}]")
                plt.colorbar()
                plt.tight_layout()
                attn_png = fold_dir / f"attention_map_fold{fold}_{run_tag}.png"
                plt.savefig(attn_png, dpi=140)
                plt.close()
                print(f"↳ Attention map: {attn_png}")

            # Use ONE example for all the maps
            xb_ex = xb_val[:1]
            pred_cls = int(model_to_probe(xb_ex).argmax(1).item())

            # --- 2. Per-channel saliency topomap ---
            saliency_ch = channel_saliency(model_to_probe, xb_ex, target_class=pred_cls)

            sfreq_topo = sfreq_used if sfreq_used is not None else int(
                round(X_te_std.shape[-1] / (TMAX - TMIN))
            )
            info_topo = make_mne_info_8ch(sfreq_topo)

            topo_sal_png = fold_dir / f"topomap_saliency_fold{fold}_{run_tag}.png"
            plot_topomap_from_values(
                saliency_ch,
                info_topo,
                title=f"Channel saliency — Fold {fold} [{run_tag}]",
                out_path=topo_sal_png,
                cmap="Reds",
                cbar_label="Normalized saliency"
            )
            # Save numeric per-channel saliency for the GLOBAL average
            saliency_path = fold_dir / f"saliency_channels_fold{fold}_{run_tag}.npy"
            np.save(saliency_path, saliency_ch)

            # --- 3. Topomap of -log10(p), left vs right (TEST) ---
            pvals = compute_channel_pvalues_lr(X_te_std, y_te)
            pvals_log = -np.log10(pvals + 1e-12)

            topo_p_png = fold_dir / f"topomap_log10p_fold{fold}_{run_tag}.png"
            plot_topomap_from_values(
                pvals_log,
                info_topo,
                title=f"-log10(p) Left vs Right — Fold {fold} [{run_tag}]",
                out_path=topo_p_png,
                cmap="viridis",
                cbar_label="-log10(p)"
            )

            # Save numeric per-channel values for the GLOBAL average
            pvals_log_path = fold_dir / f"pvals_log_channels_fold{fold}_{run_tag}.npy"
            np.save(pvals_log_path, pvals_log)

            # --- 4. t-SNE of model embeddings ---
            # Use a subset of TEST for t-SNE (at most 1000 samples)
            n_tsne_samples = min(1000, len(X_te_std))
            indices = rng.choice(len(X_te_std), n_tsne_samples, replace=False)
            X_tsne = X_te_std[indices]
            y_tsne = y_te[indices]

            # Use the `device` argument rather than the global DEVICE
            tsne_result, cls_embeddings = compute_tsne_embeddings(
                model_to_probe, X_tsne, device, n_components=2, perplexity=30
            )

            tsne_png = fold_dir / f"tsne_cls_fold{fold}_{run_tag}.png"
            plot_tsne(
                tsne_result,
                y_tsne,
                title=f"t-SNE CLS embeddings — Fold {fold} [{run_tag}]",
                out_path=tsne_png
            )

            # Save t-SNE results and CLS embeddings
            tsne_path = fold_dir / f"tsne_results_fold{fold}_{run_tag}.npy"
            np.save(tsne_path, tsne_result)

            cls_emb_path = fold_dir / f"cls_embeddings_fold{fold}_{run_tag}.npy"
            np.save(cls_emb_path, cls_embeddings)

        except Exception as e:
            print(f"[WARN] Interpretability not generated: {e}")

    # =========================
    # Progressive per-subject fine-tuning (optional)
    # =========================
    if do_ft:
        subjects = np.unique(sub_te)
        subj_to_idx = {s: np.where(sub_te == s)[0] for s in subjects}

        def make_ft_optimizer(model_ft):
            head_params = list(model_ft.head.parameters())
            backbone_params = [p for n,p in model_ft.named_parameters() if not n.startswith('head.')]
            return torch.optim.AdamW([
                {'params': backbone_params, 'lr': FT_LR_BACKBONE},
                {'params': head_params,     'lr': FT_LR_HEAD}
            ], weight_decay=FT_WD)

        def freeze_backbone(model_ft, freeze=True):
            for n,p in model_ft.named_parameters():
                if not n.startswith('head.'):
                    p.requires_grad_(not freeze)

        def ft_make_criterion_from_counts(counts):
            inv = counts.sum() / (2.0 * np.maximum(counts.astype(np.float32), 1.0))
            a = torch.tensor(inv, dtype=torch.float32, device=device)
            return FocalLoss(alpha=a, gamma=1.5, reduction='mean')

        def evaluate_tensor(model_t, X, y):
            model_t.eval()
            with torch.no_grad():
                xb = torch.tensor(X, dtype=torch.float32, device=device)
                p = model_t(xb).argmax(1).cpu().numpy()
            return accuracy_score(y, p), f1_score(y, p, average='macro'), p

        all_true, all_pred = [], []
        base_state = (ema.ema if (ema is not None) else model).state_dict()

        for s in subjects:
            idx = subj_to_idx[s]
            Xs, ys = X_te_std[idx], y_te[idx]
            # Single-class or very small subjects: skip fine-tuning and score the global model as-is
            if len(np.unique(ys)) < 2 or len(ys) < 20:
                eval_model2 = model_factory() if model_factory is not None else EEGCNNTransformer(
                    n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                    n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                    n_dw_blocks=2, capture_attn=False).to(device)
                eval_model2.load_state_dict(base_state)
                _, _, pred_s = evaluate_tensor(eval_model2, Xs, ys)
                all_true.append(ys); all_pred.append(pred_s)
                continue

            skf = StratifiedKFold(n_splits=FT_N_FOLDS, shuffle=True, random_state=RANDOM_STATE + fold + int(s))
            for tr_idx, te_idx in skf.split(Xs, ys):
                Xtr, ytr = Xs[tr_idx].copy(), ys[tr_idx].copy()
                Xte_s, yte_s = Xs[te_idx].copy(), ys[te_idx].copy()
                Xtr_std, Xte_std = standardize_per_channel(Xtr, Xte_s)

                ft_model = model_factory() if model_factory is not None else EEGCNNTransformer(
                    n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                    n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                    n_dw_blocks=2, capture_attn=False).to(device)
                ft_model.load_state_dict(base_state)

                opt_ft = make_ft_optimizer(ft_model)
                alpha_counts = np.bincount(ytr, minlength=2)
                ft_crit = ft_make_criterion_from_counts(alpha_counts)

                ds_tr = TensorDataset(torch.tensor(Xtr_std), torch.tensor(ytr).long())
                ds_te = TensorDataset(torch.tensor(Xte_std), torch.tensor(yte_s).long())
                ld_tr = DataLoader(ds_tr, batch_size=FT_BATCH, shuffle=True,  drop_last=False, worker_init_fn=seed_worker)
                ld_te = DataLoader(ds_te, batch_size=FT_BATCH, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

                # Stage 1: backbone frozen, only the head is trained
                freeze_backbone(ft_model, True)
                best_f1_s, best_state_s, wait_s = -1, None, 0
                for _ep in range(1, FT_FREEZE_EPOCHS+1):
                    ft_model.train()
                    for xb, yb in ld_tr:
                        xb = xb.to(device); yb = yb.to(device)
                        xb = augment_batch_ft(xb)
                        opt_ft.zero_grad()
                        logits = ft_model(xb)
                        loss = ft_crit(logits, yb)
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                        opt_ft.step()
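                    # NOTE: checkpoint selection uses the same held-out split that is scored
                    # at the end, so the reported FT metrics are optimistically biased.
                    acc_s, f1_s, _ = evaluate_tensor(ft_model, Xte_std, yte_s)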
                    if f1_s > best_f1_s + 1e-4:
                        best_f1_s = f1_s
                        best_state_s = {k: v.detach().cpu() for k,v in ft_model.state_dict().items()}
                        wait_s = 0
                    else:
                        wait_s += 1
                    if wait_s >= FT_PATIENCE: break
                if best_state_s is not None:
                    ft_model.load_state_dict(best_state_s)

                # Stage 2: backbone unfrozen, all parameters are trained
                freeze_backbone(ft_model, False)
                best_f1_s2, best_state_s2, wait_s2 = -1, None, 0
                for _ep in range(1, FT_UNFREEZE_EPOCHS+1):
                    ft_model.train()
                    for xb, yb in ld_tr:
                        xb = xb.to(device); yb = yb.to(device)
                        xb = augment_batch_ft(xb)
                        opt_ft.zero_grad()
                        logits = ft_model(xb)
                        loss = ft_crit(logits, yb)
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(ft_model.parameters(), 1.0)
                        opt_ft.step()
                    acc_s, f1_s, _ = evaluate_tensor(ft_model, Xte_std, yte_s)
                    if f1_s > best_f1_s2 + 1e-4:
                        best_f1_s2 = f1_s
                        best_state_s2 = {k: v.detach().cpu() for k,v in ft_model.state_dict().items()}
                        wait_s2 = 0
                    else:
                        wait_s2 += 1
                    if wait_s2 >= FT_PATIENCE: break
                if best_state_s2 is not None:
                    ft_model.load_state_dict(best_state_s2)

                _, _, pred_s = evaluate_tensor(ft_model, Xte_std, yte_s)
                all_true.append(yte_s); all_pred.append(pred_s)

        all_true = np.concatenate(all_true) if len(all_true) else np.array([], dtype=int)
        all_pred = np.concatenate(all_pred) if len(all_pred) else np.array([], dtype=int)
        if len(all_true) > 0:
            ft_acc = accuracy_score(all_true, all_pred)
            ft_f1m = f1_score(all_true, all_pred, average='macro')
            print(f"  PROGRESSIVE fine-tuning (per subject, {FT_N_FOLDS}-fold CV) acc={ft_acc:.4f} | f1_macro={ft_f1m:.4f}")
            if save_artifacts:
                cm_ft = confusion_matrix(all_true, all_pred, labels=[0,1])
                ft_png = fold_dir / f"confusion_ft_fold{fold}_{run_tag}.png"
                plot_confusion_with_text(
                    cm=cm_ft,
                    class_names=CLASS_NAMES,
                    title=f"Confusion — Fold {fold} FT (2 classes) [{run_tag}]",
                    out_path=ft_png,
                    cmap='Greens'
                )
                print(f"↳ FT confusion matrix saved: {ft_png}")
        else:
            ft_acc = float('nan')
            print("  PROGRESSIVE fine-tuning: no valid splits were generated.")
    else:
        ft_acc = float('nan')
        print("  [LITE] FT disabled (do_ft=False).")

    # =========================
    # Cost metrics (always computed); the JSON is saved only if save_artifacts
    # =========================
    cons = {}
    try:
        eval_model.eval()
        with torch.no_grad():
            xb_small, _, _ = next(iter(val_ld))
        xb_small = xb_small[:8].to(device)

        with torch.no_grad():
            z = eval_model.conv_t(xb_small)
            z = eval_model.proj(z).transpose(1, 2)  # (B, L, D)
            L, d = z.shape[1], z.shape[2]

        params_total = count_params(eval_model, trainable_only=False)
        params_train = count_params(model, trainable_only=True)
        flops = estimate_transformer_flops(L=L, d=d, n_heads=N_HEADS, n_layers=N_LAYERS)
        lat   = benchmark_latency(eval_model, xb_small, device)

        cons = dict(
            params_total=params_total,
            params_trainable=params_train,
            flops_est=int(flops),
            latency_s=float(lat),
            L=L, d=d,
            n_heads=N_HEADS, n_layers=N_LAYERS,
            run_tag=run_tag, fold=fold,
            # Train/val gap at the epoch with the best validation F1
            val_best_acc=float(best_val_acc),
            val_best_f1=float(best_val_f1),
            tr_at_best_acc=float(best_tr_acc),
            tr_at_best_f1=float(best_tr_f1),
            # === Time/memory metrics per phase ===
            train_time_s=float(train_time_s) if train_time_s is not None else None,
            train_mem_mb=float(train_mem_mb) if train_mem_mb is not None else None,
            test_time_s=float(test_time_s) if test_time_s is not None else None,
            test_mem_mb=float(test_mem_mb) if test_mem_mb is not None else None,
        )

        if save_artifacts and (fold_dir is not None):
            cons_json = fold_dir / f"consumption_fold{fold}_{run_tag}.json"
            with open(cons_json, "w") as f:
                json.dump(cons, f, indent=2)

        print(f"[Cost] params_total={params_total:,} | params_trainable={params_train:,} | "
              f"FLOPs~{flops/1e6:.1f}M | latency/batch={lat*1000:.2f} ms (B={xb_small.size(0)})")
    except Exception as e:
        print(f"[WARN] Cost metrics not generated: {e}")

    return acc, f1m, ft_acc, cons, met
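
The balancing rule inside `make_weighted_sampler` may be easier to follow on a toy example. The sketch below recomputes the per-trial weights in plain Python: each trial is weighted by the inverse of its subject's trial count (raised to `a`) times the inverse of its (subject, class) count (raised to `b`), then normalized to mean 1. The function name `sampler_weights` and the toy subject/label lists are made up for the example, and a tuple key stands in for the notebook's `subject * 10 + label` integer key.

```python
from collections import Counter

def sampler_weights(subjects, labels, a=0.8, b=1.0):
    """Per-trial sampling weights, normalized so that their mean is 1."""
    n_subj = Counter(subjects)                 # trials per subject
    n_pair = Counter(zip(subjects, labels))    # trials per (subject, class)
    raw = [n_subj[s] ** (-a) * n_pair[(s, y)] ** (-b)
           for s, y in zip(subjects, labels)]
    mean = sum(raw) / len(raw)
    return [w / mean for w in raw]

# Toy data: subject 1 has 6 trials (4 of class 0, 2 of class 1); subject 2 has 2 trials
subjects = [1, 1, 1, 1, 1, 1, 2, 2]
labels   = [0, 0, 0, 0, 1, 1, 0, 1]
w = sampler_weights(subjects, labels)
print([round(x, 3) for x in w])
```

In this toy case, trials of the minority class within subject 1 get twice the weight of its majority-class trials, and trials from the smaller subject 2 get the largest weights of all.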

### Split Utilities and Additional Configuration

Helper functions for building subject-level stratified splits.
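
Stratifying at the subject level needs one label per subject, so each subject is represented by its most frequent class, and subjects without usable trials fall back to the most common dominant label. The following is a minimal sketch of that idea in plain Python; the helper names and the subject IDs (`S001` and so on) are hypothetical, and ties are broken towards the lowest label, as `np.argmax` would do on a bincount.

```python
from collections import Counter

def mode_lowest(values):
    """Most frequent value; ties are resolved in favour of the smallest value."""
    counts = Counter(values)
    return max(sorted(counts), key=counts.get)

def dominant_labels(labels_by_subject):
    """Map each subject to its dominant label, filling empty subjects with the mode."""
    dom = {sid: (mode_lowest(ys) if ys else -1) for sid, ys in labels_by_subject.items()}
    known = [v for v in dom.values() if v >= 0]
    fallback = mode_lowest(known) if known else 0
    return {sid: (v if v >= 0 else fallback) for sid, v in dom.items()}

# Hypothetical per-subject trial labels (0 = left, 1 = right)
toy = {"S001": [0, 0, 1], "S002": [1, 1, 0, 1], "S003": [], "S004": [1, 1]}
result = dominant_labels(toy)
print(result)
```

The resulting per-subject labels can then be passed as the stratification target to a splitter such as `StratifiedShuffleSplit`, which is how the training function uses them.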

In [58]:
# =========================
# Subject-level stratified split utilities
# =========================
def build_subject_label_map(subject_ids):
    """
    Construye mapa de etiquetas dominantes por sujeto para estratificación.
    """
    y_dom_list = []
    for sid in subject_ids:
        Xs, ys, _ = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0:
            y_dom_list.append(-1)
            continue
        binc = np.bincount(ys, minlength=2)
        y_dom = int(np.argmax(binc))
        y_dom_list.append(y_dom)
    return np.array(y_dom_list, dtype=int)
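The dominant-label rule used by `build_subject_label_map` can be checked in isolation. The sketch below is a minimal NumPy illustration, not the notebook's loader: `dominant_label` and the subject IDs are hypothetical names, and the label arrays are made up. It shows two edge cases worth knowing about: ties resolve to the lowest class index, and subjects without epochs receive the `-1` sentinel, which should be filtered out before stratifying.

```python
import numpy as np

def dominant_label(ys, n_classes=2, empty_value=-1):
    """Return the most frequent class in ys, or empty_value if ys is empty."""
    ys = np.asarray(ys, dtype=int)
    if ys.size == 0:
        return empty_value
    counts = np.bincount(ys, minlength=n_classes)
    # np.argmax returns the first maximum, so ties resolve to the lowest class index
    return int(np.argmax(counts))

# Hypothetical per-subject label arrays (class indices 0 and 1)
labels_by_subject = {
    "S001": [0, 0, 1],      # mostly class 0
    "S002": [1, 1, 0, 1],   # mostly class 1
    "S003": [0, 1],         # tie -> class 0
    "S004": [],             # no epochs -> -1
}
y_dom = np.array([dominant_label(v) for v in labels_by_subject.values()])
keep = y_dom >= 0  # mask of subjects that can take part in stratification
print(y_dom, keep)
```

With balanced left/right recordings most subjects will be ties, so the stratification signal may be weak; that is a property of the rule rather than of this sketch.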

### Hyperparameter Sweep Functions

Functions for running the hyperparameter search and the full training runs.
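Before launching a sweep it can help to list which configurations will actually be trained. The sketch below reproduces, in plain Python, the clipping and divisibility rules that `run_arch_sweep` applies. `sweep_grid` is a hypothetical helper, and the `d_model` values (144 and 128) are assumptions chosen for illustration, since the value of `D_MODEL` is defined elsewhere in the notebook. Unlike the sweep function, this helper also drops duplicate pairs that clipping can create.

```python
from itertools import product

def sweep_grid(d_model, base_blocks=2, deltas=(-2, 0, 2), head_set=(2, 4, 6),
               min_blocks=0, max_blocks=4):
    """List the (n_dw_blocks, n_heads) pairs a sweep would train."""
    configs = []
    for dlt, nh in product(deltas, head_set):
        # Same clipping as np.clip(base_blocks + dlt, 0, 4)
        n_blocks = max(min_blocks, min(max_blocks, base_blocks + dlt))
        # Multi-head attention needs d_model divisible by the number of heads
        if d_model % nh != 0:
            continue
        # Extra step (not in the sweep function): skip duplicates created by clipping
        if (n_blocks, nh) not in configs:
            configs.append((n_blocks, nh))
    return configs

grid_144 = sweep_grid(d_model=144)  # assumed value, divisible by 2, 4 and 6
grid_128 = sweep_grid(d_model=128)  # assumed value, not divisible by 6
print(len(grid_144), len(grid_128))
```

Under these assumptions a model width of 128 would silently skip every 6-head run, which is worth checking before comparing sweep results.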

In [59]:
# =========================
# Sweep (with folders and summary)
# =========================
def run_arch_sweep(device, base_blocks=2, deltas=(-2,0,+2), head_set=(2,4,6), sweep_fold=1):
    """
    Run an architecture sweep over the number of depthwise blocks and attention heads.
    """
    base_dir = ensure_dir(Path("Barrido"))
    results = []
    run_id = 0
    for dlt in deltas:
        n_blocks = int(np.clip(base_blocks + dlt, 0, 4))
        for nh in head_set:
            if D_MODEL % nh != 0:
                print(f"[SKIP] n_heads={nh} incompatible con D_MODEL={D_MODEL} (no divisible).")
                continue
            run_id += 1
            run_tag = f"nb{n_blocks}_h{nh}"
            print(f"\n=== RUN {run_id} | n_dw_blocks={n_blocks} | n_heads={nh} ===")

            def make_model_for_fold():
                return EEGCNNTransformer(n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=nh,
                                         n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                                         n_dw_blocks=n_blocks, capture_attn=True).to(device)
            acc, f1m, ft_acc, cons, met = train_one_fold(
                fold=sweep_fold, device=device,
                model_factory=make_model_for_fold, run_tag=run_tag, out_dir=base_dir,
                save_artifacts=False, do_ft=False
            )
            results.append(dict(
                run=run_id,
                n_dw_blocks=n_blocks,
                n_heads=nh,
                fold=sweep_fold,
                acc=acc,
                f1_macro=f1m,
                ft_acc=ft_acc,
                # === EXTENDED TEST METRICS ===
                precision_macro=met["precision_macro"],
                recall_macro=met["recall_macro"],
                specificity_macro=met["specificity_macro"],
                sensitivity_macro=met["sensitivity_macro"],
                # === RESOURCE USAGE ===
                tag=run_tag,
                params_total=cons.get("params_total"),
                params_trainable=cons.get("params_trainable"),
                flops_est=cons.get("flops_est"),
                latency_s=cons.get("latency_s"),
                L=cons.get("L"),
                d=cons.get("d"),
                enc_heads=cons.get("n_heads"),
                enc_layers=cons.get("n_layers"),
            ))

    if len(results):
        summary_path = Path("Barrido") / "summary_sweep.csv"
        fields = list(results[0].keys())
        save_csv_rows(summary_path, fields, results)
        print(f"↳ Resumen de barrido (con consumo) actualizado: {summary_path.resolve()}")
    return results

# =========================
# 5 folds for the winning configuration (with folders and summary)
# =========================
def run_full_5fold_for_config(device, n_blocks, n_heads):
    """
    Train and evaluate the model on all 5 folds for a given configuration.
    """
    tag = f"nb{n_blocks}_h{n_heads}"
    base_dir = ensure_dir(Path("ModeloW") / tag)

    acc_folds, f1_folds, ft_acc_folds = [], [], []
    rows = []

    # === Per-fold computational cost lists ===
    train_times, train_mems = [], []
    test_times,  test_mems  = [], []
    params_list, flops_list = [], []

    # Accumulators for GLOBAL maps
    saliency_list = []   # Saliency per fold
    pvals_log_list = []  # -log10(p) per fold
    tsne_results_list = []  # t-SNE per fold (stored only, not averaged)

    # Accumulator for the global confusion matrix (sum of the per-fold matrices)
    global_cm = np.zeros((2, 2), dtype=int)
    
    def factory():
        return EEGCNNTransformer(n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=n_heads,
                                 n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                                 n_dw_blocks=n_blocks, capture_attn=True).to(device)

    for fold in range(1, 6):
        acc, f1m, ft_acc, cons, met = train_one_fold(
            fold, device, model_factory=factory, run_tag=tag, out_dir=base_dir
        )
        # Add this fold's confusion matrix to the GLOBAL matrix
        if met is not None and "cm" in met:
            global_cm += np.array(met["cm"], dtype=int)
        acc_folds.append(acc); f1_folds.append(f1m)
        # ft_acc == ft_acc is False only for NaN; a None value is also stored as NaN
        ft_acc_folds.append(ft_acc if (ft_acc is not None and ft_acc == ft_acc) else np.nan)

        # === Store this fold's resource usage ===
        train_times.append(cons.get("train_time_s"))
        train_mems.append(cons.get("train_mem_mb"))
        test_times.append(cons.get("test_time_s"))
        test_mems.append(cons.get("test_mem_mb"))
        params_list.append(cons.get("params_total"))
        flops_list.append(cons.get("flops_est"))

        rows.append(dict(
            tag=tag, fold=fold, acc=acc, f1_macro=f1m, ft_acc=ft_acc,
            params_total=cons.get("params_total"),
            params_trainable=cons.get("params_trainable"),
            flops_est=cons.get("flops_est"),
            latency_s=cons.get("latency_s"),
            L=cons.get("L"), d=cons.get("d"),
            enc_heads=cons.get("n_heads"), enc_layers=cons.get("n_layers")
        ))

        # Load this fold's interpretability maps for the GLOBAL average
        try:
            fold_dir = base_dir / tag / f"fold{fold}"
            # Load saliency
            sal_path = fold_dir / f"saliency_channels_fold{fold}_{tag}.npy"
            if sal_path.exists():
                saliency_list.append(np.load(sal_path))

            # Load -log10(p)
            pvals_path = fold_dir / f"pvals_log_channels_fold{fold}_{tag}.npy"
            if pvals_path.exists():
                pvals_log_list.append(np.load(pvals_path))

            # Load t-SNE (for reference only; not averaged)
            tsne_path = fold_dir / f"tsne_results_fold{fold}_{tag}.npy"
            if tsne_path.exists():
                tsne_results_list.append(np.load(tsne_path))
        except Exception as e:
            print(f"[WARN] No se pudieron cargar mapas para fold {fold}: {e}")

    acc_mean = float(np.nanmean(acc_folds))
    f1_mean  = float(np.nanmean(f1_folds))
    ft_mean  = float(np.nanmean(ft_acc_folds))
    delta_mean = float(ft_mean - acc_mean) if not np.isnan(ft_mean) else np.nan

    summary_path = base_dir / f"summary_{tag}.csv"
    save_csv_rows(summary_path, list(rows[0].keys()), rows)
    print(f"↳ Resumen 5-fold guardado: {summary_path.resolve()}")

    print("\n============================================================")
    print(f"RESULTADOS FINALES [{tag}] (2 clases: left/right)")
    print("============================================================")
    print("Global folds (ACC):", [f"{a:.4f}" for a in acc_folds])
    print("Global mean ACC:   ", f"{acc_mean:.4f}")
    print("F1 folds (MACRO):  ", [f"{f:.4f}" for f in f1_folds])
    print("F1 mean (MACRO):   ", f"{f1_mean:.4f}")
    if not np.isnan(ft_mean):
        print("FT folds (ACC):    ", [("nan" if np.isnan(a) else f"{a:.4f}") for a in ft_acc_folds])
        print("FT mean (ACC):     ", f"{ft_mean:.4f}")
        print("Δ(FT-Global) mean: ", f"{delta_mean:+.4f}")
    else:
        print("FT mean (ACC): N/A")

    # ========================================================
    # Paper-style table of computational metrics
    # ========================================================

    tt = np.array([t for t in train_times if t is not None], dtype=float)
    tm = np.array([m for m in train_mems if m is not None], dtype=float)

    train_time_mean = float(tt.mean()) if len(tt) else float('nan')
    train_time_std  = float(tt.std(ddof=0)) if len(tt) else float('nan')
    train_mem_mean  = float(tm.mean()) if len(tm) else float('nan')
    train_mem_std   = float(tm.std(ddof=0)) if len(tm) else float('nan')

    # For the test phase, take the first valid fold (single run)
    test_time = next((float(t) for t in test_times if t is not None), float('nan'))
    test_mem  = next((float(m) for m in test_mems if m is not None), float('nan'))

    # Model (same in all folds)
    params_total = next((p for p in params_list if p is not None), 0)
    flops_est    = next((f for f in flops_list if f is not None), 0)

    params_m = params_total / 1e6
    flops_g  = flops_est / 1e9

    print("\n================ Computational metrics ================")
    print(f"Model tag: {tag}")
    print("Phase        time (s)           memory (MB)        FLOPs (g)   params (m)")
    print("-----------  -----------------  -----------------  ----------  ----------")
    print(f"model specs  N/A                N/A                 {flops_g:8.3f}   {params_m:8.3f}")
    print(f"train        {train_time_mean:6.2f} ({train_time_std:5.2f})   {train_mem_mean:7.1f} ({train_mem_std:5.1f})     N/A        N/A")
    print(f"test         {test_time:6.2f}              {test_mem:7.1f}              N/A        N/A")

    comp_rows = [
        dict(phase="model_specs",
             time_mean_s=None, time_std_s=None,
             mem_mean_mb=None, mem_std_mb=None,
             flops_g=flops_g, params_m=params_m),
        dict(phase="train",
             time_mean_s=train_time_mean, time_std_s=train_time_std,
             mem_mean_mb=train_mem_mean, mem_std_mb=train_mem_std,
             flops_g=None, params_m=None),
        dict(phase="test",
             time_mean_s=test_time, time_std_s=None,
             mem_mean_mb=test_mem, mem_std_mb=None,
             flops_g=None, params_m=None),
    ]

    comp_path = base_dir / f"computational_{tag}.csv"
    save_csv_rows(comp_path,
                  fieldnames=list(comp_rows[0].keys()),
                  rows=comp_rows)
    print(f"↳ Tabla de métricas computacionales guardada en: {comp_path.resolve()}")
    
    # ========================================================
    # GLOBAL confusion matrix (sum over the 5 folds)
    # ========================================================
    try:
        cm_global_png = base_dir / f"confusion_global_allfolds_{tag}.png"
        plot_confusion_with_text(
            cm=global_cm,
            class_names=CLASS_NAMES,
            title=f"Confusion — Global (5-fold) [{tag}]",
            out_path=cm_global_png,
            cmap='Blues'
        )
        print(f"↳ Matriz de confusión GLOBAL guardada: {cm_global_png}")
    except Exception as e:
        print(f"[WARN] No se pudo generar la matriz de confusión global: {e}")

    # ========================================================
    # GLOBAL loss curve (mean over the 5 folds)
    # ========================================================
    try:
        histories = []
        for fold in range(1, 6):
            hist_path = base_dir / tag / f"fold{fold}" / f"history_fold{fold}_{tag}.npz"
            if not hist_path.exists():
                print(f"[WARN] History no encontrado para fold {fold}: {hist_path}")
                continue
            data = np.load(hist_path)
            hist = {
                "ep":      data["ep"],
                "tr_loss": data["tr_loss"],
                "val_loss": data["val_loss"],
            }
            histories.append(hist)

        if len(histories) > 0:
            # Maximum number of epochs across folds
            max_len = max(h["ep"].shape[0] for h in histories)
            epochs = np.arange(1, max_len + 1)

            mean_tr = []
            mean_val = []

            for i in range(max_len):
                tr_vals = [h["tr_loss"][i] for h in histories if h["tr_loss"].shape[0] > i]
                val_vals = [h["val_loss"][i] for h in histories if h["val_loss"].shape[0] > i]
                if len(tr_vals) == 0 or len(val_vals) == 0:
                    break  # no fold reached this epoch
                mean_tr.append(np.mean(tr_vals))
                mean_val.append(np.mean(val_vals))

            epochs = epochs[:len(mean_tr)]
            mean_tr = np.array(mean_tr)
            mean_val = np.array(mean_val)

            fig = plt.figure(figsize=(8, 4.5))
            ax = plt.gca()
            ax.plot(epochs, mean_tr, label="Train loss (mean across folds)")
            ax.plot(epochs, mean_val, label="Val loss (mean across folds)")
            ax.set_xlabel("Epoch")
            ax.set_ylabel("Loss")
            ax.set_title(f"Global Training Curve — 5-fold mean [{tag}]")
            ax.legend()
            ax.grid(True, alpha=0.3)
            out_png = base_dir / f"training_curve_global_{tag}.png"
            plt.tight_layout()
            plt.savefig(out_png, dpi=140)
            plt.close(fig)
            print(f"↳ Curva de loss GLOBAL guardada: {out_png}")
        else:
            print("[WARN] No se encontraron historiales para construir curva global.")
    except Exception as e:
        print(f"[WARN] No se pudo generar la curva de loss global: {e}")

    # ========================================================
    # GLOBAL MAPS (mean over the 5 folds)
    # ========================================================
    try:
        # Create the info object for the topomaps
        info_topo_global = make_mne_info_8ch(sfreq=160.0)

        # 1. GLOBAL saliency topomap (mean over 5 folds)
        if len(saliency_list) > 0:
            sal_mean = np.mean(np.stack(saliency_list, axis=0), axis=0)   # (C,)
            topo_sal_global_png = base_dir / f"topomap_saliency_GLOBAL_{tag}.png"
            plot_topomap_from_values(
                sal_mean,
                info_topo_global,
                title=f"Channel saliency — GLOBAL (5 folds) [{tag}]",
                out_path=topo_sal_global_png,
                cmap="Reds",
                cbar_label="Normalized saliency"
            )
            print(f"↳ Topomap GLOBAL de saliency guardado: {topo_sal_global_png}")
        
        # 2. GLOBAL -log10(p) topomap (mean over 5 folds)
        if len(pvals_log_list) > 0:
            pvals_log_mean = np.mean(np.stack(pvals_log_list, axis=0), axis=0)   # (C,)
            topo_p_global_png = base_dir / f"topomap_log10p_GLOBAL_{tag}.png"
            plot_topomap_from_values(
                pvals_log_mean,
                info_topo_global,
                title=f"-log10(p) Left vs Right — GLOBAL (5 folds) [{tag}]",
                out_path=topo_p_global_png,
                cmap="viridis",
                cbar_label="-log10(p)"
            )
            print(f"↳ Topomap GLOBAL de -log10(p) guardado: {topo_p_global_png}")
        
        # 3. GLOBAL t-SNE using the CLS embeddings of the last fold
        if len(tsne_results_list) > 0:
            # Use the last fold's t-SNE as representative
            tsne_representative = tsne_results_list[-1]

            # Load the matching labels (from the last fold)
            fold_dir = base_dir / tag / "fold5"

            # Load y_te for fold 5 (if it was saved)
            y_te_fold5_path = fold_dir / "y_te.npy"
            if y_te_fold5_path.exists():
                y_te_fold5 = np.load(y_te_fold5_path)
                # Re-draw the subsample; this assumes the t-SNE step used the same seed and sample size
                rng = np.random.RandomState(RANDOM_STATE)
                n_tsne_samples = min(1000, len(y_te_fold5))
                indices = rng.choice(len(y_te_fold5), n_tsne_samples, replace=False)
                y_tsne_global = y_te_fold5[indices]
                
                tsne_global_png = base_dir / f"tsne_cls_GLOBAL_{tag}.png"
                plot_tsne(
                    tsne_representative,
                    y_tsne_global,
                    title=f"t-SNE CLS embeddings — GLOBAL (Fold 5 representative) [{tag}]",
                    out_path=tsne_global_png
                )
                print(f"↳ t-SNE GLOBAL (CLS embeddings) guardado: {tsne_global_png}")
                
    except Exception as e:
        print(f"[WARN] No se pudieron generar mapas GLOBALES: {e}")
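The global loss curve above averages folds that may have stopped at different epochs. A compact way to express the same idea is to pad the shorter histories with NaN and use `np.nanmean`. The sketch below is an illustration with made-up loss values; `mean_curve` is a hypothetical helper, not a function from this notebook. It also returns how many folds contribute to each epoch, which is useful because the tail of the curve is averaged over fewer folds.

```python
import numpy as np

def mean_curve(curves):
    """Average 1-D curves of unequal length, epoch by epoch."""
    max_len = max(len(c) for c in curves)
    # Pad with NaN so that missing epochs are ignored by nanmean
    padded = np.full((len(curves), max_len), np.nan)
    for row, c in enumerate(curves):
        padded[row, :len(c)] = c
    mean = np.nanmean(padded, axis=0)
    n_folds = np.sum(~np.isnan(padded), axis=0)  # folds contributing per epoch
    return mean, n_folds

# Made-up validation losses for two folds; the second one stopped one epoch earlier
val_losses = [
    [1.00, 0.80, 0.60],
    [0.90, 0.70],
]
mean_val, n_per_epoch = mean_curve(val_losses)
print(mean_val, n_per_epoch)
```

Plotting `n_per_epoch` alongside the mean, or truncating the curve where it drops below the number of folds, makes it clearer which part of the curve is a true 5-fold average.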

### Main and Entry Point

Main entry point that runs the full pipeline.

In [60]:
# =========================
# MAIN
# =========================
if __name__ == "__main__":
    print("MODELO")
    # 1) Sweep on A SINGLE FOLD (fold=1). Adjust deltas and head_set as needed.
    # sweep = run_arch_sweep(DEVICE, base_blocks=2, deltas=(-2,0,+2), head_set=(2,4,6), sweep_fold=1)
    # sweep = run_arch_sweep(DEVICE, base_blocks=2, deltas=(-2,0,+2), head_set=(2,4,6), sweep_fold=2)
    # sweep = run_arch_sweep(DEVICE, base_blocks=2, deltas=(-2,0,+2), head_set=(2,4,6), sweep_fold=3)
    # sweep = run_arch_sweep(DEVICE, base_blocks=2, deltas=(-2,0,+2), head_set=(2,4,6), sweep_fold=4)
    # sweep = run_arch_sweep(DEVICE, base_blocks=2, deltas=(-2,0,+2), head_set=(2,4,6), sweep_fold=5)

    # 2) Then run all 5 folds with the WINNING configuration:
    run_full_5fold_for_config(DEVICE, n_blocks=2, n_heads=6)

    # 3) (Optional) print the counted hyperparameters:
    # hp = current_hparams()
    # print(f"\nHiperparámetros (n={len(hp)}):")
    # for k,v in hp.items():
    #     print(f" - {k}: {v}")

MODELO
[IO] Guardando artefactos en: /root/Proyecto/EEG_Clasificador/models/04_hybrid/ModeloW/nb2_h6/nb2_h6/fold1


Cargando train fold1: 100%|██████████| 67/67 [00:03<00:00, 17.43it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:00<00:00, 17.44it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:01<00:00, 17.44it/s]


[Fold 1/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1318 | train_acc=0.5213 | val_loss=0.1249 | val_acc=0.5000 | val_f1m=0.3443 | LR=0.000125
  Época   2 | train_loss=0.1224 | train_acc=0.5729 | val_loss=0.1187 | val_acc=0.6190 | val_f1m=0.6166 | LR=0.000250
  Época   3 | train_loss=0.1102 | train_acc=0.6546 | val_loss=0.1111 | val_acc=0.6508 | val_f1m=0.6402 | LR=0.000375
  Época   4 | train_loss=0.0962 | train_acc=0.7416 | val_loss=0.1044 | val_acc=0.7206 | val_f1m=0.7198 | LR=0.000500
  Época   5 | train_loss=0.0841 | train_acc=0.7839 | val_loss=0.0987 | val_acc=0.7444 | val_f1m=0.7439 | LR=0.000500
  Época   6 | train_loss=0.0764 | train_acc=0.8003 | val_loss=0.0960 | val_acc=0.7556 | val_f1m=0.7539 | LR=0.000500
  Época   7 | train_loss=0.0721 | train_acc=0.8124 | val_loss=0.0919 | val_acc=0.7635 | val_f1m=0.7627 | LR=0.000499
  Época   8 | train_loss=0.0662 | train_acc=0.8348 | val_loss=0.1250 | val_acc=0.7190 | val_f1m=0.7119



↳ t-SNE (CLS embeddings) guardado: ModeloW/nb2_h6/nb2_h6/fold1/tsne_cls_fold1_nb2_h6.png
  Fine-tuning PROGRESIVO (por sujeto, 4-fold CV) acc=0.8526 | f1_macro=0.8526
↳ Matriz de confusión FT guardada: ModeloW/nb2_h6/nb2_h6/fold1/confusion_ft_fold1_nb2_h6.png
[Consumo] params_total=232,290 | params_trainable=232,290 | FLOPs~24.3M | latency/batch=1.11 ms (B=8)
[IO] Guardando artefactos en: /root/Proyecto/EEG_Clasificador/models/04_hybrid/ModeloW/nb2_h6/nb2_h6/fold2


Cargando train fold2: 100%|██████████| 67/67 [00:03<00:00, 16.97it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:00<00:00, 17.25it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:01<00:00, 17.27it/s]


[Fold 2/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1290 | train_acc=0.5267 | val_loss=0.1202 | val_acc=0.6032 | val_f1m=0.6029 | LR=0.000125
  Época   2 | train_loss=0.1214 | train_acc=0.5839 | val_loss=0.1159 | val_acc=0.6222 | val_f1m=0.6183 | LR=0.000250
  Época   3 | train_loss=0.1168 | train_acc=0.6276 | val_loss=0.1052 | val_acc=0.7095 | val_f1m=0.7095 | LR=0.000375
  Época   4 | train_loss=0.0987 | train_acc=0.7278 | val_loss=0.0958 | val_acc=0.7206 | val_f1m=0.7202 | LR=0.000500
  Época   5 | train_loss=0.0910 | train_acc=0.7594 | val_loss=0.0973 | val_acc=0.7175 | val_f1m=0.7028 | LR=0.000500
  Época   6 | train_loss=0.0808 | train_acc=0.7921 | val_loss=0.0942 | val_acc=0.7492 | val_f1m=0.7492 | LR=0.000500
  Época   7 | train_loss=0.0776 | train_acc=0.8042 | val_loss=0.0910 | val_acc=0.7476 | val_f1m=0.7417 | LR=0.000499
  Época   8 | train_loss=0.0771 | train_acc=0.8056 | val_loss=0.0823 | val_acc=0.7873 | val_f1m=0.7870



↳ t-SNE (CLS embeddings) guardado: ModeloW/nb2_h6/nb2_h6/fold2/tsne_cls_fold2_nb2_h6.png
  Fine-tuning PROGRESIVO (por sujeto, 4-fold CV) acc=0.8878 | f1_macro=0.8877
↳ Matriz de confusión FT guardada: ModeloW/nb2_h6/nb2_h6/fold2/confusion_ft_fold2_nb2_h6.png
[Consumo] params_total=232,290 | params_trainable=232,290 | FLOPs~24.3M | latency/batch=1.12 ms (B=8)
[IO] Guardando artefactos en: /root/Proyecto/EEG_Clasificador/models/04_hybrid/ModeloW/nb2_h6/nb2_h6/fold3


Cargando train fold3: 100%|██████████| 67/67 [00:03<00:00, 16.92it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:00<00:00, 17.52it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:01<00:00, 17.13it/s]


[Fold 3/5] Entrenando modelo global... (n_train=2814 | n_val=630 | n_test=882)
  Época   1 | train_loss=0.1315 | train_acc=0.5174 | val_loss=0.1267 | val_acc=0.5079 | val_f1m=0.3685 | LR=0.000125
  Época   2 | train_loss=0.1251 | train_acc=0.5402 | val_loss=0.1324 | val_acc=0.5190 | val_f1m=0.3840 | LR=0.000250
  Época   3 | train_loss=0.1242 | train_acc=0.5547 | val_loss=0.1109 | val_acc=0.6810 | val_f1m=0.6809 | LR=0.000375
  Época   4 | train_loss=0.1067 | train_acc=0.6844 | val_loss=0.0947 | val_acc=0.7222 | val_f1m=0.7221 | LR=0.000500
  Época   5 | train_loss=0.0889 | train_acc=0.7633 | val_loss=0.0884 | val_acc=0.7651 | val_f1m=0.7621 | LR=0.000500
  Época   6 | train_loss=0.0765 | train_acc=0.7974 | val_loss=0.0827 | val_acc=0.8016 | val_f1m=0.8015 | LR=0.000500
  Época   7 | train_loss=0.0721 | train_acc=0.8259 | val_loss=0.0834 | val_acc=0.7937 | val_f1m=0.7936 | LR=0.000499
  Época   8 | train_loss=0.0738 | train_acc=0.8223 | val_loss=0.0810 | val_acc=0.8063 | val_f1m=0.8063



↳ t-SNE (CLS embeddings) guardado: ModeloW/nb2_h6/nb2_h6/fold3/tsne_cls_fold3_nb2_h6.png
  Fine-tuning PROGRESIVO (por sujeto, 4-fold CV) acc=0.8469 | f1_macro=0.8469
↳ Matriz de confusión FT guardada: ModeloW/nb2_h6/nb2_h6/fold3/confusion_ft_fold3_nb2_h6.png
[Consumo] params_total=232,290 | params_trainable=232,290 | FLOPs~24.3M | latency/batch=1.14 ms (B=8)
[IO] Guardando artefactos en: /root/Proyecto/EEG_Clasificador/models/04_hybrid/ModeloW/nb2_h6/nb2_h6/fold4


Cargando train fold4: 100%|██████████| 68/68 [00:03<00:00, 17.37it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:00<00:00, 17.29it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:01<00:00, 17.22it/s]


[Fold 4/5] Entrenando modelo global... (n_train=2856 | n_val=630 | n_test=840)
  Época   1 | train_loss=0.1308 | train_acc=0.4972 | val_loss=0.1204 | val_acc=0.5651 | val_f1m=0.5476 | LR=0.000125
  Época   2 | train_loss=0.1254 | train_acc=0.5501 | val_loss=0.1177 | val_acc=0.6079 | val_f1m=0.6056 | LR=0.000250
  Época   3 | train_loss=0.1174 | train_acc=0.6008 | val_loss=0.1263 | val_acc=0.6508 | val_f1m=0.6418 | LR=0.000375
  Época   4 | train_loss=0.1039 | train_acc=0.6992 | val_loss=0.1061 | val_acc=0.6889 | val_f1m=0.6888 | LR=0.000500
  Época   5 | train_loss=0.0898 | train_acc=0.7626 | val_loss=0.1008 | val_acc=0.7349 | val_f1m=0.7327 | LR=0.000500
  Época   6 | train_loss=0.0802 | train_acc=0.7805 | val_loss=0.0876 | val_acc=0.7635 | val_f1m=0.7622 | LR=0.000500
  Época   7 | train_loss=0.0810 | train_acc=0.7920 | val_loss=0.0990 | val_acc=0.7286 | val_f1m=0.7182 | LR=0.000499
  Época   8 | train_loss=0.0737 | train_acc=0.8099 | val_loss=0.0873 | val_acc=0.7762 | val_f1m=0.7749



↳ t-SNE (CLS embeddings) guardado: ModeloW/nb2_h6/nb2_h6/fold4/tsne_cls_fold4_nb2_h6.png
  Fine-tuning PROGRESIVO (por sujeto, 4-fold CV) acc=0.8738 | f1_macro=0.8738
↳ Matriz de confusión FT guardada: ModeloW/nb2_h6/nb2_h6/fold4/confusion_ft_fold4_nb2_h6.png
[Consumo] params_total=232,290 | params_trainable=232,290 | FLOPs~24.3M | latency/batch=1.13 ms (B=8)
[IO] Guardando artefactos en: /root/Proyecto/EEG_Clasificador/models/04_hybrid/ModeloW/nb2_h6/nb2_h6/fold5


Cargando train fold5: 100%|██████████| 68/68 [00:03<00:00, 17.14it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:00<00:00, 17.27it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:01<00:00, 17.35it/s]


[Fold 5/5] Entrenando modelo global... (n_train=2856 | n_val=630 | n_test=840)
  Época   1 | train_loss=0.1320 | train_acc=0.5035 | val_loss=0.1220 | val_acc=0.5095 | val_f1m=0.4079 | LR=0.000125
  Época   2 | train_loss=0.1241 | train_acc=0.5445 | val_loss=0.1177 | val_acc=0.5889 | val_f1m=0.5691 | LR=0.000250
  Época   3 | train_loss=0.1225 | train_acc=0.5795 | val_loss=0.1122 | val_acc=0.6619 | val_f1m=0.6540 | LR=0.000375
  Época   4 | train_loss=0.0968 | train_acc=0.7416 | val_loss=0.1055 | val_acc=0.7000 | val_f1m=0.6988 | LR=0.000500
  Época   5 | train_loss=0.0907 | train_acc=0.7560 | val_loss=0.1002 | val_acc=0.7127 | val_f1m=0.7107 | LR=0.000500
  Época   6 | train_loss=0.0747 | train_acc=0.8123 | val_loss=0.0961 | val_acc=0.7190 | val_f1m=0.7162 | LR=0.000500
  Época   7 | train_loss=0.0737 | train_acc=0.8200 | val_loss=0.1029 | val_acc=0.7286 | val_f1m=0.7273 | LR=0.000499
  Época   8 | train_loss=0.0708 | train_acc=0.8232 | val_loss=0.1032 | val_acc=0.7333 | val_f1m=0.7278



↳ t-SNE (CLS embeddings) guardado: ModeloW/nb2_h6/nb2_h6/fold5/tsne_cls_fold5_nb2_h6.png
  Fine-tuning PROGRESIVO (por sujeto, 4-fold CV) acc=0.8786 | f1_macro=0.8786
↳ Matriz de confusión FT guardada: ModeloW/nb2_h6/nb2_h6/fold5/confusion_ft_fold5_nb2_h6.png
[Consumo] params_total=232,290 | params_trainable=232,290 | FLOPs~24.3M | latency/batch=1.12 ms (B=8)
↳ Resumen 5-fold guardado: /root/Proyecto/EEG_Clasificador/models/04_hybrid/ModeloW/nb2_h6/summary_nb2_h6.csv

RESULTADOS FINALES [nb2_h6] (2 clases: left/right)
Global folds (ACC): ['0.8107', '0.8356', '0.8084', '0.8119', '0.8464']
Global mean ACC:    0.8226
F1 folds (MACRO):   ['0.8107', '0.8353', '0.8084', '0.8098', '0.8464']
F1 mean (MACRO):    0.8221
FT folds (ACC):     ['0.8526', '0.8878', '0.8469', '0.8738', '0.8786']
FT mean (ACC):      0.8679
Δ(FT-Global) mean:  +0.0453

Model tag: nb2_h6
Phase        time (s)           memory (MB)        FLOPs (g)   params (m)
-----------  -----------------  -----------------  ----------
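The per-fold accuracies printed above can be summarized a little further. The sketch below uses only the values shown in the output (global and fine-tuned accuracy per fold) and the standard library; the dictionary keys are illustrative names. It reports the sample standard deviation (`statistics.stdev`, i.e. ddof=1), whereas the timing table in the code uses the population form (ddof=0). Note that the fine-tuned figure comes from a per-subject 4-fold CV, so the two accuracies are measured under different protocols and the difference should be read with that in mind.

```python
from statistics import mean, stdev

# Values copied from the printed results for tag nb2_h6
acc_global = [0.8107, 0.8356, 0.8084, 0.8119, 0.8464]
acc_ft     = [0.8526, 0.8878, 0.8469, 0.8738, 0.8786]

# Paired per-fold difference between fine-tuned and global accuracy
deltas = [ft - g for ft, g in zip(acc_ft, acc_global)]

summary = {
    "global_mean": mean(acc_global),
    "global_std": stdev(acc_global),   # sample std (ddof=1)
    "ft_mean": mean(acc_ft),
    "ft_std": stdev(acc_ft),
    "delta_mean": mean(deltas),
    "folds_improved": sum(d > 0 for d in deltas),
}
for key, value in summary.items():
    print(f"{key}: {value:.4f}" if isinstance(value, float) else f"{key}: {value}")
```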

---
## ABLATION STUDY

This section implements a systematic ablation study to evaluate the contribution of each component of the hybrid CNN-Transformer model.

### Variants evaluated:
1. **Baseline**: Full model with all components
2. **No Depthwise Blocks**: Convolutional stem only (n_dw_blocks=0)
3. **No Positional Encoding**: Removes the positional encoding
4. **No CLS Token**: Uses average pooling instead of the classification token
5. **With BatchNorm**: Replaces GroupNorm with BatchNorm
6. **With ReLU**: Replaces ELU/GELU activations with ReLU
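To make the positional-encoding ablation concrete, the sketch below rebuilds the sinusoidal table in NumPy, following the same formula as the `_positional_encoding` method in the model classes. The function name `sinusoidal_encoding` and the sizes `L=50`, `d=128` are assumptions for illustration. Each row is a distinct vector added to the token at that time step; without it, self-attention has no explicit signal about token order, which is what the corresponding variant tests.

```python
import numpy as np

def sinusoidal_encoding(L, d):
    """Return an (L, d) table: sine on even columns, cosine on odd columns."""
    pos = np.arange(L, dtype=np.float32)[:, None]   # (L, 1) token positions
    i = np.arange(d, dtype=np.float32)[None, :]     # (1, d) feature indices
    # Each pair of columns (2k, 2k+1) shares the frequency 1 / 10000^(2k/d)
    angle = pos / np.power(10000.0, (2 * (i // 2)) / d)
    pe = np.zeros((L, d), dtype=np.float32)
    pe[:, 0::2] = np.sin(angle[:, 0::2])
    pe[:, 1::2] = np.cos(angle[:, 1::2])
    return pe

pe = sinusoidal_encoding(L=50, d=128)  # illustrative sequence length and width
print(pe.shape, float(pe.min()), float(pe.max()))
```

Position 0 maps to all zeros on the sine columns and all ones on the cosine columns, and every value stays within [-1, 1], so the encoding does not dominate the projected CNN features by scale alone.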

In [None]:
# =========================
# MODEL VARIANTS FOR ABLATION
# =========================

# VARIANT 1: No depthwise blocks (stem only)
class EEGCNNTransformer_NoDepthwise(nn.Module):
    """Variant without depthwise separable blocks"""
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3, capture_attn=False):
        super().__init__()
        self.capture_attn = capture_attn
        self._last_attn = None
        self.pos_encoding = None

        # Stem only, no depthwise blocks
        self.conv_t = nn.Sequential(
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
        )
        
        self.proj = nn.Conv1d(32, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)

        # Transformer encoder
        enc = _CustomEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
                                  dropout=0.1, batch_first=True, norm_first=False, return_attn=True)
        self.encoder = _CustomEncoder(enc, num_layers=n_layers)

        # CLS token and head
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)
        z = self.proj(z)
        z = self.dropout(z)
        z = z.transpose(1, 2)
        
        # Positional encoding
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        
        # CLS token
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)

        z, attn_list = self.encoder(z)
        if self.capture_attn:
            self._last_attn = attn_list

        cls = z[:, 0, :]
        return self.head(cls)

    def get_last_attention_maps(self):
        return self._last_attn


# VARIANT 2: No positional encoding
class EEGCNNTransformer_NoPosEncoding(nn.Module):
    """Variant without positional encoding"""
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3, n_dw_blocks=2,
                 k_list=(31,15,7), s_list=(2,2,2), p_list=(15,7,3), capture_attn=False):
        super().__init__()
        self.capture_attn = capture_attn
        self._last_attn = None

        # Convolutional stem
        stem = [
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
        ]
        
        # Depthwise blocks
        blocks = []
        in_c, out_cs = 32, [64, 128, 256, 256, 256]
        n_dw_blocks = int(np.clip(n_dw_blocks, 0, len(out_cs)))
        for i in range(n_dw_blocks):
            k = k_list[i] if i < len(k_list) else 7
            s = s_list[i] if i < len(s_list) else 2
            p = p_list[i] if i < len(p_list) else k//2
            blocks.append(DepthwiseSeparableConv(in_c, out_cs[i], k=k, s=s, p=p, p_drop=p_drop))
            in_c = out_cs[i]

        self.conv_t = nn.Sequential(*stem, *blocks)
        self.proj = nn.Conv1d(in_c if n_dw_blocks>0 else 32, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)

        # Transformer encoder
        enc = _CustomEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
                                  dropout=0.1, batch_first=True, norm_first=False, return_attn=True)
        self.encoder = _CustomEncoder(enc, num_layers=n_layers)

        # CLS token and head
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def forward(self, x):
        z = self.conv_t(x)
        z = self.proj(z)
        z = self.dropout(z)
        z = z.transpose(1, 2)
        
        # NO positional encoding is added here
        
        # CLS token
        B = z.shape[0]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)

        z, attn_list = self.encoder(z)
        if self.capture_attn:
            self._last_attn = attn_list

        cls = z[:, 0, :]
        return self.head(cls)

    def get_last_attention_maps(self):
        return self._last_attn


# VARIANT 3: No CLS token (average pooling)
class EEGCNNTransformer_NoCLS(nn.Module):
    """Variant without a CLS token; uses average pooling over the sequence"""
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3, n_dw_blocks=2,
                 k_list=(31,15,7), s_list=(2,2,2), p_list=(15,7,3), capture_attn=False):
        super().__init__()
        self.capture_attn = capture_attn
        self._last_attn = None
        self.pos_encoding = None

        # Stem convolucional
        stem = [
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
        ]
        
        # Bloques depthwise
        blocks = []
        in_c, out_cs = 32, [64, 128, 256, 256, 256]
        n_dw_blocks = int(np.clip(n_dw_blocks, 0, len(out_cs)))
        for i in range(n_dw_blocks):
            k = k_list[i] if i < len(k_list) else 7
            s = s_list[i] if i < len(s_list) else 2
            p = p_list[i] if i < len(p_list) else k//2
            blocks.append(DepthwiseSeparableConv(in_c, out_cs[i], k=k, s=s, p=p, p_drop=p_drop))
            in_c = out_cs[i]

        self.conv_t = nn.Sequential(*stem, *blocks)
        self.proj = nn.Conv1d(in_c if n_dw_blocks>0 else 32, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)

        # Encoder Transformer
        enc = _CustomEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
                                  dropout=0.1, batch_first=True, norm_first=False, return_attn=True)
        self.encoder = _CustomEncoder(enc, num_layers=n_layers)

        # Head (sin CLS token)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)
        z = self.proj(z)
        z = self.dropout(z)
        z = z.transpose(1, 2)
        
        # Positional encoding (cached; rebuilt if the shape or device changes)
        B, L, D = z.shape
        pe = self.pos_encoding
        if pe is None or pe.shape != (L, D) or pe.device != z.device:
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        
        # SIN CLS token

        z, attn_list = self.encoder(z)
        if self.capture_attn:
            self._last_attn = attn_list

        # Average pooling en lugar de CLS
        pooled = z.mean(dim=1)  # (B, D)
        return self.head(pooled)

    def get_last_attention_maps(self):
        return self._last_attn


# VARIANTE 4: Con BatchNorm en vez de GroupNorm
def make_bn(num_channels):
    """Crea capa BatchNorm1d"""
    return nn.BatchNorm1d(num_channels)

class DepthwiseSeparableConv_BN(nn.Module):
    """Convolución depthwise separable con BatchNorm"""
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_bn(out_ch)
        self.act = nn.ELU()
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

class EEGCNNTransformer_BatchNorm(nn.Module):
    """Variante con BatchNorm en vez de GroupNorm"""
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3, n_dw_blocks=2,
                 k_list=(31,15,7), s_list=(2,2,2), p_list=(15,7,3), capture_attn=False):
        super().__init__()
        self.capture_attn = capture_attn
        self._last_attn = None
        self.pos_encoding = None

        # Stem con BatchNorm
        stem = [
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_bn(32),
            nn.ELU(),
            nn.Dropout(p=p_drop),
        ]
        
        # Bloques depthwise con BatchNorm
        blocks = []
        in_c, out_cs = 32, [64, 128, 256, 256, 256]
        n_dw_blocks = int(np.clip(n_dw_blocks, 0, len(out_cs)))
        for i in range(n_dw_blocks):
            k = k_list[i] if i < len(k_list) else 7
            s = s_list[i] if i < len(s_list) else 2
            p = p_list[i] if i < len(p_list) else k//2
            blocks.append(DepthwiseSeparableConv_BN(in_c, out_cs[i], k=k, s=s, p=p, p_drop=p_drop))
            in_c = out_cs[i]

        self.conv_t = nn.Sequential(*stem, *blocks)
        self.proj = nn.Conv1d(in_c if n_dw_blocks>0 else 32, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)

        # Encoder Transformer
        enc = _CustomEncoderLayer(d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
                                  dropout=0.1, batch_first=True, norm_first=False, return_attn=True)
        self.encoder = _CustomEncoder(enc, num_layers=n_layers)

        # Token CLS y head
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)
        z = self.proj(z)
        z = self.dropout(z)
        z = z.transpose(1, 2)
        
        # Positional encoding (cached; rebuilt if the shape or device changes)
        B, L, D = z.shape
        pe = self.pos_encoding
        if pe is None or pe.shape != (L, D) or pe.device != z.device:
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        
        # CLS token
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)

        z, attn_list = self.encoder(z)
        if self.capture_attn:
            self._last_attn = attn_list

        cls = z[:, 0, :]
        return self.head(cls)

    def get_last_attention_maps(self):
        return self._last_attn


# VARIANTE 5: Con ReLU en vez de ELU/GELU
class DepthwiseSeparableConv_ReLU(nn.Module):
    """Convolución depthwise separable con ReLU"""
    def __init__(self, in_ch, out_ch, k, s=1, p=0, p_drop=0.2):
        super().__init__()
        self.dw = nn.Conv1d(in_ch, in_ch, kernel_size=k, stride=s, padding=p, groups=in_ch, bias=False)
        self.pw = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = make_gn(out_ch)
        self.act = nn.ReLU()  # ReLU en vez de ELU
        self.dropout = nn.Dropout(p=p_drop)

    def forward(self, x):
        x = self.dw(x); x = self.pw(x); x = self.norm(x)
        x = self.act(x); x = self.dropout(x)
        return x

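class _CustomEncoderLayer_ReLU(nn.Module):
    """Transformer encoder layer with ReLU in the FFN.

    Note: forward() is post-norm only; `norm_first` and `return_attn` are
    stored but not used, and attention weights are always returned.
    """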
    def __init__(self, d_model, nhead, dim_feedforward, dropout, batch_first, norm_first, return_attn=True):
        super().__init__()
        self.return_attn = return_attn
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU()  # ReLU en vez de GELU
        self.norm_first = norm_first

    def forward(self, src):
        sa_out, attn_weights = self.self_attn(src, src, src, need_weights=True, average_attn_weights=False)
        src = self.norm1(src + self.dropout1(sa_out))
        ff = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff))
        return src, attn_weights

class EEGCNNTransformer_ReLU(nn.Module):
    """Variante con ReLU en vez de ELU/GELU"""
    def __init__(self, n_ch=8, n_cls=2, d_model=128, n_heads=4, n_layers=2,
                 p_drop=0.2, p_drop_encoder=0.3, n_dw_blocks=2,
                 k_list=(31,15,7), s_list=(2,2,2), p_list=(15,7,3), capture_attn=False):
        super().__init__()
        self.capture_attn = capture_attn
        self._last_attn = None
        self.pos_encoding = None

        # Stem con ReLU
        stem = [
            nn.Conv1d(n_ch, 32, kernel_size=129, stride=2, padding=64, bias=False),
            make_gn(32),
            nn.ReLU(),  # ReLU en vez de ELU
            nn.Dropout(p=p_drop),
        ]
        
        # Bloques depthwise con ReLU
        blocks = []
        in_c, out_cs = 32, [64, 128, 256, 256, 256]
        n_dw_blocks = int(np.clip(n_dw_blocks, 0, len(out_cs)))
        for i in range(n_dw_blocks):
            k = k_list[i] if i < len(k_list) else 7
            s = s_list[i] if i < len(s_list) else 2
            p = p_list[i] if i < len(p_list) else k//2
            blocks.append(DepthwiseSeparableConv_ReLU(in_c, out_cs[i], k=k, s=s, p=p, p_drop=p_drop))
            in_c = out_cs[i]

        self.conv_t = nn.Sequential(*stem, *blocks)
        self.proj = nn.Conv1d(in_c if n_dw_blocks>0 else 32, d_model, kernel_size=1, bias=False)
        self.dropout = nn.Dropout(p=p_drop_encoder)

        # Encoder Transformer con ReLU
        enc = _CustomEncoderLayer_ReLU(d_model=d_model, nhead=n_heads, dim_feedforward=2*d_model,
                                       dropout=0.1, batch_first=True, norm_first=False, return_attn=True)
        self.encoder = _CustomEncoder(enc, num_layers=n_layers)

        # Token CLS y head
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.conv_t(x)
        z = self.proj(z)
        z = self.dropout(z)
        z = z.transpose(1, 2)
        
        # Positional encoding (cached; rebuilt if the shape or device changes)
        B, L, D = z.shape
        pe = self.pos_encoding
        if pe is None or pe.shape != (L, D) or pe.device != z.device:
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        
        # CLS token
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)

        z, attn_list = self.encoder(z)
        if self.capture_attn:
            self._last_attn = attn_list

        cls = z[:, 0, :]
        return self.head(cls)

    def get_last_attention_maps(self):
        return self._last_attn

print("✓ Variantes del modelo definidas para estudio de ablación")

### Ablation Study Training Function

This function trains each model variant and collects metrics for comparison.
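
The training loop in the next cell uses test-set accuracy for early stopping and checkpoint selection, which tends to make the reported scores optimistic. One possible alternative is to hold out a few training subjects for validation. The sketch below is only an illustration: `split_train_val_subjects`, `val_fraction`, and the subject IDs are hypothetical names that do not appear in the notebook.

```python
import random

def split_train_val_subjects(train_subs, val_fraction=0.2, seed=42):
    """Hold out a fraction of the training subjects for validation.

    Hypothetical helper (not part of the notebook). The split is done by
    subject, so no subject contributes epochs to both sets.
    """
    subs = sorted(train_subs)          # fixed order before shuffling
    rng = random.Random(seed)          # local RNG; global seed state is untouched
    rng.shuffle(subs)
    n_val = max(1, round(len(subs) * val_fraction))
    return subs[n_val:], subs[:n_val]  # (fit subjects, validation subjects)

# Made-up subject IDs, for illustration only.
fit_subs, val_subs = split_train_val_subjects([f"S{i:03d}" for i in range(1, 11)])
print(len(fit_subs), len(val_subs))  # 8 2
```

With such a split, early stopping would monitor accuracy on the validation subjects, and the test subjects would be evaluated only once at the end.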

In [None]:
# =========================
# ESTUDIO DE ABLACIÓN - FUNCIÓN PRINCIPAL
# =========================

def run_ablation_study(device, fold_list=None, base_dir=None):
    """
    Run the ablation study by training every model variant.
    
    Args:
        device: Device used for training (cuda/cpu)
        fold_list: List of folds to evaluate (default: [1])
        base_dir: Base directory where results are saved
    
    Returns:
        DataFrame with the comparative results
    """
    # Avoid a mutable default argument; fall back to fold 1 only.
    if fold_list is None:
        fold_list = [1]
    if base_dir is None:
        base_dir = PROJ / 'models' / '04_hybrid' / 'ablation_study'
    else:
        base_dir = Path(base_dir)
    
    ensure_dir(base_dir)
    
    # Definir variantes del modelo
    variants = {
        'baseline': {
            'name': 'Baseline (Completo)',
            'factory': lambda: EEGCNNTransformer(
                n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS, 
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                n_dw_blocks=2, capture_attn=False
            ),
            'description': 'Modelo completo con todos los componentes'
        },
        'no_depthwise': {
            'name': 'Sin Depthwise Blocks',
            'factory': lambda: EEGCNNTransformer_NoDepthwise(
                n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                capture_attn=False
            ),
            'description': 'Solo stem convolucional, sin bloques depthwise'
        },
        'no_posenc': {
            'name': 'Sin Codificación Posicional',
            'factory': lambda: EEGCNNTransformer_NoPosEncoding(
                n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                n_dw_blocks=2, capture_attn=False
            ),
            'description': 'Sin positional encoding en Transformer'
        },
        'no_cls': {
            'name': 'Sin Token CLS',
            'factory': lambda: EEGCNNTransformer_NoCLS(
                n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                n_dw_blocks=2, capture_attn=False
            ),
            'description': 'Usa average pooling en vez de token CLS'
        },
        'batchnorm': {
            'name': 'Con BatchNorm',
            'factory': lambda: EEGCNNTransformer_BatchNorm(
                n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                n_dw_blocks=2, capture_attn=False
            ),
            'description': 'BatchNorm en vez de GroupNorm'
        },
        'relu': {
            'name': 'Con ReLU',
            'factory': lambda: EEGCNNTransformer_ReLU(
                n_ch=8, n_cls=2, d_model=D_MODEL, n_heads=N_HEADS,
                n_layers=N_LAYERS, p_drop=P_DROP, p_drop_encoder=P_DROP_ENCODER,
                n_dw_blocks=2, capture_attn=False
            ),
            'description': 'ReLU en vez de ELU/GELU'
        }
    }
    
    # Almacenar resultados
    results = []
    
    print("="*80)
    print("INICIANDO ESTUDIO DE ABLACIÓN")
    print("="*80)
    print(f"Variantes a evaluar: {len(variants)}")
    print(f"Folds: {fold_list}")
    print(f"Directorio de salida: {base_dir}")
    print("="*80)
    
    # Entrenar cada variante
    for variant_key, variant_info in variants.items():
        print(f"\n{'='*80}")
        print(f"VARIANTE: {variant_info['name']}")
        print(f"Descripción: {variant_info['description']}")
        print(f"{'='*80}")
        
        variant_dir = base_dir / variant_key
        ensure_dir(variant_dir)
        
        # Resultados de esta variante
        variant_results = {
            'variant': variant_key,
            'name': variant_info['name'],
            'description': variant_info['description']
        }
        
        # Métricas acumuladas por fold
        fold_accs = []
        fold_f1s = []
        fold_params = []
        
        for fold in fold_list:
            print(f"\n--- Fold {fold} ---")
            seed_everything(RANDOM_STATE + fold)
            
            try:
                # Cargar datos
                train_subs, test_subs = load_fold_subjects(FOLDS_JSON, fold)
                print(f"Train: {len(train_subs)} sujetos, Test: {len(test_subs)} sujetos")
                
                # Cargar épocas
                X_train_list, y_train_list = [], []
                for sid in tqdm(train_subs, desc="Cargando train"):
                    X, y, _ = load_subject_epochs(
                        sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, 
                        DO_CAR, BP_LO, BP_HI
                    )
                    if X.shape[0] > 0:
                        X_train_list.append(X)
                        y_train_list.append(y)
                
                X_test_list, y_test_list = [], []
                for sid in tqdm(test_subs, desc="Cargando test"):
                    X, y, _ = load_subject_epochs(
                        sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS,
                        DO_CAR, BP_LO, BP_HI
                    )
                    if X.shape[0] > 0:
                        X_test_list.append(X)
                        y_test_list.append(y)
                
                X_train = np.concatenate(X_train_list, axis=0)
                y_train = np.concatenate(y_train_list, axis=0)
                X_test = np.concatenate(X_test_list, axis=0)
                y_test = np.concatenate(y_test_list, axis=0)
                
                # Estandarizar
                X_train, X_test = standardize_per_channel(X_train, X_test)
                
                print(f"Train: {X_train.shape}, Test: {X_test.shape}")
                
                # Crear modelo
                model = variant_info['factory']().to(device)
                n_params = count_params(model)
                fold_params.append(n_params)
                print(f"Parámetros: {n_params:,}")
                
                # Preparar DataLoader
                train_dataset = TensorDataset(
                    torch.tensor(X_train, dtype=torch.float32),
                    torch.tensor(y_train, dtype=torch.long)
                )
                
                if USE_WEIGHTED_SAMPLER:
                    class_counts = np.bincount(y_train)
                    class_weights = 1.0 / (class_counts + 1e-6)
                    sample_weights = class_weights[y_train]
                    sampler = WeightedRandomSampler(
                        weights=sample_weights,
                        num_samples=len(sample_weights),
                        replacement=True
                    )
                    train_loader = DataLoader(
                        train_dataset, batch_size=BATCH_SIZE,
                        sampler=sampler, worker_init_fn=seed_worker,
                        generator=torch.Generator().manual_seed(RANDOM_STATE)
                    )
                else:
                    train_loader = DataLoader(
                        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                        worker_init_fn=seed_worker,
                        generator=torch.Generator().manual_seed(RANDOM_STATE)
                    )
                
                # Pérdida y optimizador
                alpha = torch.tensor([1.0, 1.0], device=device)
                criterion = FocalLoss(alpha=alpha, gamma=1.5)
                optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, T_max=EPOCHS, eta_min=1e-6
                )
                
                # EMA
                ema = ModelEMA(model, decay=EMA_DECAY, device=device) if USE_EMA else None
                
                # Entrenamiento
                best_val_acc = 0.0
                patience_counter = 0
                
                for epoch in range(EPOCHS):
                    model.train()
                    train_loss = 0.0
                    
                    for xb, yb in train_loader:
                        xb, yb = xb.to(device), yb.to(device)
                        
                        # Augmentación
                        if epoch > WARMUP_EPOCHS:
                            xb = augment_batch(xb)
                        
                        optimizer.zero_grad()
                        logits = model(xb)
                        loss = criterion(logits, yb)
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                        optimizer.step()
                        
                        if ema is not None:
                            ema.update(model)
                        
                        train_loss += loss.item()
                    
                    scheduler.step()
                    
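                    # Validation on the test set (simplification). Early stopping and
                    # checkpoint selection below therefore see test data, so the final
                    # accuracy is optimistically biased.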
                    model.eval()
                    with torch.no_grad():
                        X_test_t = torch.tensor(X_test, dtype=torch.float32, device=device)
                        logits = model(X_test_t)
                        y_pred = logits.argmax(dim=1).cpu().numpy()
                        val_acc = accuracy_score(y_test, y_pred)
                    
                    if (epoch + 1) % 10 == 0:
                        print(f"Epoch {epoch+1}/{EPOCHS} - Loss: {train_loss/len(train_loader):.4f}, Val Acc: {val_acc:.4f}")
                    
                    # Early stopping
                    if val_acc > best_val_acc:
                        best_val_acc = val_acc
                        patience_counter = 0
                        # Guardar mejor modelo
                        torch.save(model.state_dict(), variant_dir / f"best_fold{fold}.pt")
                    else:
                        patience_counter += 1
                        if patience_counter >= PATIENCE:
                            print(f"Early stopping en epoch {epoch+1}")
                            break
                
                # Final evaluation (map_location keeps the checkpoint on the active device)
                model.load_state_dict(torch.load(variant_dir / f"best_fold{fold}.pt", map_location=device))
                model.eval()
                
                with torch.no_grad():
                    X_test_t = torch.tensor(X_test, dtype=torch.float32, device=device)
                    logits = model(X_test_t)
                    y_pred = logits.argmax(dim=1).cpu().numpy()
                
                # Métricas
                acc = accuracy_score(y_test, y_pred)
                f1 = f1_score(y_test, y_pred, average='macro')
                metrics = compute_metrics(y_test, y_pred)
                
                fold_accs.append(acc)
                fold_f1s.append(f1)
                
                print(f"Fold {fold} - Acc: {acc:.4f}, F1: {f1:.4f}")
                
                # Limpiar memoria
                del model, optimizer, scheduler, train_loader, X_train, y_train
                torch.cuda.empty_cache()
                
            except Exception as e:
                print(f"ERROR en fold {fold}: {e}")
                import traceback
                traceback.print_exc()
                fold_accs.append(0.0)
                fold_f1s.append(0.0)
                fold_params.append(0)
        
        # Compilar resultados de la variante
        variant_results['mean_accuracy'] = np.mean(fold_accs) if fold_accs else 0.0
        variant_results['std_accuracy'] = np.std(fold_accs) if fold_accs else 0.0
        variant_results['mean_f1'] = np.mean(fold_f1s) if fold_f1s else 0.0
        variant_results['std_f1'] = np.std(fold_f1s) if fold_f1s else 0.0
        variant_results['n_params'] = fold_params[0] if fold_params else 0
        variant_results['folds'] = fold_list
        variant_results['fold_accs'] = fold_accs
        variant_results['fold_f1s'] = fold_f1s
        
        results.append(variant_results)
        
        print(f"\n{variant_info['name']} - Resumen:")
        print(f"  Accuracy: {variant_results['mean_accuracy']:.4f} ± {variant_results['std_accuracy']:.4f}")
        print(f"  F1-Score: {variant_results['mean_f1']:.4f} ± {variant_results['std_f1']:.4f}")
        print(f"  Parámetros: {variant_results['n_params']:,}")
    
    # Crear DataFrame con resultados
    import pandas as pd
    df_results = pd.DataFrame(results)
    
    # Guardar resultados
    results_csv = base_dir / 'ablation_results.csv'
    df_results.to_csv(results_csv, index=False)
    print(f"\n{'='*80}")
    print(f"Resultados guardados en: {results_csv}")
    print(f"{'='*80}")
    
    return df_results


def plot_ablation_results(df_results, output_dir=None):
    """
    Generate visualizations of the ablation study results.
    
    Args:
        df_results: DataFrame with the results
        output_dir: Directory where the figures are saved
    
    Returns:
        DataFrame sorted by mean accuracy in descending order
    """
    if output_dir is None:
        output_dir = PROJ / 'models' / '04_hybrid' / 'ablation_study'
    else:
        output_dir = Path(output_dir)
    
    ensure_dir(output_dir)
    
    # Ordenar por accuracy
    df_sorted = df_results.sort_values('mean_accuracy', ascending=False)
    
    # Gráfico de barras con accuracy y F1
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Accuracy
    ax1.barh(df_sorted['name'], df_sorted['mean_accuracy'], 
             xerr=df_sorted['std_accuracy'], capsize=5, color='steelblue')
    ax1.set_xlabel('Accuracy')
    ax1.set_title('Comparación de Accuracy por Variante')
    ax1.set_xlim([0, 1.0])
    ax1.axvline(x=df_sorted.iloc[0]['mean_accuracy'], color='red', 
                linestyle='--', alpha=0.5, label='Mejor')
    ax1.legend()
    ax1.grid(axis='x', alpha=0.3)
    
    # F1-Score
    ax2.barh(df_sorted['name'], df_sorted['mean_f1'], 
             xerr=df_sorted['std_f1'], capsize=5, color='coral')
    ax2.set_xlabel('F1-Score (Macro)')
    ax2.set_title('Comparación de F1-Score por Variante')
    ax2.set_xlim([0, 1.0])
    ax2.axvline(x=df_sorted.iloc[0]['mean_f1'], color='red', 
                linestyle='--', alpha=0.5, label='Mejor')
    ax2.legend()
    ax2.grid(axis='x', alpha=0.3)
    
    plt.tight_layout()
    fig_path = output_dir / 'ablation_comparison.png'
    plt.savefig(fig_path, dpi=150, bbox_inches='tight')
    print(f"Gráfico guardado en: {fig_path}")
    plt.close()
    
    # Tabla de comparación
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.axis('tight')
    ax.axis('off')
    
    table_data = []
    for _, row in df_sorted.iterrows():
        table_data.append([
            row['name'],
            f"{row['mean_accuracy']:.4f} ± {row['std_accuracy']:.4f}",
            f"{row['mean_f1']:.4f} ± {row['std_f1']:.4f}",
            f"{row['n_params']:,}"
        ])
    
    table = ax.table(cellText=table_data,
                     colLabels=['Variante', 'Accuracy', 'F1-Score', 'Parámetros'],
                     cellLoc='left',
                     loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 2)
    
    # Colorear la mejor fila
    for i in range(4):
        table[(1, i)].set_facecolor('#90EE90')
    
    plt.title('Resumen de Resultados - Estudio de Ablación', 
              fontsize=12, fontweight='bold', pad=20)
    
    table_path = output_dir / 'ablation_table.png'
    plt.savefig(table_path, dpi=150, bbox_inches='tight')
    print(f"Tabla guardada en: {table_path}")
    plt.close()
    
    return df_sorted

print("✓ Funciones de estudio de ablación definidas")

### Running the Ablation Study

Runs the ablation study and generates comparative visualizations.

In [None]:
# =========================
# EJECUTAR ESTUDIO DE ABLACIÓN
# =========================

if __name__ == "__main__":
    # Study configuration
    # Change fold_list to evaluate on multiple folds
    # By default only fold 1 is used, for speed (for testing)
    # For final results, use fold_list=[1, 2, 3, 4, 5]
    
    print("\n" + "="*80)
    print("ESTUDIO DE ABLACIÓN - CNN+Transformer para EEG Motor Imagery")
    print("="*80 + "\n")
    
    # OPCIÓN 1: Prueba rápida con 1 fold
    fold_list = [1]
    
    # OPCIÓN 2: Evaluación completa con 5 folds (descomenta para usar)
    # fold_list = [1, 2, 3, 4, 5]
    
    # Ejecutar estudio de ablación
    print(f"Ejecutando estudio con {len(fold_list)} fold(s): {fold_list}")
    print("Esto puede tardar varios minutos dependiendo del número de folds...\n")
    
    df_results = run_ablation_study(
        device=DEVICE,
        fold_list=fold_list,
        base_dir=PROJ / 'models' / '04_hybrid' / 'ablation_study'
    )
    
    # Mostrar resultados
    print("\n" + "="*80)
    print("RESULTADOS DEL ESTUDIO DE ABLACIÓN")
    print("="*80)
    print(df_results[['name', 'mean_accuracy', 'std_accuracy', 'mean_f1', 'std_f1', 'n_params']].to_string(index=False))
    
    # Generar gráficos
    print("\nGenerando visualizaciones...")
    df_sorted = plot_ablation_results(
        df_results,
        output_dir=PROJ / 'models' / '04_hybrid' / 'ablation_study'
    )
    
    # Análisis de contribución relativa (vs baseline)
    print("\n" + "="*80)
    print("ANÁLISIS DE CONTRIBUCIÓN RELATIVA")
    print("="*80)
    
    baseline_acc = df_results[df_results['variant'] == 'baseline']['mean_accuracy'].values[0]
    
    print(f"\nBaseline Accuracy: {baseline_acc:.4f}\n")
    print("Variante                      | Acc ± Std          | Δ vs Baseline | % Change")
    print("-" * 80)
    
    for _, row in df_sorted.iterrows():
        delta = row['mean_accuracy'] - baseline_acc
        pct_change = (delta / baseline_acc) * 100 if baseline_acc > 0 else 0
        # The '+' format flag already prints the sign, so no manual prefix is
        # needed (the previous manual prefix produced output such as "++0.0123").
        print(f"{row['name']:30s} | {row['mean_accuracy']:.4f} ± {row['std_accuracy']:.4f} | "
              f"{delta:+.4f}       | {pct_change:+.2f}%")
    
    print("\n" + "="*80)
    print("INTERPRETACIÓN:")
    print("="*80)
    print("• Δ positivo: El componente REDUCE el rendimiento (su eliminación mejora)")
    print("• Δ negativo: El componente MEJORA el rendimiento (su eliminación empeora)")
    print("• Δ cercano a 0: El componente tiene impacto mínimo")
    print("\n")
    
    # Identificar componentes más importantes
    contributions = []
    for _, row in df_results.iterrows():
        if row['variant'] != 'baseline':
            delta = baseline_acc - row['mean_accuracy']  # Invertido: pérdida de rendimiento
            contributions.append({
                'component': row['name'],
                'contribution': delta,
                'abs_contribution': abs(delta)
            })
    
    contributions_sorted = sorted(contributions, key=lambda x: x['abs_contribution'], reverse=True)
    
    print("Ranking de componentes por impacto (mayor a menor):")
    print("-" * 80)
    for i, comp in enumerate(contributions_sorted, 1):
        impact = "BENEFICIOSO" if comp['contribution'] > 0 else "PERJUDICIAL" if comp['contribution'] < 0 else "NEUTRAL"
        print(f"{i}. {comp['component']:30s} | Contribución: {comp['contribution']:+.4f} | {impact}")
    
    print("\n" + "="*80)
    print("ESTUDIO COMPLETADO")
    print("="*80)
    print(f"Resultados guardados en: {PROJ / 'models' / '04_hybrid' / 'ablation_study'}")
    print("="*80 + "\n")

---
## Usage Instructions

### Quick setup:
1. **Quick test** (1 fold): leave `fold_list = [1]` in the previous cell
2. **Full evaluation** (5 folds): change it to `fold_list = [1, 2, 3, 4, 5]`

### Estimated time:
- **1 fold**: ~10-15 minutes per variant (~60-90 minutes in total for 6 variants)
- **5 folds**: ~50-75 minutes per variant (~5-7.5 hours in total for 6 variants)
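
The totals above follow from multiplying the number of runs (variants × folds) by the per-fold time. A small sketch of that arithmetic, assuming the runs execute one after another; `estimate_total_minutes` is a hypothetical helper, not something defined in the notebook:

```python
def estimate_total_minutes(n_variants, n_folds, min_per_fold=(10, 15)):
    """Return (low, high) total minutes, assuming sequential runs."""
    lo, hi = min_per_fold
    runs = n_variants * n_folds      # one training run per (variant, fold) pair
    return runs * lo, runs * hi

print(estimate_total_minutes(6, 1))  # (60, 90)
print(estimate_total_minutes(6, 5))  # (300, 450), i.e. 5 to 7.5 hours
```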

### Generated outputs:
- **CSV**: `ablation_results.csv` - Table with all metrics
- **Figures**: 
  - `ablation_comparison.png` - Visual comparison of accuracy and F1
  - `ablation_table.png` - Formatted table of results
- **Models**: Saved to `ablation_study/<variant>/best_fold{X}.pt`
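
The `folds`, `fold_accs`, and `fold_f1s` columns hold Python lists, so in the CSV they are stored as text and need to be parsed when the file is read back. The sketch below shows one way to do this with the standard library. It uses an in-memory stand-in for the file with placeholder numbers (not real results), and it assumes the lists were written as plain Python floats. If the list elements are NumPy scalars, the text form may differ and `ast.literal_eval` may fail, so casting to `float` before saving is safer.

```python
import ast
import csv
import io

# In-memory stand-in for ablation_results.csv; the values are placeholders.
csv_text = io.StringIO()
writer = csv.DictWriter(csv_text, fieldnames=["variant", "mean_accuracy", "fold_accs"])
writer.writeheader()
writer.writerow({"variant": "baseline", "mean_accuracy": 0.8, "fold_accs": str([0.79, 0.81])})
writer.writerow({"variant": "no_cls", "mean_accuracy": 0.77, "fold_accs": str([0.76, 0.78])})

csv_text.seek(0)
rows = []
for row in csv.DictReader(csv_text):
    row["mean_accuracy"] = float(row["mean_accuracy"])
    # List-valued cells come back as text such as "[0.79, 0.81]".
    row["fold_accs"] = ast.literal_eval(row["fold_accs"])
    rows.append(row)

print(rows[0]["fold_accs"])  # [0.79, 0.81]
```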

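### Interpreting the results:
- **Baseline** is your full model (the reference)
- If accuracy **drops** when a **component is removed** → that component is **beneficial**
- If accuracy **rises** when a **component is removed** → that component may be **hurting** performance
- The `batchnorm` and `relu` variants replace a component rather than remove it, so their Δ compares the two alternatives
- Differences below 1-2% may be statistical noise; with a single fold the reported standard deviation is always 0, so it says nothing about variability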
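
To judge whether a difference is larger than fold-to-fold noise, one option is to compare variants fold by fold instead of comparing only the means. The sketch below is a minimal illustration using the standard library. `paired_delta` is a hypothetical helper, and the accuracies are placeholder values, not measured results.

```python
from statistics import mean, stdev

def paired_delta(baseline_accs, variant_accs):
    """Mean and sample std of the per-fold differences (variant - baseline)."""
    diffs = [v - b for b, v in zip(baseline_accs, variant_accs)]
    # The sample standard deviation is undefined for a single fold.
    spread = stdev(diffs) if len(diffs) > 1 else float("nan")
    return mean(diffs), spread

# Placeholder accuracies for illustration only.
baseline = [0.80, 0.78, 0.82, 0.79, 0.81]
variant = [0.78, 0.77, 0.80, 0.78, 0.79]
delta, spread = paired_delta(baseline, variant)
print(f"delta = {delta:+.4f}, std of fold differences = {spread:.4f}")
```

A mean difference that is small relative to the spread of the per-fold differences is weak evidence that the component matters. The per-fold lists come from the `fold_accs` column when the study is run on several folds.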

### For your thesis:
You can use the generated figures and tables directly in your document. The "Relative Contribution" analysis (printed as `ANÁLISIS DE CONTRIBUCIÓN RELATIVA`) will help you argue which components are essential to your model.