### Rutas, deps, utilidades básicas (crear carpetas, seeds, logger)

In [1]:
# %% [setup]
import os, sys, random, time, json
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

# ---------- Rutas base (ajusta si tu repo difiere) ----------
PROJ      = Path('..').resolve().parent   # asumiendo este notebook vive en eegnet/
DATA_PROC = PROJ / 'data' / 'processed'

OUT_INTRA = PROJ / 'models' / 'eegnet'
OUT_LOSO  = PROJ / 'models' / 'eegnet'
OUT_INTER = PROJ / 'models' / 'inter_fixed_eegnet'
for d in (OUT_INTRA, OUT_LOSO, OUT_INTER):
    (d/'figures').mkdir(parents=True, exist_ok=True)
    (d/'tables').mkdir(parents=True, exist_ok=True)
    (d/'logs').mkdir(parents=True, exist_ok=True)

# ---------- Utilidades ----------
def set_seed(seed: int = 42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class AverageMeter:
    def __init__(self): self.reset()
    def reset(self):
        self.sum = 0.0; self.cnt = 0; self.avg = 0.0
    def update(self, val, n=1):
        self.sum += float(val) * n; self.cnt += n; self.avg = self.sum / max(1, self.cnt)

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def list_subjects(fif_dir=DATA_PROC, pattern='S???_MI-epo.fif'):
    return sorted({Path(p).stem.split('_')[0] for p in Path(fif_dir).glob(pattern)})

def epochs_to_arrays(epochs: mne.Epochs):
    X = epochs.get_data().astype(np.float32)         # (N, C, T)
    y = epochs.events[:, -1].astype(np.int64)        # ints
    return X, y

def epochs_to_labels(epochs: mne.Epochs):
    inv = {v:k for k,v in epochs.event_id.items()}
    y = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
    return y

def ts_now(): return datetime.now().strftime("%Y%m%d-%H%M%S")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


Using device: cuda


### Modelo EEGNet v4 (PyTorch) con hiperparámetros potentes

In [2]:
# %% [modelo]
class EEGNet(nn.Module):
    """
    EEGNet (Lawhern et al., 2018) para entrada [B, 1, C, T].
    Parámetros típicos: F1=8, D=2, F2=16, kernel_time=64, pool1=4, pool2=8, dropout=0.5
    """
    def __init__(self, n_classes, C, T, F1=8, D=2, F2=16, kernel_time=64,
                 pool1=4, pool2=8, dropout=0.5):
        super().__init__()
        self.n_classes = n_classes

        # Temporal
        self.conv_time = nn.Conv2d(1, F1, (1, kernel_time),
                                   padding=(0, kernel_time//2), bias=False)
        self.bn_time = nn.BatchNorm2d(F1)

        # Spatial (depthwise)
        self.depthwise = nn.Conv2d(F1, F1*D, (C, 1), groups=F1, bias=False)
        self.bn_depth = nn.BatchNorm2d(F1*D)
        self.pool1 = nn.AvgPool2d((1, pool1))
        self.drop1 = nn.Dropout(dropout)

        # Separable (depthwise temporal + pointwise)
        self.sep_depth = nn.Conv2d(F1*D, F1*D, (1, 16), padding=(0, 16//2),
                                   groups=F1*D, bias=False)
        self.sep_point = nn.Conv2d(F1*D, F2, (1, 1), bias=False)
        self.bn_sep = nn.BatchNorm2d(F2)
        self.pool2 = nn.AvgPool2d((1, pool2))
        self.drop2 = nn.Dropout(dropout)

        # Clasificador (se define lazy en forward)
        self.classifier = None

    def forward_features(self, x):
        x = self.conv_time(x); x = self.bn_time(x); x = F.elu(x)
        x = self.depthwise(x); x = self.bn_depth(x); x = F.elu(x)
        x = self.pool1(x); x = self.drop1(x)
        x = self.sep_depth(x); x = self.sep_point(x); x = self.bn_sep(x); x = F.elu(x)
        x = self.pool2(x); x = self.drop2(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if self.classifier is None:
            feat_dim = x.shape[1] * x.shape[2] * x.shape[3]
            self.classifier = nn.Linear(feat_dim, self.n_classes).to(x.device)
        x = torch.flatten(x, 1)
        return self.classifier(x)


### Dataset, normalización por canal (stats de train), augmentations suaves y DataLoaders

In [3]:
# %% [dataloader — FIX: targets 0..K-1 consistente]
from sklearn.preprocessing import LabelEncoder
from collections import Counter

def _make_windows_indices(n_times, win_size, stride):
    if stride <= 0:
        s = max(0, (n_times - win_size)//2)
        return [(s, min(s + win_size, n_times))]
    idx = []
    for s in range(0, max(1, n_times - win_size + 1), stride):
        e = s + win_size
        if e <= n_times: idx.append((s, e))
    if not idx:
        idx = [(0, n_times)]
    return idx

def fit_channel_normalizer(X_train):
    mean_c = X_train.mean(axis=(0,2))
    std_c  = X_train.std(axis=(0,2)) + 1e-7
    return mean_c.astype(np.float32), std_c.astype(np.float32)

def augment_light(x, p_noise=0.3, p_scale=0.3):
    C,T = x.shape
    if np.random.rand() < p_noise:
        x = x + 0.01*np.random.randn(C,T).astype(np.float32)
    if np.random.rand() < p_scale:
        scale = (0.9 + 0.2*np.random.rand(C)).astype(np.float32)
        x = x * scale[:, None]
    return x

class EEGWindowDataset(Dataset):
    def __init__(self, X, y_enc, win_size, stride=0, mean_=None, std_=None,
                 augment_fn=None, add_channel_dim=True):
        assert X.ndim == 3
        self.X = X; self.y = y_enc
        self.n_trials, self.n_ch, self.n_times = X.shape
        self.win_size = min(int(win_size), self.n_times)
        self.stride = int(stride)
        self.augment_fn = augment_fn
        self.add_channel_dim = add_channel_dim

        self.win_index = []
        if self.stride > 0:
            widx = _make_windows_indices(self.n_times, self.win_size, self.stride)
            for t in range(self.n_trials):
                for (s,e) in widx: self.win_index.append((t,s,e))
        else:
            (s,e) = _make_windows_indices(self.n_times, self.win_size, 0)[0]
            for t in range(self.n_trials): self.win_index.append((t,s,e))

        if mean_ is None or std_ is None:
            self.mean_ = np.zeros((self.n_ch,), dtype=np.float32)
            self.std_  = np.ones((self.n_ch,), dtype=np.float32)
        else:
            self.mean_ = mean_.astype(np.float32)
            self.std_  = np.clip(std_.astype(np.float32), 1e-7, None)

    def __len__(self): return len(self.win_index)

    def __getitem__(self, i):
        t, s, e = self.win_index[i]
        x = self.X[t, :, s:e]                           # (C,Tw)
        x = (x - self.mean_[:, None]) / self.std_[:, None]
        if self.augment_fn is not None: x = self.augment_fn(x)
        x = torch.from_numpy(x)                         # (C,Tw)
        if self.add_channel_dim: x = x.unsqueeze(0)     # (1,C,Tw)
        y = torch.tensor(self.y[t], dtype=torch.long)   # ya 0..K-1
        return x, y

def _epochs_to_arrays_and_labels(epochs: mne.Epochs):
    X = epochs.get_data().astype(np.float32)         # (N,C,T)
    inv = {v:k for k,v in epochs.event_id.items()}
    y_names = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
    return X, y_names

def make_eegnet_loaders_from_epochs(
    ep_train, ep_val, ep_test, fs,
    crop_sec=(0.5,3.5),
    train_win_sec=1.5, train_stride_sec=0.25,
    eval_win_sec=None, batch_size=64, num_workers=4, use_augment=True
):
    # Recorte temporal coherente
    if crop_sec is not None:
        ep_train = ep_train.copy().crop(*crop_sec)
        ep_val   = ep_val.copy().crop(*crop_sec)
        ep_test  = ep_test.copy().crop(*crop_sec)

    # Arrays y etiquetas nominales
    Xtr, ytr_names = _epochs_to_arrays_and_labels(ep_train)
    Xva, yva_names = _epochs_to_arrays_and_labels(ep_val)
    Xte, yte_names = _epochs_to_arrays_and_labels(ep_test)

    # Encoder global (train+val+test) para que todo quede 0..K-1
    le = LabelEncoder().fit(np.concatenate([ytr_names, yva_names, yte_names]))
    ytr = le.transform(ytr_names).astype(np.int64)
    yva = le.transform(yva_names).astype(np.int64)
    yte = le.transform(yte_names).astype(np.int64)

    # Normalización por canal aprendida SOLO en TRAIN
    mean_c, std_c = fit_channel_normalizer(Xtr)

    # Ventanas
    win_tr = int(round((train_win_sec or (crop_sec[1]-crop_sec[0]))*fs))
    stride = int(round(train_stride_sec*fs)) if train_stride_sec else 0
    win_ev = int(round((eval_win_sec if eval_win_sec is not None else train_win_sec)*fs))

    ds_tr = EEGWindowDataset(Xtr,ytr,win_tr,stride,mean_c,std_c,augment_light if use_augment else None,True)
    ds_va = EEGWindowDataset(Xva,yva,win_ev,0,mean_c,std_c,None,True)
    ds_te = EEGWindowDataset(Xte,yte,win_ev,0,mean_c,std_c,None,True)

    dl_tr = DataLoader(ds_tr,batch_size=batch_size,shuffle=True,num_workers=num_workers,
                       pin_memory=True, drop_last=True)
    dl_va = DataLoader(ds_va,batch_size=batch_size,shuffle=False,num_workers=num_workers,
                       pin_memory=True)
    dl_te = DataLoader(ds_te,batch_size=batch_size,shuffle=False,num_workers=num_workers,
                       pin_memory=True)

    # Pesos de clase con etiquetas codificadas 0..K-1
    cnt = Counter(ytr.tolist()); n = sum(cnt.values())
    idxs = sorted(cnt.keys())
    weights = torch.tensor([n/cnt[c] for c in idxs], dtype=torch.float32)

    info = dict(mean=mean_c, std=std_c, n_classes=len(idxs), class_weights=weights,
                n_train_windows=len(ds_tr), n_val_windows=len(ds_va), n_test_windows=len(ds_te),
                C=Xtr.shape[1], T=win_tr, label_encoder_classes=le.classes_)
    return dl_tr, dl_va, dl_te, info


### Entrenamiento, evaluación y calibración few-shot (BN + FC, opcional depthwise)

In [4]:
# %% [train+eval con logging cada 10]
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    loss_m = AverageMeter(); correct=0; total=0
    for xb,yb in loader:
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        loss_m.update(loss.item(), xb.size(0))
        pred = logits.argmax(1); correct += (pred==yb).sum().item(); total += yb.numel()
    return loss_m.avg, correct/max(1,total)

@torch.no_grad()
def evaluate_model(model, loader, device):
    model.eval()
    ys=[]; yh=[]
    for xb,yb in loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        yh.append(logits.argmax(1).cpu().numpy())
        ys.append(yb.numpy())
    y_true = np.concatenate(ys); y_hat = np.concatenate(yh)
    acc = accuracy_score(y_true, y_hat)
    f1m = f1_score(y_true, y_hat, average='macro')
    cm  = confusion_matrix(y_true, y_hat)
    return acc, f1m, cm

def fit_model(model, dl_tr, dl_va, max_epochs=150, lr=3e-4, weight_decay=1e-3,
              device='cuda', class_weights=None, patience=20, log_file=None):
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(device) if class_weights is not None else None)

    # logger simple a archivo solo cada 10 épocas
    def log10(msg, epoch=None):
        if epoch is None or (epoch % 10 == 0):
            print(msg)
            if log_file is not None:
                with open(log_file, "a", encoding="utf-8") as f:
                    f.write(msg + "\n")

    best = {'val_f1': -1, 'state': None, 'epoch': -1}
    bad = 0
    for epoch in range(1, max_epochs+1):
        tr_loss, tr_acc = train_epoch(model, dl_tr, optimizer, criterion, device)
        va_acc, va_f1, _ = evaluate_model(model, dl_va, device)
        scheduler.step()

        log10(f"[Epoch {epoch:03d}] train_loss={tr_loss:.4f} | train_acc={tr_acc:.3f} | val_acc={va_acc:.3f} | val_f1={va_f1:.3f}",
              epoch=epoch)

        if va_f1 > best['val_f1']:
            best.update(val_f1=va_f1, state={k:v.detach().cpu() for k,v in model.state_dict().items()}, epoch=epoch)
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                log10(f"[EarlyStop] sin mejora en {patience} épocas (mejor f1={best['val_f1']:.3f} @ {best['epoch']})",
                      epoch=epoch)
                break

    # resumen final (se imprime y guarda)
    final_msg = f"[TRAIN END] best_val_f1={best['val_f1']:.3f} @ epoch={best['epoch']}"
    print(final_msg)
    if log_file is not None:
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(final_msg + "\n")

    if best['state'] is not None:
        model.load_state_dict({k:v for k,v in best['state'].items()})
    return model, best


### INTRA — k-fold por sujeto (reporta CSV/TXT + confusiones)

In [5]:
# %% [pipelines: INTRA, LOSO, INTER + calibración]
def split_calibration(ep_test: mne.Epochs, k_per_class=5, seed=42):
    if not k_per_class or k_per_class <= 0:
        return None, ep_test
    rng = np.random.RandomState(int(seed))
    labels = ep_test.events[:, -1]
    ep_calib_list, ep_eval_list = [], []
    for code in np.unique(labels):
        idx = np.where(labels == code)[0]
        rng.shuffle(idx)
        take = min(k_per_class, len(idx))
        if take > 0: ep_calib_list.append(ep_test.copy()[idx[:take]])
        if len(idx) > take: ep_eval_list.append(ep_test.copy()[idx[take:]])
    ep_calib = mne.concatenate_epochs(ep_calib_list, on_mismatch='ignore') if ep_calib_list else None
    ep_eval  = mne.concatenate_epochs(ep_eval_list,  on_mismatch='ignore') if ep_eval_list else ep_test
    return ep_calib, ep_eval

# -------------------- INTRA (k-fold) --------------------
def run_intra_all_eegnet(
    fif_dir=DATA_PROC, k=5, crop_window=(0.5,3.5),
    max_epochs=120, lr=3e-4, weight_decay=1e-3,
    F1=16, D=2, F2=32, kernel_time=64, dropout=0.5, pool1=4, pool2=8,
    train_win_sec=1.5, train_stride_sec=0.25, eval_win_sec=None,
    batch_size=128, seed=42, fs=250.0, use_augment=True, log_tag="intra"
):
    set_seed(seed)
    ts = ts_now()
    rows = []
    rows_subj = [] 

    subs = list_subjects(fif_dir)
    print(f"[INTRA-EEGNet] sujetos: {subs}")

    # asegurar carpeta de logs
    (OUT_INTRA/'logs').mkdir(parents=True, exist_ok=True)
    log_file = OUT_INTRA/'logs'/f"{ts}_eegnet_{log_tag}.txt"

    for sid in subs:
        ep = mne.read_epochs(str(Path(fif_dir)/f"{sid}_MI-epo.fif"), preload=True, verbose=False)
        if crop_window is not None:
            ep.crop(*crop_window)

        # Etiquetas base del sujeto (para estratificar)
        y = epochs_to_labels(ep)
        le = LabelEncoder().fit(y)
        y_enc = le.transform(y)

        idx = np.arange(len(ep))
        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=seed)

        for fold, (tr_idx, va_idx) in enumerate(skf.split(idx, y_enc), start=1):
            ep_tr = ep.copy()[tr_idx]
            ep_va = ep.copy()[va_idx]

            dl_tr, dl_va, dl_te, info = make_eegnet_loaders_from_epochs(
                ep_tr, ep_va, ep_va,
                fs=fs, crop_sec=None,
                train_win_sec=train_win_sec,
                train_stride_sec=train_stride_sec,
                eval_win_sec=(eval_win_sec if eval_win_sec is not None else train_win_sec),
                batch_size=batch_size, use_augment=use_augment
            )

            n_classes = info['n_classes']
            C = ep.info['nchan']
            Tw = int(info['T'])

            model = EEGNet(
                n_classes, C, Tw,
                F1=F1, D=D, F2=F2, kernel_time=kernel_time,
                pool1=pool1, pool2=pool2, dropout=dropout
            )

            model, best = fit_model(
                model, dl_tr, dl_va,
                max_epochs=max_epochs, lr=lr, weight_decay=weight_decay,
                device=device, class_weights=info['class_weights'],
                patience=20, log_file=str(log_file)
            )

            acc, f1m, _ = evaluate_model(model, dl_va, device)
            msg = f"[INTRA] {sid} | fold {fold} | acc={acc:.3f} | f1m={f1m:.3f}"
            print(msg)
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(msg + "\n")

            rows.append(dict(subject=sid, fold=fold, acc=float(acc), f1_macro=float(f1m)))

        # --- PROMEDIO POR SUJETO (después de completar los k folds del sujeto) ---
        subj_rows = [r for r in rows if r['subject'] == sid]
        acc_mean = float(np.mean([r['acc'] for r in subj_rows])) if subj_rows else 0.0
        f1_mean  = float(np.mean([r['f1_macro'] for r in subj_rows])) if subj_rows else 0.0
        rows_subj.append(dict(subject=sid, acc_mean=acc_mean, f1_mean=f1_mean, n_folds=len(subj_rows)))

        msg_subj = f"[INTRA] {sid} | mean_acc={acc_mean:.3f} | mean_f1={f1_mean:.3f}"
        print(msg_subj)
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(msg_subj + "\n")

    # --- DataFrames finales ---
    df = pd.DataFrame(rows).sort_values(['subject','fold'])
    df_avg_subject = pd.DataFrame(rows_subj).sort_values('subject')

    # Promedio GLOBAL como media de los promedios por sujeto
    if len(df_avg_subject):
        acc_mu = float(df_avg_subject['acc_mean'].mean())
        f1_mu  = float(df_avg_subject['f1_mean'].mean())
    else:
        acc_mu = 0.0
        f1_mu = 0.0

    print(f"[INTRA GLOBAL] mean_acc={acc_mu:.3f} | mean_f1={f1_mu:.3f}")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"[INTRA GLOBAL] mean_acc={acc_mu:.3f} | mean_f1={f1_mu:.3f}\n")

    # --- Guardar CSVs ---
    (OUT_INTRA/'tables').mkdir(parents=True, exist_ok=True)
    ts2 = ts_now()
    out_csv_all = OUT_INTRA/'tables'/f"{ts2}_eegnet_intra_all.csv"
    out_csv_avg = OUT_INTRA/'tables'/f"{ts2}_eegnet_intra_subject_avg.csv"
    df.to_csv(out_csv_all, index=False)
    df_avg_subject.to_csv(out_csv_avg, index=False)
    print("CSV →", out_csv_all)
    print("CSV →", out_csv_avg)

    return df, df_avg_subject


# -------------------- LOSO --------------------
def run_loso_eegnet(
    fif_dir=DATA_PROC, crop_window=(0.5,3.5),
    max_epochs=150, lr=3e-4, weight_decay=1e-3,
    F1=16, D=2, F2=32, kernel_time=64, dropout=0.5, pool1=4, pool2=8,
    train_win_sec=1.5, train_stride_sec=0.25, eval_win_sec=None,
    batch_size=128, seed=42, fs=250.0, use_augment=True,
    calibrate_k_per_class=0, calibrate_mode='head', log_tag="loso"
):
    set_seed(seed)
    ts = ts_now()
    rows=[]
    subs = list_subjects(fif_dir)
    log_file = OUT_LOSO/'logs'/f"{ts}_eegnet_{log_tag}.txt"

    for s_test in subs:
        train_ids = [s for s in subs if s != s_test]
        ep_tr = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                       preload=True, verbose=False) for s in train_ids],
                                       on_mismatch='ignore')
        ep_te = mne.read_epochs(str(Path(fif_dir)/f"{s_test}_MI-epo.fif"), preload=True, verbose=False)
        ep_te = ep_te.reorder_channels(ep_tr.ch_names)

        if crop_window is not None:
            ep_tr.crop(*crop_window); ep_te.crop(*crop_window)

        # split calib/eval dentro de TEST
        ep_calib, ep_eval = split_calibration(ep_te, k_per_class=calibrate_k_per_class, seed=seed)

        # separa VALIDATION del TRAIN (para early stopping)
        y_tr = epochs_to_labels(ep_tr); le = LabelEncoder().fit(y_tr)
        idx = np.arange(len(ep_tr))
        tr_idx, va_idx = train_test_split(idx, test_size=0.15, random_state=seed,
                                          shuffle=True, stratify=le.transform(y_tr))
        ep_tr_sub = ep_tr.copy()[tr_idx]; ep_va = ep_tr.copy()[va_idx]

        dl_tr, dl_va, dl_eval, info = make_eegnet_loaders_from_epochs(
            ep_tr_sub, ep_va, ep_eval, fs=fs, crop_sec=None,
            train_win_sec=train_win_sec, train_stride_sec=train_stride_sec,
            eval_win_sec=(eval_win_sec if eval_win_sec is not None else train_win_sec),
            batch_size=batch_size, use_augment=use_augment
        )
        n_classes = info['n_classes']; C = ep_tr.info['nchan']; Tw = int(info['T'])
        model = EEGNet(n_classes, C, Tw, F1=F1, D=D, F2=F2, kernel_time=kernel_time,
                       pool1=pool1, pool2=pool2, dropout=dropout)

        model, best = fit_model(model, dl_tr, dl_va, max_epochs=max_epochs, lr=lr,
                                weight_decay=weight_decay, device=device,
                                class_weights=info['class_weights'], patience=20,
                                log_file=str(log_file))

        # calibración few-shot si aplica
        if calibrate_k_per_class and calibrate_k_per_class > 0 and ep_calib is not None:
            _, _, dl_calib, _ = make_eegnet_loaders_from_epochs(
                ep_tr, ep_va, ep_calib, fs=fs, crop_sec=None,
                train_win_sec=train_win_sec, train_stride_sec=train_stride_sec,
                eval_win_sec=(eval_win_sec if eval_win_sec is not None else train_win_sec),
                batch_size=batch_size, use_augment=False
            )
            if calibrate_mode.lower() == 'head':
                for p in model.parameters(): p.requires_grad_(False)
                for p in model.classifier.parameters(): p.requires_grad_(True)
                opt = torch.optim.AdamW(model.classifier.parameters(), lr=1e-3, weight_decay=0.0)
                crit = torch.nn.CrossEntropyLoss()
                model.to(device); model.train()
                for xb,yb in dl_calib:
                    xb,yb = xb.to(device), yb.to(device)
                    opt.zero_grad(set_to_none=True)
                    logits = model(xb); loss = crit(logits, yb)
                    loss.backward(); opt.step()
            else:
                opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
                crit = torch.nn.CrossEntropyLoss()
                model.to(device); model.train()
                steps = min(300, len(dl_calib))
                it = 0
                for xb,yb in dl_calib:
                    xb,yb = xb.to(device), yb.to(device)
                    opt.zero_grad(set_to_none=True)
                    logits = model(xb); loss = crit(logits, yb)
                    loss.backward(); opt.step()
                    it += 1
                    if it >= steps: break

        acc, f1m, cm = evaluate_model(model, dl_eval, device)
        msg = f"[LOSO] test={s_test} | acc={acc:.3f} | f1m={f1m:.3f}"
        print(msg)
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(msg + "\n")
        rows.append(dict(test_subject=s_test, acc=float(acc), f1_macro=float(f1m), n_test=len(ep_eval)))

        # figura por sujeto (opcional)
        fig, ax = plt.subplots(figsize=(5,4), dpi=140)
        ConfusionMatrixDisplay(cm).plot(ax=ax, cmap='Blues', colorbar=False, values_format='d')
        ax.set_title(f"EEGNet LOSO — {s_test}")
        fig.tight_layout()
        fig.savefig(OUT_LOSO/'figures'/f"cm_loso_{s_test}.png"); plt.close(fig)

    df = pd.DataFrame(rows)
    # fila GLOBAL
    g_acc = df.acc.mean(); g_f1 = df.f1_macro.mean()
    df.loc[len(df)] = dict(test_subject='GLOBAL', acc=g_acc, f1_macro=g_f1, n_test=int(df.n_test.sum()))
    print(f"[LOSO GLOBAL] acc_avg={g_acc:.3f} | f1m_avg={g_f1:.3f}")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"[LOSO GLOBAL] acc_avg={g_acc:.3f} | f1m_avg={g_f1:.3f}\n")

    ts2 = ts_now()
    out_csv = OUT_LOSO/'tables'/f"{ts2}_eegnet_loso.csv"
    df.to_csv(out_csv, index=False)
    print("CSV →", out_csv)
    return df

# -------------------- INTER (test fijo) --------------------
FIXED_TEST_SUBJECTS = [
    'S007','S025','S029','S031','S032','S034','S035','S042','S043','S049','S056','S058','S062','S072',
    'S001','S010','S013','S017','S019','S030',
    'S005','S006','S009','S097'
]

def make_split_with_fixed_test(fif_dir=DATA_PROC, fixed_test=(), val_size=16, seed=42):
    all_files = sorted(Path(fif_dir).glob('S???_MI-epo.fif'))
    sids = [p.stem.split('_')[0] for p in all_files]
    existing = set(sids)
    test_subs = [s for s in list(fixed_test) if s in existing]
    remaining = [s for s in sids if s not in test_subs]
    if len(remaining) <= val_size:
        raise ValueError("No hay suficientes sujetos para validación.")
    train_subs, val_subs = train_test_split(remaining, test_size=val_size, random_state=seed, shuffle=True)
    return train_subs, val_subs, test_subs

def run_inter_fixedsplit_eegnet(
    fif_dir=DATA_PROC, crop_window=(0.5,3.5),
    max_epochs=180, lr=3e-4, weight_decay=1e-3,
    F1=16, D=2, F2=32, kernel_time=64, dropout=0.5, pool1=4, pool2=8,
    train_win_sec=1.5, train_stride_sec=0.25, eval_win_sec=None,
    batch_size=128, seed=42, fs=250.0, use_augment=True,
    refit_on_trainval_for_test=True,
    calibrate_k_per_class=0, calibrate_mode='head', log_tag="inter"
):
    set_seed(seed)
    ts = ts_now()
    log_file = OUT_INTER/'logs'/f"{ts}_eegnet_{log_tag}.txt"

    train_subs, val_subs, test_subs = make_split_with_fixed_test(fif_dir, FIXED_TEST_SUBJECTS, val_size=16, seed=seed)
    print(f"TRAIN={len(train_subs)} | VAL={len(val_subs)} | TEST(fijo)={len(test_subs)}")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"TRAIN={len(train_subs)} | VAL={len(val_subs)} | TEST={len(test_subs)}\n")

    ep_tr = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                   preload=True, verbose=False) for s in train_subs],
                                   on_mismatch='ignore')
    ep_va = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                   preload=True, verbose=False) for s in val_subs],
                                   on_mismatch='ignore')
    ep_te_all = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                       preload=True, verbose=False) for s in test_subs],
                                       on_mismatch='ignore')
    ep_va = ep_va.reorder_channels(ep_tr.ch_names)
    ep_te_all = ep_te_all.reorder_channels(ep_tr.ch_names)

    # Entrena (TRAIN vs VAL)
    dl_tr, dl_va, _, info = make_eegnet_loaders_from_epochs(
        ep_tr, ep_va, ep_va, fs=fs, crop_sec=crop_window,
        train_win_sec=train_win_sec, train_stride_sec=train_stride_sec,
        eval_win_sec=(eval_win_sec if eval_win_sec is not None else train_win_sec),
        batch_size=batch_size, use_augment=use_augment
    )
    n_classes = info['n_classes']; C = ep_tr.info['nchan']; Tw = int(info['T'])
    model = EEGNet(n_classes, C, Tw, F1=F1, D=D, F2=F2, kernel_time=kernel_time,
                   pool1=pool1, pool2=pool2, dropout=dropout)

    model, best = fit_model(model, dl_tr, dl_va, max_epochs=max_epochs, lr=lr,
                            weight_decay=weight_decay, device=device,
                            class_weights=info['class_weights'], patience=20,
                            log_file=str(log_file))

    # Test: refit base (TRAIN+VAL) opcional para normalización y ventanas
    ep_trva = mne.concatenate_epochs([ep_tr, ep_va], on_mismatch='ignore') if refit_on_trainval_for_test else ep_tr
    ep_calib, ep_eval = split_calibration(ep_te_all, k_per_class=calibrate_k_per_class, seed=seed)

    # Loaders para eval y (si aplica) calib
    _, _, dl_eval, _ = make_eegnet_loaders_from_epochs(
        ep_trva, ep_va, ep_eval, fs=fs, crop_sec=crop_window,
        train_win_sec=train_win_sec, train_stride_sec=train_stride_sec,
        eval_win_sec=(eval_win_sec if eval_win_sec is not None else train_win_sec),
        batch_size=batch_size, use_augment=False
    )

    if calibrate_k_per_class and calibrate_k_per_class > 0 and ep_calib is not None:
        _, _, dl_calib, _ = make_eegnet_loaders_from_epochs(
            ep_trva, ep_va, ep_calib, fs=fs, crop_sec=crop_window,
            train_win_sec=train_win_sec, train_stride_sec=train_stride_sec,
            eval_win_sec=(eval_win_sec if eval_win_sec is not None else train_win_sec),
            batch_size=batch_size, use_augment=False
        )
        if calibrate_mode.lower() == 'head':
            for p in model.parameters(): p.requires_grad_((False))
            for p in model.classifier.parameters(): p.requires_grad_(True)
            opt = torch.optim.AdamW(model.classifier.parameters(), lr=1e-3, weight_decay=0.0)
            crit = torch.nn.CrossEntropyLoss()
            model.to(device); model.train()
            for xb,yb in dl_calib:
                xb,yb = xb.to(device), yb.to(device)
                opt.zero_grad(set_to_none=True)
                logits = model(xb); loss = crit(logits, yb)
                loss.backward(); opt.step()
        else:
            opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
            crit = torch.nn.CrossEntropyLoss()
            model.to(device); model.train()
            steps = min(300, len(dl_calib))
            it = 0
            for xb,yb in dl_calib:
                xb,yb = xb.to(device), yb.to(device)
                opt.zero_grad(set_to_none=True)
                logits = model(xb); loss = crit(logits, yb)
                loss.backward(); opt.step()
                it += 1
                if it >= steps: break

    acc_te, f1_te, cm_te = evaluate_model(model, dl_eval, device)
    msg = f"[INTER-EEGNet] ACC_test={acc_te:.3f} | F1m_test={f1_te:.3f}"
    print(msg)
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(msg + "\n")

    # Figura y tablas
    fig, ax = plt.subplots(figsize=(5.4, 4.6), dpi=140)
    ConfusionMatrixDisplay(cm_te).plot(ax=ax, cmap="Blues", colorbar=True, values_format='d')
    ax.set_title("EEGNet — TEST fijo")
    fig.tight_layout()
    fig.savefig(OUT_INTER/'figures'/f"inter_eegnet_confusion_{ts}.png"); plt.close(fig)

    df = pd.DataFrame([dict(
        mode="inter_fixedsplit_eegnet",
        acc_test=float(acc_te), f1_test=float(f1_te),
        n_train=len(train_subs), n_val=len(val_subs), n_test=len(test_subs),
        crop=str(crop_window), train_win_sec=float(train_win_sec),
        stride_sec=float(train_stride_sec), eval_win_sec=float(eval_win_sec or train_win_sec),
        F1=int(F1), D=int(D), F2=int(F2), kernel_time=int(kernel_time),
        dropout=float(dropout), pool1=int(pool1), pool2=int(pool2),
        refit_on_trainval_for_test=bool(refit_on_trainval_for_test),
        calibrate_k_per_class=int(calibrate_k_per_class or 0),
        calibrate_mode=str(calibrate_mode)
    )])
    out_csv = OUT_INTER/'tables'/f"{ts}_inter_eegnet.csv"
    df.to_csv(out_csv, index=False)
    print("CSV →", out_csv)
    return df


### Ejemplos de ejecución

In [6]:
# %% [ejemplos de ejecución]
# Ajusta hiperparámetros aquí
COMMON = dict(
    crop_window=(0.5, 3.5),
    F1=16, D=2, F2=32, kernel_time=64, dropout=0.5, pool1=4, pool2=8,
    train_win_sec=1.5, train_stride_sec=0.25, eval_win_sec=1.5,
    batch_size=256, seed=42, fs=160.0, use_augment=True
)

# -------- INTRA (k-fold) --------
# df_intra, df_intra_subject_avg = run_intra_all_eegnet(
#     fif_dir=DATA_PROC,
#     k=5,
#     max_epochs=200, lr=3e-4, weight_decay=1e-3,
#     **COMMON, log_tag="intra_default"
# )
# try:
#     display(df_intra.head())
#     display(df_intra_subject_avg.head())
# except:
#     pass

# -------- LOSO --------
# df_loso = run_loso_eegnet(
#     fif_dir=DATA_PROC,
#     max_epochs=150, lr=3e-4, weight_decay=1e-3,
#     calibrate_k_per_class=5,      # 0/None => sin calibración
#     calibrate_mode='head',        # 'head' recomendado
#     **COMMON, log_tag="loso_default"
# )
# try:
#     display(df_loso.head())
# except:
#     pass

# -------- INTER (test fijo + calibración) --------
df_inter = run_inter_fixedsplit_eegnet(
    fif_dir=DATA_PROC,
    max_epochs=180, lr=3e-4, weight_decay=1e-3,
    refit_on_trainval_for_test=True,
    calibrate_k_per_class=10,     # few-shot por clase
    calibrate_mode='head',
    **COMMON, log_tag="inter_default"
)
try:
    display(df_inter)
except:
    pass


TRAIN=63 | VAL=16 | TEST(fijo)=24
Not setting metadata
5287 matching events found
No baseline correction applied
Not setting metadata
1367 matching events found
No baseline correction applied
Not setting metadata
2043 matching events found
No baseline correction applied
[Epoch 010] train_loss=1.3678 | train_acc=0.297 | val_acc=0.300 | val_f1=0.296
[Epoch 020] train_loss=1.3551 | train_acc=0.316 | val_acc=0.296 | val_f1=0.296
[Epoch 030] train_loss=1.3512 | train_acc=0.323 | val_acc=0.296 | val_f1=0.295
[TRAIN END] best_val_f1=0.309 @ epoch=13
Not setting metadata
6654 matching events found
No baseline correction applied
Not setting metadata
40 matching events found
No baseline correction applied
Not setting metadata
2003 matching events found
No baseline correction applied


  ep_calib = mne.concatenate_epochs(ep_calib_list, on_mismatch='ignore') if ep_calib_list else None
  ep_eval  = mne.concatenate_epochs(ep_eval_list,  on_mismatch='ignore') if ep_eval_list else ep_test


[INTER-EEGNet] ACC_test=0.377 | F1m_test=0.379
CSV → /root/Proyecto/EEG_Clasificador/models/inter_fixed_eegnet/tables/20251006-214917_inter_eegnet.csv


Unnamed: 0,mode,acc_test,f1_test,n_train,n_val,n_test,crop,train_win_sec,stride_sec,eval_win_sec,F1,D,F2,kernel_time,dropout,pool1,pool2,refit_on_trainval_for_test,calibrate_k_per_class,calibrate_mode
0,inter_fixedsplit_eegnet,0.376935,0.378948,63,16,24,"(0.5, 3.5)",1.5,0.25,1.5,16,2,32,64,0.5,4,8,True,10,head
