### Bloque — Rutas y logger

In [76]:
# %% [Rutas & logger — modelo Riemanniano (MDM/FgMDM)]
import sys, logging, warnings
from datetime import datetime
from pathlib import Path
import mne

# Este notebook está en: models/riemanniano_mdm/
# PROJ -> raíz del repo
PROJ = Path('..').resolve().parent          # .../models/riemanniano_mdm -> parent() = models -> parent() = <repo root>
DATA_PROC = PROJ / 'data' / 'processed'     # datos preprocesados (S???_MI-epo.fif)

# Salidas de este modelo (separadas)
RIEM_OUT_ROOT = PROJ / 'models' / 'riemanniano_mdm'
RIEM_FIG_DIR  = RIEM_OUT_ROOT / 'figures'
RIEM_TAB_DIR  = RIEM_OUT_ROOT / 'tables'
RIEM_LOG_DIR  = RIEM_OUT_ROOT / 'logs'
for d in (RIEM_FIG_DIR, RIEM_TAB_DIR, RIEM_LOG_DIR):
    d.mkdir(parents=True, exist_ok=True)

print(f"[Riemann] data procesados → {DATA_PROC}")
print(f"[Riemann] figuras  → {RIEM_FIG_DIR}")
print(f"[Riemann] tablas   → {RIEM_TAB_DIR}")
print(f"[Riemann] logs     → {RIEM_LOG_DIR}")

def _init_logger_riem(run_name: str):
    """
    Logger propio del modelo riemanniano.
    - Escribe a consola y a un TXT con timestamp en models/riemanniano_mdm/logs/.
    - Silencia el ruido de MNE.
    """
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = RIEM_LOG_DIR / f"{ts}_{run_name}.txt"

    logger = logging.getLogger(run_name)
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", datefmt="%H:%M:%S")
    ch = logging.StreamHandler(stream=sys.stdout); ch.setLevel(logging.INFO); ch.setFormatter(fmt)
    fh = logging.FileHandler(log_path, encoding="utf-8"); fh.setLevel(logging.INFO); fh.setFormatter(fmt)
    logger.addHandler(ch); logger.addHandler(fh)

    mne.set_log_level("ERROR")
    warnings.filterwarnings("ignore", category=UserWarning, module="mne")
    warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")
    return logger, log_path


[Riemann] data procesados → /root/Proyecto/EEG_Clasificador/data/processed
[Riemann] figuras  → /root/Proyecto/EEG_Clasificador/models/riemanniano_mdm/figures
[Riemann] tablas   → /root/Proyecto/EEG_Clasificador/models/riemanniano_mdm/tables
[Riemann] logs     → /root/Proyecto/EEG_Clasificador/models/riemanniano_mdm/logs


### Bloque - Helpers riemannianos (filtro-banco → covarianzas SPD → combinación “block”)

Filtra por sub-bandas (mu/beta),

Opcionalmente hace z-score por época,

Calcula covarianzas (estables con shrinkage),

Combina las sub-bandas en una covarianza bloque (SPD grande) para MDM/FgMDM.

In [77]:
# %% [Helpers Riemann — FB covariances en bloque (MDM vs FgMDM)]
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from math import ceil
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import ConfusionMatrixDisplay

# Bandas por defecto (mu/beta denso)
FB_BANDS_DENSE = [(f, f+2) for f in range(8, 30, 2)]
FB_BANDS_CLASSIC = [(8,12), (12,16), (16,20), (20,24), (24,28), (28,30)]
DEFAULT_FB_BANDS = FB_BANDS_DENSE

# pyRiemann
try:
    from pyriemann.estimation import Covariances
    from pyriemann.classification import MDM, FgMDM
except ImportError:
    raise ImportError("pyriemann no está instalado. Instala con: pip install pyriemann")

# Tokens motores (si quisieras recortar todavía más, aun cuando ya usas 8 canales)
_RIEM_MOTOR_TOKENS = ['C3', 'CZ', 'C4', 'FC3', 'FC4', 'CP3', 'CPZ', 'CP4']

def _riem_find_motor_chs(ch_names, tokens=_RIEM_MOTOR_TOKENS):
    up = [c.upper() for c in ch_names]
    picks = []
    for tok in tokens:
        TU = tok.upper()
        for i, name in enumerate(up):
            if TU in name:
                picks.append(i); break
    return sorted(set(picks))

def _riem_epochwise_zscore(X, eps=1e-8):
    # No recomendado por defecto para Riemann; dejar False salvo diagnóstico
    mean = X.mean(axis=-1, keepdims=True)
    std  = X.std(axis=-1, keepdims=True)
    return (X - mean) / (std + eps)

def _riem_epochs_to_Xy(epochs):
    X = epochs.get_data()  # (n_trials, n_ch, n_times)
    inv = {v:k for k,v in epochs.event_id.items()}
    y = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
    return X, y

def _normalize_trace(C):
    """
    Normaliza cada SPD por su traza para estabilizar escala.
    Soporta shape (..., n_ch, n_ch) arbitraria.
    """
    C = np.asarray(C, dtype=float)
    tr = np.trace(C, axis1=-2, axis2=-1)
    tr = np.where(tr == 0, 1.0, tr)
    return C / tr[..., None, None]

def _riem_fb_covariances(epochs,
                         fb_bands=DEFAULT_FB_BANDS,
                         motor_only=False,
                         zscore_epoch=False,
                         crop_window=None,
                         cov_estimator='oas',
                         model='mdm'):
    """
    Calcula covarianzas multi-banda:
      - Para 'mdm' → devuelve cov bloque-diagonal: (n_trials, n_fb*n_ch, n_fb*n_ch)
      - Para 'fgmdm' → devuelve pila por banda:    (n_trials, n_fb, n_ch, n_ch)
    En ambos casos, se normaliza por traza banda a banda.

    Retorna: C, y, classes
    """
    ep = epochs.copy()
    if crop_window is not None:
        ep.crop(*crop_window)

    if motor_only:
        picks = _riem_find_motor_chs(ep.ch_names)
        if picks:
            ep.pick(picks)

    X, y = _riem_epochs_to_Xy(ep)  # (n_trials, n_ch, n_times)
    n_trials, n_ch, _ = X.shape
    n_fb = len(fb_bands)

    # 1) covarianzas por banda
    covs_per_band = []
    for (fmin, fmax) in fb_bands:
        ep_b = ep.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        Xb = ep_b.get_data()  # (n_trials, n_ch, n_times)
        if zscore_epoch:
            Xb = _riem_epochwise_zscore(Xb)
        Cb = Covariances(estimator=cov_estimator).fit_transform(Xb)  # (n_trials, n_ch, n_ch)
        Cb = _normalize_trace(Cb)
        covs_per_band.append(Cb)

    if model.lower() == 'fgmdm':
        # pila 4D: (n_trials, n_fb, n_ch, n_ch)
        C = np.stack(covs_per_band, axis=1)
    else:
        # bloque-diagonal (n_trials, n_fb*n_ch, n_fb*n_ch)
        C = np.zeros((n_trials, n_fb * n_ch, n_fb * n_ch), dtype=float)
        for b, Cb in enumerate(covs_per_band):
            i0 = b * n_ch
            i1 = (b + 1) * n_ch
            C[:, i0:i1, i0:i1] = Cb

    classes = np.unique(y).tolist()
    return C, y, classes

def _riem_fb_cov_train_test(ep_tr, ep_te,
                             fb_bands=DEFAULT_FB_BANDS,
                             motor_only=False,
                             zscore_epoch=False,
                             crop_window=None,
                             cov_estimator='oas',
                             model='mdm'):
    """
    Helper: saca cov multi-banda para TRAIN y TEST y alinea etiquetas con LabelEncoder.
    - Para 'mdm'   → Ctr/Cte 3D (bloque-diagonal).
    - Para 'fgmdm' → Ctr/Cte 4D (pila por banda).
    """
    Ctr, y_tr, _ = _riem_fb_covariances(ep_tr, fb_bands, motor_only, zscore_epoch, crop_window, cov_estimator, model)
    Cte, y_te, _ = _riem_fb_covariances(ep_te, fb_bands, motor_only, zscore_epoch, crop_window, cov_estimator, model)
    le = LabelEncoder().fit(np.concatenate([y_tr, y_te]))
    return Ctr, le.transform(y_tr), Cte, le.transform(y_te), list(le.classes_)

def _to_block_if_4d(X):
    """
    Si X viene 4D (n_trials, n_bands, n_ch, n_ch), lo convertimos
    a bloque-diagonal 3D (n_trials, n_bands*n_ch, n_bands*n_ch).
    Si ya es 3D, se devuelve tal cual.
    """
    X = np.asarray(X)
    if X.ndim == 4:
        n_trials, n_fb, n_ch, _ = X.shape
        Xb = np.zeros((n_trials, n_fb * n_ch, n_fb * n_ch), dtype=X.dtype)
        for b in range(n_fb):
            i0, i1 = b * n_ch, (b + 1) * n_ch
            Xb[:, i0:i1, i0:i1] = X[:, b, :, :]
        return Xb
    return X


### Bloque — Clasificadores riemannianos (MDM y FgMDM)

In [78]:
# %% [CLF Riemann — MDM/FgMDM]
def _riem_make_clf(model='mdm', metric='riemann'):
    """
    Crea el clasificador riemanniano:
      - 'mdm'   → Minimum Distance to Mean (metric='riemann' por defecto)
      - 'fgmdm' → Filter-Geodesic MDM (agrega por banda en la geometría)
    """
    model = (model or 'mdm').lower()
    if model == 'fgmdm':
        return FgMDM(metric=metric)
    return MDM(metric=metric)


### Bloque — INTRA (todos los sujetos) con MDM/FgMDM

Repite el k-fold por sujeto usando covarianzas SPD + MDM/FgMDM.

Guarda CSV/TXT únicos (con fila GLOBAL) y mosaicos de confusión con timestamp.

In [79]:
# %% [INTRA Riemann — todos los sujetos]
from glob import glob
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

def run_intra_all_riem(
    fif_dir=DATA_PROC,
    k=5,
    random_state=42,
    crop_window=(0.5, 4.5),
    motor_only=True,
    zscore_epoch=False,           # <-- por defecto OFF para Riemann
    fb_bands=DEFAULT_FB_BANDS,
    cov_estimator='oas',
    model='mdm',                  # 'mdm' o 'fgmdm'
    max_subplots_per_fig=12,
    n_cols=4,
    save_txt_name=None,
    save_csv_name=None
):
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_tag = f"riem_intra_all_{model}_{ts}"
    logger, log_path = _init_logger_riem(run_name=run_tag)

    # Descubre sujetos
    subs = sorted({Path(f).stem.split('_')[0] for f in glob(str(fif_dir / 'S???_MI-epo.fif'))})
    if not subs:
        print("No se encontraron sujetos en", fif_dir); return None

    logger.info(f"[RUN {run_tag}] INTRA Riemann | model={model} | k={k} | bands={len(fb_bands)} | cov={cov_estimator} | zscore_epoch={zscore_epoch}")
    logger.info(f"Sujetos: {subs}")
    print(f"[RIEM-INTRA] sujetos: {subs}")

    rows, cm_items = [], []

    for sid in subs:
        ep = mne.read_epochs(str(fif_dir / f"{sid}_MI-epo.fif"), preload=True, verbose=False)
        _, y_str = _riem_epochs_to_Xy(ep)
        le = LabelEncoder().fit(y_str)
        y = le.transform(y_str)
        classes = list(le.classes_)

        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=random_state)
        accs, f1s = [], []
        cm_sum = np.zeros((len(classes), len(classes)), dtype=int)

        for fold, (tr_idx, te_idx) in enumerate(skf.split(np.zeros(len(y)), y), start=1):
            ep_tr, ep_te = ep[tr_idx], ep[te_idx]
            with mne.utils.use_log_level("ERROR"):
                Ctr, y_tr, Cte, y_te, _ = _riem_fb_cov_train_test(
                    ep_tr, ep_te, fb_bands,
                    motor_only=motor_only, zscore_epoch=zscore_epoch,
                    crop_window=crop_window, cov_estimator=cov_estimator,
                    model=model
                )
            clf = _riem_make_clf(model=model)
            Ctr = _to_block_if_4d(Ctr)
            Cte = _to_block_if_4d(Cte)
            clf.fit(Ctr, y_tr)
            yhat = clf.predict(Cte)

            acc = accuracy_score(y_te, yhat)
            f1m = f1_score(y_te, yhat, average='macro')
            cm_sum += confusion_matrix(y_te, yhat, labels=np.arange(len(classes)))
            accs.append(acc); f1s.append(f1m)
            logger.info(f"[{sid} | fold {fold}] acc={acc:.3f} | f1m={f1m:.3f}")

        acc_mu, acc_sd = float(np.mean(accs)), float(np.std(accs))
        f1_mu,  f1_sd  = float(np.mean(f1s)),  float(np.std(f1s))
        logger.info(f"[{sid}] ACC={acc_mu:.3f}±{acc_sd:.3f} | F1m={f1_mu:.3f}±{f1_sd:.3f}")

        rows.append(dict(
            subject=sid,
            acc_mean=acc_mu,
            f1_macro_mean=f1_mu,
            k=k,
            n_classes=len(classes),
            crop=str(crop_window),
            motor_only=bool(motor_only),
            zscore_epoch=bool(zscore_epoch),
            fb_bands=len(fb_bands),
            cov_estimator=cov_estimator,
            model=model
        ))
        cm_items.append((sid, cm_sum, classes))

    # Consolidado + GLOBAL
    df = pd.DataFrame(rows).sort_values("subject")
    acc_mu = float(df['acc_mean'].mean()) if not df.empty else 0.0
    acc_sd = float(df['acc_mean'].std(ddof=0)) if not df.empty else 0.0
    f1_mu  = float(df['f1_macro_mean'].mean()) if not df.empty else 0.0
    f1_sd  = float(df['f1_macro_mean'].std(ddof=0)) if not df.empty else 0.0

    df_global = pd.DataFrame([{
        'subject': 'GLOBAL',
        'acc_mean': acc_mu,
        'f1_macro_mean': f1_mu,
        'k': k,
        'n_classes': int(df['n_classes'].mean()) if 'n_classes' in df.columns and not df.empty else 0,
        'crop': str(crop_window),
        'motor_only': bool(motor_only),
        'zscore_epoch': bool(zscore_epoch),
        'fb_bands': len(fb_bands),
        'cov_estimator': cov_estimator,
        'model': model
    }])
    df_out = pd.concat([df, df_global], ignore_index=True)

    out_csv = (RIEM_TAB_DIR / f"{ts}_{save_csv_name}") if save_csv_name else (RIEM_TAB_DIR / f"riem_metrics_intra_all_{model}_{ts}.csv")
    df_out.to_csv(out_csv, index=False)
    logger.info(f"CSV → {out_csv}"); print("CSV →", out_csv)
    try: display(df_out)
    except: pass

    out_txt = (RIEM_LOG_DIR / f"{ts}_{save_txt_name}") if save_txt_name else (RIEM_LOG_DIR / f"riem_metrics_intra_all_{model}_{ts}.txt")
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(f"INTRA Riemann — {model}\nGenerado: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total filas: {len(df_out)}\n\n")
        header = df_out.columns.tolist(); f.write(" | ".join(header) + "\n"); f.write("-"*90 + "\n")
        for _, row in df_out.iterrows():
            vals = []
            for kcol in header:
                v = row[kcol]
                vals.append(f"{v:.4f}" if isinstance(v, float) else str(int(v)) if isinstance(v, (np.integer,)) else str(v))
            f.write(" | ".join(vals) + "\n")
    logger.info(f"TXT → {out_txt}"); print("TXT →", out_txt)

    # Mosaicos de confusión
    if cm_items:
        n = len(cm_items)
        per_fig = max(1, int(max_subplots_per_fig))
        n_figs = ceil(n / per_fig)
        n_rows_for = lambda count: ceil(count / n_cols)

        for fig_idx in range(n_figs):
            start, end = fig_idx*per_fig, min((fig_idx+1)*per_fig, n)
            chunk = cm_items[start:end]
            count = len(chunk); n_rows = n_rows_for(count)

            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5*n_cols, 3.8*n_rows), dpi=140)
            axes = np.atleast_2d(axes).flatten()
            for ax_i, (sid, cm_sum, classes) in enumerate(chunk):
                ax = axes[ax_i]
                disp = ConfusionMatrixDisplay(cm_sum, display_labels=classes)
                disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format='d')
                ax.set_title(f"{sid}"); ax.set_xlabel(""); ax.set_ylabel("")
            for j in range(ax_i+1, len(axes)): axes[j].axis("off")

            out_png = RIEM_FIG_DIR / f"riem_intra_all_confusions_{model}_{ts}_p{fig_idx+1}.png"
            fig.suptitle(f"Riemann-INTRA ({model}) — pág {fig_idx+1}/{n_figs}", y=0.995, fontsize=14)
            fig.tight_layout(rect=[0,0,1,0.97])
            fig.savefig(out_png); plt.close(fig)
            logger.info(f"Figura → {out_png}"); print("Figura →", out_png)

    logger.info(f"Log → {log_path}"); print("Log →", log_path)
    print(f"[GLOBAL RIEM-INTRA] ACC={acc_mu:.3f}±{acc_sd:.3f} | F1m={f1_mu:.3f}±{f1_sd:.3f}")
    return df_out


### Bloque 4 — LOSO (todos) con MDM/FgMDM + calibración opcional

LOSO clásico (sin calibración) = generalización pura inter-sujeto.

Con calibración few-shot (calibrate_k_per_class > 0), se toman k épocas/clase del sujeto test para ajustar las medias riemannianas (MDM), lo que típicamente mejora ACC/F1 sin reentrenar toda la cadena (práctica común en el estado del arte).

In [80]:
# %% [LOSO Riemann — todos los sujetos con calibración opcional (corregido)]
import numpy as np
import mne
import matplotlib.pyplot as plt
from math import ceil
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

def _to_block_if_4d(X):
    """
    Si X viene como (n_trials, n_bands, n_ch, n_ch) -> lo convierte a bloque
    (n_trials, n_bands*n_ch, n_bands*n_ch), colocando cada banda en la diagonal.
    Si ya es 3D (n_trials, n, n), se regresa tal cual.
    """
    X = np.asarray(X)
    if X.ndim == 4:
        n_trials, n_fb, n_ch, _ = X.shape
        Xb = np.zeros((n_trials, n_fb * n_ch, n_fb * n_ch), dtype=X.dtype)
        for b in range(n_fb):
            i0, i1 = b * n_ch, (b + 1) * n_ch
            Xb[:, i0:i1, i0:i1] = X[:, b, :, :]
        return Xb
    return X

def _riem_split_calibration(ep_te, k_per_class=5, random_state=42):
    """
    Selecciona aleatoriamente k_per_class épocas por clase del TEST para calibración.
    Devuelve (ep_calib, ep_eval). Si k_per_class<=0 -> (None, ep_te).
    """
    if k_per_class is None or k_per_class <= 0:
        return None, ep_te
    rng = np.random.RandomState(int(random_state))
    labels = ep_te.events[:, -1]
    ep_calib_list, ep_eval_list = [], []
    for code in np.unique(labels):
        idx = np.where(labels == code)[0]
        if len(idx) == 0:
            continue
        rng.shuffle(idx)  # barajar para no tomar siempre las primeras
        take = min(int(k_per_class), len(idx))
        sel, rem = idx[:take], idx[take:]
        if take > 0: ep_calib_list.append(ep_te.copy()[sel])
        if len(rem) > 0: ep_eval_list.append(ep_te.copy()[rem])
    ep_calib = mne.concatenate_epochs(ep_calib_list, on_mismatch='ignore') if ep_calib_list else None
    ep_eval  = mne.concatenate_epochs(ep_eval_list,  on_mismatch='ignore') if ep_eval_list  else ep_te
    return ep_calib, ep_eval

def run_loso_all_riem(
    fif_dir=DATA_PROC,
    use_strict=True,                 # (reservado para futuras validaciones de sujetos)
    crop_window=(0.5, 3.5),
    motor_only=True,
    zscore_epoch=False,              # por defecto OFF para Riemann
    fb_bands=DEFAULT_FB_BANDS,
    cov_estimator='oas',
    model='mdm',                     # 'mdm' o 'fgmdm'
    calibrate_k_per_class=None,      # None/0 → sin calibración; >0 → con calibración
    random_state=42,                 # <-- controla el muestreo de calibración
    max_subplots_per_fig=12,
    n_cols=4,
    save_txt_name=None,
    save_csv_name=None
):
    from pathlib import Path
    import pandas as pd
    from datetime import datetime

    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_tag = f"riem_loso_all_{model}_{ts}"
    logger, log_path = _init_logger_riem(run_name=run_tag)

    # Descubre sujetos
    subs = sorted({Path(f).stem.split('_')[0] for f in Path(fif_dir).glob('S???_MI-epo.fif')})
    if not subs:
        print("No se encontraron sujetos en", fif_dir); return None

    kcal = int(calibrate_k_per_class or 0)
    logger.info(f"[RUN {run_tag}] LOSO Riemann | model={model} | bands={len(fb_bands)} | "
                f"cov={cov_estimator} | zscore_epoch={zscore_epoch} | calib_k={kcal}")
    logger.info(f"Sujetos: {subs}")
    print(f"[RIEM-LOSO] sujetos: {subs}")

    # Carga todo en memoria
    ep_map = {sid: mne.read_epochs(str(fif_dir / f"{sid}_MI-epo.fif"),
                                   preload=True, verbose=False) for sid in subs}

    classes_global, cm_global = None, None
    rows, cm_items = [], []

    for s_test, ep_te_full in ep_map.items():
        train_ids = [sid for sid in subs if sid != s_test]
        ep_tr = mne.concatenate_epochs([ep_map[s] for s in train_ids], on_mismatch='ignore')
        ep_te_full = ep_te_full.copy().reorder_channels(ep_tr.ch_names)

        # Calibración opcional (aleatoria por clase)
        ep_calib, ep_eval = _riem_split_calibration(ep_te_full, k_per_class=kcal, random_state=random_state)
        kcal_eff = kcal if (ep_calib is not None and len(ep_calib) > 0) else 0

        # Covarianzas (train + eval)
        with mne.utils.use_log_level("ERROR"):
            Ctr, y_tr, Cev, y_ev, classes = _riem_fb_cov_train_test(
                ep_tr, ep_eval, fb_bands,
                motor_only=motor_only, zscore_epoch=zscore_epoch,
                crop_window=crop_window, cov_estimator=cov_estimator
            )

        # Inicializa acumuladores globales
        if classes_global is None:
            classes_global = classes
            cm_global = np.zeros((len(classes), len(classes)), dtype=int)

        # Clasificador (MDM o FgMDM)
        clf = _riem_make_clf(model=model)

        # Asegurar SPD 3D (evitar error "too many values to unpack")
        Ctr = _to_block_if_4d(Ctr)
        Cev = _to_block_if_4d(Cev)

        if kcal_eff <= 0:
            # LOSO clásico
            clf.fit(Ctr, y_tr)
            yhat = clf.predict(Cev)
        else:
            # Añadir calibración al entrenamiento
            with mne.utils.use_log_level("ERROR"):
                Ctr2, y_tr2, Cca, y_ca, _ = _riem_fb_cov_train_test(
                    ep_tr, ep_calib, fb_bands,
                    motor_only=motor_only, zscore_epoch=zscore_epoch,
                    crop_window=crop_window, cov_estimator=cov_estimator
                )
            Ctr2 = _to_block_if_4d(Ctr2)
            Cca  = _to_block_if_4d(Cca)

            C_join = np.concatenate([Ctr2, Cca], axis=0)
            y_join = np.concatenate([y_tr2, y_ca], axis=0)

            clf.fit(C_join, y_join)
            yhat = clf.predict(Cev)

        # Métricas por sujeto
        acc = accuracy_score(y_ev, yhat)
        f1m = f1_score(y_ev, yhat, average='macro')
        cm  = confusion_matrix(y_ev, yhat, labels=np.arange(len(classes)))
        cm_global += cm

        logger.info(f"[RIEM-LOSO] test={s_test} | acc={acc:.3f} | f1m={f1m:.3f} | "
                    f"n_test={len(y_ev)} | calib_k={kcal_eff}")

        rows.append(dict(
            test_subject=s_test,
            acc=float(acc),
            f1_macro=float(f1m),
            n_test=int(len(y_ev)),
            crop=str(crop_window),
            motor_only=bool(motor_only),
            zscore_epoch=bool(zscore_epoch),
            fb_bands=len(fb_bands),
            cov_estimator=cov_estimator,
            model=model,
            calibrate_k=int(kcal_eff)
        ))
        cm_items.append((s_test, cm, classes))

    # Fila GLOBAL
    acc_mu = float(np.mean([r['acc'] for r in rows])) if rows else 0.0
    f1_mu  = float(np.mean([r['f1_macro'] for r in rows])) if rows else 0.0
    rows.append(dict(
        test_subject="GLOBAL",
        acc=acc_mu,
        f1_macro=f1_mu,
        n_test=int(np.sum([r['n_test'] for r in rows])) if rows else 0,
        crop=str(crop_window),
        motor_only=bool(motor_only),
        zscore_epoch=bool(zscore_epoch),
        fb_bands=len(fb_bands),
        cov_estimator=cov_estimator,
        model=model,
        calibrate_k=int(kcal)
    ))

    # Salidas
    import pandas as pd
    from pathlib import Path
    from datetime import datetime

    df_rows = pd.DataFrame(rows).sort_values('test_subject')
    out_csv = (RIEM_TAB_DIR / f"{ts}_{save_csv_name}") if save_csv_name else (RIEM_TAB_DIR / f"riem_metrics_loso_all_{model}_{ts}.csv")
    df_rows.to_csv(out_csv, index=False)
    logger.info(f"CSV → {out_csv}"); print("CSV →", out_csv)
    try: display(df_rows)
    except: pass

    out_txt = (RIEM_LOG_DIR / f"{ts}_{save_txt_name}") if save_txt_name else (RIEM_LOG_DIR / f"riem_metrics_loso_all_{model}_{ts}.txt")
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(f"LOSO Riemann — {model}\nGenerado: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total filas: {len(df_rows)}\n\n")
        header = df_rows.columns.tolist(); f.write(" | ".join(header) + "\n"); f.write("-"*90 + "\n")
        for _, r in df_rows.iterrows():
            vals = []
            for k in header:
                v = r[k]
                vals.append(f"{v:.4f}" if isinstance(v, float) else str(int(v)) if isinstance(v, (np.integer,)) else str(v))
            f.write(" | ".join(vals) + "\n")
    logger.info(f"TXT → {out_txt}"); print("TXT →", out_txt)

    # Mosaicos por sujeto
    if cm_items:
        n = len(cm_items); per_fig = max(1, int(max_subplots_per_fig))
        n_figs = ceil(n / per_fig); n_cols = max(1, int(n_cols))
        n_rows_needed = lambda count: ceil(count / n_cols)

        for fig_idx in range(n_figs):
            start, end = fig_idx*per_fig, min((fig_idx+1)*per_fig, n)
            chunk = cm_items[start:end]
            count = len(chunk); n_rows = n_rows_needed(count)

            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5*n_cols, 3.8*n_rows), dpi=140)
            axes = np.atleast_2d(axes).flatten()
            for ax_i, (sid, cm_sum, classes) in enumerate(chunk):
                ax = axes[ax_i]
                disp = ConfusionMatrixDisplay(cm_sum, display_labels=classes)
                disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format='d')
                ax.set_title(f"{sid}"); ax.set_xlabel(""); ax.set_ylabel("")
            for j in range(ax_i+1, len(axes)): axes[j].axis("off")

            out_png = RIEM_FIG_DIR / f"riem_loso_all_confusions_{model}_{ts}_p{fig_idx+1}.png"
            fig.suptitle(f"Riemann-LOSO ({model}) — pág {fig_idx+1}/{n_figs}", y=0.995, fontsize=14)
            fig.tight_layout(rect=[0,0,1,0.97])
            fig.savefig(out_png); plt.close(fig)
            logger.info(f"Figura → {out_png}"); print("Figura →", out_png)

    # Matriz GLOBAL consolidada (suma de confusiones)
    if cm_items:
        fig, ax = plt.subplots(figsize=(6.5, 5.2), dpi=140)
        disp = ConfusionMatrixDisplay(cm_global, display_labels=classes_global)
        disp.plot(ax=ax, cmap="Blues", colorbar=True, values_format='d')
        ax.set_title(f"Riemann-LOSO ({model}) — Matriz GLOBAL")
        fig.tight_layout()
        out_png_glob = RIEM_FIG_DIR / f"riem_loso_global_confusion_{model}_{ts}.png"
        fig.savefig(out_png_glob); plt.close(fig)
        logger.info(f"Matriz GLOBAL → {out_png_glob}"); print("Matriz GLOBAL →", out_png_glob)

    logger.info(f"Log → {log_path}"); print("Log →", log_path)
    print(f"[GLOBAL RIEM-LOSO] ACC={acc_mu:.3f} | F1m={f1_mu:.3f}")
    return df_rows.reset_index(drop=True)

### Bloque — Ejemplos

Corre INTRA y LOSO con MDM (o FgMDM).

In [81]:
# %% [Ejemplos — Riemann]
# INTRA — FgMDM (aprovecha mejor la estructura multi-banda), mismas bandas y setup
# df_intra_fgmdm = run_intra_all_riem(
#     fif_dir=DATA_PROC,
#     k=5,
#     random_state=42,
#     crop_window=(0.5, 4.5),
#     motor_only=False,
#     zscore_epoch=False,
#     fb_bands=DEFAULT_FB_BANDS,
#     cov_estimator='oas',
#     model='fgmdm',             # << FgMDM
#     save_csv_name="riem_intra_fgmdm.csv",
#     save_txt_name="riem_intra_fgmdm.txt"
# )


# LOSO — MDM, sin calibración
df_loso_mdm = run_loso_all_riem(
    fif_dir=DATA_PROC,
    crop_window=(0.5, 4.5),
    motor_only=False,
    zscore_epoch=False,
    fb_bands=DEFAULT_FB_BANDS,
    cov_estimator='oas',
    model='fgmdm',               # << MDM
    calibrate_k_per_class=5,# None/0 => sin calibración
    save_csv_name="riem_loso_mdm.csv",
    save_txt_name="riem_loso_mdm.txt"
)


[05:09:05] INFO: [RUN riem_loso_all_fgmdm_20251006-050905] LOSO Riemann | model=fgmdm | bands=11 | cov=oas | zscore_epoch=False | calib_k=5


INFO:riem_loso_all_fgmdm_20251006-050905:[RUN riem_loso_all_fgmdm_20251006-050905] LOSO Riemann | model=fgmdm | bands=11 | cov=oas | zscore_epoch=False | calib_k=5


[05:09:05] INFO: Sujetos: ['S001', 'S002', 'S003', 'S004', 'S005', 'S006', 'S007', 'S008', 'S009', 'S010', 'S011', 'S012', 'S013', 'S014', 'S015', 'S016', 'S017', 'S018', 'S019', 'S020', 'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029', 'S030', 'S031', 'S032', 'S033', 'S034', 'S035', 'S036', 'S037', 'S039', 'S040', 'S041', 'S042', 'S043', 'S044', 'S045', 'S046', 'S047', 'S048', 'S049', 'S050', 'S051', 'S052', 'S053', 'S054', 'S055', 'S056', 'S057', 'S058', 'S059', 'S060', 'S061', 'S062', 'S063', 'S064', 'S065', 'S066', 'S067', 'S068', 'S069', 'S070', 'S071', 'S072', 'S073', 'S074', 'S075', 'S076', 'S077', 'S078', 'S079', 'S080', 'S081', 'S082', 'S083', 'S084', 'S085', 'S086', 'S087', 'S090', 'S091', 'S093', 'S094', 'S095', 'S096', 'S097', 'S098', 'S099', 'S101', 'S102', 'S103', 'S105', 'S106', 'S107', 'S108', 'S109']


INFO:riem_loso_all_fgmdm_20251006-050905:Sujetos: ['S001', 'S002', 'S003', 'S004', 'S005', 'S006', 'S007', 'S008', 'S009', 'S010', 'S011', 'S012', 'S013', 'S014', 'S015', 'S016', 'S017', 'S018', 'S019', 'S020', 'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029', 'S030', 'S031', 'S032', 'S033', 'S034', 'S035', 'S036', 'S037', 'S039', 'S040', 'S041', 'S042', 'S043', 'S044', 'S045', 'S046', 'S047', 'S048', 'S049', 'S050', 'S051', 'S052', 'S053', 'S054', 'S055', 'S056', 'S057', 'S058', 'S059', 'S060', 'S061', 'S062', 'S063', 'S064', 'S065', 'S066', 'S067', 'S068', 'S069', 'S070', 'S071', 'S072', 'S073', 'S074', 'S075', 'S076', 'S077', 'S078', 'S079', 'S080', 'S081', 'S082', 'S083', 'S084', 'S085', 'S086', 'S087', 'S090', 'S091', 'S093', 'S094', 'S095', 'S096', 'S097', 'S098', 'S099', 'S101', 'S102', 'S103', 'S105', 'S106', 'S107', 'S108', 'S109']


[RIEM-LOSO] sujetos: ['S001', 'S002', 'S003', 'S004', 'S005', 'S006', 'S007', 'S008', 'S009', 'S010', 'S011', 'S012', 'S013', 'S014', 'S015', 'S016', 'S017', 'S018', 'S019', 'S020', 'S021', 'S022', 'S023', 'S024', 'S025', 'S026', 'S027', 'S028', 'S029', 'S030', 'S031', 'S032', 'S033', 'S034', 'S035', 'S036', 'S037', 'S039', 'S040', 'S041', 'S042', 'S043', 'S044', 'S045', 'S046', 'S047', 'S048', 'S049', 'S050', 'S051', 'S052', 'S053', 'S054', 'S055', 'S056', 'S057', 'S058', 'S059', 'S060', 'S061', 'S062', 'S063', 'S064', 'S065', 'S066', 'S067', 'S068', 'S069', 'S070', 'S071', 'S072', 'S073', 'S074', 'S075', 'S076', 'S077', 'S078', 'S079', 'S080', 'S081', 'S082', 'S083', 'S084', 'S085', 'S086', 'S087', 'S090', 'S091', 'S093', 'S094', 'S095', 'S096', 'S097', 'S098', 'S099', 'S101', 'S102', 'S103', 'S105', 'S106', 'S107', 'S108', 'S109']
[05:13:35] INFO: [RIEM-LOSO] test=S001 | acc=0.429 | f1m=0.406 | n_test=70 | calib_k=5


INFO:riem_loso_all_fgmdm_20251006-050905:[RIEM-LOSO] test=S001 | acc=0.429 | f1m=0.406 | n_test=70 | calib_k=5


[05:17:57] INFO: [RIEM-LOSO] test=S002 | acc=0.578 | f1m=0.557 | n_test=64 | calib_k=5


INFO:riem_loso_all_fgmdm_20251006-050905:[RIEM-LOSO] test=S002 | acc=0.578 | f1m=0.557 | n_test=64 | calib_k=5


KeyboardInterrupt: 

### Inter sujeto con Cross validation

In [92]:
# %% [INTER-SUJETO — Riemann (MDM/FgMDM) con TEST fijo, TRAIN/VAL aleatorios y CALIBRACIÓN opcional]
# Requiere:
#   - Archivos preprocesados S???_MI-epo.fif en DATA_PROC
#   - Paquetes: mne, numpy, pandas, scikit-learn, matplotlib, pyriemann

import os, mne, numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from pyriemann.estimation import Covariances
from pyriemann.classification import MDM, FgMDM

# -------------------- RUTAS --------------------
# Este notebook está en: models/inter_fixedsplit_riem/
PROJ      = Path('..').resolve().parent  # .../models/inter_fixedsplit_riem -> parent()=models -> parent()=<repo root>
DATA_PROC = PROJ / 'data' / 'processed'

OUT_ROOT  = PROJ / 'models' / 'inter_fixedsplit_riem'
FIG_DIR   = OUT_ROOT / 'figures'
TAB_DIR   = OUT_ROOT / 'tables'
LOG_DIR   = OUT_ROOT / 'logs'
for d in (FIG_DIR, TAB_DIR, LOG_DIR): d.mkdir(parents=True, exist_ok=True)

print(f"[Inter-Riemann] data → {DATA_PROC}")
print(f"[Inter-Riemann] figs → {FIG_DIR}")
print(f"[Inter-Riemann] tabs → {TAB_DIR}")
print(f"[Inter-Riemann] logs → {LOG_DIR}")

# -------------------- CONFIG RIEMANN --------------------
FB_BANDS_DENSE = [(f, f+2) for f in range(8, 30, 2)]  # 8–30 Hz paso 2 Hz
DEFAULT_MODEL  = 'fgmdm'   # 'mdm' | 'fgmdm'
DEFAULT_METRIC = 'riemann' # métrica del MDM/FgMDM
DEFAULT_COV    = 'oas'     # 'oas'|'scm'|'lwf'
MOTOR_TOKENS   = ['C3','CZ','C4','FC3','FC4','CP3','CPZ','CP4']

# -------------------- UTILS --------------------
def _discover_subject_ids(fif_dir=DATA_PROC, pattern='S???_MI-epo.fif'):
    files = sorted(glob(str(Path(fif_dir) / pattern)))
    return [Path(f).stem.split('_')[0] for f in files]

def _epochs_to_Xy(epochs: mne.Epochs):
    X = epochs.get_data()  # (n_trials, n_ch, n_times)
    inv = {v:k for k,v in epochs.event_id.items()}
    y = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
    return X, y

def _pick_motor_indices(ch_names, tokens=MOTOR_TOKENS):
    up = [c.upper() for c in ch_names]
    picks = []
    for tok in tokens:
        TU = tok.upper()
        for i, name in enumerate(up):
            if TU == name or TU in name:
                picks.append(i); break
    return sorted(set(picks))

def _epochwise_zscore(X, eps=1e-8):
    # X: (n_trials, n_ch, n_times)
    mu  = X.mean(axis=-1, keepdims=True)
    std = X.std(axis=-1, keepdims=True)
    return (X - mu) / (std + eps)

def _filterbank_cov_block(epochs: mne.Epochs, fb_bands, estimator='oas',
                          motor_only=True, zscore_epoch=False):
    """
    Devuelve covarianzas bloque-diagonal vectorizadas para TODAS las sub-bandas:
      C_block: (n_trials, n_fb*n_ch, n_fb*n_ch), y: (n_trials,), classes: list
    Implementación:
      - Para cada banda: filtra, (opcional) zscore por época, covarianzas con pyRiemann (vectorizado).
      - Apila en un bloque-diagonal grande por banda (sin términos cruzados).
    """
    ep = epochs.copy()
    if motor_only:
        picks = _pick_motor_indices(ep.ch_names)
        if picks:
            ep.pick(picks)

    X, y = _epochs_to_Xy(ep)
    n_trials, n_ch, _ = X.shape
    n_fb = len(fb_bands)

    covs_bands = []
    for (fmin, fmax) in fb_bands:
        ep_b = ep.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        Xb = ep_b.get_data()
        if zscore_epoch:
            Xb = _epochwise_zscore(Xb)
        Cb = Covariances(estimator=estimator).fit_transform(Xb)  # (n_trials, n_ch, n_ch)
        covs_bands.append(Cb)

    C_block = np.zeros((n_trials, n_fb*n_ch, n_fb*n_ch), dtype=float)
    for b, Cb in enumerate(covs_bands):
        i0 = b*n_ch; i1 = (b+1)*n_ch
        C_block[:, i0:i1, i0:i1] = Cb

    classes = np.unique(y).tolist()
    return C_block, y, classes

def _make_split_with_fixed_test(fif_dir=DATA_PROC, fixed_test=(), val_size=16, random_state=42):
    all_subjects = _discover_subject_ids(fif_dir)
    existing = {Path(p).stem.split('_')[0] for p in Path(fif_dir).glob('S???_MI-epo.fif')}
    fixed_test = [s for s in list(fixed_test) if s in existing]
    remaining = [s for s in all_subjects if s not in fixed_test]
    if len(remaining) <= val_size:
        raise ValueError(f"Insuficientes sujetos para VALIDATION: remaining={len(remaining)} <= val_size={val_size}")
    train_subjects, val_subjects = train_test_split(
        remaining, test_size=val_size, shuffle=True, random_state=random_state
    )
    if len(train_subjects) == 0 or len(val_subjects) == 0 or len(fixed_test) == 0:
        raise ValueError(f"Split inválido: train={len(train_subjects)}, val={len(val_subjects)}, test={len(fixed_test)}")
    return train_subjects, val_subjects, fixed_test

def _split_calibration(ep_test: mne.Epochs, k_per_class=5, random_state=42):
    """
    Toma k_per_class épocas por clase del set de TEST como 'calibración' y devuelve:
      ep_calib (k por clase) y ep_eval (el resto). Si k_per_class<=0 -> (None, ep_test).
    """
    if not k_per_class or k_per_class <= 0:
        return None, ep_test
    rng = np.random.RandomState(int(random_state))
    labels = ep_test.events[:, -1]
    ep_calib_list, ep_eval_list = [], []
    for code in np.unique(labels):
        idx = np.where(labels == code)[0]
        if len(idx) == 0:
            continue
        rng.shuffle(idx)
        take = min(int(k_per_class), len(idx))
        sel, rem = idx[:take], idx[take:]
        if take > 0: ep_calib_list.append(ep_test.copy()[sel])
        if len(rem) > 0: ep_eval_list.append(ep_test.copy()[rem])
    ep_calib = mne.concatenate_epochs(ep_calib_list, on_mismatch='ignore') if ep_calib_list else None
    ep_eval  = mne.concatenate_epochs(ep_eval_list,  on_mismatch='ignore') if ep_eval_list  else ep_test
    return ep_calib, ep_eval

# -------------------- TEST fijo (24 sujetos): mayoría buenos, algunos medios, pocos malos --------------------
FIXED_TEST_SUBJECTS = [
    'S007','S025','S029','S031','S032','S034','S035','S042','S043','S049','S056','S058','S062','S072',  # buenos
    'S001','S010','S013','S017','S019','S030',                                                          # medios
    'S005','S006','S009','S097'                                                                         # pocos bajos
]

# -------------------- FUNCIÓN PRINCIPAL --------------------
def run_inter_fixedsplit_riem(
    fif_dir=DATA_PROC,
    crop_window=(0.5, 4.5),
    fb_bands=FB_BANDS_DENSE,
    cov_estimator=DEFAULT_COV,
    model=DEFAULT_MODEL,           # 'mdm' o 'fgmdm'
    metric=DEFAULT_METRIC,         # 'riemann' recomendado
    motor_only=True,
    zscore_epoch=False,
    val_size=16,
    random_state=42,
    # === Calibración ===
    refit_on_trainval_for_test=True,   # si True, base para test = TRAIN+VAL; si False, solo TRAIN
    calibrate_k_per_class=None,        # None/0 -> sin calibración; >0 -> usar k por clase del TEST
    # === Salidas ===
    save_csv_name="inter_riem_fixedsplit.csv",
    save_txt_name="inter_riem_fixedsplit.txt"
):
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    train_subs, val_subs, test_subs = _make_split_with_fixed_test(
        fif_dir=fif_dir, fixed_test=FIXED_TEST_SUBJECTS,
        val_size=val_size, random_state=random_state
    )
    print(f"TRAIN={len(train_subs)} | VAL={len(val_subs)} | TEST(fijo)={len(test_subs)}")
    print(f"Train: {sorted(train_subs)[:6]}{' ...' if len(train_subs)>6 else ''}")
    print(f"Val  : {sorted(val_subs)}")
    print(f"Test : {sorted(test_subs)}")

    # ----- carga epochs por grupo
    ep_tr = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                    preload=True, verbose=False) for s in train_subs],
                                   on_mismatch='ignore')
    ep_va = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                    preload=True, verbose=False) for s in val_subs],
                                   on_mismatch='ignore')
    ep_te_all = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                        preload=True, verbose=False) for s in test_subs],
                                       on_mismatch='ignore')

    # ----- alinear canales (val/test a train)
    ep_va = ep_va.copy().reorder_channels(ep_tr.ch_names)
    ep_te_all = ep_te_all.copy().reorder_channels(ep_tr.ch_names)

    # ----- recorte opcional
    if crop_window is not None:
        ep_tr.crop(*crop_window); ep_va.crop(*crop_window); ep_te_all.crop(*crop_window)

    # ----- etiquetas y encoder consistente
    _, y_tr_str = _epochs_to_Xy(ep_tr)
    _, y_va_str = _epochs_to_Xy(ep_va)
    _, y_te_all_str = _epochs_to_Xy(ep_te_all)
    le = LabelEncoder().fit(np.concatenate([y_tr_str, y_va_str, y_te_all_str]))
    y_tr = le.transform(y_tr_str); y_va = le.transform(y_va_str)

    # ----- covs TRAIN/VAL
    Ctr, _, _ = _filterbank_cov_block(ep_tr, fb_bands, estimator=cov_estimator,
                                      motor_only=motor_only, zscore_epoch=zscore_epoch)
    Cva, _, _ = _filterbank_cov_block(ep_va, fb_bands, estimator=cov_estimator,
                                      motor_only=motor_only, zscore_epoch=zscore_epoch)

    # ----- clasificador base (para VALIDATION)
    clf_val = FgMDM(metric=metric) if model.lower() == 'fgmdm' else MDM(metric=metric)
    clf_val.fit(Ctr, y_tr)
    yhat_va = clf_val.predict(Cva)
    acc_va = accuracy_score(y_va, yhat_va); f1_va = f1_score(y_va, yhat_va, average='macro')
    print(f"[VAL] ACC={acc_va:.3f} | F1m={f1_va:.3f}")

    # ====== TEST FIJO (con o sin calibración) ======
    # Split calib/eval dentro de TEST
    ep_calib, ep_eval = _split_calibration(ep_te_all, k_per_class=calibrate_k_per_class, random_state=random_state)
    k_eff = 0 if ep_calib is None else len(ep_calib)

    # Base de entrenamiento para TEST
    ep_base = mne.concatenate_epochs([ep_tr, ep_va], on_mismatch='ignore') if refit_on_trainval_for_test else ep_tr
    _, y_base_str = _epochs_to_Xy(ep_base); y_base = le.transform(y_base_str)

    # Covs para base y EVAL
    Cbase, _, _ = _filterbank_cov_block(ep_base, fb_bands, estimator=cov_estimator,
                                        motor_only=motor_only, zscore_epoch=zscore_epoch)
    Cev, y_ev_str, _ = _filterbank_cov_block(ep_eval, fb_bands, estimator=cov_estimator,
                                             motor_only=motor_only, zscore_epoch=zscore_epoch)
    y_ev = le.transform(y_ev_str)

    if not calibrate_k_per_class or calibrate_k_per_class <= 0:
        # --- SIN calibración ---
        clf = FgMDM(metric=metric) if model.lower() == 'fgmdm' else MDM(metric=metric)
        clf.fit(Cbase, y_base)
        yhat_te = clf.predict(Cev)
    else:
        # --- CON calibración (añadir covs de calib al entrenamiento y re-ajustar MDM/FgMDM) ---
        Cca, y_ca_str, _ = _filterbank_cov_block(ep_calib, fb_bands, estimator=cov_estimator,
                                                 motor_only=motor_only, zscore_epoch=zscore_epoch)
        y_ca = le.transform(y_ca_str)
        C_join = np.concatenate([Cbase, Cca], axis=0)
        y_join = np.concatenate([y_base, y_ca], axis=0)
        clf = FgMDM(metric=metric) if model.lower() == 'fgmdm' else MDM(metric=metric)
        clf.fit(C_join, y_join)
        yhat_te = clf.predict(Cev)

    acc_te = accuracy_score(y_ev, yhat_te); f1_te = f1_score(y_ev, yhat_te, average='macro')
    print(f"[TEST {'CALIB' if k_eff>0 else 'SIN CALIB'}] k={calibrate_k_per_class or 0} | ACC={acc_te:.3f} | F1m={f1_te:.3f}")

    # ----- matriz de confusión TEST
    cm_te = confusion_matrix(y_ev, yhat_te, labels=np.arange(len(le.classes_)))
    fig, ax = plt.subplots(figsize=(5.4, 4.6), dpi=140)
    ConfusionMatrixDisplay(cm_te, display_labels=list(le.classes_)).plot(
        ax=ax, cmap="Blues", colorbar=True, values_format='d'
    )
    ax.set_title(f"Riemann — TEST fijo ({model}, cov={cov_estimator}, calib k={calibrate_k_per_class or 0})")
    fig.tight_layout()
    png = FIG_DIR / f"inter_riem_confusion_test_{model}_{cov_estimator}_{ts}.png"
    fig.savefig(png); plt.close(fig)
    print("Confusión TEST →", png)

    # ----- guardar CSV + TXT
    df = pd.DataFrame([dict(
        mode="inter_fixedsplit_riem",
        model=model, metric=metric, cov_estimator=cov_estimator,
        motor_only=bool(motor_only), zscore_epoch=bool(zscore_epoch),
        fb_bands=len(fb_bands), crop=str(crop_window),
        acc_val=float(acc_va), f1_val=float(f1_va),
        acc_test=float(acc_te), f1_test=float(f1_te),
        n_train=len(train_subs), n_val=len(val_subs), n_test=len(test_subs),
        refit_on_trainval_for_test=bool(refit_on_trainval_for_test),
        calibrate_k_per_class=int(calibrate_k_per_class or 0),
        train_subjects=",".join(sorted(train_subs)),
        val_subjects=",".join(sorted(val_subs)),
        test_subjects=",".join(sorted(test_subs))
    )])
    out_csv = TAB_DIR / f"{ts}_inter_riem_fixedsplit_{model}_{cov_estimator}.csv"
    df.to_csv(out_csv, index=False)

    out_txt = LOG_DIR / f"{ts}_inter_riem_fixedsplit_{model}_{cov_estimator}.txt"
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write("INTER-SUJETO (Riemann) — TEST fijo\n")
        f.write(f"Generado: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(df.to_string(index=False))

    print("CSV →", out_csv); print("TXT →", out_txt)
    return df

# -------------------- EJECUCIÓN INMEDIATA --------------------
df_riem_inter = run_inter_fixedsplit_riem(
    fif_dir=DATA_PROC,
    crop_window=(0.5, 3.0),
    fb_bands=FB_BANDS_CLASSIC,
    cov_estimator='oas',     # 'oas'|'lwf'|'scm'
    model='fgmdm',           # 'mdm'|'fgmdm'
    metric='riemann',
    motor_only=True,        # si tus FIF ya están en MI8, motor_only=False está OK
    zscore_epoch=False,      # suele ir mejor OFF con covarianzas
    val_size=16,
    random_state=42,
    refit_on_trainval_for_test=True,   # usar TRAIN+VAL como base antes de calibrar
    calibrate_k_per_class=10            # few-shot: 5 épocas por clase del TEST como calibración
)
try: display(df_riem_inter)
except: pass


[Inter-Riemann] data → /root/Proyecto/EEG_Clasificador/data/processed
[Inter-Riemann] figs → /root/Proyecto/EEG_Clasificador/models/inter_fixedsplit_riem/figures
[Inter-Riemann] tabs → /root/Proyecto/EEG_Clasificador/models/inter_fixedsplit_riem/tables
[Inter-Riemann] logs → /root/Proyecto/EEG_Clasificador/models/inter_fixedsplit_riem/logs
TRAIN=63 | VAL=16 | TEST(fijo)=24
Train: ['S003', 'S004', 'S008', 'S012', 'S014', 'S015'] ...
Val  : ['S002', 'S011', 'S020', 'S022', 'S033', 'S040', 'S048', 'S051', 'S052', 'S054', 'S057', 'S069', 'S074', 'S095', 'S096', 'S099']
Test : ['S001', 'S005', 'S006', 'S007', 'S009', 'S010', 'S013', 'S017', 'S019', 'S025', 'S029', 'S030', 'S031', 'S032', 'S034', 'S035', 'S042', 'S043', 'S049', 'S056', 'S058', 'S062', 'S072', 'S097']
[VAL] ACC=0.335 | F1m=0.333
[TEST CALIB] k=10 | ACC=0.408 | F1m=0.402
Confusión TEST → /root/Proyecto/EEG_Clasificador/models/inter_fixedsplit_riem/figures/inter_riem_confusion_test_fgmdm_oas_20251006-062650.png
CSV → /root/Proy

Unnamed: 0,mode,model,metric,cov_estimator,motor_only,zscore_epoch,fb_bands,crop,acc_val,f1_val,acc_test,f1_test,n_train,n_val,n_test,refit_on_trainval_for_test,calibrate_k_per_class,train_subjects,val_subjects,test_subjects
0,inter_fixedsplit_riem,fgmdm,riemann,oas,True,False,6,"(0.5, 3.0)",0.33504,0.332676,0.407888,0.401838,63,16,24,True,10,"S003,S004,S008,S012,S014,S015,S016,S018,S021,S...","S002,S011,S020,S022,S033,S040,S048,S051,S052,S...","S001,S005,S006,S007,S009,S010,S013,S017,S019,S..."
