### Bloque — Rutas y logger

In [2]:
# %% [Rutas & logger — modelo Riemanniano (MDM/FgMDM)]
import sys, logging, warnings
from datetime import datetime
from pathlib import Path
import mne

# Este notebook está en: models/riemanniano_mdm/
# PROJ -> raíz del repo
PROJ = Path('..').resolve().parent          # .../models/riemanniano_mdm -> parent() = models -> parent() = <repo root>
DATA_PROC = PROJ / 'data' / 'processed'     # datos preprocesados (S???_MI-epo.fif)

# Salidas de este modelo (separadas)
RIEM_OUT_ROOT = PROJ / 'models' / 'riemanniano_mdm'
RIEM_FIG_DIR  = RIEM_OUT_ROOT / 'figures'
RIEM_TAB_DIR  = RIEM_OUT_ROOT / 'tables'
RIEM_LOG_DIR  = RIEM_OUT_ROOT / 'logs'
for d in (RIEM_FIG_DIR, RIEM_TAB_DIR, RIEM_LOG_DIR):
    d.mkdir(parents=True, exist_ok=True)

print(f"[Riemann] data procesados → {DATA_PROC}")
print(f"[Riemann] figuras  → {RIEM_FIG_DIR}")
print(f"[Riemann] tablas   → {RIEM_TAB_DIR}")
print(f"[Riemann] logs     → {RIEM_LOG_DIR}")

def _init_logger_riem(run_name: str):
    """
    Logger propio del modelo riemanniano.
    - Escribe a consola y a un TXT con timestamp en models/riemanniano_mdm/logs/.
    - Silencia el ruido de MNE.
    """
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = RIEM_LOG_DIR / f"{ts}_{run_name}.txt"

    logger = logging.getLogger(run_name)
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", datefmt="%H:%M:%S")
    ch = logging.StreamHandler(stream=sys.stdout); ch.setLevel(logging.INFO); ch.setFormatter(fmt)
    fh = logging.FileHandler(log_path, encoding="utf-8"); fh.setLevel(logging.INFO); fh.setFormatter(fmt)
    logger.addHandler(ch); logger.addHandler(fh)

    mne.set_log_level("ERROR")
    warnings.filterwarnings("ignore", category=UserWarning, module="mne")
    warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")
    return logger, log_path


[Riemann] data procesados → /root/Proyecto/EEG_Clasificador/data/processed
[Riemann] figuras  → /root/Proyecto/EEG_Clasificador/models/riemanniano_mdm/figures
[Riemann] tablas   → /root/Proyecto/EEG_Clasificador/models/riemanniano_mdm/tables
[Riemann] logs     → /root/Proyecto/EEG_Clasificador/models/riemanniano_mdm/logs


### Bloque - Helpers riemannianos (filtro-banco → covarianzas SPD → combinación “block”)

Filtra por sub-bandas (mu/beta),

Opcionalmente hace z-score por época,

Calcula covarianzas (estables con shrinkage),

Combina las sub-bandas en una covarianza bloque (SPD grande) para MDM/FgMDM.

In [3]:
# %% [Helpers Riemann — FB covariances en bloque (MDM vs FgMDM)]
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import ceil
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import ConfusionMatrixDisplay

# Bandas por defecto (mu/beta denso)
FB_BANDS_DENSE = [(f, f+2) for f in range(8, 30, 2)]
FB_BANDS_CLASSIC = [(8,12), (12,16), (16,20), (20,24), (24,28), (28,30)]
DEFAULT_FB_BANDS = FB_BANDS_DENSE

# pyRiemann
try:
    from pyriemann.estimation import Covariances
    from pyriemann.classification import MDM, FgMDM
except ImportError:
    raise ImportError("pyriemann no está instalado. Instala con: pip install pyriemann")

# Tokens motores (si quisieras recortar todavía más, aun cuando ya usas 8 canales)
_RIEM_MOTOR_TOKENS = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

def _riem_find_motor_chs(ch_names, tokens=_RIEM_MOTOR_TOKENS):
    up = [c.upper() for c in ch_names]
    picks = []
    for tok in tokens:
        TU = tok.upper()
        for i, name in enumerate(up):
            if TU in name:
                picks.append(i); break
    return sorted(set(picks))

def _riem_epochwise_zscore(X, eps=1e-8):
    # No recomendado por defecto para Riemann; dejar False salvo diagnóstico
    mean = X.mean(axis=-1, keepdims=True)
    std  = X.std(axis=-1, keepdims=True)
    return (X - mean) / (std + eps)

def _riem_epochs_to_Xy(epochs):
    X = epochs.get_data()  # (n_trials, n_ch, n_times)
    inv = {v:k for k,v in epochs.event_id.items()}
    y = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
    return X, y

def _normalize_trace(C):
    """
    Normaliza cada SPD por su traza para estabilizar escala.
    Soporta shape (..., n_ch, n_ch) arbitraria.
    """
    C = np.asarray(C, dtype=float)
    tr = np.trace(C, axis1=-2, axis2=-1)
    tr = np.where(tr == 0, 1.0, tr)
    return C / tr[..., None, None]

def _split_calibration(ep_te, k_per_class=5, shuffle=True, random_state=42, 
                       require_all_classes=False, return_indices=False):
    """
    Divide un conjunto de épocas de TEST en:
      - CALIB (k_per_class por clase)
      - EVAL (el resto)

    Parámetros
    ----------
    ep_te : mne.Epochs
        Épocas del sujeto a partir de las cuales se hará calibración+evaluación.
    k_per_class : int
        Nº de épocas por clase que se irán a CALIB. Si <=0 → (None, ep_te).
    shuffle : bool
        Si True, baraja índices dentro de cada clase antes de tomar k.
        (Recomendado para evitar sesgo por orden temporal).
    random_state : int
        Semilla para la aleatoriedad (si shuffle=True).
    require_all_classes : bool
        Si True, exige que TODAS las clases tengan al menos k épocas;
        si alguna no alcanza, devuelve (None, ep_te).
        Si False, usa min(k, n_clase) y continúa.
    return_indices : bool
        Si True, además devuelve (idx_calib, idx_eval, class_counts).

    Retorna
    -------
    ep_calib : mne.Epochs or None
    ep_eval  : mne.Epochs
    (opcionales)
    idx_calib : np.ndarray (int)
    idx_eval  : np.ndarray (int)
    class_counts : dict {code: (n_calib, n_eval, n_total)}
    """
    if k_per_class <= 0:
        if return_indices:
            n = len(ep_te)
            idx_all = np.arange(n)
            labels = ep_te.events[:, -1]
            counts = {int(c): (0, int((labels == c).sum()), int((labels == c).sum()))
                      for c in np.unique(labels)}
            return None, ep_te, np.array([], dtype=int), idx_all, counts
        return None, ep_te

    labels = ep_te.events[:, -1].astype(int)
    classes = np.unique(labels)
    rng = np.random.RandomState(random_state) if shuffle else None

    calib_idx = []
    eval_idx  = []
    class_counts = {}

    # Chequeo opcional: todas las clases deben tener >= k
    if require_all_classes:
        for c in classes:
            n_c = int((labels == c).sum())
            if n_c < k_per_class:
                # no cumple el mínimo → no calibrar este sujeto
                if return_indices:
                    counts = {int(code): (0, int((labels == code).sum()), int((labels == code).sum()))
                              for code in classes}
                    return None, ep_te, np.array([], dtype=int), np.arange(len(ep_te)), counts
                return None, ep_te

    for c in classes:
        idx_c = np.where(labels == c)[0]
        if shuffle and len(idx_c) > 1:
            rng.shuffle(idx_c)

        take = min(k_per_class, len(idx_c))
        sel = idx_c[:take]
        rem = idx_c[take:]

        if take > 0:
            calib_idx.append(sel)
        if len(rem) > 0:
            eval_idx.append(rem)

        class_counts[int(c)] = (int(take), int(len(rem)), int(len(idx_c)))

    # Si no hay nada para calibrar, devolver (None, ep_te)
    if len(calib_idx) == 0:
        if return_indices:
            idx_all = np.arange(len(ep_te))
            return None, ep_te, np.array([], dtype=int), idx_all, class_counts
        return None, ep_te

    calib_idx = np.concatenate(calib_idx) if len(calib_idx) else np.array([], dtype=int)
    eval_idx  = np.concatenate(eval_idx)  if len(eval_idx)  else np.array([], dtype=int)

    # Ordenar índices para que cada subconjunto quede en orden cronológico
    calib_idx.sort()
    eval_idx.sort()

    ep_calib = ep_te.copy()[calib_idx]
    ep_eval  = ep_te.copy()[eval_idx] if len(eval_idx) > 0 else ep_te.copy()[[]]  # vacío si no hay eval

    if return_indices:
        return ep_calib, ep_eval, calib_idx, eval_idx, class_counts
    return ep_calib, ep_eval



def _riem_fb_covariances(epochs,
                         fb_bands=DEFAULT_FB_BANDS,
                         motor_only=False,
                         zscore_epoch=False,
                         crop_window=None,
                         cov_estimator='oas',
                         model='mdm'):
    """
    Calcula covarianzas multi-banda:
      - Para 'mdm' → devuelve cov bloque-diagonal: (n_trials, n_fb*n_ch, n_fb*n_ch)
      - Para 'fgmdm' → devuelve pila por banda:    (n_trials, n_fb, n_ch, n_ch)
    En ambos casos, se normaliza por traza banda a banda.

    Retorna: C, y, classes
    """
    ep = epochs.copy()
    if crop_window is not None:
        ep.crop(*crop_window)

    if motor_only:
        picks = _riem_find_motor_chs(ep.ch_names)
        if picks:
            ep.pick(picks)

    X, y = _riem_epochs_to_Xy(ep)  # (n_trials, n_ch, n_times)
    n_trials, n_ch, _ = X.shape
    n_fb = len(fb_bands)

    # 1) covarianzas por banda
    covs_per_band = []
    for (fmin, fmax) in fb_bands:
        ep_b = ep.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        Xb = ep_b.get_data()  # (n_trials, n_ch, n_times)
        if zscore_epoch:
            Xb = _riem_epochwise_zscore(Xb)
        Cb = Covariances(estimator=cov_estimator).fit_transform(Xb)  # (n_trials, n_ch, n_ch)
        Cb = _normalize_trace(Cb)
        covs_per_band.append(Cb)

    if model.lower() == 'fgmdm':
        # pila 4D: (n_trials, n_fb, n_ch, n_ch)
        C = np.stack(covs_per_band, axis=1)
    else:
        # bloque-diagonal (n_trials, n_fb*n_ch, n_fb*n_ch)
        C = np.zeros((n_trials, n_fb * n_ch, n_fb * n_ch), dtype=float)
        for b, Cb in enumerate(covs_per_band):
            i0 = b * n_ch
            i1 = (b + 1) * n_ch
            C[:, i0:i1, i0:i1] = Cb

    classes = np.unique(y).tolist()
    return C, y, classes

def _riem_fb_cov_train_test(ep_tr, ep_te,
                             fb_bands=DEFAULT_FB_BANDS,
                             motor_only=False,
                             zscore_epoch=False,
                             crop_window=None,
                             cov_estimator='oas',
                             model='mdm'):
    """
    Helper: saca cov multi-banda para TRAIN y TEST y alinea etiquetas con LabelEncoder.
    - Para 'mdm'   → Ctr/Cte 3D (bloque-diagonal).
    - Para 'fgmdm' → Ctr/Cte 4D (pila por banda).
    """
    Ctr, y_tr, _ = _riem_fb_covariances(ep_tr, fb_bands, motor_only, zscore_epoch, crop_window, cov_estimator, model)
    Cte, y_te, _ = _riem_fb_covariances(ep_te, fb_bands, motor_only, zscore_epoch, crop_window, cov_estimator, model)
    le = LabelEncoder().fit(np.concatenate([y_tr, y_te]))
    return Ctr, le.transform(y_tr), Cte, le.transform(y_te), list(le.classes_)

def _to_block_if_4d(X):
    """
    Si X viene 4D (n_trials, n_bands, n_ch, n_ch), lo convertimos
    a bloque-diagonal 3D (n_trials, n_bands*n_ch, n_bands*n_ch).
    Si ya es 3D, se devuelve tal cual.
    """
    X = np.asarray(X)
    if X.ndim == 4:
        n_trials, n_fb, n_ch, _ = X.shape
        Xb = np.zeros((n_trials, n_fb * n_ch, n_fb * n_ch), dtype=X.dtype)
        for b in range(n_fb):
            i0, i1 = b * n_ch, (b + 1) * n_ch
            Xb[:, i0:i1, i0:i1] = X[:, b, :, :]
        return Xb
    return X


### Bloque — Clasificadores riemannianos (MDM y FgMDM)

In [4]:
# %% [CLF Riemann — MDM/FgMDM]
def _riem_make_clf(model='mdm', metric='riemann'):
    """
    Crea el clasificador riemanniano:
      - 'mdm'   → Minimum Distance to Mean (metric='riemann' por defecto)
      - 'fgmdm' → Filter-Geodesic MDM (agrega por banda en la geometría)
    """
    model = (model or 'mdm').lower()
    if model == 'fgmdm':
        return FgMDM(metric=metric)
    return MDM(metric=metric)


### Bloque 4 — LOSO (todos) con MDM/FgMDM + calibración opcional

LOSO clásico (sin calibración) = generalización pura inter-sujeto.

Con calibración few-shot (calibrate_k_per_class > 0), se toman k épocas/clase del sujeto test para ajustar las medias riemannianas (MDM), lo que típicamente mejora ACC/F1 sin reentrenar toda la cadena (práctica común en el estado del arte).

In [5]:
# %% [INTER-SUBJECT Riemann desde JSON — validación por sujetos + calibración opcional]
import json
import numpy as np
import pandas as pd
import mne
import matplotlib.pyplot as plt
from math import ceil
from pathlib import Path
from datetime import datetime

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay,
    classification_report
)

def run_inter_subject_riem_from_json(
    fif_dir=DATA_PROC,
    folds_json_path=None,
    crop_window=(0.5, 3.5),
    motor_only=True,
    zscore_epoch=False,
    fb_bands=DEFAULT_FB_BANDS,
    cov_estimator='oas',
    model='mdm',
    calibrate_n=None,                 # calibración opcional con sujetos de TEST (modo antiguo)
    calibrate_k_per_class=None,       # <<< NUEVO: calibración per-subject k épocas/clase
    val_ratio_subjects=0.16,          # % sujetos de TRAIN a VALID
    random_state=42,                  # reproducibilidad del split por sujeto
    max_subplots_per_fig=12,
    n_cols=4,
    save_csv_name=None,
    save_txt_name=None,
    print_fold_classification_table=True
):
    """
    Inter-subject CV Riemann con VALIDACIÓN INTERNA POR SUJETOS y calibración opcional.

    Modos de calibración:
      - Sin calibración: calibrate_n=None y calibrate_k_per_class=None o <=0.
      - Calibración por sujetos completos (modo antiguo): calibrate_n>0.
      - Calibración per-subject k-shots (RECOMENDADO): calibrate_k_per_class>0.
        Para cada sujeto de TEST: toma k épocas por clase para calibrar (junto con TRAIN)
        y evalúa en las épocas restantes de ese mismo sujeto.
    """
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_tag = f"riem_inter_subject_{model}_{ts}"
    logger, log_path = _init_logger_riem(run_name=run_tag)

    # Normalizar flags de calibración
    k_cal = 0 if calibrate_k_per_class is None else int(calibrate_k_per_class)
    subj_cal = 0 if calibrate_n is None else int(calibrate_n)
    if k_cal > 0:
        subj_cal = 0  # prioridad al modo per-subject
        logger.info(f"[RUN {run_tag}] VALID por sujetos + CALIBRACIÓN PER-SUBJECT (k={k_cal} por clase)")
    elif subj_cal > 0:
        logger.info(f"[RUN {run_tag}] VALID por sujetos + CALIBRACIÓN por sujetos completos (n={subj_cal})")
    else:
        logger.info(f"[RUN {run_tag}] VALID por sujetos (SIN calibración)")

    logger.info(
        f"Perillas: crop={crop_window}, motor_only={motor_only}, zscore_epoch={zscore_epoch}, "
        f"fb_bands={fb_bands}, cov={cov_estimator}, val_ratio_subjects={val_ratio_subjects:.2f}"
    )

    # ---------- JSON de folds ----------
    if folds_json_path is None:
        folds_json_path = PROJ / 'models' / 'folds' / 'Kfold5.json'
    folds_json_path = Path(folds_json_path)
    if not folds_json_path.exists():
        raise FileNotFoundError(f"No se encontró folds JSON en {folds_json_path}")

    with open(folds_json_path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    folds = payload.get("folds", [])
    subject_ids_json = payload.get("subject_ids", [])
    logger.info(f"Folds cargadas: {len(folds)} | sujetos en JSON: {len(subject_ids_json)}")

    # ---------- Cargar epochs por sujeto ----------
    ep_map = {}
    for sid in subject_ids_json:
        fif_path = Path(fif_dir) / f"{sid}_MI-epo.fif"
        if fif_path.exists():
            try:
                ep_map[sid] = mne.read_epochs(str(fif_path), preload=True, verbose=False)
            except Exception as e:
                logger.warning(f"Error leyendo {fif_path} para {sid}: {e}")
        else:
            logger.warning(f"Falta archivo FIF para {sid}: {fif_path}")

    # ---------- Acumuladores ----------
    rows = []
    cm_items = []
    cm_global = None
    classes_global = None
    per_fold_reports = []

    # ---------- Iterar folds ----------
    for f in folds:
        fold_i = int(f.get("fold"))
        train_sids = [sid for sid in f.get("train", []) if sid in ep_map]
        test_sids  = [sid for sid in f.get("test", [])  if sid in ep_map]

        logger.info(f"[Fold {fold_i}] train({len(train_sids)}): {train_sids}")
        logger.info(f"[Fold {fold_i}] test ({len(test_sids)}): {test_sids}")

        if len(train_sids) == 0 or len(test_sids) == 0:
            logger.warning(f"[Fold {fold_i}] faltan sujetos train/test — saltando fold.")
            continue

        # ---------- VALIDACIÓN INTERNA POR SUJETOS ----------
        rng = np.random.RandomState(random_state + fold_i)
        n_val_subj = max(1, int(round(len(train_sids) * float(val_ratio_subjects))))
        val_indices = rng.choice(len(train_sids), size=n_val_subj, replace=False)
        val_sids = sorted([train_sids[i] for i in val_indices])
        tr_sids  = sorted([sid for sid in train_sids if sid not in set(val_sids)])

        logger.info(f"[Fold {fold_i}] split interno → train_sids={len(tr_sids)}, val_sids={len(val_sids)}")

        # Concatenar epochs por split
        ep_tr  = mne.concatenate_epochs([ep_map[sid] for sid in tr_sids],  on_mismatch='ignore')
        ep_val = mne.concatenate_epochs([ep_map[sid] for sid in val_sids], on_mismatch='ignore')
        ep_te  = mne.concatenate_epochs([ep_map[sid] for sid in test_sids], on_mismatch='ignore')

        # Alinear canales con respecto a train
        try:
            ep_val = ep_val.copy().reorder_channels(ep_tr.ch_names)
            ep_te  = ep_te.copy().reorder_channels(ep_tr.ch_names)
        except Exception as e:
            logger.warning(f"[Fold {fold_i}] reorder_channels: {e}")

        # Etiquetas
        _, y_tr_str  = _riem_epochs_to_Xy(ep_tr)
        _, y_val_str = _riem_epochs_to_Xy(ep_val)
        _, y_te_str  = _riem_epochs_to_Xy(ep_te)

        le = LabelEncoder().fit(np.concatenate([y_tr_str, y_val_str, y_te_str]))
        y_tr  = le.transform(y_tr_str)
        y_val = le.transform(y_val_str)
        y_te  = le.transform(y_te_str)
        classes = list(le.classes_)

        if classes_global is None:
            classes_global = classes
            cm_global = np.zeros((len(classes), len(classes)), dtype=int)

        # ---------- FEATURES/ESPACIO (ajuste SOLO con TRAIN) ----------
        with mne.utils.use_log_level("ERROR"):
            # Fit contra ep_tr; transformar ep_val (y ep_te si no hay calibración) con el MISMO ajuste
            Ctr, y_tr_fit, Cval, y_val_fit, _ = _riem_fb_cov_train_test(
                ep_tr, ep_val,
                fb_bands=fb_bands,
                motor_only=motor_only,
                zscore_epoch=zscore_epoch,
                crop_window=crop_window,
                cov_estimator=cov_estimator,
                model=model
            )
            _Ctr_dummy, _ytr_dummy, Cte, y_te_fit, _ = _riem_fb_cov_train_test(
                ep_tr, ep_te,
                fb_bands=fb_bands,
                motor_only=motor_only,
                zscore_epoch=zscore_epoch,
                crop_window=crop_window,
                cov_estimator=cov_estimator,
                model=model
            )

        # Aplanar si viene 4D (banco) a 3D (apilado)
        Ctr  = _to_block_if_4d(Ctr)
        Cval = _to_block_if_4d(Cval)
        Cte  = _to_block_if_4d(Cte)

        # ---------- ENTRENAR CLASIFICADOR SOLO CON TRAIN ----------
        clf = _riem_make_clf(model=model)
        clf.fit(Ctr, y_tr_fit)

        # ---------- VALID ----------
        yhat_val = clf.predict(Cval)
        acc_val = accuracy_score(y_val_fit, yhat_val)
        f1m_val = f1_score(y_val_fit, yhat_val, average='macro')
        logger.info(f"[Fold {fold_i}] VAL   acc={acc_val:.4f} | f1m={f1m_val:.4f} | n_val={len(y_val_fit)}")

        # ---------- TEST: tres caminos ----------
        if k_cal > 0:
            # ===== Calibración per-subject k-shots =====
            y_te_all, yhat_all = [], []
            cm_fold = np.zeros((len(classes), len(classes)), dtype=int)

            for sid in test_sids:
                ep_te_subj = ep_map[sid].copy()
                try:
                    ep_te_subj = ep_te_subj.reorder_channels(ep_tr.ch_names)
                except Exception as e:
                    logger.warning(f"[Fold {fold_i}] reorder (test subject {sid}): {e}")

                # Partir en CALIB vs EVAL (k por clase)
                ep_calib, ep_eval = _split_calibration(ep_te_subj, k_per_class=k_cal)
                if (ep_calib is None) or (len(ep_calib) == 0) or (len(ep_eval) == 0):
                    logger.warning(f"[Fold {fold_i}] {sid}: calibración insuficiente (k={k_cal}) o sin eval; se omite este sujeto.")
                    continue

                # Etiquetas
                _, y_calib_str = _riem_epochs_to_Xy(ep_calib)
                _, y_eval_str  = _riem_epochs_to_Xy(ep_eval)
                y_calib = le.transform(y_calib_str)
                y_eval  = le.transform(y_eval_str)

                # Refit espacio con TRAIN + CALIB_del_sujeto; transformar EVAL
                with mne.utils.use_log_level("ERROR"):
                    ep_train_plus_calib = mne.concatenate_epochs([ep_tr, ep_calib], on_mismatch='ignore')
                    Ctr_comb, y_tr_comb_fit, Ceval, y_eval_fit, _ = _riem_fb_cov_train_test(
                        ep_train_plus_calib, ep_eval,
                        fb_bands=fb_bands,
                        motor_only=motor_only,
                        zscore_epoch=zscore_epoch,
                        crop_window=crop_window,
                        cov_estimator=cov_estimator,
                        model=model
                    )
                Ctr_comb = _to_block_if_4d(Ctr_comb)
                Ceval    = _to_block_if_4d(Ceval)

                clf_s = _riem_make_clf(model=model)
                clf_s.fit(Ctr_comb, y_tr_comb_fit)
                yhat_eval = clf_s.predict(Ceval)

                # Acumular por sujeto
                y_te_all.append(y_eval_fit)     # usar y_eval_fit alineado a Ceval
                yhat_all.append(yhat_eval)
                cm_s = confusion_matrix(y_eval_fit, yhat_eval, labels=np.arange(len(classes)))
                cm_fold += cm_s
                logger.info(f"[Fold {fold_i}] TEST (per-subject k-shots) {sid} → acc={accuracy_score(y_eval_fit, yhat_eval):.4f}, n={len(y_eval_fit)}")

            if len(y_te_all) == 0:
                logger.warning(f"[Fold {fold_i}] Sin sujetos válidos para k-shots (k={k_cal}). Se usa modelo sin calibrar.")
                y_te_cat = y_te_fit
                yhat_te  = clf.predict(Cte)
                cm = confusion_matrix(y_te_cat, yhat_te, labels=np.arange(len(classes)))
            else:
                y_te_cat = np.concatenate(y_te_all)
                yhat_te  = np.concatenate(yhat_all)
                cm = cm_fold

        elif subj_cal > 0:
            # ===== Calibración por sujetos completos (modo antiguo) =====
            n_subjs = min(int(subj_cal), len(test_sids))
            if n_subjs >= len(test_sids):
                logger.warning(f"[Fold {fold_i}] calibrate_n ({subj_cal}) >= nº test_sids ({len(test_sids)}). Se reducirá a {len(test_sids)-1}.")
                n_subjs = max(0, len(test_sids) - 1)

            calib_sids = test_sids[:n_subjs]
            rest_sids  = test_sids[n_subjs:]

            ep_calib   = mne.concatenate_epochs([ep_map[sid] for sid in calib_sids], on_mismatch='ignore') if calib_sids else None
            ep_te_rest = mne.concatenate_epochs([ep_map[sid] for sid in rest_sids],  on_mismatch='ignore') if rest_sids else None

            if (ep_calib is None) or (ep_te_rest is None):
                y_te_cat = y_te_fit
                yhat_te  = clf.predict(Cte)
                cm = confusion_matrix(y_te_cat, yhat_te, labels=np.arange(len(classes)))
            else:
                try:
                    ep_calib   = ep_calib.copy().reorder_channels(ep_tr.ch_names)
                    ep_te_rest = ep_te_rest.copy().reorder_channels(ep_tr.ch_names)
                except Exception as e:
                    logger.warning(f"[Fold {fold_i}] reorder (calib/test_rest): {e}")

                with mne.utils.use_log_level("ERROR"):
                    ep_train_plus_calib = mne.concatenate_epochs([ep_tr, ep_calib], on_mismatch='ignore')
                    Ctr_comb, y_tr_comb_fit, Cte_rest, y_te_rest_fit, _ = _riem_fb_cov_train_test(
                        ep_train_plus_calib, ep_te_rest,
                        fb_bands=fb_bands,
                        motor_only=motor_only,
                        zscore_epoch=zscore_epoch,
                        crop_window=crop_window,
                        cov_estimator=cov_estimator,
                        model=model
                    )
                Ctr_comb = _to_block_if_4d(Ctr_comb)
                Cte_rest = _to_block_if_4d(Cte_rest)

                clf_c = _riem_make_clf(model=model)
                clf_c.fit(Ctr_comb, y_tr_comb_fit)
                yhat_te  = clf_c.predict(Cte_rest)
                y_te_cat = y_te_rest_fit
                cm = confusion_matrix(y_te_cat, yhat_te, labels=np.arange(len(classes)))
        else:
            # ===== Sin calibración =====
            y_te_cat = y_te_fit
            yhat_te  = clf.predict(Cte)
            cm = confusion_matrix(y_te_cat, yhat_te, labels=np.arange(len(classes)))

        # ---------- MÉTRICAS TEST ----------
        acc = accuracy_score(y_te_cat, yhat_te)
        f1m = f1_score(y_te_cat, yhat_te, average='macro')
        cm_global += cm

        # Reporte por fold
        if print_fold_classification_table:
            try:
                rep = classification_report(y_te_cat, yhat_te, target_names=classes, digits=4)
                print(f"\n[Fold {fold_i}/{len(folds)}] Classification report (TEST)\n{rep}")
                logger.info(f"[Fold {fold_i}] Classification report (TEST):\n{rep}")
                per_fold_reports.append((fold_i, rep))
            except Exception as e:
                logger.warning(f"[Fold {fold_i}] classification_report error: {e}")

        rows.append(dict(
            fold=int(fold_i),
            train_subjects=",".join(tr_sids),
            val_subjects=",".join(val_sids),
            test_subjects=",".join(test_sids),
            val_acc=float(acc_val),
            val_f1_macro=float(f1m_val),
            acc=float(acc),
            f1_macro=float(f1m),
            n_val=int(len(y_val_fit)),
            n_test=int(len(y_te_cat)),
            calibrate_mode=("k-per-class" if k_cal > 0 else ("subjects" if subj_cal > 0 else "none")),
            calibrate_param=(k_cal if k_cal > 0 else (subj_cal if subj_cal > 0 else 0))
        ))
        cm_items.append((f"fold_{fold_i}", cm, classes))

    # ---------- Consolidados ----------
    df_rows = pd.DataFrame(rows).sort_values("fold") if rows else pd.DataFrame()
    acc_mu   = float(df_rows['acc'].mean()) if not df_rows.empty else 0.0
    f1_mu    = float(df_rows['f1_macro'].mean()) if not df_rows.empty else 0.0
    val_mu   = float(df_rows['val_acc'].mean()) if not df_rows.empty else 0.0
    valf1_mu = float(df_rows['val_f1_macro'].mean()) if not df_rows.empty else 0.0

    if not df_rows.empty:
        df_rows = pd.concat([df_rows, pd.DataFrame([{
            'fold': 0,
            'train_subjects': 'GLOBAL',
            'val_subjects': 'GLOBAL',
            'test_subjects': 'GLOBAL',
            'val_acc': val_mu,
            'val_f1_macro': valf1_mu,
            'acc': acc_mu,
            'f1_macro': f1_mu,
            'n_val': int(df_rows['n_val'].sum()),
            'n_test': int(df_rows['n_test'].sum()),
            'calibrate_mode': df_rows['calibrate_mode'].mode()[0] if 'calibrate_mode' in df_rows else 'none',
            'calibrate_param': int(df_rows['calibrate_param'].mean()) if 'calibrate_param' in df_rows else 0
        }])], ignore_index=True)

    # ---------- Guardar CSV/TXT ----------
    out_csv = (RIEM_TAB_DIR / f"{ts}_{save_csv_name}") if save_csv_name \
              else (RIEM_TAB_DIR / f"riem_inter_subject_{model}_{ts}.csv")
    df_rows.to_csv(out_csv, index=False)
    logger.info(f"CSV consolidado → {out_csv}")
    print("CSV consolidado →", out_csv)

    out_txt = (RIEM_LOG_DIR / f"{ts}_{save_txt_name}") if save_txt_name \
              else (RIEM_LOG_DIR / f"riem_inter_subject_{model}_{ts}.txt")
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(f"INTER-SUBJECT Riemann ({model}) — Con VALID interno por sujetos\n")
        f.write(f"Generado: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total filas: {len(df_rows)}\n\n")
        header = df_rows.columns.tolist()
        f.write(" | ".join(header) + "\n")
        f.write("-" * 160 + "\n")
        for _, r in df_rows.iterrows():
            vals = []
            for kcol in header:
                v = r[kcol]
                if isinstance(v, float):
                    vals.append(f"{v:.4f}")
                elif isinstance(v, (np.integer,)):
                    vals.append(str(int(v)))
                else:
                    vals.append(str(v))
            f.write(" | ".join(vals) + "\n")
    logger.info(f"TXT consolidado → {out_txt}")
    print("TXT consolidado →", out_txt)

    # ---------- Mosaicos de confusión por fold ----------
    if cm_items:
        n = len(cm_items)
        per_fig = max(1, int(max_subplots_per_fig))
        n_figs = ceil(n / per_fig)

        def _n_rows_for_count(count):
            return ceil(count / n_cols)

        for fig_idx in range(n_figs):
            start = fig_idx * per_fig
            end   = min((fig_idx + 1) * per_fig, n)
            chunk = cm_items[start:end]
            count = len(chunk)
            n_rows = _n_rows_for_count(count)

            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5*n_cols, 3.8*n_rows), dpi=140)
            axes = np.atleast_2d(axes).flatten()
            for ax_i, (label, cm_sum, classes) in enumerate(chunk):
                ax = axes[ax_i]
                disp = ConfusionMatrixDisplay(cm_sum, display_labels=classes)
                disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format='d')
                ax.set_title(f"{label}")
                ax.set_xlabel(""); ax.set_ylabel("")
            for j in range(ax_i + 1, len(axes)):
                axes[j].axis("off")

            out_png = RIEM_FIG_DIR / f"riem_inter_subject_confusions_{model}_{ts}_p{fig_idx+1}.png"
            fig.suptitle(f"Inter-Subject Riemann ({model}) — Matrices de confusión (página {fig_idx+1}/{n_figs})",
                         y=0.995, fontsize=14)
            fig.tight_layout(rect=[0, 0, 1, 0.97])
            fig.savefig(out_png)
            plt.close(fig)
            logger.info(f"Figura consolidada → {out_png}")
            print("Figura consolidada →", out_png)

    # ---------- Matriz GLOBAL ----------
    if cm_global is not None and classes_global is not None:
        fig, ax = plt.subplots(figsize=(6.5, 5.2), dpi=140)
        disp = ConfusionMatrixDisplay(cm_global, display_labels=classes_global)
        disp.plot(ax=ax, cmap="Blues", colorbar=True, values_format='d')
        ax.set_title(f"Inter-Subject Riemann ({model}) — Matriz de confusión GLOBAL")
        fig.tight_layout()
        out_png_glob = RIEM_FIG_DIR / f"riem_inter_subject_global_confusion_{model}_{ts}.png"
        fig.savefig(out_png_glob)
        plt.close(fig)
        logger.info(f"Matriz GLOBAL → {out_png_glob}")
        print("Matriz GLOBAL →", out_png_glob)

    logger.info(f"[GLOBAL] VAL_acc={val_mu:.3f} | VAL_f1m={valf1_mu:.3f} | TEST_acc={acc_mu:.3f} | TEST_f1m={f1_mu:.3f}")
    print(f"[GLOBAL] VAL_acc={val_mu:.3f} | VAL_f1m={valf1_mu:.3f} | TEST_acc={acc_mu:.3f} | TEST_f1m={f1_mu:.3f}")
    logger.info(f"Log global → {log_path}")
    print(f"Log global → {log_path}")

    return df_rows.reset_index(drop=True)


### Bloque — Ejemplos

Corre INTRA y LOSO con MDM (o FgMDM).

In [None]:
# %% [Ejemplos — Riemann]
# INTRA — FgMDM (aprovecha mejor la estructura multi-banda), mismas bandas y setup
# df_intra_fgmdm = run_intra_all_riem(
#     fif_dir=DATA_PROC,
#     k=5,
#     random_state=42,
#     crop_window=(0.5, 4.5),
#     motor_only=True,
#     zscore_epoch=False,
#     fb_bands=DEFAULT_FB_BANDS,
#     cov_estimator='oas',
#     model='fgmdm',             # << FgMDM
#     save_csv_name="riem_intra_fgmdm_optim.csv",
#     save_txt_name="riem_intra_fgmdm_optim.txt"
# )


# LOSO — MDM, sin calibración
df_inter_fgmdm = run_inter_subject_riem_from_json(
    fif_dir=DATA_PROC,
    folds_json_path=PROJ / 'models' / 'folds' / 'Kfold5.json',  # path a tu JSON de folds
    crop_window=(0.5, 3.5),
    motor_only=True,
    zscore_epoch=False,
    fb_bands=DEFAULT_FB_BANDS,
    cov_estimator='oas',
    model='fgmdm',                # << FgMDM
    calibrate_k_per_class=5,                # calibración: 5 epochs por sujeto de test
    max_subplots_per_fig=12,
    n_cols=4,
    save_csv_name="riem_inter_fgmdm_calib.csv",
    save_txt_name="riem_inter_fgmdm_calib.txt"
)


[00:11:04] INFO: [RUN riem_inter_subject_fgmdm_20251030-001104] VALID por sujetos + CALIBRACIÓN PER-SUBJECT (k=5 por clase)


INFO:riem_inter_subject_fgmdm_20251030-001104:[RUN riem_inter_subject_fgmdm_20251030-001104] VALID por sujetos + CALIBRACIÓN PER-SUBJECT (k=5 por clase)


[00:11:04] INFO: Perillas: crop=(0.5, 3.5), motor_only=True, zscore_epoch=False, fb_bands=[(8, 10), (10, 12), (12, 14), (14, 16), (16, 18), (18, 20), (20, 22), (22, 24), (24, 26), (26, 28), (28, 30)], cov=oas, val_ratio_subjects=0.16


INFO:riem_inter_subject_fgmdm_20251030-001104:Perillas: crop=(0.5, 3.5), motor_only=True, zscore_epoch=False, fb_bands=[(8, 10), (10, 12), (12, 14), (14, 16), (16, 18), (18, 20), (20, 22), (22, 24), (24, 26), (26, 28), (28, 30)], cov=oas, val_ratio_subjects=0.16


[00:11:04] INFO: Folds cargadas: 5 | sujetos en JSON: 103


INFO:riem_inter_subject_fgmdm_20251030-001104:Folds cargadas: 5 | sujetos en JSON: 103


[00:11:04] INFO: [Fold 1] train(82): ['S001', 'S002', 'S004', 'S005', 'S006', 'S007', 'S009', 'S010', 'S011', 'S012', 'S014', 'S015', 'S016', 'S017', 'S019', 'S020', 'S021', 'S022', 'S024', 'S025', 'S026', 'S027', 'S029', 'S030', 'S031', 'S032', 'S034', 'S035', 'S036', 'S037', 'S040', 'S041', 'S042', 'S043', 'S045', 'S046', 'S047', 'S048', 'S050', 'S051', 'S052', 'S053', 'S055', 'S056', 'S057', 'S058', 'S060', 'S061', 'S062', 'S063', 'S065', 'S066', 'S067', 'S068', 'S070', 'S071', 'S072', 'S073', 'S075', 'S076', 'S077', 'S078', 'S080', 'S081', 'S082', 'S083', 'S085', 'S086', 'S087', 'S090', 'S093', 'S094', 'S095', 'S096', 'S098', 'S099', 'S101', 'S102', 'S105', 'S106', 'S107', 'S108']


INFO:riem_inter_subject_fgmdm_20251030-001104:[Fold 1] train(82): ['S001', 'S002', 'S004', 'S005', 'S006', 'S007', 'S009', 'S010', 'S011', 'S012', 'S014', 'S015', 'S016', 'S017', 'S019', 'S020', 'S021', 'S022', 'S024', 'S025', 'S026', 'S027', 'S029', 'S030', 'S031', 'S032', 'S034', 'S035', 'S036', 'S037', 'S040', 'S041', 'S042', 'S043', 'S045', 'S046', 'S047', 'S048', 'S050', 'S051', 'S052', 'S053', 'S055', 'S056', 'S057', 'S058', 'S060', 'S061', 'S062', 'S063', 'S065', 'S066', 'S067', 'S068', 'S070', 'S071', 'S072', 'S073', 'S075', 'S076', 'S077', 'S078', 'S080', 'S081', 'S082', 'S083', 'S085', 'S086', 'S087', 'S090', 'S093', 'S094', 'S095', 'S096', 'S098', 'S099', 'S101', 'S102', 'S105', 'S106', 'S107', 'S108']


[00:11:04] INFO: [Fold 1] test (21): ['S003', 'S008', 'S013', 'S018', 'S023', 'S028', 'S033', 'S039', 'S044', 'S049', 'S054', 'S059', 'S064', 'S069', 'S074', 'S079', 'S084', 'S091', 'S097', 'S103', 'S109']


INFO:riem_inter_subject_fgmdm_20251030-001104:[Fold 1] test (21): ['S003', 'S008', 'S013', 'S018', 'S023', 'S028', 'S033', 'S039', 'S044', 'S049', 'S054', 'S059', 'S064', 'S069', 'S074', 'S079', 'S084', 'S091', 'S097', 'S103', 'S109']


[00:11:04] INFO: [Fold 1] split interno → train_sids=69, val_sids=13


INFO:riem_inter_subject_fgmdm_20251030-001104:[Fold 1] split interno → train_sids=69, val_sids=13


[00:14:34] INFO: [Fold 1] VAL   acc=0.3839 | f1m=0.3799 | n_val=1120


INFO:riem_inter_subject_fgmdm_20251030-001104:[Fold 1] VAL   acc=0.3839 | f1m=0.3799 | n_val=1120


[00:16:39] INFO: [Fold 1] TEST (per-subject k-shots) S003 → acc=0.2857, n=70


INFO:riem_inter_subject_fgmdm_20251030-001104:[Fold 1] TEST (per-subject k-shots) S003 → acc=0.2857, n=70
