# FBCSP + LDA (intra-sujeto y cross-sujeto)

### Bloque 1 — Rutas de salida + utilidades de logging y guardado

Qué hace: define las carpetas donde se guardarán figuras, tablas y logs bajo models/fbcsp_lda/. Incluye utilidades para inicializar un logger limpio, guardar matrices de confusión (PNG + CSV) y anexar métricas a CSVs acumulativos.

In [44]:
# %% [PATHS & LOGGING] — rutas de salida + helpers para logs/figuras/tablas
import sys, logging, warnings
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay
import mne

# Raíz del repo (este notebook está en models/fbcsp_lda/)
PROJ = Path('..').resolve().parent
DATA_PROC = PROJ / 'data' / 'processed'
OUT_ROOT = PROJ / 'models' / 'fbcsp_lda'
FIG_DIR  = OUT_ROOT / 'figures'
TAB_DIR  = OUT_ROOT / 'tables'
LOG_DIR  = OUT_ROOT / 'logs'
for d in (FIG_DIR, TAB_DIR, LOG_DIR):
    d.mkdir(parents=True, exist_ok=True)

print(f"Directorio de datos procesados: {DATA_PROC}")


Directorio de datos procesados: /root/Proyecto/EEG_Clasificador/data/processed


### Bloque 2 — FBCSP Helpers

Qué hace: helpers de modelo. Extraen X/y desde Epochs, aplican Filter-Bank + CSP por bandas y entrenan/escala LDA. Mantiene tus comentarios y añade compatibilidad con versiones de CSP.

In [45]:
# %% [HELPERS — comunes FBCSP/LOSO/Calibración/Logging]
# Reúne en un solo bloque:
#  - Descubrimiento de sujetos + DROP-only
#  - Perillas (_knobs_dict)
#  - FBCSP helpers (banco de filtros, picks motores, z-score por época, FBCSP transform)
#  - Clasificador (scaler + LDA)
#  - Split de calibración
#  - Logger

# ====== IMPORTS ======
import sys, logging, warnings, re
from glob import glob
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import mne

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from mne.decoding import CSP

# ====== RUTAS y ARCHIVOS AUXILIARES ======
# NOTA: este bloque asume que en un bloque anterior ya definiste:
#   PROJ, DATA_PROC, LOG_DIR  (del bloque PATHS & LOGGING original)
# Archivo opcional con sujetos a excluir (DROP-only)
STRICT_DROP_TXT = PROJ / 'reports' / 'tables' / '02_prepro' / 'subjects_strict_DROP.txt'
_re_sid = re.compile(r'^S\d{3}$')
# ====== FBCSP HELPERS ======
# Bancos de filtros
FB_BANDS_DENSE = [(f, f+2) for f in range(8, 30, 2)]                  # denso 8–30 por pasos de 2 Hz
FB_BANDS_CLASSIC = [(8,12), (12,16), (16,20), (20,24), (24,28), (28,30)]
DEFAULT_FB_BANDS = FB_BANDS_DENSE

# Nº de componentes CSP por sub-banda (típico 6–8 en bancos densos)
DEFAULT_N_CSP = 6

# LDA con shrinkage automático robusto
LDA_PARAMS = dict(solver='lsqr', shrinkage='auto')

# Picks motores (si deseas limitar; tus datos ya tienen 8 canales motores)
MOTOR_TOKENS = ['C3', 'CZ', 'C4', 'FC3', 'FC4', 'CP3', 'CPZ', 'CP4']

def _list_subject_fifs(fif_dir=DATA_PROC, pattern='S???_MI-epo.fif'):
    """Devuelve lista ordenada de rutas a los FIF de sujetos disponibles."""
    return sorted(glob(str(fif_dir / pattern)))

def _list_available_subjects(fif_dir=DATA_PROC):
    """IDs únicos SXXX disponibles en el directorio de FIF procesados."""
    files = _list_subject_fifs(fif_dir)
    return sorted({Path(f).stem.split('_')[0] for f in files})

def _read_drop_file(path: Path):
    """Lee archivo con IDs de sujetos a excluir (uno por línea, formato SXXX)."""
    if not path.exists():
        return set()
    s = set()
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        for ln in f:
            sid = ln.strip().upper()
            if _re_sid.match(sid):
                s.add(sid)
    return s

def _strict_valid_from_drop(avail_ids):
    """
    Aplica DROP-only:
      - Lee STRICT_DROP_TXT si existe
      - Devuelve (lista válidos, info string)
    """
    drop = _read_drop_file(STRICT_DROP_TXT)
    avail = set(avail_ids)
    valid = sorted(avail - drop)
    info = f"DROP-only: {len(drop)} en DROP; válidos={len(valid)}/{len(avail)}"
    if not STRICT_DROP_TXT.exists():
        info += " (archivo DROP no encontrado → sin exclusiones)"
    return valid, info

# Helper para registrar “perillas” (config) en logs/CSV
def _knobs_dict(crop_window, motor_only, zscore_epoch, fb_bands, n_csp):
    return dict(
        crop_window=crop_window if crop_window is not None else None,
        motor_only=bool(motor_only),
        zscore_epoch=bool(zscore_epoch),
        fb_bands=str(fb_bands),
        n_csp=int(n_csp)
    )

def _epochs_to_Xy(epochs: mne.Epochs):
    """
    Extrae X e y (clases string) desde Epochs respetando event_id.
    X: (n_epochs, n_channels, n_times)
    y: (n_epochs,) etiquetas string según mapping event_id.
    """
    X = epochs.get_data()
    inv = {v: k for k, v in epochs.event_id.items()}  # int->clase
    y = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
    return X, y

def _find_motor_chs(ch_names, tokens=MOTOR_TOKENS):
    """Devuelve índices de canales que contienen tokens motores (case-insensitive)."""
    up = [c.upper() for c in ch_names]
    picks = []
    for tok in tokens:
        TU = tok.upper()
        for i, name in enumerate(up):
            if TU in name:
                picks.append(i); break
    return sorted(set(picks))

def _epochwise_zscore(X, eps=1e-8):
    """
    Z-score por época y canal (normaliza a lo largo del tiempo).
    X: (n_epochs, n_channels, n_times) → mismo shape.
    """
    mean = X.mean(axis=-1, keepdims=True)
    std  = X.std(axis=-1, keepdims=True)
    return (X - mean) / (std + eps)

def _fit_fb_csp_transform(train_ep: mne.Epochs,
                          test_ep: mne.Epochs,
                          fb_bands=DEFAULT_FB_BANDS,
                          n_csp=DEFAULT_N_CSP,
                          motor_only=False,
                          zscore_epoch=False,
                          crop_window=None):
    """
    Aplica FBCSP con opciones:
      - crop_window=(tmin,tmax): recorta épocas
      - motor_only=True: usa solo canales motores comunes
      - zscore_epoch=True: z-score por época/canal antes de CSP
      - fb_bands: lista de sub-bandas [(fmin,fmax), ...]
    Devuelve (Xtr_fb, Xte_fb) con features concatenadas por sub-banda.
    """
    tr = train_ep.copy()
    te = test_ep.copy()

    if crop_window is not None:
        tmin, tmax = crop_window
        tr.crop(tmin, tmax)
        te.crop(tmin, tmax)

    if motor_only:
        picks = _find_motor_chs(tr.ch_names)
        if picks:
            tr.pick(picks)
            # Alinear canales del test con los del train
            te = te.copy().reorder_channels(tr.ch_names)

    Xtr_list, Xte_list = [], []
    y_tr = tr.events[:, -1]

    for (fmin, fmax) in fb_bands:
        tr_b = tr.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        te_b = te.copy().filter(fmin, fmax, picks='eeg', verbose=False)

        Xtr = tr_b.get_data()
        Xte = te_b.get_data()

        if zscore_epoch:
            Xtr = _epochwise_zscore(Xtr)
            Xte = _epochwise_zscore(Xte)

        try:
            csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True, norm_trace=False)
        except TypeError:
            # Compatibilidad con MNE antiguos sin 'norm_trace'
            csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True)

        Xtr_c = csp.fit_transform(Xtr, y_tr)
        Xte_c = csp.transform(Xte)

        Xtr_list.append(Xtr_c)
        Xte_list.append(Xte_c)

    Xtr_fb = np.concatenate(Xtr_list, axis=1)
    Xte_fb = np.concatenate(Xte_list,  axis=1)
    return Xtr_fb, Xte_fb

def _fit_scale_lda(Xtr, ytr, Xte, lda_params=LDA_PARAMS):
    """
    Estandariza features (fit solo en train) y entrena LDA.
    Devuelve: (yhat, clf, scaler)
    """
    scaler = StandardScaler()
    Xtr_s = scaler.fit_transform(Xtr)
    Xte_s = scaler.transform(Xte)

    clf = LDA(**lda_params)
    clf.fit(Xtr_s, ytr)
    yhat = clf.predict(Xte_s)
    return yhat, clf, scaler

# ====== CALIBRACIÓN (para LOSO calibrado) ======
def _split_calibration(ep_te, k_per_class=5):
    """
    Toma k épocas por clase como calibración y el resto como evaluación.
    Si k_per_class <= 0 → (None, ep_te).
    """
    if k_per_class <= 0:
        return None, ep_te
    labels = ep_te.events[:, -1]
    ep_calib_list, ep_eval_list = [], []
    for code in np.unique(labels):
        idx = np.where(labels == code)[0]
        if len(idx) == 0:
            continue
        take = min(k_per_class, len(idx))
        sel = idx[:take]
        rem = idx[take:]
        if take > 0:
            ep_calib_list.append(ep_te.copy()[sel])
        if len(rem) > 0:
            ep_eval_list.append(ep_te.copy()[rem])
    ep_calib = mne.concatenate_epochs(ep_calib_list, on_mismatch='ignore') if ep_calib_list else None
    ep_eval  = mne.concatenate_epochs(ep_eval_list,  on_mismatch='ignore') if ep_eval_list  else ep_te
    return ep_calib, ep_eval

# ====== LOGGER ======
def _init_logger(run_name: str):
    """
    Crea un logger que escribe a consola y a TXT en models/fbcsp_lda/logs/.
    Reduce la verbosidad de MNE para que los logs sean legibles.
    """
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = LOG_DIR / f"{ts}_{run_name}.txt"

    logger = logging.getLogger(run_name)
    logger.setLevel(logging.INFO)
    logger.handlers.clear()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", datefmt="%H:%M:%S")
    ch = logging.StreamHandler(stream=sys.stdout); ch.setLevel(logging.INFO); ch.setFormatter(fmt)
    fh = logging.FileHandler(log_path, encoding="utf-8"); fh.setLevel(logging.INFO); fh.setFormatter(fmt)
    logger.addHandler(ch); logger.addHandler(fh)

    # Silenciar ruido externo (sin afectar tus prints/logs)
    mne.set_log_level("ERROR")
    warnings.filterwarnings("ignore", category=UserWarning, module="mne")
    warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")
    return logger, log_path


### Bloque 3 — Inspección de datos

Qué hace: muestra rutas y un listado rápido del contenido de data/ y data/processed/ para verificar que los FIF están donde esperamos

In [46]:
# %% [Inspect data folders]
# Celda añadida automáticamente: muestra rutas y lista contenidos de data y data/processed
try:
    print(f"PROJ: {PROJ}")
    print(f"DATA: {PROJ / 'data'}")
    print(f"DATA_PROC: {DATA_PROC}")
    print('\nContenido de data (top-level):')
    data_dir = PROJ / 'data'
    if data_dir.exists():
        for p in sorted(data_dir.iterdir()):
            print(f" - {p.name}{'/' if p.is_dir() else ''}")
    else:
        print('  (no existe)')

    print('\nContenido de data/processed (muestras):')
    if DATA_PROC.exists():
        for p in sorted(DATA_PROC.glob('*'))[:50]:
            print(f" - {p.name}")
    else:
        print('  (no existe)')
except Exception as e:
    print('Error inspeccionando data:', e)


PROJ: /root/Proyecto/EEG_Clasificador
DATA: /root/Proyecto/EEG_Clasificador/data
DATA_PROC: /root/Proyecto/EEG_Clasificador/data/processed

Contenido de data (top-level):
 - processed/
 - raw/

Contenido de data/processed (muestras):
 - S001_MI-epo.fif
 - S002_MI-epo.fif
 - S003_MI-epo.fif
 - S004_MI-epo.fif
 - S005_MI-epo.fif
 - S006_MI-epo.fif
 - S007_MI-epo.fif
 - S008_MI-epo.fif
 - S009_MI-epo.fif
 - S010_MI-epo.fif
 - S011_MI-epo.fif
 - S012_MI-epo.fif
 - S013_MI-epo.fif
 - S014_MI-epo.fif
 - S015_MI-epo.fif
 - S016_MI-epo.fif
 - S017_MI-epo.fif
 - S018_MI-epo.fif
 - S019_MI-epo.fif
 - S020_MI-epo.fif
 - S021_MI-epo.fif
 - S022_MI-epo.fif
 - S023_MI-epo.fif
 - S024_MI-epo.fif
 - S025_MI-epo.fif
 - S026_MI-epo.fif
 - S027_MI-epo.fif
 - S028_MI-epo.fif
 - S029_MI-epo.fif
 - S030_MI-epo.fif
 - S031_MI-epo.fif
 - S032_MI-epo.fif
 - S033_MI-epo.fif
 - S034_MI-epo.fif
 - S035_MI-epo.fif
 - S036_MI-epo.fif
 - S037_MI-epo.fif
 - S039_MI-epo.fif
 - S040_MI-epo.fif
 - S041_MI-epo.fif
 - S04

### Bloque 4 — Intra-sujeto (k-Fold CV) con logs + guardado

Qué hace: ejecuta CV por sujeto con FBCSP+LDA, imprime métricas limpias, guarda matriz de confusión (PNG/CSV), y escribe métricas en tables/metrics_intra.csv. Además genera un TXT en logs/ con el detalle de la corrida.

Añade argumentos crop_window, motor_only, zscore_epoch, fb_bands, n_csp. Guarda lo de siempre (matriz, métricas, log) pero ahora puedes probar rápidamente ventanas/picks

In [47]:
# %% [INTRA — todos los sujetos, con timestamp y artefactos consolidados + fila GLOBAL]
from glob import glob
from pathlib import Path
from math import ceil
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import mne

def _discover_subject_ids(fif_dir=DATA_PROC, pattern='S???_MI-epo.fif'):
    files = sorted(glob(str(fif_dir / pattern)))
    return [Path(f).stem.split('_')[0] for f in files]


def run_intra_all(
    fif_dir=DATA_PROC,
    k=5,
    random_state=42,
    crop_window=(0.5, 3.5),
    motor_only=True,
    zscore_epoch=True,
    fb_bands=DEFAULT_FB_BANDS,
    n_csp=DEFAULT_N_CSP,
    max_subplots_per_fig=12,
    n_cols=4,
    save_txt_name=None,            # opcional: nombre base TXT (se antepone timestamp)
    save_csv_name=None             # opcional: nombre base CSV (se antepone timestamp)
):
    """
    Ejecuta INTRA (k-fold) en TODOS los sujetos.
    Guarda: 1 log con timestamp, 1 CSV y 1 TXT (ambos con fila GLOBAL),
    y mosaicos de matrices de confusión por sujeto (sin archivos por sujeto).
    """
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_tag = f"intra_all_{ts}"

    # Log global
    logger, log_path = _init_logger(run_name=run_tag)
    logger.info(f"[RUN {run_tag}] Inicio de ejecución INTRA")
    logger.info(f"Parámetros: k={k}, crop_window={crop_window}, motor_only={motor_only}, "
                f"zscore_epoch={zscore_epoch}, n_csp={n_csp}, fb_bands={len(fb_bands)}")

    subject_ids = _discover_subject_ids(fif_dir)
    if not subject_ids:
        print("No se encontraron sujetos en", fif_dir)
        return None

    logger.info(f"Sujetos detectados: {subject_ids}")
    print(f"[INTRA ALL] sujetos detectados: {subject_ids}")

    rows_summary = []
    cm_items = []

    for subject_id in subject_ids:
        fif_path = fif_dir / f"{subject_id}_MI-epo.fif"
        epochs = mne.read_epochs(fif_path, preload=True, verbose=False)

        _, y_str = _epochs_to_Xy(epochs)
        le = LabelEncoder(); y = le.fit_transform(y_str)
        classes = list(le.classes_)

        logger.info(f"== {subject_id} | n_epochs={len(y)} | clases={classes} | sfreq={epochs.info['sfreq']}")
        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=random_state)
        accs, f1s = [], []
        cm_sum = np.zeros((len(classes), len(classes)), dtype=int)

        for fold, (tr_idx, te_idx) in enumerate(skf.split(np.zeros(len(y)), y), start=1):
            ep_tr = epochs[tr_idx]; ep_te = epochs[te_idx]
            with mne.utils.use_log_level("ERROR"):
                Xtr_fb, Xte_fb = _fit_fb_csp_transform(
                    ep_tr, ep_te,
                    fb_bands=fb_bands,
                    n_csp=n_csp,
                    motor_only=motor_only,
                    zscore_epoch=zscore_epoch,
                    crop_window=crop_window
                )
            yhat, clf, scaler = _fit_scale_lda(Xtr_fb, y[tr_idx], Xte_fb)

            acc = accuracy_score(y[te_idx], yhat)
            f1m = f1_score(y[te_idx], yhat, average='macro')
            cm_sum += confusion_matrix(y[te_idx], yhat, labels=np.arange(len(classes)))
            accs.append(acc); f1s.append(f1m)
            logger.info(f"[{subject_id} | fold {fold}] acc={acc:.3f} | f1m={f1m:.3f}")

        acc_mu, acc_sd = float(np.mean(accs)), float(np.std(accs))
        f1_mu,  f1_sd  = float(np.mean(f1s)),  float(np.std(f1s))

        logger.info(f"[{subject_id}] ACC={acc_mu:.3f}±{acc_sd:.3f} | F1m={f1_mu:.3f}±{f1_sd:.3f}")
        rows_summary.append(dict(
            subject=subject_id,
            acc_mean=acc_mu,
            f1_macro_mean=f1_mu,
            k=k,
            n_classes=len(classes),
            crop=str(crop_window),
            motor_only=bool(motor_only),
            zscore_epoch=bool(zscore_epoch),
            n_csp=int(n_csp),
            fb_bands=len(fb_bands)
        ))
        cm_items.append((subject_id, cm_sum, classes))

    # --- CSV/TXT (con fila GLOBAL)
    df = pd.DataFrame(rows_summary).sort_values("subject")

    acc_mu = float(df['acc_mean'].mean()) if not df.empty else 0.0
    acc_sd = float(df['acc_mean'].std(ddof=0)) if not df.empty else 0.0
    f1_mu  = float(df['f1_macro_mean'].mean()) if not df.empty else 0.0
    f1_sd  = float(df['f1_macro_mean'].std(ddof=0)) if not df.empty else 0.0

    df_global = pd.DataFrame([{
        'subject': 'GLOBAL',
        'acc_mean': acc_mu,
        'f1_macro_mean': f1_mu,
        'k': k,
        'n_classes': int(df['n_classes'].mean()) if 'n_classes' in df.columns and not df.empty else 0,
        'crop': str(crop_window),
        'motor_only': bool(motor_only),
        'zscore_epoch': bool(zscore_epoch),
        'n_csp': int(n_csp),
        'fb_bands': len(fb_bands)
    }])
    df_out = pd.concat([df, df_global], ignore_index=True)

    out_csv = (TAB_DIR / f"{ts}_{save_csv_name}") if save_csv_name else (TAB_DIR / f"metrics_intra_all_{ts}.csv")
    df_out.to_csv(out_csv, index=False)
    logger.info(f"Resumen CSV guardado → {out_csv}")
    print("Resumen INTRA (todos) →", out_csv)

    logger.info(f"[GLOBAL INTRA] ACC={acc_mu:.3f}±{acc_sd:.3f} | F1m={f1_mu:.3f}±{f1_sd:.3f}")
    print(f"[GLOBAL INTRA] ACC={acc_mu:.3f}±{acc_sd:.3f} | F1m={f1_mu:.3f}±{f1_sd:.3f}")

    try:
        display(df_out)
    except Exception:
        pass

    out_txt = (LOG_DIR / f"{ts}_{save_txt_name}") if save_txt_name else (LOG_DIR / f"metrics_intra_all_{ts}.txt")
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write(f"INTRA-SUJETO (k-fold) — Métricas por sujeto (incluye GLOBAL)\n")
        f.write(f"Generado: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total filas: {len(df_out)}\n\n")
        header = df_out.columns.tolist()
        f.write(" | ".join(header) + "\n")
        f.write("-" * 90 + "\n")
        for _, row in df_out.iterrows():
            vals = []
            for kcol in header:
                v = row[kcol]
                vals.append(f"{v:.4f}" if isinstance(v, float) else str(int(v)) if isinstance(v, (np.integer,)) else str(v))
            f.write(" | ".join(vals) + "\n")
    logger.info(f"Métricas TXT guardadas → {out_txt}")
    print("Métricas TXT guardadas →", out_txt)

    # --- Mosaicos de matrices de confusión
    if cm_items:
        n = len(cm_items)
        per_fig = max(1, int(max_subplots_per_fig))
        n_figs = ceil(n / per_fig)
        n_rows_per_fig = lambda count: ceil(count / n_cols)

        for fig_idx in range(n_figs):
            start = fig_idx * per_fig
            end   = min((fig_idx + 1) * per_fig, n)
            chunk = cm_items[start:end]
            count = len(chunk)
            n_rows = n_rows_per_fig(count)

            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5*n_cols, 3.8*n_rows), dpi=140)
            axes = np.atleast_2d(axes).flatten()

            for ax_i, (sid, cm_sum, classes) in enumerate(chunk):
                ax = axes[ax_i]
                disp = ConfusionMatrixDisplay(cm_sum, display_labels=classes)
                disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format='d')
                ax.set_title(f"{sid}")
                ax.set_xlabel(""); ax.set_ylabel("")
            for j in range(ax_i + 1, len(axes)):
                axes[j].axis("off")

            out_png = FIG_DIR / f"intra_all_confusions_{ts}_p{fig_idx+1}.png"
            fig.suptitle(f"Intra — Matrices de confusión (página {fig_idx+1}/{n_figs})", y=0.995, fontsize=14)
            fig.tight_layout(rect=[0, 0, 1, 0.97])
            fig.savefig(out_png)
            plt.close(fig)
            logger.info(f"Figura consolidada → {out_png}")
            print("Figura consolidada →", out_png)

    logger.info(f"Log global de esta corrida → {log_path}")
    print(f"Log global → {log_path}")

    return df_out


### Bloque 5 — Cross-sujeto (LOSO) con logs + guardado

Qué hace: para cada sujeto como test, entrena en el resto, calcula métricas y guarda una matriz de confusión por sujeto y una global. También guarda métricas por sujeto en tables/metrics_loso_per_subject.csv y un resumen global en tables/metrics_loso.csv, además de un TXT en logs/

- run_loso(..., use_strict=True) hace LOSO clásico sobre el conjunto resuelto (con Strict si está ON).

- run_loso_single(test_subject, ...) entrena con todos los demás y prueba sólo en ese sujeto (Strict opcional).

- Incluye utilidades de selección Strict y reemplazo para subject_list.

In [48]:
# %% [LOSO — todos los sujetos, con timestamp, fila GLOBAL y calibración opcional]
import numpy as np
import pandas as pd
import mne
import matplotlib.pyplot as plt
from math import ceil
from datetime import datetime
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay

def run_loso_all(
    fif_dir=DATA_PROC,
    use_strict=True,                # True → aplica DROP-only si existe el txt
    crop_window=(0.5, 3.5),
    motor_only=True,
    zscore_epoch=True,
    fb_bands=DEFAULT_FB_BANDS,
    n_csp=DEFAULT_N_CSP,
    max_subplots_per_fig=12,        # sujetos por página en los mosaicos
    n_cols=4,                       # columnas por página
    save_txt_name=None,             # opcional: nombre base TXT (se antepone timestamp)
    save_csv_name=None,             # opcional: nombre base CSV (se antepone timestamp)
    calibrate_k_per_class=None      # NUEVO: None/<=0 → sin calibración; >0 → calibración con k por clase
):
    """
    LOSO clásico en TODOS los sujetos, con opción de calibración ligera.

    - Sin calibración (por defecto): entrena en sujetos train y evalúa todo el sujeto test.
    - Con calibración (calibrate_k_per_class > 0): toma k épocas por clase del sujeto test para
      recalibrar solo StandardScaler + LDA y evalúa en el resto del sujeto test.

    Artefactos:
      - 1 log (timestamp), 1 CSV (con fila GLOBAL), 1 TXT (con fila GLOBAL)
      - Mosaicos de matrices por sujeto y matriz GLOBAL.
    """
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    run_tag = f"loso_all_{ts}"

    logger, log_path = _init_logger(run_name=run_tag)
    knobs = _knobs_dict(crop_window, motor_only, zscore_epoch, fb_bands, n_csp)
    knobs['calibrate_k'] = int(calibrate_k_per_class or 0)
    logger.info(f"[RUN {run_tag}] Inicio de ejecución LOSO")
    logger.info(f"Perillas: {knobs}")

    avail = _list_available_subjects(fif_dir)
    if not avail:
        logger.error(f"No hay FIF S???_MI-epo.fif en {fif_dir}")
        return None

    if use_strict:
        sids, info = _strict_valid_from_drop(avail)
        logger.info(f"Usando todos los válidos ({len(sids)}). {info}")
    else:
        sids = avail
        logger.info(f"(strict=OFF) tests → {sids[:10]}{' ...' if len(sids)>10 else ''}")

    if not sids:
        logger.warning("No hay sujetos para LOSO.")
        return None

    ep_map = {sid: mne.read_epochs(str(fif_dir / f"{sid}_MI-epo.fif"),
                                   preload=True, verbose=False) for sid in sids}

    classes_global, cm_global = None, None
    rows, cm_items = [], []

    # -------------------------------------------------------------
    # Bucle LOSO
    for s_test, ep_te_full in ep_map.items():
        train_ids = [sid for sid in sids if sid != s_test]
        if not train_ids:
            logger.warning(f"Sin train para {s_test}.")
            continue

        ep_tr = mne.concatenate_epochs([ep_map[s] for s in train_ids], on_mismatch='ignore')
        ep_te_full = ep_te_full.copy().reorder_channels(ep_tr.ch_names)

        # Etiquetas para train + (test/calib/eval)
        _, y_tr_str = _epochs_to_Xy(ep_tr)

        # ¿Calibración?
        kcal = int(calibrate_k_per_class or 0)
        if kcal > 0:
            # Divide test en calibración y evaluación
            ep_calib, ep_eval = _split_calibration(ep_te_full, k_per_class=kcal)
            if (ep_calib is None) or (len(ep_calib) == 0):
                logger.info(f"[{s_test}] Calibración solicitada (k={kcal}), pero no hay muestras válidas → sin calibración.")
                ep_eval = ep_te_full
                kcal = 0  # desactiva calibración
        else:
            ep_calib, ep_eval = None, ep_te_full

        # Ajusta codificador con TRAIN + EVAL (y CALIB si existe, aunque no es imprescindible)
        _, y_ev_str = _epochs_to_Xy(ep_eval)
        if ep_calib is not None and len(ep_calib) > 0:
            _, y_ca_str = _epochs_to_Xy(ep_calib)
            le = LabelEncoder().fit(np.concatenate([y_tr_str, y_ev_str, y_ca_str]))
        else:
            le = LabelEncoder().fit(np.concatenate([y_tr_str, y_ev_str]))

        y_tr = le.transform(y_tr_str)
        y_ev = le.transform(y_ev_str)
        classes = list(le.classes_)

        if classes_global is None:
            classes_global = classes
            cm_global = np.zeros((len(classes), len(classes)), dtype=int)

        # --- Features para TRAIN + EVAL
        with mne.utils.use_log_level("ERROR"):
            Xtr_fb, Xev_fb = _fit_fb_csp_transform(
                ep_tr, ep_eval,
                fb_bands=fb_bands,
                n_csp=n_csp,
                motor_only=motor_only,
                zscore_epoch=zscore_epoch,
                crop_window=crop_window
            )

        if kcal <= 0:
            # LOSO clásico
            yhat, clf, scaler = _fit_scale_lda(Xtr_fb, y_tr, Xev_fb)
        else:
            # --- Features para CALIBRACIÓN (mismos filtros CSP)
            with mne.utils.use_log_level("ERROR"):
                Xtr_fb2, Xca_fb = _fit_fb_csp_transform(
                    ep_tr, ep_calib,
                    fb_bands=fb_bands,
                    n_csp=n_csp,
                    motor_only=motor_only,
                    zscore_epoch=zscore_epoch,
                    crop_window=crop_window
                )

            _, y_ca_str = _epochs_to_Xy(ep_calib)
            y_ca = le.transform(y_ca_str)

            # Re-ajuste SOLO de scaler + LDA con TRAIN + CALIB
            scaler2 = StandardScaler()
            X_join = np.vstack([Xtr_fb2, Xca_fb])
            y_join = np.concatenate([y_tr, y_ca])
            X_join_s = scaler2.fit_transform(X_join)
            Xev_s = scaler2.transform(Xev_fb)

            clf2 = LDA(**LDA_PARAMS)
            clf2.fit(X_join_s, y_join)
            yhat = clf2.predict(Xev_s)

        # Métricas por sujeto (en EVAL)
        acc = accuracy_score(y_ev, yhat)
        f1m = f1_score(y_ev, yhat, average='macro')
        cm  = confusion_matrix(y_ev, yhat, labels=np.arange(len(classes)))
        cm_global += cm

        logger.info(f"[LOSO] test={s_test} | acc={acc:.3f} | f1m={f1m:.3f} | "
                    f"n_test={len(y_ev)} | calib_k={kcal}")

        rows.append(dict(
            test_subject=s_test,
            acc=float(acc),
            f1_macro=float(f1m),
            n_test=int(len(y_ev)),
            crop=str(crop_window),
            motor_only=bool(motor_only),
            zscore_epoch=bool(zscore_epoch),
            n_csp=int(n_csp),
            fb_bands=len(fb_bands),
            calibrate_k=int(kcal)
        ))
        cm_items.append((s_test, cm, classes))

    # -------------------------------------------------------------
    # Fila GLOBAL
    acc_mu = float(np.mean([r['acc'] for r in rows])) if rows else 0.0
    f1_mu  = float(np.mean([r['f1_macro'] for r in rows])) if rows else 0.0
    rows.append(dict(
        test_subject="GLOBAL",
        acc=acc_mu,
        f1_macro=f1_mu,
        n_test=int(np.sum([r['n_test'] for r in rows])) if rows else 0,
        crop=str(crop_window),
        motor_only=bool(motor_only),
        zscore_epoch=bool(zscore_epoch),
        n_csp=int(n_csp),
        fb_bands=len(fb_bands),
        calibrate_k=int(calibrate_k_per_class or 0)
    ))
    logger.info(f"[GLOBAL LOSO] ACC={acc_mu:.3f} | F1m={f1_mu:.3f}")

    # -------------------------------------------------------------
    # CSV único
    df_rows = pd.DataFrame(rows).sort_values('test_subject')
    out_csv = (TAB_DIR / f"{ts}_{save_csv_name}") if save_csv_name else (TAB_DIR / f"metrics_loso_all_{ts}.csv")
    df_rows.to_csv(out_csv, index=False)
    logger.info(f"CSV consolidado → {out_csv}")
    print("CSV consolidado →", out_csv)
    try:
        display(df_rows)
    except Exception:
        pass

    # TXT único
    out_txt = (LOG_DIR / f"{ts}_{save_txt_name}") if save_txt_name else (LOG_DIR / f"metrics_loso_all_{ts}.txt")
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write("LOSO — Métricas por sujeto (incluye GLOBAL)\n")
        f.write(f"Generado: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total filas: {len(df_rows)}\n\n")
        header = df_rows.columns.tolist()
        f.write(" | ".join(header) + "\n")
        f.write("-" * 90 + "\n")
        for _, r in df_rows.iterrows():
            vals = []
            for k in header:
                v = r[k]
                vals.append(f"{v:.4f}" if isinstance(v, float) else str(int(v)) if isinstance(v, (np.integer,)) else str(v))
            f.write(" | ".join(vals) + "\n")
    logger.info(f"TXT consolidado → {out_txt}")
    print("TXT consolidado →", out_txt)

    # -------------------------------------------------------------
    # Mosaicos por sujeto
    if cm_items:
        n = len(cm_items)
        per_fig = max(1, int(max_subplots_per_fig))
        n_figs = ceil(n / per_fig)
        n_cols = max(1, int(n_cols))
        n_rows_needed = lambda count: ceil(count / n_cols)

        for fig_idx in range(n_figs):
            start = fig_idx * per_fig
            end   = min((fig_idx + 1) * per_fig, n)
            chunk = cm_items[start:end]
            count = len(chunk)
            n_rows = n_rows_needed(count)

            fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5*n_cols, 3.8*n_rows), dpi=140)
            axes = np.atleast_2d(axes).flatten()

            for ax_i, (sid, cm_sum, classes) in enumerate(chunk):
                ax = axes[ax_i]
                disp = ConfusionMatrixDisplay(cm_sum, display_labels=classes)
                disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format='d')
                ax.set_title(f"{sid}")
                ax.set_xlabel(""); ax.set_ylabel("")
            for j in range(ax_i + 1, len(axes)):
                axes[j].axis("off")

            out_png = FIG_DIR / f"loso_all_confusions_{ts}_p{fig_idx+1}.png"
            fig.suptitle(f"LOSO — Matrices de confusión por sujeto (página {fig_idx+1}/{n_figs})", y=0.995, fontsize=14)
            fig.tight_layout(rect=[0, 0, 1, 0.97])
            fig.savefig(out_png)
            plt.close(fig)
            logger.info(f"Mosaico de confusiones → {out_png}")
            print("Mosaico de confusiones →", out_png)

    # Matriz GLOBAL
    if len(cm_items) > 0 and "GLOBAL" in df_rows["test_subject"].values:
        fig, ax = plt.subplots(figsize=(6.5, 5.2), dpi=140)
        disp = ConfusionMatrixDisplay(cm_global, display_labels=classes_global)
        disp.plot(ax=ax, cmap="Blues", colorbar=True, values_format='d')
        ax.set_title("LOSO — Matriz de confusión GLOBAL")
        fig.tight_layout()
        out_png_glob = FIG_DIR / f"loso_global_confusion_{ts}.png"
        fig.savefig(out_png_glob)
        plt.close(fig)
        logger.info(f"Matriz GLOBAL → {out_png_glob}")
        print("Matriz GLOBAL →", out_png_glob)

    # Print/Log global final
    logger.info(f"[GLOBAL LOSO] ACC={acc_mu:.3f} | F1m={f1_mu:.3f}")
    print(f"[GLOBAL LOSO] ACC={acc_mu:.3f} | F1m={f1_mu:.3f}")
    logger.info(f"Log global de esta corrida → {log_path}")
    print(f"Log global → {log_path}")

    return df_rows.reset_index(drop=True)


### Bloque 7 — Ejemplos de ejecución

Qué hace: muestra cómo lanzar los “batch” por defecto (4 intra, 2 LOSO) y cómo pasar propia lista de sujetos.

In [50]:
# INTRA en todos los sujetos
# df_intra = run_intra_all(
#     k=5,
#     random_state=42,
#     crop_window=(0.5, 4.5),
#     motor_only=False,
#     zscore_epoch=True,
#     fb_bands=FB_BANDS_DENSE,
#     n_csp=4,
#     save_txt_name="metrics_intra_all.txt"
# )

# LOSO clásico en todos los sujetos
# df_loso = run_loso_all(
#     crop_window=(0.5, 4.5),
#     motor_only=False,
#     zscore_epoch=False,
#     fb_bands=FB_BANDS_DENSE,
#     n_csp=4,
#     use_strict=False,
#     save_txt_name="metrics_loso_all.txt",
#     calibrate_k_per_class=5  # None o numero entero > 0 para calibración ligera
# )


### Inter sujeto con Cross validation

In [64]:
# %% [INTER-SUJETO — FBCSP+LDA con TEST fijo, TRAIN/VAL aleatorios y CALIBRACIÓN opcional (LDA-only)]
# Requiere:
#   - Archivos preprocesados S???_MI-epo.fif en DATA_PROC
#   - Paquetes: mne, numpy, pandas, scikit-learn, matplotlib

import numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from mne.decoding import CSP
import mne, os, sys

# -------------------- RUTAS --------------------
PROJ      = Path('..').resolve().parent     # asumiendo notebook en models/...
DATA_PROC = PROJ / 'data' / 'processed'
OUT_ROOT  = PROJ / 'models' / 'inter_fixedsplit'
FIG_DIR   = OUT_ROOT / 'figures'
TAB_DIR   = OUT_ROOT / 'tables'
LOG_DIR   = OUT_ROOT / 'logs'
for d in (FIG_DIR, TAB_DIR, LOG_DIR): d.mkdir(parents=True, exist_ok=True)

# -------------------- CONFIG FBCSP --------------------
FB_BANDS_DENSE = [(f, f+2) for f in range(8, 30, 2)]  # 8–30 Hz, paso 2 Hz
DEFAULT_N_CSP  = 6
LDA_PARAMS     = dict(solver='lsqr', shrinkage='auto')  # shrinkage automático ayuda inter-sujeto

# -------------------- UTILS --------------------
def _discover_subject_ids(fif_dir=DATA_PROC, pattern='S???_MI-epo.fif'):
    files = sorted(glob(str(Path(fif_dir) / pattern)))
    return [Path(f).stem.split('_')[0] for f in files]

def _epochs_to_Xy(epochs: mne.Epochs):
    X = epochs.get_data()
    inv = {v:k for k,v in epochs.event_id.items()}
    y = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
    return X, y

def _fit_fb_csp_transform(train_ep: mne.Epochs,
                          test_ep:  mne.Epochs,
                          fb_bands=FB_BANDS_DENSE,
                          n_csp=DEFAULT_N_CSP,
                          crop_window=None):
    """
    Ajusta CSP por sub-banda en TRAIN y transforma TRAIN/TEST. Devuelve features log-var concatenadas.
    """
    tr = train_ep.copy(); te = test_ep.copy()
    if crop_window is not None:
        tr.crop(*crop_window); te.crop(*crop_window)
    y_tr = tr.events[:, -1]
    Xtr_list, Xte_list = [], []
    for (fmin, fmax) in fb_bands:
        tr_b = tr.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        te_b = te.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        Xtr = tr_b.get_data(); Xte = te_b.get_data()
        try:
            csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True, norm_trace=False)
        except TypeError:
            csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True)
        Xtr_c = csp.fit_transform(Xtr, y_tr)
        Xte_c = csp.transform(Xte)
        Xtr_list.append(Xtr_c); Xte_list.append(Xte_c)
    return np.concatenate(Xtr_list, axis=1), np.concatenate(Xte_list, axis=1)

def _transform_with_trained_fb_csp(train_ep: mne.Epochs,
                                   target_ep: mne.Epochs,
                                   fb_bands=FB_BANDS_DENSE,
                                   n_csp=DEFAULT_N_CSP,
                                   crop_window=None):
    """
    Entrena CSP SOLO con 'train_ep' y transforma 'target_ep' con esos mismos filtros.
    Útil para calibración LDA-only.
    """
    tr = train_ep.copy(); tg = target_ep.copy()
    if crop_window is not None:
        tr.crop(*crop_window); tg.crop(*crop_window)
    y_tr = tr.events[:, -1]
    Xtr_list, Xtg_list = [], []
    for (fmin, fmax) in fb_bands:
        tr_b = tr.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        tg_b = tg.copy().filter(fmin, fmax, picks='eeg', verbose=False)
        Xtr = tr_b.get_data(); Xtg = tg_b.get_data()
        try:
            csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True, norm_trace=False)
        except TypeError:
            csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True)
        csp.fit(Xtr, y_tr)
        Xtg_c = csp.transform(Xtg)
        Xtr_list.append(csp.transform(Xtr))  # para tener base coherente
        Xtg_list.append(Xtg_c)
    return np.concatenate(Xtr_list, axis=1), np.concatenate(Xtg_list, axis=1)

def _fit_scale_lda(Xtr, ytr, Xte):
    scaler = StandardScaler()
    Xtr_s = scaler.fit_transform(Xtr); Xte_s = scaler.transform(Xte)
    clf = LDA(**LDA_PARAMS).fit(Xtr_s, ytr)
    return clf.predict(Xte_s), clf, scaler

def _save_confusion(cm: np.ndarray, class_names, title: str, out_png: Path):
    fig, ax = plt.subplots(figsize=(5.4, 4.6), dpi=140)
    disp = ConfusionMatrixDisplay(cm, display_labels=class_names)
    disp.plot(ax=ax, cmap="Blues", colorbar=True, values_format='d')
    ax.set_title(title); fig.tight_layout(); fig.savefig(out_png); plt.close(fig)

def _split_calibration(ep_test: mne.Epochs, k_per_class=5, random_state=42):
    """
    Toma k_per_class épocas por clase del set de TEST como 'calibración' y devuelve:
      ep_calib (k por clase) y ep_eval (el resto).
    Si k_per_class es 0/None -> (None, ep_test).
    """
    if not k_per_class or k_per_class <= 0:
        return None, ep_test
    rng = np.random.RandomState(int(random_state))
    labels = ep_test.events[:, -1]
    ep_calib_list, ep_eval_list = [], []
    for code in np.unique(labels):
        idx = np.where(labels == code)[0]
        if len(idx) == 0:
            continue
        rng.shuffle(idx)
        take = min(int(k_per_class), len(idx))
        sel, rem = idx[:take], idx[take:]
        if take > 0: ep_calib_list.append(ep_test.copy()[sel])
        if len(rem) > 0: ep_eval_list.append(ep_test.copy()[rem])
    ep_calib = mne.concatenate_epochs(ep_calib_list, on_mismatch='ignore') if ep_calib_list else None
    ep_eval  = mne.concatenate_epochs(ep_eval_list,  on_mismatch='ignore') if ep_eval_list  else ep_test
    return ep_calib, ep_eval

# -------------------- TEST fijo (24 sujetos) --------------------
FIXED_TEST_SUBJECTS = [
    # Buenos (>~0.70)
    'S007','S025','S029','S031','S032','S034','S035','S042','S043','S049','S056','S058','S062','S072',
    # Medios (~0.50–0.69)
    'S001','S010','S013','S017','S019','S030',
    # Malos (<~0.45)
    'S005','S006','S009','S097'
]

def _make_split_with_fixed_test(fif_dir=DATA_PROC, fixed_test=(), val_size=16, random_state=42):
    all_subjects = _discover_subject_ids(fif_dir)
    existing = {p.stem.split('_')[0] for p in Path(fif_dir).glob('S???_MI-epo.fif')}
    fixed_test = [s for s in list(fixed_test) if s in existing]
    remaining = [s for s in all_subjects if s not in fixed_test]
    if len(remaining) <= val_size:
        raise ValueError(f"Insuficientes sujetos para VALIDATION: remaining={len(remaining)} <= val_size={val_size}")
    train_subjects, val_subjects = train_test_split(remaining, test_size=val_size,
                                                    shuffle=True, random_state=random_state)
    return train_subjects, val_subjects, fixed_test

# -------------------- FUNCIÓN PRINCIPAL --------------------
def run_inter_fixedsplit_fbcsp(
    fif_dir=DATA_PROC,
    crop_window=(0.5, 4.5),
    fb_bands=FB_BANDS_DENSE,
    n_csp=DEFAULT_N_CSP,
    val_size=16,
    random_state=42,
    refit_on_trainval_for_test=True,
    # === Calibración (LDA-only) ===
    calibrate_k_per_class=None,         # None/0 -> sin calibración; >0 -> usar k por clase del TEST
    # === Salidas ===
    save_csv_name="inter_fbcsp_fixedsplit.csv",
    save_txt_name="inter_fbcsp_fixedsplit.txt"
):
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    train_subs, val_subs, test_subs = _make_split_with_fixed_test(
        fif_dir=fif_dir, fixed_test=FIXED_TEST_SUBJECTS,
        val_size=val_size, random_state=random_state
    )
    print(f"TRAIN={len(train_subs)} | VAL={len(val_subs)} | TEST(fijo)={len(test_subs)}")

    # ----- carga epochs por grupo
    ep_tr = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                    preload=True, verbose=False) for s in train_subs],
                                   on_mismatch='ignore')
    ep_va = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                    preload=True, verbose=False) for s in val_subs],
                                   on_mismatch='ignore')
    ep_te_all = mne.concatenate_epochs([mne.read_epochs(str(Path(fif_dir)/f"{s}_MI-epo.fif"),
                                                        preload=True, verbose=False) for s in test_subs],
                                       on_mismatch='ignore')

    # ----- alinear canales
    ep_va = ep_va.copy().reorder_channels(ep_tr.ch_names)
    ep_te_all = ep_te_all.copy().reorder_channels(ep_tr.ch_names)

    # ----- labels globales (para encoder consistente)
    _, y_tr_str = _epochs_to_Xy(ep_tr)
    _, y_va_str = _epochs_to_Xy(ep_va)
    _, y_te_str_all = _epochs_to_Xy(ep_te_all)
    le = LabelEncoder().fit(np.concatenate([y_tr_str, y_va_str, y_te_str_all]))
    y_tr = le.transform(y_tr_str); y_va = le.transform(y_va_str)
    classes = list(le.classes_)

    # ----- features FBCSP para VALIDATION (sin calibración)
    with mne.utils.use_log_level("ERROR"):
        Xtr, Xva = _fit_fb_csp_transform(ep_tr, ep_va, fb_bands=fb_bands, n_csp=n_csp, crop_window=crop_window)
        yhat_va, clf, scaler = _fit_scale_lda(Xtr, y_tr, Xva)
    y_va_enc = le.transform(y_va_str)
    acc_va = accuracy_score(y_va_enc, yhat_va); f1_va = f1_score(y_va_enc, yhat_va, average='macro')
    print(f"[VAL] ACC={acc_va:.3f} | F1m={f1_va:.3f}")

    # ====== TEST FIJO (con o sin calibración LDA-only) ======
    ep_calib, ep_eval = _split_calibration(ep_te_all, k_per_class=calibrate_k_per_class, random_state=random_state)
    k_eff = 0 if ep_calib is None else len(ep_calib)

    # Base para TEST: si quieres, re-ajusta base con TRAIN+VAL antes de calibrar
    if refit_on_trainval_for_test:
        ep_base = mne.concatenate_epochs([ep_tr, ep_va], on_mismatch='ignore')
        _, y_base_str = _epochs_to_Xy(ep_base)
        y_base = le.transform(y_base_str)
    else:
        ep_base = ep_tr
        y_base = y_tr

    if not calibrate_k_per_class or calibrate_k_per_class <= 0:
        # --- SIN calibración ---
        with mne.utils.use_log_level("ERROR"):
            Xbase, Xte = _fit_fb_csp_transform(ep_base, ep_eval, fb_bands=fb_bands, n_csp=n_csp, crop_window=crop_window)
        yhat_te, clf2, scaler2 = _fit_scale_lda(Xbase, y_base, Xte)
        _, y_eval_str = _epochs_to_Xy(ep_eval)
        y_eval = le.transform(y_eval_str)
    else:
        # --- CON calibración (LDA-only) ---
        with mne.utils.use_log_level("ERROR"):
            # 1) Entrena CSP con ep_base, transforma base/calib/eval con esos mismos filtros
            Xbase_from_base, Xcal_from_base = _transform_with_trained_fb_csp(
                ep_base, ep_calib, fb_bands=fb_bands, n_csp=n_csp, crop_window=crop_window
            )
            _, Xeval_from_base = _transform_with_trained_fb_csp(
                ep_base, ep_eval, fb_bands=fb_bands, n_csp=n_csp, crop_window=crop_window
            )
        # 2) Reentrena SOLO LDA con (base + calib) y evalúa en eval
        _, y_calib_str = _epochs_to_Xy(ep_calib)
        y_calib = le.transform(y_calib_str)
        X_join = np.vstack([Xbase_from_base, Xcal_from_base])
        y_join = np.concatenate([y_base, y_calib])
        yhat_te, clf2, scaler2 = _fit_scale_lda(X_join, y_join, Xeval_from_base)
        _, y_eval_str = _epochs_to_Xy(ep_eval)
        y_eval = le.transform(y_eval_str)

    # Métricas TEST
    acc_te = accuracy_score(y_eval, yhat_te); f1_te = f1_score(y_eval, yhat_te, average='macro')
    print(f"[TEST {'CALIB-LDA' if k_eff>0 else 'SIN CALIB'}] k={calibrate_k_per_class} | ACC={acc_te:.3f} | F1m={f1_te:.3f}")

    # Matriz de confusión (TEST eval)
    cm_te = confusion_matrix(y_eval, yhat_te, labels=np.arange(len(classes)))
    cm_png = FIG_DIR / f"inter_fbcsp_confusion_test_{ts}.png"
    _save_confusion(cm_te, classes, f"FBCSP — TEST fijo (calib LDA k={calibrate_k_per_class})", cm_png)
    print("Confusión TEST →", cm_png)

    # Guardar CSV + TXT
    df = pd.DataFrame([dict(
        mode="inter_fixedsplit_fbcsp",
        acc_val=float(acc_va), f1_val=float(f1_va),
        acc_test=float(acc_te), f1_test=float(f1_te),
        n_train=len(train_subs), n_val=len(val_subs), n_test=len(test_subs),
        crop=str(crop_window), n_csp=int(n_csp), fb_bands=len(fb_bands),
        refit_on_trainval_for_test=bool(refit_on_trainval_for_test),
        calibrate_k_per_class=int(calibrate_k_per_class or 0)
    )])
    out_csv = TAB_DIR / f"{ts}_inter_fbcsp_fixedsplit.csv"; df.to_csv(out_csv, index=False)
    out_txt = LOG_DIR / f"{ts}_inter_fbcsp_fixedsplit.txt"
    with open(out_txt, "w", encoding="utf-8") as f:
        f.write("INTER-SUJETO (FBCSP) — TEST fijo\n")
        f.write(f"Generado: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        f.write(df.to_string(index=False))
    print("CSV →", out_csv); print("TXT →", out_txt)
    return df

# -------------------- EJECUCIÓN INMEDIATA --------------------
df_fbcsp_inter = run_inter_fixedsplit_fbcsp(
    fif_dir=DATA_PROC,
    crop_window=(0.5, 3.0),
    fb_bands=FB_BANDS_CLASSIC,
    n_csp=4,
    val_size=16,
    random_state=42,
    refit_on_trainval_for_test=True,   # re-ajusta con TRAIN+VAL antes del test si True
    calibrate_k_per_class=10             # <= pon 0/None si no quieres calibración
)
try: display(df_fbcsp_inter)
except: pass


TRAIN=63 | VAL=16 | TEST(fijo)=24
[VAL] ACC=0.340 | F1m=0.340
[TEST CALIB-LDA] k=10 | ACC=0.410 | F1m=0.407
Confusión TEST → /root/Proyecto/EEG_Clasificador/models/inter_fixedsplit/figures/inter_fbcsp_confusion_test_20251006-062822.png
CSV → /root/Proyecto/EEG_Clasificador/models/inter_fixedsplit/tables/20251006-062822_inter_fbcsp_fixedsplit.csv
TXT → /root/Proyecto/EEG_Clasificador/models/inter_fixedsplit/logs/20251006-062822_inter_fbcsp_fixedsplit.txt


Unnamed: 0,mode,acc_val,f1_val,acc_test,f1_test,n_train,n_val,n_test,crop,n_csp,fb_bands,refit_on_trainval_for_test,calibrate_k_per_class
0,inter_fixedsplit_fbcsp,0.340161,0.340162,0.409885,0.406748,63,16,24,"(0.5, 3.0)",4,6,True,10
