# FBCSP + LDA (intra-sujeto y cross-sujeto)

### Bloque 1 — Rutas de salida + utilidades de logging y guardado

Qué hace: define las carpetas donde se guardarán figuras, tablas y logs bajo models/fbcsp_lda/. Incluye utilidades para inicializar un logger limpio, guardar matrices de confusión (PNG + CSV) y anexar métricas a CSVs acumulativos.

In [26]:
# %% [PATHS & LOGGING] — rutas de salida + helpers para logs/figuras/tablas
import sys, logging, warnings
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay
import mne

# Repository root: the notebook is expected to run from models/fbcsp_lda/,
# so the root is the grandparent of the current working directory.
PROJ = Path('..').resolve().parent
DATA_PROC = PROJ.joinpath('data', 'processed')
OUT_ROOT = PROJ.joinpath('models', 'fbcsp_lda')
FIG_DIR, TAB_DIR, LOG_DIR = (OUT_ROOT / sub for sub in ('figures', 'tables', 'logs'))

# Create every output folder up front so later cells can write directly.
for d in (FIG_DIR, TAB_DIR, LOG_DIR):
    d.mkdir(parents=True, exist_ok=True)

print(f"Directorio de datos procesados: {DATA_PROC}")

def _init_logger(run_name: str):
    """
    Build a logger that writes to stdout and to a timestamped TXT in LOG_DIR.

    Handlers left on the logger by a previous call with the same ``run_name``
    are removed AND closed, so the earlier log file is released instead of
    staying open for the lifetime of the kernel.

    Side effects: sets the MNE log level to "ERROR" and ignores UserWarning /
    RuntimeWarning raised from modules matching "mne".

    Parameters
    ----------
    run_name : str
        Logger name; also used as the suffix of the log file name.

    Returns
    -------
    (logging.Logger, pathlib.Path)
        The configured logger and the path of the log file it writes to.
    """
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = LOG_DIR / f"{ts}_{run_name}.txt"

    logger = logging.getLogger(run_name)
    logger.setLevel(logging.INFO)

    # Detach and close stale handlers; clearing the list alone would leave
    # the previous FileHandler's file open.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    fmt = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", datefmt="%H:%M:%S")

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.INFO)
    ch.setFormatter(fmt)

    fh = logging.FileHandler(log_path, encoding="utf-8")
    fh.setLevel(logging.INFO)
    fh.setFormatter(fmt)

    logger.addHandler(ch)
    logger.addHandler(fh)

    # Silence third-party noise (does not affect this logger or print calls).
    mne.set_log_level("ERROR")
    warnings.filterwarnings("ignore", category=UserWarning, module="mne")
    warnings.filterwarnings("ignore", category=RuntimeWarning, module="mne")
    return logger, log_path

def _save_confusion(cm: np.ndarray, class_names, title: str, stem: str):
    """
    Persist a confusion matrix as a CSV (in TAB_DIR) and a PNG (in FIG_DIR).

    Parameters
    ----------
    cm : np.ndarray
        Confusion matrix; cells are rendered with the integer format 'd'.
    class_names : sequence
        Labels used for both the rows and the columns.
    title : str
        Title drawn on the figure.
    stem : str
        Base file name without extension; "_confusion.csv" / "_confusion.png"
        are appended to it.

    Returns
    -------
    (pathlib.Path, pathlib.Path)
        Paths of the CSV and PNG files, in that order.
    """
    csv_path = TAB_DIR / f"{stem}_confusion.csv"
    png_path = FIG_DIR / f"{stem}_confusion.png"

    # Table: class names label both axes.
    table = pd.DataFrame(cm, index=class_names, columns=class_names)
    table.to_csv(csv_path, index=True)

    # Figure: closed after saving so figures do not accumulate.
    fig, ax = plt.subplots(figsize=(5.4, 4.6), dpi=140)
    ConfusionMatrixDisplay(cm, display_labels=class_names).plot(
        ax=ax, cmap="Blues", colorbar=True, values_format='d'
    )
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(png_path)
    plt.close(fig)

    return csv_path, png_path

def _append_metrics(row: dict, table_name: str):
    """
    Append one metrics row to a CSV inside TAB_DIR, creating it if needed.

    The row is aligned to the header already on disk before it is written:
    - keys in a different order are reordered to match the header;
    - keys missing from the row are written as empty cells;
    - keys that are new to the file trigger a rewrite of the table with the
      extra column(s) added (earlier rows get empty cells there).
    A missing or empty file is (re)created with a header.

    Parameters
    ----------
    row : dict
        Column name -> value for the single row to add.
    table_name : str
        File name of the CSV, relative to TAB_DIR.

    Returns
    -------
    pathlib.Path
        Path of the updated CSV file.
    """
    path = TAB_DIR / table_name
    new_df = pd.DataFrame([row])

    existing_cols = []
    if path.exists():
        try:
            existing_cols = list(pd.read_csv(path, nrows=0).columns)
        except pd.errors.EmptyDataError:
            # File exists but holds no header: treat it as a new table.
            existing_cols = []

    if not existing_cols:
        new_df.to_csv(path, index=False)
        return path

    if set(new_df.columns) <= set(existing_cols):
        # Same schema (possibly reordered or partial): plain append, aligned
        # to the header that is already on disk.
        new_df.reindex(columns=existing_cols).to_csv(
            path, mode='a', header=False, index=False
        )
        return path

    # Schema grew: rewrite the table. Old cells are read back as raw text so
    # their formatting is not altered by type inference.
    old_df = pd.read_csv(path, dtype=str, keep_default_na=False)
    merged = pd.concat([old_df, new_df], ignore_index=True, sort=False)
    merged.to_csv(path, index=False)
    return path


Directorio de datos procesados: C:\Users\joelc\Desktop\eeg2\data\processed


### Bloque 2 — FBCSP Helpers

Qué hace: helpers de modelo. Extraen X/y desde Epochs, aplican Filter-Bank + CSP por bandas y estandarizan las features antes de entrenar LDA. Mantiene tus comentarios y añade compatibilidad con distintas versiones de CSP (con y sin el argumento norm_trace).

In [27]:
# %% [FBCSP Helpers — Mejorado]  — banco de filtros, picks motores, z-score por época, crop
import numpy as np
import pandas as pd
import mne
from pathlib import Path
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from mne.decoding import CSP

# --- Filter bank ---
# Dense bank: contiguous 2 Hz sub-bands from 8 Hz up to 30 Hz.
FB_BANDS_DENSE = list(zip(range(8, 30, 2), range(10, 32, 2)))
# Classic bank: fewer, wider sub-bands over the same range.
FB_BANDS_CLASSIC = [
    (8, 12), (12, 16), (16, 20),
    (20, 24), (24, 28), (28, 30),
]
# Default alias so other cells can name a bank without hitting a NameError.
DEFAULT_FB_BANDS = FB_BANDS_DENSE

# Number of CSP components kept per sub-band.
DEFAULT_N_CSP = 6

# LDA keyword arguments: 'lsqr' solver with automatic shrinkage.
LDA_PARAMS = {'solver': 'lsqr', 'shrinkage': 'auto'}

# Name tokens used to locate the motor channels.
MOTOR_TOKENS = ['C3', 'CZ', 'C4']

def _epochs_to_Xy(epochs: mne.Epochs):
    """
    Return the epoch data and one class-name label per epoch.

    Labels are recovered by inverting ``epochs.event_id`` (name -> code) and
    looking up the last column of each row in ``epochs.events``.

    Returns
    -------
    (X, y)
        X is whatever ``epochs.get_data()`` returns (presumably
        (n_epochs, n_channels, n_times) - confirm against the MNE version in
        use); y is an object array of class names.
    """
    code_to_name = {code: name for name, code in epochs.event_id.items()}
    labels = np.array(
        [code_to_name[event[-1]] for event in epochs.events],
        dtype=object,
    )
    data = epochs.get_data()
    return data, labels

def _find_motor_chs(ch_names, tokens=MOTOR_TOKENS):
    """
    Return the sorted, de-duplicated indices of the channels named by ``tokens``.

    Matching is case-insensitive and done per token in two steps:
    1. exact match on the name with non-alphanumeric characters removed
       (so "C3", "c3" and "C3.." all match the token "C3");
    2. only if no exact match exists, the first name that contains the token
       as a substring.

    Step 1 matters because plain substring matching selects e.g. "FC3" for
    the token "C3" whenever "FC3" comes earlier than "C3" in ``ch_names``.

    Parameters
    ----------
    ch_names : sequence of str
        Channel names, in channel order.
    tokens : sequence of str
        Channel tokens to look for (default MOTOR_TOKENS).

    Returns
    -------
    list of int
        At most one index per token; empty if nothing matches.
    """
    upper = [c.upper() for c in ch_names]
    # Names reduced to letters/digits for the exact comparison.
    bare = [''.join(ch for ch in name if ch.isalnum()) for name in upper]

    picks = set()
    for tok in tokens:
        key = tok.upper()
        hit = next((i for i, name in enumerate(bare) if name == key), None)
        if hit is None:
            # Fallback: first name containing the token.
            hit = next((i for i, name in enumerate(upper) if key in name), None)
        if hit is not None:
            picks.add(hit)
    return sorted(picks)

def _epochwise_zscore(X, eps=1e-8):
    """
    Standardise ``X`` along its last axis.

    Every slice along the last axis has its own mean subtracted and is
    divided by its own standard deviation plus ``eps`` (which keeps the
    division finite for constant slices). The output has the same shape as
    the input. Callers in this file pass epoch data, where the last axis is
    expected to be time.
    """
    centered = X - np.mean(X, axis=-1, keepdims=True)
    scale = np.std(X, axis=-1, keepdims=True) + eps
    return centered / scale

def _fit_fb_csp_transform(train_ep: mne.Epochs,
                          test_ep: mne.Epochs,
                          fb_bands=DEFAULT_FB_BANDS,
                          n_csp=DEFAULT_N_CSP,
                          motor_only=False,
                          zscore_epoch=False,
                          crop_window=None):
    """
    Filter-bank CSP: fit one CSP per sub-band on the training epochs and
    transform both sets.

    Both inputs are copied first, so the caller's Epochs are not modified.

    Parameters
    ----------
    train_ep, test_ep : mne.Epochs
        Training and test epochs.
    fb_bands : list of (fmin, fmax)
        Sub-bands; each one is band-pass filtered and gets its own CSP.
    n_csp : int
        CSP components per sub-band.
    motor_only : bool
        If True and `_find_motor_chs` finds channels, keep only those
        channels; if it finds none, all channels are kept.
    zscore_epoch : bool
        If True, apply `_epochwise_zscore` to the filtered data before CSP.
    crop_window : (tmin, tmax) or None
        If given, both sets are cropped to this window in memory.

    Returns
    -------
    (Xtr_fb, Xte_fb)
        Per-band CSP features concatenated along axis 1.
    """
    train = train_ep.copy()
    test = test_ep.copy()

    if crop_window is not None:
        start, stop = crop_window
        train.crop(start, stop)
        test.crop(start, stop)

    if motor_only:
        motor_idx = _find_motor_chs(train.ch_names)
        if motor_idx:
            train.pick(motor_idx)
            # Bring the test set to the same channels, in the same order.
            test = test.copy().reorder_channels(train.ch_names)

    labels = train.events[:, -1]
    feats_train = []
    feats_test = []

    for low, high in fb_bands:
        data_train = train.copy().filter(low, high, picks='eeg', verbose=False).get_data()
        data_test = test.copy().filter(low, high, picks='eeg', verbose=False).get_data()

        if zscore_epoch:
            data_train = _epochwise_zscore(data_train)
            data_test = _epochwise_zscore(data_test)

        # Fall back to a constructor call without norm_trace when the
        # installed CSP rejects that argument.
        try:
            band_csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True, norm_trace=False)
        except TypeError:
            band_csp = CSP(n_components=n_csp, reg='ledoit_wolf', log=True)

        # CSP is fitted on the training data only, then applied to the test data.
        feats_train.append(band_csp.fit_transform(data_train, labels))
        feats_test.append(band_csp.transform(data_test))

    return (np.concatenate(feats_train, axis=1),
            np.concatenate(feats_test, axis=1))

def _fit_scale_lda(Xtr, ytr, Xte, lda_params=LDA_PARAMS):
    """
    Standardise the features and classify with LDA.

    The scaler is fitted on ``Xtr`` only and then applied to both sets; the
    LDA (built with ``lda_params``) is fitted on the scaled training data.

    Returns
    -------
    (yhat, clf, scaler)
        Predictions for ``Xte``, the fitted LDA and the fitted scaler.
    """
    scaler = StandardScaler().fit(Xtr)
    clf = LDA(**lda_params).fit(scaler.transform(Xtr), ytr)
    yhat = clf.predict(scaler.transform(Xte))
    return yhat, clf, scaler


## Bloque - Diagnóstico

In [None]:
# # %% [Running classifier — Diagnóstico de ventana temporal]
# from sklearn.model_selection import StratifiedKFold
# from sklearn.pipeline import Pipeline
# from sklearn.preprocessing import LabelEncoder

# def run_running_classifier(epochs: mne.Epochs,
#                            crop_train=(1.0, 2.0),  # ventana fija para entrenar
#                            w_len=0.5,             # ancho de ventana deslizante (s)
#                            w_step=0.1,            # paso entre ventanas (s)
#                            n_splits=10,
#                            n_csp=6):
#     """
#     Entrena CSP+LDA en crop_train y evalúa en ventanas deslizantes a lo largo del epoch completo.
#     Devuelve (times, curve) donde curve es la accuracy promedio vs tiempo (centro de ventana).
#     Útil para elegir la mejor crop_window para tus experimentos.
#     """
#     # Datos completos (para test deslizante)
#     X_full = epochs.get_data()
#     inv = {v: k for k, v in epochs.event_id.items()}
#     y_str = np.array([inv[e[-1]] for e in epochs.events], dtype=object)
#     y = LabelEncoder().fit_transform(y_str)
#     sf = epochs.info['sfreq']

#     # Conjunto de entrenamiento limitado a la ventana crop_train
#     ep_train = epochs.copy().crop(*crop_train)
#     Xtr = ep_train.get_data()

#     # Ventanas deslizantes (sobre el epoch completo)
#     w_len_s = int(sf * w_len)
#     w_step_s = int(sf * w_step)
#     starts = np.arange(0, X_full.shape[2] - w_len_s, w_step_s)

#     skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
#     win_scores = []
#     for tr, te in skf.split(Xtr, y):
#         pipe = Pipeline([
#             ('csp', CSP(n_components=n_csp, reg='ledoit_wolf', log=True)),
#             ('lda', LDA(**LDA_PARAMS))
#         ])
#         pipe.fit(Xtr[tr], y[tr])

#         fold_scores = []
#         for n in starts:
#             Xwin = X_full[te][:, :, n:n+w_len_s]
#             fold_scores.append(pipe.score(Xwin, y[te]))
#         win_scores.append(fold_scores)

#     curve = np.mean(win_scores, axis=0)
#     times = (starts + w_len_s/2) / sf + epochs.tmin
#     return times, curve


### Bloque 3 — Inspección de datos

Qué hace: muestra rutas y un listado rápido del contenido de data/ y data/processed/ para verificar que los FIF están donde esperamos

In [29]:
# %% [Inspect data folders]
# Celda añadida automáticamente: muestra rutas y lista contenidos de data y data/processed
try:
    # Where the notebook believes the data lives.
    data_dir = PROJ / 'data'
    print(f"PROJ: {PROJ}")
    print(f"DATA: {PROJ / 'data'}")
    print(f"DATA_PROC: {DATA_PROC}")

    # Top level of data/: directories are shown with a trailing slash.
    print('\nContenido de data (top-level):')
    if not data_dir.exists():
        print('  (no existe)')
    else:
        for p in sorted(data_dir.iterdir()):
            print(f" - {p.name}{'/' if p.is_dir() else ''}")

    # data/processed: only the first 50 entries are listed.
    print('\nContenido de data/processed (muestras):')
    if not DATA_PROC.exists():
        print('  (no existe)')
    else:
        for p in sorted(DATA_PROC.glob('*'))[:50]:
            print(f" - {p.name}")
except Exception as e:
    # Diagnostic cell only: report the problem instead of stopping the notebook.
    print('Error inspeccionando data:', e)


PROJ: C:\Users\joelc\Desktop\eeg2
DATA: C:\Users\joelc\Desktop\eeg2\data
DATA_PROC: C:\Users\joelc\Desktop\eeg2\data\processed

Contenido de data (top-level):
 - processed/
 - raw/

Contenido de data/processed (muestras):
 - all_subjects-epo.fif
 - S001_MI-epo.fif
 - S002_MI-epo.fif
 - S003_MI-epo.fif
 - S004_MI-epo.fif
 - S005_MI-epo.fif
 - S006_MI-epo.fif
 - S007_MI-epo.fif
 - S008_MI-epo.fif
 - S009_MI-epo.fif
 - S010_MI-epo.fif
 - S011_MI-epo.fif
 - S012_MI-epo.fif
 - S013_MI-epo.fif
 - S014_MI-epo.fif
 - S015_MI-epo.fif
 - S016_MI-epo.fif
 - S017_MI-epo.fif
 - S018_MI-epo.fif
 - S019_MI-epo.fif
 - S020_MI-epo.fif
 - S021_MI-epo.fif
 - S022_MI-epo.fif
 - S023_MI-epo.fif
 - S024_MI-epo.fif
 - S025_MI-epo.fif
 - S026_MI-epo.fif
 - S027_MI-epo.fif
 - S028_MI-epo.fif
 - S029_MI-epo.fif
 - S030_MI-epo.fif
 - S031_MI-epo.fif
 - S032_MI-epo.fif
 - S033_MI-epo.fif
 - S034_MI-epo.fif
 - S035_MI-epo.fif
 - S036_MI-epo.fif
 - S037_MI-epo.fif
 - S039_MI-epo.fif
 - S040_MI-epo.fif
 - S041_MI-ep

### Bloque 4 — Intra-sujeto (k-Fold CV) con logs + guardado

Qué hace: ejecuta CV por sujeto con FBCSP+LDA, imprime métricas limpias, guarda matriz de confusión (PNG/CSV), y escribe métricas en tables/metrics_intra.csv. Además genera un TXT en logs/ con el detalle de la corrida.

Añade argumentos crop_window, motor_only, zscore_epoch, fb_bands, n_csp. Guarda lo de siempre (matriz, métricas, log) pero ahora puedes probar rápidamente ventanas/picks

In [30]:
# %% [Intra-sujeto — Mejorado] CV por sujeto con crop/picks/zscore/banco denso
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder

def run_intra_subject(subject_id='S001',
                      fif_dir=DATA_PROC,
                      k=5,
                      random_state=42,
                      crop_window=None,          # e.g. (0.5, 3.5)
                      motor_only=False,          # True -> motor channels only
                      zscore_epoch=False,        # True -> per-epoch z-score
                      fb_bands=FB_BANDS_DENSE,   # or FB_BANDS_CLASSIC
                      n_csp=DEFAULT_N_CSP):
    """
    Stratified k-fold cross-validation of FBCSP+LDA on one subject.

    Reads ``<fif_dir>/<subject_id>_MI-epo.fif``, runs ``k`` shuffled
    stratified folds, and for each fold fits the filter-bank CSP and the
    scaler+LDA on the training epochs only.

    Outputs written:
    - summed confusion matrix (PNG + CSV) via `_save_confusion`;
    - one summary row appended to ``metrics_intra.csv`` via `_append_metrics`;
    - a log file via `_init_logger`; its handlers are closed when the
      function exits (normally or with an error) so the file is released.

    Parameters
    ----------
    subject_id : str
        Subject identifier used to build the FIF file name.
    fif_dir : path-like
        Folder containing the epochs files.
    k : int
        Number of folds.
    random_state : int
        Seed for the fold shuffling.
    crop_window, motor_only, zscore_epoch, fb_bands, n_csp
        Passed through to `_fit_fb_csp_transform`.

    Returns
    -------
    dict
        ``accs`` and ``f1s`` (per-fold lists), ``cm`` (confusion matrix
        summed over folds) and ``classes`` (class names in label order).
    """
    logger, log_path = _init_logger(run_name=f"intra_{subject_id}_099")
    try:
        fif_path = Path(fif_dir) / f'{subject_id}_MI-epo.fif'
        epochs = mne.read_epochs(fif_path, preload=True, verbose=False)

        # Only the labels are needed here; every fold extracts its own data,
        # so the full data array is not kept around.
        _, y_str = _epochs_to_Xy(epochs)
        le = LabelEncoder()
        y = le.fit_transform(y_str)
        classes = list(le.classes_)

        knobs = dict(crop_window=crop_window, motor_only=motor_only,
                     zscore_epoch=zscore_epoch, fb_bands=fb_bands, n_csp=n_csp)
        logger.info(f"INTRA {subject_id} | k={k} | n_epochs={len(y)} | clases={classes} | sfreq={epochs.info['sfreq']}")
        logger.info(f"Perillas: {knobs}")

        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=random_state)
        accs, f1s = [], []
        cm_sum = np.zeros((len(classes), len(classes)), dtype=int)

        # The splitter only needs the labels; a dummy X of matching length is enough.
        for fold, (tr_idx, te_idx) in enumerate(skf.split(np.zeros(len(y)), y), start=1):
            ep_tr = epochs[tr_idx]
            ep_te = epochs[te_idx]
            with mne.utils.use_log_level("ERROR"):
                Xtr_fb, Xte_fb = _fit_fb_csp_transform(ep_tr, ep_te,
                                                       fb_bands=fb_bands,
                                                       n_csp=n_csp,
                                                       motor_only=motor_only,
                                                       zscore_epoch=zscore_epoch,
                                                       crop_window=crop_window)
            yhat, _, _ = _fit_scale_lda(Xtr_fb, y[tr_idx], Xte_fb)

            acc = accuracy_score(y[te_idx], yhat)
            f1m = f1_score(y[te_idx], yhat, average='macro')
            # Fixed label set keeps every fold's matrix the same shape.
            cm_sum += confusion_matrix(y[te_idx], yhat, labels=np.arange(len(classes)))
            accs.append(acc)
            f1s.append(f1m)
            logger.info(f"[fold {fold}] acc={acc:.3f} | f1m={f1m:.3f} | n_test={len(te_idx)}")

        acc_mu, acc_sd = float(np.mean(accs)), float(np.std(accs))
        f1_mu,  f1_sd  = float(np.mean(f1s)),  float(np.std(f1s))
        logger.info(f"[INTRA {subject_id}] ACC={acc_mu:.3f}±{acc_sd:.3f} | F1m={f1_mu:.3f}±{f1_sd:.3f}")
        logger.info("Matriz agregada:\n" + pd.DataFrame(cm_sum, index=classes, columns=classes).to_string())

        stem = f"intra_{subject_id}"
        _save_confusion(cm_sum, classes, title=f"Intra {subject_id} — FBCSP+LDA", stem=stem)

        row = dict(mode="intra", subject=subject_id,
                   acc_mean=round(acc_mu,4), acc_std=round(acc_sd,4),
                   f1_macro_mean=round(f1_mu,4), f1_macro_std=round(f1_sd,4),
                   n_epochs=int(len(y)), n_classes=int(len(classes)),
                   crop=str(crop_window), motor_only=bool(motor_only),
                   zscore_epoch=bool(zscore_epoch), n_csp=int(n_csp),
                   fb_bands=len(fb_bands))
        _append_metrics(row, table_name="metrics_intra.csv")
        logger.info(f"Log → {log_path}")

        print(f"[Intra {subject_id}] k={k} | ACC={acc_mu:.3f}±{acc_sd:.3f} | F1m={f1_mu:.3f}±{f1_sd:.3f}")
        print("Clases:", classes)
        print("Matriz (suma folds):\n", cm_sum)
        return dict(accs=accs, f1s=f1s, cm=cm_sum, classes=classes)
    finally:
        # Release the log file even if the run failed part-way.
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()


### Bloque 5 — Cross-sujeto (LOSO) con logs + guardado

Qué hace: para cada sujeto como test, entrena en el resto, calcula métricas y guarda una matriz de confusión por sujeto y una global. También guarda métricas por sujeto en tables/metrics_loso_per_subject.csv y un resumen global en tables/metrics_loso.csv, además de un TXT en logs/

- run_loso(..., use_strict=True) hace LOSO clásico sobre el conjunto resuelto (con Strict si está ON).

- run_loso_single(test_subject, ...) entrena con todos los demás y prueba sólo en ese sujeto (Strict opcional).

- Incluye utilidades de selección Strict y reemplazo para subject_list.

In [31]:
# %% [Cross-sujeto LOSO — Mejorado]  — guarda “perillas” en logs y CSVs (clásico, subset y single)
from glob import glob
import re
import numpy as np
import pandas as pd
import mne
from pathlib import Path
from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler

# Optional exclusion list; it is read line by line and only lines that are a
# valid subject ID are used.
STRICT_DROP_TXT = PROJ.joinpath('reports', 'tables', '02_prepro', 'subjects_strict_DROP.txt')
# A subject ID is "S" followed by exactly three digits.
_re_sid = re.compile(r'^S\d{3}$')
print(f"Archivo DROP-only: {STRICT_DROP_TXT} (si no existe, no hay exclusiones)")

def _list_subject_fifs(fif_dir=DATA_PROC, pattern='S???_MI-epo.fif'):
    """Return the sorted path strings of the files in ``fif_dir`` matching ``pattern``."""
    matches = glob(str(fif_dir / pattern))
    matches.sort()
    return matches

def _list_available_subjects(fif_dir=DATA_PROC):
    """
    Return the sorted, unique subject IDs that have an epochs file in ``fif_dir``.

    The ID is the part of the file name before the first underscore.
    """
    subject_ids = set()
    for fif in _list_subject_fifs(fif_dir):
        subject_ids.add(Path(fif).stem.partition('_')[0])
    return sorted(subject_ids)

def _read_drop_file(path: Path):
    """
    Read the set of subject IDs listed in a DROP text file.

    Each line is stripped and upper-cased; only lines that then match the
    subject-ID pattern ``_re_sid`` are kept, everything else is ignored.
    The file is decoded as "utf-8-sig" so that a leading byte-order mark does
    not make the first ID fail the pattern. Undecodable bytes are skipped.

    Parameters
    ----------
    path : path-like
        Location of the DROP file.

    Returns
    -------
    set of str
        Subject IDs to exclude; empty if the file does not exist.
    """
    path = Path(path)
    if not path.exists():
        return set()
    with open(path, 'r', encoding='utf-8-sig', errors='ignore') as fh:
        candidates = (line.strip().upper() for line in fh)
        return {sid for sid in candidates if _re_sid.match(sid)}

def _strict_valid_from_drop(avail_ids):
    """
    Remove the subjects listed in STRICT_DROP_TXT from ``avail_ids``.

    Returns
    -------
    (valid, info)
        ``valid`` is the sorted list of remaining IDs; ``info`` is a one-line
        summary for logging, which also says when the DROP file is missing.
    """
    drop = _read_drop_file(STRICT_DROP_TXT)
    avail = set(avail_ids)
    valid = sorted(avail.difference(drop))

    missing_note = "" if STRICT_DROP_TXT.exists() else " (archivo DROP no encontrado → sin exclusiones)"
    info = f"DROP-only: {len(drop)} en DROP; válidos={len(valid)}/{len(avail)}"
    return valid, info + missing_note

def _knobs_dict(crop_window, motor_only, zscore_epoch, fb_bands, n_csp):
    """
    Pack the experiment settings into a flat dict for logs and CSV rows.

    ``crop_window`` is stored unchanged, the two flags are coerced to bool,
    ``fb_bands`` is stored as its string representation and ``n_csp`` as int.
    """
    return {
        'crop_window': crop_window,
        'motor_only': bool(motor_only),
        'zscore_epoch': bool(zscore_epoch),
        'fb_bands': str(fb_bands),
        'n_csp': int(n_csp),
    }

def run_loso(fif_dir=DATA_PROC,
             subject_list=None,
             use_strict=True,
             crop_window=None,
             motor_only=False,
             zscore_epoch=False,
             fb_bands=DEFAULT_FB_BANDS,
             n_csp=DEFAULT_N_CSP):
    """
    Leave-one-subject-out evaluation of FBCSP+LDA, recording the settings used.

    Subject selection:
    - ``use_strict=True``: subjects listed in the DROP file are excluded;
      ``subject_list`` (if given) is filtered to the remaining ones.
    - ``use_strict=False``: ``subject_list`` is used as given, or every
      available subject if it is empty/None.

    NOTE(review): each fold trains ONLY on the other subjects of the selected
    set, not on every subject on disk. With a short ``subject_list`` the
    training set is therefore correspondingly small - confirm this is the
    intended protocol for subset runs.

    Outputs written: one confusion matrix per test subject plus a global one,
    ``metrics_loso_per_subject.csv`` (overwritten), one summary row appended
    to ``metrics_loso.csv``, and a log file (closed on exit).

    Returns
    -------
    pandas.DataFrame or None
        Per-subject metrics sorted by test subject; None when there is
        nothing to evaluate (no files, no subjects selected, or no fold had
        any training subject).
    """
    fif_dir = Path(fif_dir)
    logger, log_path = _init_logger(run_name="loso")
    try:
        knobs = _knobs_dict(crop_window, motor_only, zscore_epoch, fb_bands, n_csp)
        logger.info(f"Perillas: {knobs}")

        avail = _list_available_subjects(fif_dir)
        if not avail:
            logger.error(f"No hay FIF S???_MI-epo.fif en {fif_dir}")
            return None

        if use_strict:
            valid, info = _strict_valid_from_drop(avail)
            if subject_list:
                sids = [s for s in subject_list if s in valid]
                logger.info(f"(DROP-only) tests solicitados → {sids} | {info}")
            else:
                sids = valid
                logger.info(f"Usando todos los válidos ({len(sids)}). {info}")
        else:
            sids = subject_list or avail
            logger.info(f"(strict=OFF) tests → {sids[:10]}{' ...' if len(sids)>10 else ''}")

        if not sids:
            logger.warning("No hay sujetos para LOSO.")
            return None

        # Load every selected subject once; folds reuse these objects.
        ep_map = {sid: mne.read_epochs(str(fif_dir / f"{sid}_MI-epo.fif"),
                                       preload=True, verbose=False) for sid in sids}

        classes_global, cm_global, rows = None, None, []
        for s_test, ep_te in ep_map.items():
            train_ids = [sid for sid in sids if sid != s_test]
            if not train_ids:
                logger.warning(f"Sin train para {s_test}.")
                continue

            ep_tr = mne.concatenate_epochs([ep_map[s] for s in train_ids], on_mismatch='ignore')
            # Same channel order in train and test before spatial filtering.
            ep_te = ep_te.copy().reorder_channels(ep_tr.ch_names)

            _, y_tr_str = _epochs_to_Xy(ep_tr)
            _, y_te_str = _epochs_to_Xy(ep_te)
            le = LabelEncoder().fit(np.concatenate([y_tr_str, y_te_str]))
            y_tr = le.transform(y_tr_str)
            y_te = le.transform(y_te_str)
            classes = list(le.classes_)

            if classes_global is None:
                classes_global = classes
                cm_global = np.zeros((len(classes), len(classes)), dtype=int)

            with mne.utils.use_log_level("ERROR"):
                try:
                    Xtr_fb, Xte_fb = _fit_fb_csp_transform(ep_tr, ep_te,
                                                           fb_bands=fb_bands,
                                                           n_csp=n_csp,
                                                           motor_only=motor_only,
                                                           zscore_epoch=zscore_epoch,
                                                           crop_window=crop_window)
                except TypeError as exc:
                    # Compatibility fallback kept from the original code, but
                    # no longer silent: the reduced call ignores three of the
                    # logged settings.
                    logger.warning(
                        f"_fit_fb_csp_transform lanzó TypeError ({exc}); se reintenta SOLO con "
                        f"fb_bands/n_csp — crop_window, motor_only y zscore_epoch NO se aplican "
                        f"en test={s_test}."
                    )
                    Xtr_fb, Xte_fb = _fit_fb_csp_transform(ep_tr, ep_te,
                                                           fb_bands=fb_bands,
                                                           n_csp=n_csp)

            yhat, _, _ = _fit_scale_lda(Xtr_fb, y_tr, Xte_fb)

            acc = accuracy_score(y_te, yhat)
            f1m = f1_score(y_te, yhat, average='macro')
            cm = confusion_matrix(y_te, yhat, labels=np.arange(len(classes)))
            cm_global += cm

            logger.info(f"[LOSO] test={s_test} | acc={acc:.3f} | f1m={f1m:.3f} | n_test={len(y_te)}")
            _save_confusion(cm, classes, title=f"LOSO test {s_test} — FBCSP+LDA",
                            stem=f"loso_test-{s_test}")

            rows.append(dict(mode="loso",
                             test_subject=s_test,
                             acc=round(acc,4),
                             f1_macro=round(f1m,4),
                             n_test=int(len(y_te)),
                             **knobs))

        if not rows:
            # Happens e.g. with a single selected subject: there is no one
            # left to train on, so no table or global matrix can be produced.
            logger.warning("Ningún fold LOSO pudo ejecutarse (se necesitan al menos 2 sujetos).")
            return None

        df_rows = pd.DataFrame(rows).sort_values('test_subject').reset_index(drop=True)

        per_subj_path = TAB_DIR / "metrics_loso_per_subject.csv"
        df_rows.to_csv(per_subj_path, index=False)
        logger.info(f"Métricas por sujeto → {per_subj_path}")

        _save_confusion(cm_global, classes_global, title="LOSO Global — FBCSP+LDA", stem="loso_global")

        acc_mu = float(np.mean([r['acc'] for r in rows]))
        f1_mu  = float(np.mean([r['f1_macro'] for r in rows]))
        row_g = dict(mode="loso_global",
                     acc_mean=round(acc_mu,4),
                     f1_macro_mean=round(f1_mu,4),
                     n_subjects=len(rows),
                     **knobs)
        metrics_path = _append_metrics(row_g, table_name="metrics_loso.csv")
        logger.info(f"Resumen global LOSO → {metrics_path}")
        logger.info(f"Log guardado en → {log_path}")

        print("\nResumen LOSO (promedios): ACC={:.3f} | F1m={:.3f}".format(acc_mu, f1_mu))
        return df_rows
    finally:
        # Release the log file even if the run failed part-way.
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()

def _split_calibration(ep_te, k_per_class=5):
    """
    Split test epochs into a calibration part and an evaluation part.

    For every event code, the first ``k_per_class`` epochs (in their current
    order) go to calibration and the rest to evaluation. Both parts are taken
    by index from ``ep_te``, so each keeps the original relative order of its
    epochs and ``ep_te`` itself is not modified.

    Parameters
    ----------
    ep_te : Epochs-like
        Object exposing ``events`` (last column = event code) and integer
        array indexing.
    k_per_class : int
        Calibration epochs per class; ``<= 0`` disables calibration.

    Returns
    -------
    (ep_calib, ep_eval)
        ``ep_calib`` is None when calibration is disabled or no epoch was
        selected. NOTE: if every epoch ends up in calibration, ``ep_eval`` is
        ``ep_te`` itself (behaviour kept from the original code), so
        evaluation then overlaps the calibration epochs.
    """
    if k_per_class <= 0:
        return None, ep_te

    labels = ep_te.events[:, -1]
    calib_mask = np.zeros(len(labels), dtype=bool)
    for code in np.unique(labels):
        # First k occurrences of this class.
        calib_mask[np.flatnonzero(labels == code)[:k_per_class]] = True

    calib_idx = np.flatnonzero(calib_mask)
    eval_idx = np.flatnonzero(~calib_mask)

    ep_calib = ep_te[calib_idx] if calib_idx.size else None
    ep_eval = ep_te[eval_idx] if eval_idx.size else ep_te
    return ep_calib, ep_eval

def run_loso_single(test_subject: str,
                    fif_dir=DATA_PROC,
                    use_strict=True,
                    crop_window=None,
                    motor_only=False,
                    zscore_epoch=False,
                    fb_bands=DEFAULT_FB_BANDS,
                    n_csp=DEFAULT_N_CSP,
                    calibrate_k_per_class=0):
    """
    Train on every other subject and test on ``test_subject`` only.

    Training subjects are all available subjects except the test one; with
    ``use_strict=True`` those listed in the DROP file are also removed.

    Optional calibration: with ``calibrate_k_per_class > 0`` the first k test
    epochs of each class are taken out of the evaluation set and added to the
    data used to fit the scaler and the LDA. The CSP filters are still fitted
    on the training subjects only.

    The filter-bank CSP is fitted once; calibration and evaluation epochs are
    transformed in the same pass. The label encoder also sees the calibration
    labels, so a class present only there does not break the encoding.

    Outputs written: a confusion matrix (PNG + CSV), one row appended to
    ``metrics_loso_single.csv`` and a log file (closed on exit).

    Returns
    -------
    dict or None
        The metrics row that was appended; None if the test file is missing
        or no training subject is available.
    """
    # Normalise the ID first so the log name matches the file name.
    test_subject = test_subject.upper()
    fif_dir = Path(fif_dir)
    logger, log_path = _init_logger(run_name=f"loso_single_{test_subject}")
    try:
        test_path = fif_dir / f"{test_subject}_MI-epo.fif"
        if not test_path.exists():
            logger.error(f"No existe: {test_path}")
            return None
        ep_te_full = mne.read_epochs(str(test_path), preload=True, verbose=False)

        knobs = _knobs_dict(crop_window, motor_only, zscore_epoch, fb_bands, n_csp)
        knobs['calib_k'] = int(calibrate_k_per_class)
        logger.info(f"Perillas: {knobs}")

        avail = _list_available_subjects(fif_dir)
        train_ids = [s for s in avail if s != test_subject]
        if use_strict:
            valid, info = _strict_valid_from_drop(avail)
            train_ids = [s for s in train_ids if s in valid]
            logger.info(f"(DROP-only) {info} | train_ids={train_ids[:10]}{' ...' if len(train_ids)>10 else ''}")
        else:
            logger.info(f"(strict=OFF) train_ids={train_ids[:10]}{' ...' if len(train_ids)>10 else ''}")

        if not train_ids:
            logger.error("Sin sujetos de entrenamiento disponibles.")
            return None

        ep_tr = mne.concatenate_epochs(
            [mne.read_epochs(str(fif_dir / f"{s}_MI-epo.fif"), preload=True, verbose=False) for s in train_ids],
            on_mismatch='ignore'
        )
        # Same channel order in train and test before spatial filtering.
        ep_te_full = ep_te_full.copy().reorder_channels(ep_tr.ch_names)

        ep_calib, ep_eval = _split_calibration(ep_te_full, k_per_class=calibrate_k_per_class)
        has_calib = ep_calib is not None and len(ep_calib) > 0

        _, y_tr_str = _epochs_to_Xy(ep_tr)
        _, y_ev_str = _epochs_to_Xy(ep_eval)
        label_sets = [y_tr_str, y_ev_str]
        if has_calib:
            _, y_ca_str = _epochs_to_Xy(ep_calib)
            label_sets.append(y_ca_str)
        le = LabelEncoder().fit(np.concatenate(label_sets))
        y_tr = le.transform(y_tr_str)
        y_ev = le.transform(y_ev_str)
        classes = list(le.classes_)

        # One FBCSP pass: calibration epochs (if any) are placed in front of
        # the evaluation epochs and the feature rows are split afterwards.
        if has_calib:
            ep_te_all = mne.concatenate_epochs([ep_calib, ep_eval], on_mismatch='ignore')
            n_ca = len(ep_calib)
        else:
            ep_te_all, n_ca = ep_eval, 0

        with mne.utils.use_log_level("ERROR"):
            try:
                Xtr_fb, Xall_fb = _fit_fb_csp_transform(ep_tr, ep_te_all,
                                                        fb_bands=fb_bands,
                                                        n_csp=n_csp,
                                                        motor_only=motor_only,
                                                        zscore_epoch=zscore_epoch,
                                                        crop_window=crop_window)
            except TypeError as exc:
                # Compatibility fallback kept from the original code, but no
                # longer silent: the reduced call ignores three logged settings.
                logger.warning(
                    f"_fit_fb_csp_transform lanzó TypeError ({exc}); se reintenta SOLO con "
                    f"fb_bands/n_csp — crop_window, motor_only y zscore_epoch NO se aplican."
                )
                Xtr_fb, Xall_fb = _fit_fb_csp_transform(ep_tr, ep_te_all,
                                                        fb_bands=fb_bands,
                                                        n_csp=n_csp)

        Xca_fb, Xev_fb = Xall_fb[:n_ca], Xall_fb[n_ca:]

        if has_calib:
            # Scaler + LDA are fitted on training and calibration features.
            X_fit = np.vstack([Xtr_fb, Xca_fb])
            y_fit = np.concatenate([y_tr, le.transform(y_ca_str)])
        else:
            X_fit, y_fit = Xtr_fb, y_tr

        yhat, _, _ = _fit_scale_lda(X_fit, y_fit, Xev_fb)

        acc = accuracy_score(y_ev, yhat)
        f1m = f1_score(y_ev, yhat, average='macro')
        cm = confusion_matrix(y_ev, yhat, labels=np.arange(len(classes)))
        logger.info(f"[LOSO SINGLE] test={test_subject} | acc={acc:.3f} | f1m={f1m:.3f} | n_test={len(y_ev)}")

        _save_confusion(cm, classes, title=f"LOSO single {test_subject} — FBCSP+LDA",
                        stem=f"loso_single_test-{test_subject}")

        row = dict(mode="loso_single",
                   test_subject=test_subject,
                   acc=round(acc,4),
                   f1_macro=round(f1m,4),
                   n_test=int(len(y_ev)),
                   **knobs)
        _append_metrics(row, table_name="metrics_loso_single.csv")
        logger.info(f"Log guardado en → {log_path}")

        print(f"[LOSO single] {test_subject} | ACC={acc:.3f} | F1m={f1m:.3f}")
        return row
    finally:
        # Release the log file even if the run failed part-way.
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()

def run_loso_subset(subject_ids=None, fif_dir=DATA_PROC, use_strict=True,
                    crop_window=None, motor_only=False, zscore_epoch=False,
                    fb_bands=DEFAULT_FB_BANDS, n_csp=DEFAULT_N_CSP, n_default_test=2):
    """
    Run `run_loso` on a subset of subjects.

    Subject resolution:
    - ``subject_ids=None``: the first ``n_default_test`` eligible subjects.
    - explicit ``subject_ids``: every listed ID that is eligible. The list is
      NOT cut down to ``n_default_test``; that value is only the default
      count used when no list is given.
    "Eligible" means available on disk and, with ``use_strict=True``, not
    listed in the DROP file.

    The resolved list is handed to `run_loso` as ``subject_list``; all other
    settings are passed through unchanged.

    Returns
    -------
    Whatever `run_loso` returns, or None if no subject could be resolved.
    """
    avail = _list_available_subjects(fif_dir)
    if not avail:
        print("No hay sujetos en data/processed.")
        return None

    if use_strict:
        pool, info = _strict_valid_from_drop(avail)
    else:
        pool, info = avail, None

    if subject_ids is None:
        subs = pool[:n_default_test]
    else:
        subs = [s for s in subject_ids if s in pool]

    if use_strict:
        print(f"[LOSO SUBSET] tests: {subs} | {info}")
    else:
        print(f"(strict=OFF) tests: {subs}")

    if not subs:
        print("No se encontraron sujetos válidos para loso-subset.")
        return None

    return run_loso(fif_dir=fif_dir,
                    subject_list=subs,
                    use_strict=use_strict,
                    crop_window=crop_window,
                    motor_only=motor_only,
                    zscore_epoch=zscore_epoch,
                    fb_bands=fb_bands,
                    n_csp=n_csp)


Archivo DROP-only: C:\Users\joelc\Desktop\eeg2\reports\tables\02_prepro\subjects_strict_DROP.txt (si no existe, no hay exclusiones)


### Bloque 6 — Lanzadores “batch”: intra para N sujetos y LOSO para M sujetos

Qué hace: agrega funciones y un ejemplo de uso para:

- run_intra_batch(subject_ids=None, ...) ejecuta intra para los sujetos indicados (se descartan los IDs que no tienen FIF en data/processed). Esta función no recibe use_strict ni aplica el filtro DROP.

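- run_loso_subset(..., use_strict=True) corre LOSO sólo en un subset de test subjects (con Strict se descartan los sujetos listados en el archivo DROP; no hay reemplazo automático). Ojo: run_loso entrena cada fold únicamente con los demás sujetos de ese mismo subset.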

- Ambos ofrecen defaults (4 intra, 2 loso) si no pasas lista

In [None]:
# %% [Batch runners] intra/loso con perillas y defaults fuertes
from glob import glob
from pathlib import Path

def _discover_subject_ids(fif_dir=DATA_PROC, pattern='S???_MI-epo.fif'):
    """
    Return the subject IDs of the files in ``fif_dir`` matching ``pattern``.

    IDs follow the sorted order of the matching paths; the ID is the part of
    the file name before the first underscore.
    """
    found = glob(str(fif_dir / pattern))
    found.sort()
    return [Path(name).stem.partition('_')[0] for name in found]

def run_intra_batch(subject_ids=None, k=10, fif_dir=DATA_PROC,
                    crop_window=(0.5, 3.5),
                    motor_only=True,
                    zscore_epoch=True,
                    fb_bands=DEFAULT_FB_BANDS,
                    n_csp=DEFAULT_N_CSP):
    """Run the intra-subject evaluation for several subjects and tabulate it.

    Each subject is evaluated with ``run_intra_subject``; the per-subject mean
    accuracy and mean macro-F1 are collected into one table, which is written
    to ``TAB_DIR / "metrics_intra_batch.csv"`` and displayed.

    Args:
        subject_ids: Subject IDs to evaluate. IDs without a matching file in
            ``fif_dir`` are dropped. ``None`` selects the first 4 discovered.
        k: Number of folds forwarded to ``run_intra_subject``.
        fif_dir: Directory that holds the per-subject epoch files.
        crop_window: Forwarded to ``run_intra_subject``.
        motor_only: Forwarded to ``run_intra_subject``.
        zscore_epoch: Forwarded to ``run_intra_subject``.
        fb_bands: Filter-bank band list (defaults to ``DEFAULT_FB_BANDS``).
        n_csp: CSP component count (defaults to ``DEFAULT_N_CSP``).

    Returns:
        A ``pandas.DataFrame`` sorted by subject, or ``None`` when no subject
        files exist, no requested ID is valid, or no subject produced a result.
    """
    def _mean_or_zero(values):
        # Size-based emptiness test: plain truthiness raises for a NumPy
        # array holding more than one element.
        if values is None or len(values) == 0:
            return 0.0
        return float(np.mean(values))

    subs_all = _discover_subject_ids(fif_dir)
    if not subs_all:
        print("No se encontraron sujetos en", fif_dir)
        return None

    if subject_ids is None:
        subject_ids = subs_all[:4]
    else:
        known = set(subs_all)
        subject_ids = [s for s in subject_ids if s in known]
        if not subject_ids:
            print("Ningún ID válido en subject_ids.")
            return None

    rows = []
    print(f"[INTRA BATCH] sujetos: {subject_ids}")
    for sid in subject_ids:
        print(f"\n== INTRA {sid} ==")
        res = run_intra_subject(sid, fif_dir=fif_dir, k=k,
                                crop_window=crop_window,
                                motor_only=motor_only,
                                zscore_epoch=zscore_epoch,
                                fb_bands=fb_bands,
                                n_csp=n_csp)
        if res is None:
            # One subject without a result must not discard the others.
            print(f"[INTRA BATCH] {sid}: sin resultados; se omite.")
            continue
        rows.append(dict(subject=sid,
                         acc_mean=_mean_or_zero(res['accs']),
                         f1_macro_mean=_mean_or_zero(res['f1s']),
                         k=k, n_classes=len(res['classes']),
                         crop=str(crop_window), motor_only=bool(motor_only),
                         zscore_epoch=bool(zscore_epoch), n_csp=int(n_csp),
                         fb_bands=len(fb_bands)))
    if not rows:
        return None

    df = pd.DataFrame(rows).sort_values('subject')
    out_path = TAB_DIR / "metrics_intra_batch.csv"
    df.to_csv(out_path, index=False)
    print("Resumen intra batch →", out_path)
    display(df)
    return df

def run_loso_subset_quick(subject_ids=None, fif_dir=DATA_PROC,
                          crop_window=(0.5, 3.5),
                          motor_only=True,
                          zscore_epoch=True,
                          fb_bands=DEFAULT_FB_BANDS,
                          n_csp=DEFAULT_N_CSP,
                          n_default_test=2):
    """Convenience wrapper around ``run_loso_subset`` (from the LOSO block).

    Every argument is forwarded unchanged; ``use_strict`` is always ``True``.
    Returns whatever ``run_loso_subset`` returns.
    """
    forwarded = dict(
        subject_ids=subject_ids,
        fif_dir=fif_dir,
        crop_window=crop_window,
        motor_only=motor_only,
        zscore_epoch=zscore_epoch,
        fb_bands=fb_bands,
        n_csp=n_csp,
        n_default_test=n_default_test,
    )
    return run_loso_subset(use_strict=True, **forwarded)


### Bloque 7 — Ejemplos de ejecución

Qué hace: muestra cómo lanzar los “batch” por defecto (4 intra, 2 LOSO) y cómo pasar tu propia lista de sujetos.

In [33]:
# %% [Examples — how to launch the runners]
# 1) Intra-subject batch. subject_ids is omitted, so run_intra_batch takes the
#    first 4 discovered subjects. The knobs below override the function
#    defaults ("strong" settings for MI, per the author).
_ = run_intra_batch(
    k=10,
    crop_window=(0.5, 4.5),
    motor_only=False,
    zscore_epoch=False,
    fb_bands=FB_BANDS_DENSE,
    n_csp=6
)

# 2) LOSO with the first 2 subjects as TEST (quick, for validation)
# df_loso_subset = run_loso_subset(
#     crop_window=(0.5, 3.5),
#     motor_only=True,
#     zscore_epoch=True,
#     fb_bands=FB_BANDS_DENSE,
#     n_csp=6
# )

# 3) Classic LOSO over ALL subjects (careful: this can take a long time)
# df_loso_all = run_loso(
#     crop_window=(0.5, 3.5),
#     motor_only=True,
#     zscore_epoch=True,
#     fb_bands=FB_BANDS_DENSE,
#     n_csp=6
# )

# 4) Single-subject LOSO with minimal calibration
#    (k=5 epochs per class taken from the test subject)
row_single = run_loso_single(
    'S001',
    crop_window=(0.5, 4.5),
    motor_only=False,
    zscore_epoch=False,
    fb_bands=FB_BANDS_DENSE,
    n_csp=6,
    calibrate_k_per_class=5
)


[INTRA BATCH] sujetos: ['S001', 'S002', 'S003', 'S004']

== INTRA S001 ==
[21:38:36] INFO: INTRA S001 | k=10 | n_epochs=77 | clases=['Both Feet', 'Both Fists', 'Left', 'Right'] | sfreq=160.0
[21:38:36] INFO: Perillas: {'crop_window': (0.5, 4.5), 'motor_only': False, 'zscore_epoch': False, 'fb_bands': [(8, 10), (10, 12), (12, 14), (14, 16), (16, 18), (18, 20), (20, 22), (22, 24), (24, 26), (26, 28), (28, 30)], 'n_csp': 6}
[21:39:03] INFO: [fold 1] acc=0.750 | f1m=0.667 | n_test=8
[21:39:29] INFO: [fold 2] acc=0.875 | f1m=0.867 | n_test=8
[21:39:55] INFO: [fold 3] acc=1.000 | f1m=1.000 | n_test=8
[21:40:21] INFO: [fold 4] acc=1.000 | f1m=1.000 | n_test=8
[21:40:47] INFO: [fold 5] acc=0.875 | f1m=0.867 | n_test=8
[21:41:13] INFO: [fold 6] acc=0.875 | f1m=0.867 | n_test=8
[21:41:39] INFO: [fold 7] acc=0.875 | f1m=0.867 | n_test=8
[21:42:05] INFO: [fold 8] acc=1.000 | f1m=1.000 | n_test=7
[21:42:31] INFO: [fold 9] acc=0.857 | f1m=0.867 | n_test=7
[21:42:57] INFO: [fold 10] acc=1.000 | f1m=1

Unnamed: 0,subject,acc_mean,f1_macro_mean,k,n_classes,crop,motor_only,zscore_epoch,n_csp,fb_bands
0,S001,0.910714,0.9,10,4,"(0.5, 4.5)",False,False,6,11
1,S002,0.858929,0.84,10,4,"(0.5, 4.5)",False,False,6,11
2,S003,0.446667,0.395,10,4,"(0.5, 4.5)",False,False,6,11
3,S004,0.7,0.666667,10,4,"(0.5, 4.5)",False,False,6,11


[21:55:23] INFO: Perillas: {'crop_window': (0.5, 4.5), 'motor_only': False, 'zscore_epoch': False, 'fb_bands': '[(8, 10), (10, 12), (12, 14), (14, 16), (16, 18), (18, 20), (20, 22), (22, 24), (24, 26), (26, 28), (28, 30)]', 'n_csp': 6, 'calib_k': 5}
[21:55:23] INFO: (DROP-only) DROP-only: 33 en DROP; válidos=70/103 | train_ids=['S002', 'S003', 'S004', 'S005', 'S006', 'S007', 'S008', 'S010', 'S011', 'S012'] ...


KeyboardInterrupt: 

In [None]:
# %% [LOSO single — Barrido de perillas y ranking de resultados]
import itertools
import pandas as pd
import numpy as np

# ---------- Presets de bandas (por si no existen en tu sesión actual) ----------
def _bands_dense():
    # 2 Hz entre 8 y 30 (11 bandas)
    return [(f, f+2) for f in range(8, 30, 2)]

def _bands_classic():
    # Las clásicas que venías usando (6 bandas)
    return [(8,12), (12,16), (16,20), (20,24), (24,28), (28,30)]

def _bands_wide():
    # Un preset más ancho (3 bandas)
    return [(8,14), (14,22), (22,30)]

# Reuse FB_BANDS_DENSE / DEFAULT_N_CSP when the session already defines them;
# otherwise fall back to the values below.
if "FB_BANDS_DENSE" not in globals():
    FB_BANDS_DENSE = _bands_dense()
if "DEFAULT_N_CSP" not in globals():
    DEFAULT_N_CSP = 6

# ---------- Helper: obtener preset de bandas por nombre ----------
def _get_fb_bands(preset_name: str):
    """Resolve a filter-bank preset name (case-insensitive) to its band list.

    Accepted names: "dense"/"2hz"/"narrow", "classic"/"std"/"6bands" and
    "wide"/"3bands".

    Raises:
        ValueError: If ``preset_name`` is not one of the names above.
    """
    builders = {
        "dense": _bands_dense, "2hz": _bands_dense, "narrow": _bands_dense,
        "classic": _bands_classic, "std": _bands_classic, "6bands": _bands_classic,
        "wide": _bands_wide, "3bands": _bands_wide,
    }
    builder = builders.get(preset_name.lower())
    if builder is None:
        raise ValueError(f"Preset de bandas desconocido: {preset_name}")
    return builder()

# ---------- Barrido ----------
def run_loso_single_sweep(
    test_subject: str = "S001",
    crop_windows = ((0.5, 4.5), (1.0, 4.0)),     # two crop windows to try
    motor_only_opts = (False, True),             # all channels vs motor picks only
    zscore_epoch_opts = (False, True),           # without / with per-epoch z-score
    fb_band_presets = ("dense", "classic"),      # filter-bank preset names
    n_csp_opts = (4, 6, 8),                      # CSP components per band
    calibrate_k_opts = (0, 5),                   # without / with light calibration
    fif_dir = DATA_PROC,
    use_strict = True,
    top_k = 5,                                   # rows shown in each ranking
    out_csv_stem = "loso_single_sweep"
):
    """Grid-search the knobs of ``run_loso_single`` for one test subject.

    Every combination of the option sequences is run once. Results are saved
    to ``TAB_DIR / f"{out_csv_stem}_{test_subject}.csv"`` and the ``top_k``
    runs by accuracy and by macro-F1 are printed.

    Args:
        test_subject: Subject ID passed to ``run_loso_single`` as test subject.
        crop_windows: Sequence of crop windows to try.
        motor_only_opts: Sequence of ``motor_only`` values to try.
        zscore_epoch_opts: Sequence of ``zscore_epoch`` values to try.
        fb_band_presets: Sequence of preset names understood by
            ``_get_fb_bands``.
        n_csp_opts: Sequence of ``n_csp`` values to try.
        calibrate_k_opts: Sequence of ``calibrate_k_per_class`` values to try.
        fif_dir: Forwarded to ``run_loso_single``.
        use_strict: Forwarded to ``run_loso_single``.
        top_k: Number of rows printed in each ranking.
        out_csv_stem: File-name stem of the output CSV.

    Returns:
        ``pandas.DataFrame`` with one row per combination. Runs for which
        ``run_loso_single`` returned ``None`` are kept with ``status="error"``.
        An empty grid yields an empty frame with the expected columns.

    Raises:
        ValueError: If a preset name is unknown (raised before any run starts).
        KeyboardInterrupt: Re-raised after the rows finished so far have been
            written to a separate ``*_partial.csv`` file.
    """
    columns = [
        "test_subject", "status", "acc", "f1_macro", "n_test",
        "crop", "motor_only", "zscore_epoch", "n_csp", "fb_preset", "fb_bands", "calib_k"
    ]
    rank_cols = ["acc","f1_macro","n_test","crop","motor_only","zscore_epoch","n_csp","fb_preset","calib_k"]

    def _to_frame(collected):
        # reindex (not df[[...]]) so an empty sweep, or a result row lacking a
        # column, yields NaN instead of a KeyError.
        return pd.DataFrame(collected).reindex(columns=columns)

    # Resolve every preset up front: a typo fails here, not after slow runs.
    bands_by_preset = {name: _get_fb_bands(name) for name in fb_band_presets}

    total = (len(crop_windows) * len(motor_only_opts) * len(zscore_epoch_opts) *
             len(fb_band_presets) * len(n_csp_opts) * len(calibrate_k_opts))
    print(f"[Sweep] Combinaciones a ejecutar: {total}")

    comb_iter = itertools.product(
        crop_windows, motor_only_opts, zscore_epoch_opts, fb_band_presets, n_csp_opts, calibrate_k_opts
    )
    rows = []
    try:
        for i, (cw, mo, zep, fb_name, nc, calib_k) in enumerate(comb_iter, start=1):
            # Fresh copy per run so no run can affect the bands of the next.
            fb = list(bands_by_preset[fb_name])
            print(f"\n[{i}/{total}] test={test_subject} | crop={cw} | motor_only={mo} | zscore_epoch={zep} "
                  f"| fb='{fb_name}'({len(fb)}) | n_csp={nc} | calib_k={calib_k}")

            res = run_loso_single(
                test_subject,
                fif_dir=fif_dir,
                use_strict=use_strict,
                crop_window=cw,
                motor_only=mo,
                zscore_epoch=zep,
                fb_bands=fb,
                n_csp=nc,
                calibrate_k_per_class=calib_k
            )
            if res is None:
                # Record the failure so the combination still appears in the table.
                rows.append(dict(
                    test_subject=test_subject, acc=np.nan, f1_macro=np.nan, n_test=0,
                    crop=str(cw), motor_only=bool(mo), zscore_epoch=bool(zep),
                    n_csp=int(nc), fb_preset=fb_name, fb_bands=len(fb),
                    calib_k=calib_k, status="error"
                ))
            else:
                # `res` is expected to carry the metrics and knobs already;
                # add the preset name and a status flag on a copy.
                res_row = dict(res)
                res_row["fb_preset"] = fb_name
                res_row["status"] = "ok"
                rows.append(res_row)
    except KeyboardInterrupt:
        # A sweep can run for hours: keep what finished. Written to a separate
        # file so an earlier complete CSV is not overwritten by a partial one.
        if rows:
            partial_csv = TAB_DIR / f"{out_csv_stem}_{test_subject}_partial.csv"
            _to_frame(rows).to_csv(partial_csv, index=False)
            print(f"\n[Sweep] Interrumpido: {len(rows)}/{total} resultados parciales guardados en: {partial_csv}")
        raise

    # Build the table and save it.
    df = _to_frame(rows)
    out_csv = TAB_DIR / f"{out_csv_stem}_{test_subject}.csv"
    df.to_csv(out_csv, index=False)
    print(f"\n[Salida] Resultados guardados en: {out_csv}")

    # Rankings over the successful runs only.
    df_ok = df[df["status"] == "ok"].copy()
    if len(df_ok):
        top_acc = df_ok.sort_values("acc", ascending=False).head(top_k)
        top_f1  = df_ok.sort_values("f1_macro", ascending=False).head(top_k)

        print("\nTop por ACC:")
        print(top_acc[rank_cols].to_string(index=False))

        print("\nTop por F1_macro:")
        print(top_f1[rank_cols].to_string(index=False))
    else:
        print("No hubo corridas exitosas para rankear.")
    return df

# --------------------------------------------------------------------------------
# USAGE EXAMPLE (shrink or grow the grid through the arguments below).
# This grid is 2*2*2*2*3*2 = 96 runs of run_loso_single.
# --------------------------------------------------------------------------------
df_sweep = run_loso_single_sweep(
    test_subject="S001",
    crop_windows=((0.5,4.5),(1.0,4.0)),   # you can add (0.0,4.0) or (1.0,3.5)
    motor_only_opts=(False, True),
    zscore_epoch_opts=(False, True),
    fb_band_presets=("dense","classic"),
    n_csp_opts=(4,6,8),
    calibrate_k_opts=(0,5),
    top_k=5
)


[Sweep] Combinaciones a ejecutar: 96

[1/96] test=S001 | crop=(0.5, 4.5) | motor_only=False | zscore_epoch=False | fb='dense'(11) | n_csp=4 | calib_k=0
[23:15:08] INFO: Perillas: {'crop_window': (0.5, 4.5), 'motor_only': False, 'zscore_epoch': False, 'fb_bands': '[(8, 10), (10, 12), (12, 14), (14, 16), (16, 18), (18, 20), (20, 22), (22, 24), (24, 26), (26, 28), (28, 30)]', 'n_csp': 4, 'calib_k': 0}
[23:15:08] INFO: (DROP-only) DROP-only: 13 en DROP; válidos=90/103 | train_ids=['S002', 'S003', 'S004', 'S005', 'S006', 'S007', 'S008', 'S010', 'S011', 'S012'] ...


KeyboardInterrupt: 