In [2]:
# -*- coding: utf-8 -*-
# ===========================================
# MI-EEG (PhysioNet eegmmidb) — CNN+Transformer (Balanced 21/cls/subj) + Sliding Window
# Patch: usar Kfold5.json SOLO para sujetos por fold (ignorar tr_idx/te_idx)
# Split por sujeto/ensayo -> luego windowing (sin fuga). Métrica por ENSAYO.
# FT por sujeto (L2-SP) agrupando por ensayo (sin fuga).
# ===========================================

import os, re, math, json, copy, random, itertools, collections
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset

from sklearn.model_selection import GroupKFold, GroupShuffleSplit, StratifiedKFold
from sklearn.metrics import confusion_matrix, classification_report, f1_score, cohen_kappa_score
from sklearn.utils import check_random_state
from tqdm import tqdm
import matplotlib.pyplot as plt

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'           # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR = PROJ / 'data' / 'cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
print(f"🚀 Dispositivo: {DEVICE}")

FS = 160.0

# Escenario y rango base por ensayo
CLASS_SCENARIO = '4c'            # '2c','3c','4c'
WINDOW_GLOBAL = (0.0, 4.0)       # segundos rel. T1/T2 (ensayo base)

# Sliding window (mismo nº por ensayo)
WIN_LEN_SEC = 2.0
WIN_STRIDE_SEC = 0.5
N_SUBWINS = int(math.floor((WINDOW_GLOBAL[1]-WINDOW_GLOBAL[0]-WIN_LEN_SEC)/WIN_STRIDE_SEC) + 1)

# Balance 21 por clase/sujeto
N_PER_CLASS_PER_SUBJECT = 21

# CV
N_FOLDS = 5
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10

# Entrenamiento
BATCH_SIZE = 64
EPOCHS_GLOBAL = 80
LR_INIT = 1e-3

# FT por sujeto
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Band-pass
USE_BANDPASS = False
BP_LO, BP_HI = 8.0, 30.0
BP_METHOD, BP_PHASE = 'fir', 'zero'

# Dataset
EXCLUDE_SUBJECTS = {38,88,89,92,100,104}
MI_RUNS_LR = [4,8,12]            # Left/Right
MI_RUNS_OF = [6,10,14]           # Both Fists / Both Feet
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

CLASS_NAMES = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: canales/edf
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en {getattr(raw,'filenames',[''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    return None

_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None: return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try: _SNR_TABLE = pd.read_csv(csv_path)
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:   return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None: return None

    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None):
        raw.filter(l_freq=float(BP_LO), h_freq=float(BP_HI),
                   picks='eeg', method=BP_METHOD, phase=BP_PHASE, verbose=False)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    out = []
    if raw.annotations is None or len(raw.annotations) == 0: return out
    def _norm(s): return str(s).strip().upper().replace(' ','')
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'): out.append((float(onset), tag))
    out.sort()
    dedup = []; last_t1 = last_t2 = -1e9
    for t, tag in out:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# Dataset helpers
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR+MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_base_trials_from_run(edf_path: Path, scenario: str, window_global=(0.0,4.0)):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF'): return []

    raw = read_raw_edf(edf_path)
    if raw is None: return []
    data = raw.get_data()
    fs = raw.info['sfreq']
    out = []
    events = collect_events_T1T2(raw)
    rel_start, rel_end = window_global
    ev_idx = 0
    for onset_sec, tag in events:
        ev_idx += 1
        if kind == 'LR':
            label = 0 if tag=='T1' else 1
        else:
            label = 2 if tag=='T1' else 3
        if   scenario == '2c' and label not in (0,1): continue
        elif scenario == '3c' and label not in (0,1,2): continue
        elif scenario == '4c' and label not in (0,1,2,3): continue

        s0 = int(round((raw.first_time + onset_sec + rel_start) * fs))
        e0 = int(round((raw.first_time + onset_sec + rel_end)   * fs))
        if s0 < 0 or e0 > data.shape[1] or e0 <= s0: continue

        seg = data[:, s0:e0].T.astype(np.float32)  # (T_base, C)
        seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)
        trial_key = f"S{subj:03d}_R{run:02d}_E{ev_idx:03d}"
        out.append((seg, label, subj, trial_key))
    return out

def build_balanced_trials(subjects, scenario='4c', window_global=(0.0,4.0), n_per_class=N_PER_CLASS_PER_SUBJECT):
    trials_balanced = []
    for s in tqdm(subjects, desc="Recolectando y balanceando ensayos base por sujeto"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        bins = {0:[],1:[],2:[],3:[]}
        for r in (MI_RUNS_LR + MI_RUNS_OF):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            for t in extract_base_trials_from_run(p, scenario, window_global):
                bins[t[1]].append(t)
        # descarta sujeto si no tiene todas las clases
        if any(len(bins[c]) == 0 for c in (0,1,2,3)): continue
        rng = check_random_state(RANDOM_STATE + s)
        for c in (0,1,2,3):
            pool = bins[c]
            if len(pool) >= n_per_class:
                rng.shuffle(pool); take = pool[:n_per_class]
            else:
                idx = rng.choice(len(pool), size=n_per_class, replace=True)
                take = [pool[i] for i in idx]
            trials_balanced.extend(take)
    print(f"Ensayos base balanceados: {len(trials_balanced)}  (~ #sujetos_utiles * {n_per_class*4})")
    return trials_balanced

def window_trials(trials, win_len_sec=2.0, win_stride_sec=0.5, fs=160.0):
    """Windowing para una lista de ensayos (de un split). Retorna X,y,subj,trial_ids."""
    X, y, g, tkeys = [], [], [], []
    win_len = int(round(win_len_sec * fs))
    stride  = int(round(win_stride_sec * fs))
    for seg, lab, subj, tkey in trials:
        T = seg.shape[0]
        starts = [i for i in range(0, T - win_len + 1, stride)]
        # asegura N_SUBWINS por ensayo
        if len(starts) > N_SUBWINS: starts = starts[:N_SUBWINS]
        while len(starts) < N_SUBWINS and (T - win_len) >= 0:
            starts.append(T - win_len)
        for s in starts:
            e = s + win_len
            sub = seg[s:e, :].astype(np.float32)
            sub = (sub - sub.mean(axis=0, keepdims=True)) / (sub.std(axis=0, keepdims=True) + 1e-6)
            X.append(sub); y.append(lab); g.append(subj); tkeys.append(tkey)
    uniq = {k:i for i,k in enumerate(sorted(set(tkeys)))}
    trial_ids = np.asarray([uniq[k] for k in tkeys], dtype=np.int64)
    X = np.stack(X, axis=0).astype(np.float32)
    y = np.asarray(y, dtype=np.int64)
    g = np.asarray(g, dtype=np.int64)
    return X, y, g, trial_ids, uniq

# =========================
# Dataset torch
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups, trial_ids):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.t = trial_ids.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]; x = np.expand_dims(x, 0)  # (1,Tw,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx]), torch.tensor(self.t[idx])

def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w; w = w / w.mean()
    return WeightedRandomSampler(weights=torch.from_numpy(w).float(), num_samples=len(w), replacement=True)

# =========================
# Modelo: CNN + Transformer
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1): super().__init__(); self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    for m in model.modules():
        if hasattr(m,'weight') and isinstance(m,(nn.Conv2d, nn.Linear)):
            w = m.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)
    def forward(self, x):
        T = x.size(1); return x + self.pe[:, :T, :]

class CNNTokenizer(nn.Module):
    def __init__(self, n_ch: int, d_feat: int = 128, k_temporal: int = 128, stride_t: int = 4, chdrop_p: float = 0.1, drop_p: float = 0.2):
        super().__init__()
        self.chdrop = ChannelDropout(chdrop_p)
        self.conv_t = nn.Conv2d(1, d_feat, kernel_size=(k_temporal,1), stride=(stride_t,1), padding=(k_temporal//2,0), bias=False)
        self.bn_t   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.act    = nn.ELU()
        self.conv_sp_dw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,n_ch), groups=d_feat, bias=False)
        self.bn_sp      = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.conv_pw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,1), bias=False)
        self.bn_pw   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.drop    = nn.Dropout(drop_p)
    def forward(self, x):
        z = self.chdrop(x)
        z = self.conv_t(z); z = self.bn_t(z); z = self.act(z)         # (B,D,T',C)
        z = self.conv_sp_dw(z); z = self.bn_sp(z); z = self.act(z)    # (B,D,T',1)
        z = self.conv_pw(z); z = self.bn_pw(z); z = self.act(z)
        z = self.drop(z)
        z = z.squeeze(-1).transpose(1,2)  # (B,T',D)
        return z

def _enc_layer(d_model=128, nhead=4, dim_ff=256, drop=0.2):
    return nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
                                      dropout=drop, activation='gelu', batch_first=True, norm_first=True)

class TemporalTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, dim_ff=256, depth=3, drop=0.2):
        super().__init__()
        self.pe = PositionalEncoding(d_model)
        layer = _enc_layer(d_model, nhead, dim_ff, drop)
        self.enc = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x):
        x = self.pe(x)
        x = self.enc(x)
        return self.norm(x)

class HybridCNNTransformer(nn.Module):
    def __init__(self, n_ch: int, n_classes: int, seq_len: int,
                 d_model: int = 128, nhead: int = 4, depth: int = 3,
                 dim_ff: int = 256, k_temporal: int = 128, stride_t: int = 4,
                 drop: float = 0.2, chdrop_p: float = 0.1, use_cls_token: bool = False):
        super().__init__()
        self.use_cls = use_cls_token
        self.tokenizer = CNNTokenizer(n_ch, d_model, k_temporal, stride_t, chdrop_p, drop)
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)) if use_cls_token else None
        self.transformer = TemporalTransformer(d_model, nhead, dim_ff, depth, drop)
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(d_model, n_classes)
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, x):
        tok = self.tokenizer(x)                   # (B,T',D)
        if self.use_cls:
            B = tok.size(0); cls_tok = self.cls.expand(B,-1,-1)
            tok = torch.cat([cls_tok, tok], dim=1)
        z = self.transformer(tok)
        feat = z[:,0,:] if self.use_cls else z.mean(dim=1)
        return self.head(feat)

# =========================
# Entrenamiento / Eval
# =========================
class SoftCE(nn.Module):
    def __init__(self, n_classes: int, label_smoothing: float = 0.05, class_weights: torch.Tensor | None = None):
        super().__init__()
        self.n_classes = n_classes
        self.ls = float(label_smoothing)
        self.register_buffer('w', class_weights if class_weights is not None else None)
    def forward(self, logits, targets_idx):
        B, C = logits.size()
        yt = F.one_hot(targets_idx, num_classes=C).float()
        if self.ls > 0: yt = (1 - self.ls) * yt + self.ls * (1.0 / C)
        logp = F.log_softmax(logits, dim=1)
        loss = -(yt * logp)
        if self.w is not None: loss = loss * self.w.unsqueeze(0)
        return loss.sum(dim=1).mean()

@torch.no_grad()
def _predict_tta(model, xb, n=5, fs=160.0):
    def time_jitter(x, max_ms=25):
        max_shift = int(round(max_ms/1000.0 * fs))
        if max_shift <= 0: return x
        B,_,T,C = x.shape
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
        out = torch.zeros_like(x)
        for i,s in enumerate(shifts):
            if s >= 0: out[i,:,s:,:] = x[i,:,:T-s,:]
            else: s=-s; out[i,:,:T-s,:] = x[i,:,s:,:]
        return out
    outs = []
    for _ in range(n):
        outs.append(model(time_jitter(xb,25)))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_per_trial(model, loader, use_tta=True, tta_n=5, n_classes=4):
    model.eval()
    sums, counts, labels = {}, {}, {}
    for xb, yb, _, tb in loader:
        xb = xb.to(DEVICE)
        logits = _predict_tta(model, xb, n=tta_n) if use_tta else model(xb)
        logits = logits.detach().cpu().numpy()
        yb = yb.numpy(); tb = tb.numpy()
        for i in range(xb.size(0)):
            t = int(tb[i])
            if t not in sums:
                sums[t] = np.zeros((n_classes,), dtype=np.float64)
                counts[t] = 0
                labels[t] = int(yb[i])
            sums[t] += logits[i]; counts[t] += 1
    ids = sorted(sums.keys())
    y_true_trial = np.asarray([labels[t] for t in ids], dtype=int)
    y_pred_trial = np.asarray([int((sums[t]/max(1,counts[t])).argmax()) for t in ids], dtype=int)
    acc = (y_true_trial == y_pred_trial).mean()
    return y_true_trial, y_pred_trial, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j,i,format(cm_norm[i,j],fmt),ha="center", color="white" if cm_norm[i,j]>thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred'); plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes, tag=""):
    print(f"\n=== Reporte {tag} ===")
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))
    print(f"F1-macro: {f1_score(y_true, y_pred, average='macro'):.4f} | Kappa: {cohen_kappa_score(y_true,y_pred):.4f}")

# =========================
# FT (L2-SP), SIN fuga (agrupa por ensayo)
# =========================
def _freeze_for_mode(model: HybridCNNTransformer, mode: str):
    for p in model.parameters(): p.requires_grad = False
    if   mode=='out':
        for p in model.head[-1].parameters(): p.requires_grad = True
    elif mode=='head':
        for p in model.head.parameters(): p.requires_grad = True
    elif mode=='cnn+trans+head':
        for p in model.tokenizer.parameters():   p.requires_grad = True
        for p in model.transformer.parameters(): p.requires_grad = True
        for p in model.head.parameters():        p.requires_grad = True
    else:
        raise ValueError(mode)

def _param_groups(model: HybridCNNTransformer, mode: str):
    if   mode=='out': return list(model.head[-1].parameters())
    elif mode=='head': return list(model.head.parameters())
    elif mode=='cnn+trans+head': 
        return list(model.tokenizer.parameters()) + list(model.transformer.parameters()) + list(model.head.parameters())

def finetune_l2sp(model_global: HybridCNNTransformer, Xcal, ycal, trial_ids_cal,
                  n_classes: int, mode='head', epochs=FT_EPOCHS, batch_size=32,
                  base_lr=FT_BASE_LR, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                  val_ratio=FT_VAL_RATIO, seed=RANDOM_STATE, device=DEVICE):
    # split interno sin fuga: por ensayo (GroupShuffleSplit)
    gss = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    (tr_idx, va_idx), = gss.split(Xcal, ycal, groups=trial_ids_cal)
    Xtr, ytr = Xcal[tr_idx], ycal[tr_idx]
    Xva, yva = Xcal[va_idx], ycal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(torch.from_numpy(Xtr).float().unsqueeze(1),
                                           torch.from_numpy(ytr).long())
    ds_va = torch.utils.data.TensorDataset(torch.from_numpy(Xva).float().unsqueeze(1),
                                           torch.from_numpy(yva).long())
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    model = copy.deepcopy(model_global).to(device)
    _freeze_for_mode(model, mode)
    params = _param_groups(model, mode)

    if mode=='cnn+trans+head':
        tok_params = list(model.tokenizer.parameters()) + list(model.transformer.parameters())
        head_params = list(model.head.parameters())
        opt = torch.optim.Adam([
            {'params': tok_params,  'lr': base_lr},
            {'params': head_params, 'lr': head_lr},
        ])
        train_params = tok_params + head_params
    else:
        opt = torch.optim.Adam(params, lr=head_lr)
        train_params = params

    ref = [p.detach().clone().to(device) for p in train_params]   # L2-SP ref
    criterion = nn.CrossEntropyLoss()

    best_state, best_val, bad = copy.deepcopy(model.state_dict()), float('inf'), 0
    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            (loss + l2sp_lambda*reg).backward(); opt.step()
            apply_max_norm(model, 2.0, p=2.0)

        model.eval()
        val_loss, nval = 0.0, 0
        with torch.no_grad():
            for xb, yb in dl_va:
                xb, yb = xb.to(device), yb.to(device)
                val_loss += F.cross_entropy(model(xb), yb).item() * xb.size(0)
                nval += xb.size(0)
        val_loss /= max(1,nval)
        if val_loss + 1e-7 < best_val: best_val, bad, best_state = val_loss, 0, copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= FT_PATIENCE: break

    model.load_state_dict(best_state)
    return model

# =========================
# Folds JSON: SOLO sujetos (ignora índices)
# =========================
def read_folds_subjects_only(json_path, current_subjects):
    """
    Lee Kfold5.json y devuelve lista de folds con train_subjects/test_subjects (enteros).
    Ignora por completo tr_idx/te_idx si existen.
    Valida que los sujetos existan en el dataset actual.
    """
    p = Path(json_path)
    if not p.exists():
        raise FileNotFoundError(f"No existe {p}. Debes proporcionar un JSON con sujetos por fold.")
    with open(p, 'r', encoding='utf-8') as f:
        payload = json.load(f)
    folds_raw = payload.get('folds', [])
    if not folds_raw:
        raise ValueError("JSON no contiene 'folds'.")

    cur_set = set(current_subjects)
    folds = []
    for fr in folds_raw:
        train_s = fr.get('train_subjects', fr.get('train', []))
        test_s  = fr.get('test_subjects',  fr.get('test',  []))
        if not train_s or not test_s:
            raise ValueError("Cada fold debe tener 'train_subjects' y 'test_subjects' (o 'train'/'test').")

        def to_int_list(lst):
            out = []
            for s in lst:
                if isinstance(s, int): out.append(s)
                elif isinstance(s, str) and s.upper().startswith('S') and s[1:].isdigit():
                    out.append(int(s[1:]))
                elif isinstance(s, str) and s.isdigit():
                    out.append(int(s))
                else:
                    raise ValueError(f"Sujeto inválido en JSON: {s}")
            return out

        tr_int = to_int_list(train_s)
        te_int = to_int_list(test_s)

        # validación: sujetos del JSON deben existir
        miss_tr = [s for s in tr_int if s not in cur_set]
        miss_te = [s for s in te_int if s not in cur_set]
        if miss_tr or miss_te:
            raise ValueError(f"Fold con sujetos no presentes en dataset actual. "
                             f"Faltan en TRAIN: {miss_tr}, Faltan en TEST: {miss_te}")

        folds.append({'fold': fr.get('fold', len(folds)+1),
                      'train_subjects': tr_int, 'test_subjects': te_int})
    return folds

# =========================
# Experimento
# =========================
def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def run_experiment():
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    # 1) Ensayos balanceados (21/clase/sujeto) sobre TODOS los sujetos elegibles
    trials_bal = build_balanced_trials(subs, scenario=CLASS_SCENARIO, window_global=WINDOW_GLOBAL,
                                       n_per_class=N_PER_CLASS_PER_SUBJECT)

    # Sujetos presentes tras balanceo
    subj_in_trials = sorted({t[2] for t in trials_bal})
    print(f"Sujetos con ensayos balanceados: {len(subj_in_trials)}")

    # 2) Leer folds SOLO por sujetos
    folds = read_folds_subjects_only(FOLDS_JSON, current_subjects=subj_in_trials)

    # 3) Loop de folds
    global_folds_trial, all_true_trial, all_pred_trial, ft_prog_folds = [], [], [], []

    for f in folds:
        fold = f['fold']
        train_subs = set(f['train_subjects'])
        test_subs  = set(f['test_subjects'])

        # Split por sujeto (antes de windowing): seleccionamos ensayos (trial-level)
        tr_trials = [t for t in trials_bal if t[2] in train_subs]
        te_trials = [t for t in trials_bal if t[2] in test_subs]

        # Sanity: cada clase presente en TEST por ensayo (antes de windowing)
        lbl_te = [t[1] for t in te_trials]
        present = {0,1,2,3}.intersection(set(lbl_te))
        if present != {0,1,2,3}:
            print(f"[WARN] Fold {fold}: clases en TEST (trial-level) = {sorted(present)}; se salta este fold.")
            continue

        # 3a) Split interno de VALID por sujeto dentro de TRAIN (sin fuga de sujetos)
        train_subs_list = sorted(list(train_subs))
        rng = np.random.RandomState(RANDOM_STATE + fold)
        rng.shuffle(train_subs_list)
        n_va = max(1, int(round(GLOBAL_VAL_SPLIT * len(train_subs_list))))
        val_subs = set(train_subs_list[:n_va])
        real_train_subs = set(train_subs_list[n_va:])

        tr_trials_main = [t for t in tr_trials if t[2] in real_train_subs]
        va_trials      = [t for t in tr_trials if t[2] in val_subs]

        # 4) Windowing por split (sin fuga entre ensayos)
        Xtr, ytr, gtr, ttr, _ = window_trials(trials=tr_trials_main, win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)
        Xva, yva, gva, tva, _ = window_trials(trials=va_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)
        Xte, yte, gte, tte, _ = window_trials(trials=te_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)

        # DataLoaders
        ds_tr = EEGTrials(Xtr, ytr, gtr, ttr)
        ds_va = EEGTrials(Xva, yva, gva, tva)
        ds_te = EEGTrials(Xte, yte, gte, tte)

        tr_loader = DataLoader(ds_tr, batch_size=BATCH_SIZE,
                               sampler=build_weighted_sampler(ytr, gtr), drop_last=False)
        va_loader = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # 5) Modelo
        Tw, C = Xtr.shape[1], Xtr.shape[2]
        n_classes = len(set(ytr.tolist()) | set(yva.tolist()) | set(yte.tolist()))
        model = HybridCNNTransformer(
            n_ch=C, n_classes=n_classes, seq_len=Tw,
            d_model=128, nhead=4, depth=3, dim_ff=256,
            k_temporal=128, stride_t=4, drop=0.2, chdrop_p=0.10, use_cls_token=False
        ).to(DEVICE)

        opt = torch.optim.Adam(model.parameters(), lr=LR_INIT)
        criterion = SoftCE(n_classes=n_classes, label_smoothing=0.05, class_weights=None)

        def _acc_trial(loader): return evaluate_per_trial(model, loader, use_tta=False, n_classes=n_classes)[2]

        # 6) Entrenamiento (early stopping por ENSAYO)
        history = {'train_acc': [], 'val_acc': []}
        print(f"\n[Fold {fold}/{len(folds)}] Entrenando global (VAL por ensayo)...")
        best_state = copy.deepcopy(model.state_dict()); best_val = -1.0; bad = 0

        for epoch in range(1, EPOCHS_GLOBAL+1):
            model.train()
            for xb, yb, _, _ in tr_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                loss = criterion(model(xb), yb)
                opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
                apply_max_norm(model, 2.0, p=2.0)

            tr_acc = _acc_trial(tr_loader)
            va_acc = _acc_trial(va_loader)
            history['train_acc'].append(tr_acc); history['val_acc'].append(va_acc)

            if epoch % 5 == 0 or epoch in (1,10,20,50,80):
                print(f"  Época {epoch:3d} | train(ensayo)={tr_acc:.4f} | val(ensayo)={va_acc:.4f}")

            if va_acc > best_val + 1e-4:
                best_val, bad = va_acc, 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping @ epoch {epoch} (best val ensayo={best_val:.4f})")
                    break

        plot_training_curves(history, f"training_curve_cnnTrans_subjFolds_fold{fold}.png")
        print(f"↳ Curva: training_curve_cnnTrans_subjFolds_fold{fold}.png")
        model.load_state_dict(best_state)

        # 7) Test por ENSAYO (TTA)
        y_true_tr, y_pred_tr, acc_tr = evaluate_per_trial(model, te_loader, use_tta=True, tta_n=5, n_classes=n_classes)
        global_folds_trial.append(acc_tr)
        all_true_trial.append(y_true_tr); all_pred_trial.append(y_pred_tr)
        print(f"[Fold {fold}] Ensayo acc={acc_tr:.4f}")
        print_report(y_true_tr, y_pred_tr, CLASS_NAMES, tag="por ENSAYO (agregado)")

        # 8) FT por sujeto (sin fuga, agrupando por ensayo)
        #    Trabajamos dentro del set de TEST (por cada sujeto en test_subs)
        y_true_ft_all, y_pred_ft_all, used_subjects = [], [], 0

        # Para comodidad, indices por sujeto dentro de Xte
        trial_to_label = {}
        for idx in range(len(Xte)):
            trial_to_label[int(tte[idx])] = int(yte[idx])

        for sid in sorted(test_subs):
            mask_sid = (gte == sid)
            if mask_sid.sum() == 0: continue
            Xs, ys, ts = Xte[mask_sid], yte[mask_sid], tte[mask_sid]

            # GroupKFold por ensayo (sin fuga)
            if len(np.unique(ts)) < CALIB_CV_FOLDS:
                continue
            gkf = GroupKFold(n_splits=CALIB_CV_FOLDS)
            for tr_i, ho_i in gkf.split(Xs, ys, groups=ts):
                Xcal, ycal, tcal = Xs[tr_i], ys[tr_i], ts[tr_i]
                Xho,  yho,  tho  = Xs[ho_i], ys[ho_i], ts[ho_i]

                m_head = finetune_l2sp(model, Xcal, ycal, tcal,
                                       n_classes=n_classes, mode='head',
                                       epochs=FT_EPOCHS, base_lr=FT_BASE_LR, head_lr=FT_HEAD_LR,
                                       l2sp_lambda=FT_L2SP, val_ratio=FT_VAL_RATIO, device=DEVICE)

                with torch.no_grad():
                    xb = torch.from_numpy(Xho).float().unsqueeze(1).to(DEVICE)
                    logits = m_head(xb).cpu().numpy()

                sums = collections.defaultdict(lambda: np.zeros((n_classes,), dtype=np.float64))
                counts = collections.defaultdict(int)
                labels = {}
                for i in range(len(Xho)):
                    t = int(tho[i]); sums[t] += logits[i]; counts[t] += 1; labels[t] = int(yho[i])
                for t in sums.keys():
                    y_pred_ft_all.append(int((sums[t]/max(1,counts[t])).argmax()))
                    y_true_ft_all.append(labels[t])
                used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.asarray(y_true_ft_all, dtype=int)
            y_pred_ft_all = np.asarray(y_pred_ft_all, dtype=int)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  FT progresivo por ENSAYO acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Ensayo - Global-Ensayo) = {acc_ft - acc_tr:+.4f}")
        else:
            acc_ft = np.nan
            print("  FT no ejecutado (fold sin sujetos aptos).")
        ft_prog_folds.append(acc_ft)

    # ----- resultados globales -----
    if len(all_true_trial)>0: all_true_trial = np.concatenate(all_true_trial); all_pred_trial = np.concatenate(all_pred_trial)
    else: all_true_trial = np.array([],dtype=int); all_pred_trial = np.array([],dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES — CNN+Transformer (21/cls/subj) + Sliding Window + Kfold SUBJECTS-only")
    print("="*60)
    print("Ensayo folds:", [f"{a:.4f}" for a in global_folds_trial])
    if len(global_folds_trial) > 0:
        print(f"Ensayo mean: {np.mean(global_folds_trial):.4f}")
    print("FT prog (ensayo) folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"FT (ensayo) mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Ensayo - Global-Ensayo) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds_trial):+.4f}")

    if all_true_trial.size > 0:
        plot_confusion(all_true_trial, all_pred_trial, CLASS_NAMES,
                       title="Confusion Matrix - Global Model (Per Trial, All Folds)",
                       fname="confusion_cnnTrans_subjectFolds_trial_allfolds.png")
        print("↳ Matriz de confusión (ensayo): confusion_cnnTrans_subjectFolds_trial_allfolds.png")

# ---------- MAIN ----------
if __name__ == "__main__":
    bp_status = f"ON [{BP_LO}-{BP_HI} Hz]" if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None) else "OFF"
    print("🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + FT progresivo")
    print(f"🔧 Escenario: {CLASS_SCENARIO} | rango ensayo={WINDOW_GLOBAL} s | subventana={WIN_LEN_SEC}s | stride={WIN_STRIDE_SEC}s | subwins/ensayo={N_SUBWINS} | band-pass={bp_status}")
    print(f"⚙️  Balance: 21 ensayos por clase y por sujeto (con reemplazo si falta)")
    print(f"⚙️  FT (sin fuga): GroupKFold por ensayo + GroupShuffleSplit en validación interna")
    run_experiment()


🚀 Dispositivo: cuda
🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + FT progresivo
🔧 Escenario: 4c | rango ensayo=(0.0, 4.0) s | subventana=2.0s | stride=0.5s | subwins/ensayo=5 | band-pass=OFF
⚙️  Balance: 21 ensayos por clase y por sujeto (con reemplazo si falta)
⚙️  FT (sin fuga): GroupKFold por ensayo + GroupShuffleSplit en validación interna
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Recolectando y balanceando ensayos base por sujeto: 100%|██████████| 103/103 [00:36<00:00,  2.79it/s]


Ensayos base balanceados: 8652  (~ #sujetos_utiles * 84)
Sujetos con ensayos balanceados: 103





[Fold 1/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3913 | val(ensayo)=0.3601
  Época   5 | train(ensayo)=0.4621 | val(ensayo)=0.4256
  Época  10 | train(ensayo)=0.5616 | val(ensayo)=0.4216
  Época  15 | train(ensayo)=0.6284 | val(ensayo)=0.3998
  Early stopping @ epoch 19 (best val ensayo=0.4415)
↳ Curva: training_curve_cnnTrans_subjFolds_fold1.png
[Fold 1] Ensayo acc=0.3532

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4006    0.2880    0.3351       441
       Right     0.4130    0.2154    0.2832       441
  Both Fists     0.3033    0.5488    0.3906       441
   Both Feet     0.3795    0.3605    0.3698       441

    accuracy                         0.3532      1764
   macro avg     0.3741    0.3532    0.3447      1764
weighted avg     0.3741    0.3532    0.3447      1764

F1-macro: 0.3447 | Kappa: 0.1376
  FT progresivo por ENSAYO acc=0.3203 | sujetos=84
  Δ(FT-Ensayo - Global-Ensayo) = -0




[Fold 2/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3315 | val(ensayo)=0.3304
  Época   5 | train(ensayo)=0.4603 | val(ensayo)=0.3929
  Época  10 | train(ensayo)=0.5660 | val(ensayo)=0.3780
  Época  15 | train(ensayo)=0.6189 | val(ensayo)=0.3810
  Early stopping @ epoch 17 (best val ensayo=0.4117)
↳ Curva: training_curve_cnnTrans_subjFolds_fold2.png
[Fold 2] Ensayo acc=0.4008

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4634    0.4739    0.4686       441
       Right     0.4496    0.3537    0.3959       441
  Both Fists     0.3135    0.3946    0.3494       441
   Both Feet     0.4088    0.3810    0.3944       441

    accuracy                         0.4008      1764
   macro avg     0.4088    0.4008    0.4021      1764
weighted avg     0.4088    0.4008    0.4021      1764

F1-macro: 0.4021 | Kappa: 0.2011
  FT progresivo por ENSAYO acc=0.3781 | sujetos=84
  Δ(FT-Ensayo - Global-Ensayo) = -0




[Fold 3/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3411 | val(ensayo)=0.3433
  Época   5 | train(ensayo)=0.4635 | val(ensayo)=0.3542
  Época  10 | train(ensayo)=0.5798 | val(ensayo)=0.3790
  Época  15 | train(ensayo)=0.5889 | val(ensayo)=0.3403
  Early stopping @ epoch 19 (best val ensayo=0.3790)
↳ Curva: training_curve_cnnTrans_subjFolds_fold3.png
[Fold 3] Ensayo acc=0.3730

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4642    0.2789    0.3484       441
       Right     0.3594    0.6349    0.4590       441
  Both Fists     0.3444    0.1882    0.2434       441
   Both Feet     0.3591    0.3900    0.3739       441

    accuracy                         0.3730      1764
   macro avg     0.3818    0.3730    0.3562      1764
weighted avg     0.3818    0.3730    0.3562      1764

F1-macro: 0.3562 | Kappa: 0.1640
  FT progresivo por ENSAYO acc=0.3617 | sujetos=84
  Δ(FT-Ensayo - Global-Ensayo) = -0




[Fold 4/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3170 | val(ensayo)=0.2798
  Época   5 | train(ensayo)=0.4740 | val(ensayo)=0.3413
  Época  10 | train(ensayo)=0.5784 | val(ensayo)=0.3591
  Época  15 | train(ensayo)=0.6337 | val(ensayo)=0.3294
  Early stopping @ epoch 17 (best val ensayo=0.3611)
↳ Curva: training_curve_cnnTrans_subjFolds_fold4.png
[Fold 4] Ensayo acc=0.4101

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.3899    0.6071    0.4749       420
       Right     0.4573    0.3190    0.3759       420
  Both Fists     0.3475    0.2333    0.2792       420
   Both Feet     0.4479    0.4810    0.4638       420

    accuracy                         0.4101      1680
   macro avg     0.4107    0.4101    0.3984      1680
weighted avg     0.4107    0.4101    0.3984      1680

F1-macro: 0.3984 | Kappa: 0.2135
  FT progresivo por ENSAYO acc=0.3935 | sujetos=80
  Δ(FT-Ensayo - Global-Ensayo) = -0




[Fold 5/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3077 | val(ensayo)=0.2966
  Época   5 | train(ensayo)=0.4561 | val(ensayo)=0.3532
  Época  10 | train(ensayo)=0.5395 | val(ensayo)=0.3393
  Época  15 | train(ensayo)=0.6155 | val(ensayo)=0.3393
  Early stopping @ epoch 15 (best val ensayo=0.3532)
↳ Curva: training_curve_cnnTrans_subjFolds_fold5.png
[Fold 5] Ensayo acc=0.4417

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4970    0.3929    0.4388       420
       Right     0.4608    0.5881    0.5167       420
  Both Fists     0.4035    0.3286    0.3622       420
   Both Feet     0.4085    0.4571    0.4315       420

    accuracy                         0.4417      1680
   macro avg     0.4425    0.4417    0.4373      1680
weighted avg     0.4425    0.4417    0.4373      1680

F1-macro: 0.4373 | Kappa: 0.2556
  FT progresivo por ENSAYO acc=0.4149 | sujetos=80
  Δ(FT-Ensayo - Global-Ensayo) = -0

In [3]:
# -*- coding: utf-8 -*-
# ===========================================
# MI-EEG (PhysioNet eegmmidb) — CNN+Transformer (Balanced 21/cls/subj) + Sliding Window
# Patch: usar Kfold5.json SOLO para sujetos por fold (ignora tr_idx/te_idx)
# FT opcional:
#   - "none_bn_protos": BN-Adapt + prototipos (sin entrenamiento)
#   - "head_l2sp": fine-tuning de la cabeza con L2-SP
# Cosine head con temperatura aprendible
# Tokenización densa (k=64, stride=2) + CLS token
# ===========================================

import os, re, math, json, copy, random, itertools, collections
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset

from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from sklearn.metrics import confusion_matrix, classification_report, f1_score, cohen_kappa_score
from sklearn.utils import check_random_state
from tqdm import tqdm
import matplotlib.pyplot as plt

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'           # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR = PROJ / 'data' / 'cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
print(f"🚀 Dispositivo: {DEVICE}")

FS = 160.0

# Escenario y rango base por ensayo
CLASS_SCENARIO = '4c'            # '2c','3c','4c'
WINDOW_GLOBAL = (0.0, 4.0)       # segundos rel. T1/T2 (ensayo base)

# Sliding window (mismo nº por ensayo)
WIN_LEN_SEC = 2.5                # recomendado 2.0–2.5; aquí 2.5 para +contexto
WIN_STRIDE_SEC = 0.5
N_SUBWINS = int(math.floor((WINDOW_GLOBAL[1]-WINDOW_GLOBAL[0]-WIN_LEN_SEC)/WIN_STRIDE_SEC) + 1)

# Balance 21 por clase/sujeto
N_PER_CLASS_PER_SUBJECT = 21

# CV
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10

# Entrenamiento global
BATCH_SIZE = 64
EPOCHS_GLOBAL = 80
LR_INIT = 1e-3

# FT por sujeto — estrategia
# "none_bn_protos"  -> BN-Adapt + prototipos (sin entrenamiento)
# "head_l2sp"       -> FT de la cabeza con L2-SP (supervisado)
FT_STRATEGY = "none_bn_protos"

# Hiper FT (solo si FT_STRATEGY == "head_l2sp")
FT_EPOCHS = 12
FT_BASE_LR = 5e-5        # no usado en "head" puro
FT_HEAD_LR = 3e-4
FT_L2SP = 5e-4
FT_PATIENCE = 3
FT_VAL_RATIO = 0.2
CALIB_CV_FOLDS = 4

# Band-pass
USE_BANDPASS = False
BP_LO, BP_HI = 8.0, 30.0
BP_METHOD, BP_PHASE = 'fir', 'zero'

# Dataset
EXCLUDE_SUBJECTS = {38,88,89,92,100,104}
MI_RUNS_LR = [4,8,12]            # Left/Right
MI_RUNS_OF = [6,10,14]           # Both Fists / Both Feet
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

CLASS_NAMES = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: canales/edf
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en {getattr(raw,'filenames',[''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    return None

_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None: return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try: _SNR_TABLE = pd.read_csv(csv_path)
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:   return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None: return None

    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None):
        raw.filter(l_freq=float(BP_LO), h_freq=float(BP_HI),
                   picks='eeg', method=BP_METHOD, phase=BP_PHASE, verbose=False)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    out = []
    if raw.annotations is None or len(raw.annotations) == 0: return out
    def _norm(s): return str(s).strip().upper().replace(' ','')
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'): out.append((float(onset), tag))
    out.sort()
    dedup = []; last_t1 = last_t2 = -1e9
    for t, tag in out:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# Dataset helpers
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR+MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_base_trials_from_run(edf_path: Path, scenario: str, window_global=(0.0,4.0)):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF'): return []

    raw = read_raw_edf(edf_path)
    if raw is None: return []
    data = raw.get_data()
    fs = raw.info['sfreq']
    out = []
    events = collect_events_T1T2(raw)
    rel_start, rel_end = window_global
    ev_idx = 0
    for onset_sec, tag in events:
        ev_idx += 1
        if kind == 'LR':
            label = 0 if tag=='T1' else 1
        else:
            label = 2 if tag=='T1' else 3
        if   scenario == '2c' and label not in (0,1): continue
        elif scenario == '3c' and label not in (0,1,2): continue
        elif scenario == '4c' and label not in (0,1,2,3): continue

        s0 = int(round((raw.first_time + onset_sec + rel_start) * fs))
        e0 = int(round((raw.first_time + onset_sec + rel_end)   * fs))
        if s0 < 0 or e0 > data.shape[1] or e0 <= s0: continue

        seg = data[:, s0:e0].T.astype(np.float32)  # (T_base, C)
        seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)
        trial_key = f"S{subj:03d}_R{run:02d}_E{ev_idx:03d}"
        out.append((seg, label, subj, trial_key))
    return out

def build_balanced_trials(subjects, scenario='4c', window_global=(0.0,4.0), n_per_class=N_PER_CLASS_PER_SUBJECT):
    trials_balanced = []
    for s in tqdm(subjects, desc="Recolectando y balanceando ensayos base por sujeto"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        bins = {0:[],1:[],2:[],3:[]}
        for r in (MI_RUNS_LR + MI_RUNS_OF):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            for t in extract_base_trials_from_run(p, scenario, window_global):
                bins[t[1]].append(t)
        # descarta sujeto si no tiene todas las clases
        if any(len(bins[c]) == 0 for c in (0,1,2,3)): continue
        rng = check_random_state(RANDOM_STATE + s)
        for c in (0,1,2,3):
            pool = bins[c]
            if len(pool) >= n_per_class:
                rng.shuffle(pool); take = pool[:n_per_class]
            else:
                idx = rng.choice(len(pool), size=n_per_class, replace=True)
                take = [pool[i] for i in idx]
            trials_balanced.extend(take)
    print(f"Ensayos base balanceados: {len(trials_balanced)}  (~ #sujetos_utiles * {n_per_class*4})")
    return trials_balanced

def window_trials(trials, win_len_sec=2.0, win_stride_sec=0.5, fs=160.0):
    """Windowing para una lista de ensayos (de un split). Retorna X,y,subj,trial_ids."""
    X, y, g, tkeys = [], [], [], []
    win_len = int(round(win_len_sec * fs))
    stride  = int(round(win_stride_sec * fs))
    for seg, lab, subj, tkey in trials:
        T = seg.shape[0]
        starts = [i for i in range(0, T - win_len + 1, stride)]
        # asegura N_SUBWINS por ensayo
        if len(starts) > N_SUBWINS: starts = starts[:N_SUBWINS]
        while len(starts) < N_SUBWINS and (T - win_len) >= 0:
            starts.append(T - win_len)
        for s in starts:
            e = s + win_len
            sub = seg[s:e, :].astype(np.float32)
            sub = (sub - sub.mean(axis=0, keepdims=True)) / (sub.std(axis=0, keepdims=True) + 1e-6)
            X.append(sub); y.append(lab); g.append(subj); tkeys.append(tkey)
    uniq = {k:i for i,k in enumerate(sorted(set(tkeys)))}
    trial_ids = np.asarray([uniq[k] for k in tkeys], dtype=np.int64)
    X = np.stack(X, axis=0).astype(np.float32)
    y = np.asarray(y, dtype=np.int64)
    g = np.asarray(g, dtype=np.int64)
    return X, y, g, trial_ids, uniq

# =========================
# Dataset torch
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups, trial_ids):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.t = trial_ids.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]; x = np.expand_dims(x, 0)  # (1,Tw,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx]), torch.tensor(self.t[idx])

def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w; w = w / w.mean()
    return WeightedRandomSampler(weights=torch.from_numpy(w).float(), num_samples=len(w), replacement=True)

# =========================
# Modelo: CNN + Transformer + Cosine head
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1): super().__init__(); self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    for m in model.modules():
        if hasattr(m,'weight') and isinstance(m,(nn.Conv2d, nn.Linear)):
            w = m.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)
    def forward(self, x):
        T = x.size(1); return x + self.pe[:, :T, :]

class CNNTokenizer(nn.Module):
    def __init__(self, n_ch: int, d_feat: int = 128, k_temporal: int = 64, stride_t: int = 2, chdrop_p: float = 0.1, drop_p: float = 0.2):
        super().__init__()
        self.chdrop = ChannelDropout(chdrop_p)
        self.conv_t = nn.Conv2d(1, d_feat, kernel_size=(k_temporal,1), stride=(stride_t,1), padding=(k_temporal//2,0), bias=False)
        self.bn_t   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.act    = nn.ELU()
        self.conv_sp_dw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,n_ch), groups=d_feat, bias=False)
        self.bn_sp      = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.conv_pw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,1), bias=False)
        self.bn_pw   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.drop    = nn.Dropout(drop_p)
    def forward(self, x):
        z = self.chdrop(x)
        z = self.conv_t(z); z = self.bn_t(z); z = self.act(z)         # (B,D,T',C)
        z = self.conv_sp_dw(z); z = self.bn_sp(z); z = self.act(z)    # (B,D,T',1)
        z = self.conv_pw(z); z = self.bn_pw(z); z = self.act(z)
        z = self.drop(z)
        z = z.squeeze(-1).transpose(1,2)  # (B,T',D)
        return z

def _enc_layer(d_model=128, nhead=4, dim_ff=256, drop=0.2):
    return nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
                                      dropout=drop, activation='gelu', batch_first=True, norm_first=True)

class TemporalTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, dim_ff=256, depth=3, drop=0.2):
        super().__init__()
        self.pe = PositionalEncoding(d_model)
        layer = _enc_layer(d_model, nhead, dim_ff, drop)
        self.enc = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x):
        x = self.pe(x)
        x = self.enc(x)
        return self.norm(x)

class CosineClassifier(nn.Module):
    def __init__(self, d, n_classes, tau=10.0):
        super().__init__()
        self.W = nn.Parameter(torch.randn(n_classes, d))
        nn.init.xavier_uniform_(self.W)
        self.log_tau = nn.Parameter(torch.log(torch.tensor(tau, dtype=torch.float32)))
    def forward(self, f):  # f: (B,d)
        f_n = F.normalize(f, dim=1)                  # (B,d)
        W_n = F.normalize(self.W, dim=1)             # (C,d)
        cos = torch.mm(f_n, W_n.t())                 # (B,C)
        return torch.exp(self.log_tau) * cos

class HybridCNNTransformer(nn.Module):
    def __init__(self, n_ch: int, n_classes: int, seq_len: int,
                 d_model: int = 128, nhead: int = 4, depth: int = 3,
                 dim_ff: int = 256, k_temporal: int = 64, stride_t: int = 2,
                 drop: float = 0.2, chdrop_p: float = 0.1, use_cls_token: bool = True):
        super().__init__()
        self.use_cls = use_cls_token
        self.tokenizer = CNNTokenizer(n_ch, d_model, k_temporal, stride_t, chdrop_p, drop)
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)) if use_cls_token else None
        self.transformer = TemporalTransformer(d_model, nhead, dim_ff, depth, drop)
        # Cosine head
        self.norm_head = nn.LayerNorm(d_model)
        self.head = CosineClassifier(d_model, n_classes, tau=10.0)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, x):
        tok = self.tokenizer(x)                   # (B,T',D)
        if self.use_cls:
            B = tok.size(0); cls_tok = self.cls.expand(B,-1,-1)
            tok = torch.cat([cls_tok, tok], dim=1)
        z = self.transformer(tok)
        feat = z[:,0,:] if self.use_cls else z.mean(dim=1)
        feat = self.norm_head(feat)
        return self.head(feat)                    # (B,C)

    @torch.no_grad()
    def extract_feats(self, x):
        tok = self.tokenizer(x)
        if self.use_cls:
            B = tok.size(0); cls_tok = self.cls.expand(B,-1,-1)
            tok = torch.cat([cls_tok, tok], dim=1)
        z = self.transformer(tok)
        feat = z[:,0,:] if self.use_cls else z.mean(dim=1)
        return self.norm_head(feat)

# =========================
# Entrenamiento / Eval
# =========================
class SoftCE(nn.Module):
    def __init__(self, n_classes: int, label_smoothing: float = 0.05, class_weights: torch.Tensor | None = None):
        super().__init__()
        self.n_classes = n_classes
        self.ls = float(label_smoothing)
        self.register_buffer('w', class_weights if class_weights is not None else None)
    def forward(self, logits, targets_idx):
        B, C = logits.size()
        yt = F.one_hot(targets_idx, num_classes=C).float()
        if self.ls > 0: yt = (1 - self.ls) * yt + self.ls * (1.0 / C)
        logp = F.log_softmax(logits, dim=1)
        loss = -(yt * logp)
        if self.w is not None: loss = loss * self.w.unsqueeze(0)
        return loss.sum(dim=1).mean()

@torch.no_grad()
def _predict_tta(model, xb, n=5, fs=160.0):
    def time_jitter(x, max_ms=25):
        max_shift = int(round(max_ms/1000.0 * fs))
        if max_shift <= 0: return x
        B,_,T,C = x.shape
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
        out = torch.zeros_like(x)
        for i,s in enumerate(shifts):
            if s >= 0: out[i,:,s:,:] = x[i,:,:T-s,:]
            else: s=-s; out[i,:,:T-s,:] = x[i,:,s:,:]
        return out
    outs = []
    for _ in range(n):
        outs.append(model(time_jitter(xb,25)))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_per_trial(model, loader, use_tta=True, tta_n=5, n_classes=4):
    model.eval()
    sums, counts, labels = {}, {}, {}
    for xb, yb, _, tb in loader:
        xb = xb.to(DEVICE)
        logits = _predict_tta(model, xb, n=tta_n) if use_tta else model(xb)
        logits = logits.detach().cpu().numpy()
        yb = yb.numpy(); tb = tb.numpy()
        for i in range(xb.size(0)):
            t = int(tb[i])
            if t not in sums:
                sums[t] = np.zeros((n_classes,), dtype=np.float64)
                counts[t] = 0
                labels[t] = int(yb[i])
            sums[t] += logits[i]; counts[t] += 1
    ids = sorted(sums.keys())
    y_true_trial = np.asarray([labels[t] for t in ids], dtype=int)
    y_pred_trial = np.asarray([int((sums[t]/max(1,counts[t])).argmax()) for t in ids], dtype=int)
    acc = (y_true_trial == y_pred_trial).mean()
    return y_true_trial, y_pred_trial, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j,i,format(cm_norm[i,j],fmt),ha="center", color="white" if cm_norm[i,j]>thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred'); plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes, tag=""):
    print(f"\n=== Reporte {tag} ===")
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))
    print(f"F1-macro: {f1_score(y_true, y_pred, average='macro'):.4f} | Kappa: {cohen_kappa_score(y_true,y_pred):.4f}")

# =========================
# BN-Adapt + Prototipos (FT sin entrenamiento)
# =========================
@torch.no_grad()
def bn_adapt(model, X_cal, batch=128):
    """Actualiza running stats de BatchNorm SOLO con datos del sujeto (sin gradientes)."""
    was_training = model.training
    model.train()
    for i in range(0, len(X_cal), batch):
        xb = torch.from_numpy(X_cal[i:i+batch]).float().unsqueeze(1).to(DEVICE)
        _ = model(xb)
    model.eval()
    if was_training: model.train()

@torch.no_grad()
def set_prototypes(model: HybridCNNTransformer, X_cal, y_cal, n_classes, batch=128):
    """Calcula prototipos por clase (features medios) y actualiza pesos del Cosine head."""
    feats = []
    labs = []
    model.eval()
    for i in range(0, len(X_cal), batch):
        xb = torch.from_numpy(X_cal[i:i+batch]).float().unsqueeze(1).to(DEVICE)
        f = model.extract_feats(xb).cpu()
        feats.append(f); labs.append(torch.from_numpy(y_cal[i:i+batch]))
    F_all = torch.cat(feats,0); Y_all = torch.cat(labs,0)
    protos = []
    for c in range(n_classes):
        m = (Y_all==c)
        if m.any():
            proto = F.normalize(F_all[m].mean(0, keepdim=True), dim=1)
        else:
            proto = F.normalize(F_all.mean(0, keepdim=True), dim=1)
        protos.append(proto)
    W = torch.cat(protos,0)  # (C,d)
    model.head.W.data.copy_(W)

# =========================
# FT con L2-SP (solo cabeza)
# =========================
def _freeze_for_mode(model: HybridCNNTransformer, mode: str):
    for p in model.parameters(): p.requires_grad = False
    if   mode=='head':
        for p in model.norm_head.parameters(): p.requires_grad = True
        for p in model.head.parameters():      p.requires_grad = True
    else:
        raise ValueError(mode)

def _param_groups(model: HybridCNNTransformer, mode: str):
    if   mode=='head': 
        return list(model.norm_head.parameters()) + list(model.head.parameters())

def finetune_head_l2sp(model_global: HybridCNNTransformer, Xcal, ycal, trial_ids_cal,
                       n_classes: int, epochs=FT_EPOCHS, batch_size=32,
                       head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                       val_ratio=FT_VAL_RATIO, seed=RANDOM_STATE, device=DEVICE):
    gss = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    (tr_idx, va_idx), = gss.split(Xcal, ycal, groups=trial_ids_cal)
    Xtr, ytr = Xcal[tr_idx], ycal[tr_idx]
    Xva, yva = Xcal[va_idx], ycal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(torch.from_numpy(Xtr).float().unsqueeze(1),
                                           torch.from_numpy(ytr).long())
    ds_va = torch.utils.data.TensorDataset(torch.from_numpy(Xva).float().unsqueeze(1),
                                           torch.from_numpy(yva).long())
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    model = copy.deepcopy(model_global).to(device)
    _freeze_for_mode(model, 'head')
    params = _param_groups(model, 'head')
    opt = torch.optim.Adam(params, lr=head_lr)
    ref = [p.detach().clone().to(device) for p in params]
    criterion = nn.CrossEntropyLoss()

    best_state, best_val, bad = copy.deepcopy(model.state_dict()), float('inf'), 0
    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            (loss + l2sp_lambda*reg).backward(); opt.step()
            apply_max_norm(model, 2.0, p=2.0)

        # valid
        model.eval()
        val_loss, nval = 0.0, 0
        with torch.no_grad():
            for xb, yb in dl_va:
                xb, yb = xb.to(device), yb.to(device)
                val_loss += F.cross_entropy(model(xb), yb).item() * xb.size(0)
                nval += xb.size(0)
        val_loss /= max(1,nval)
        if val_loss + 1e-7 < best_val: best_val, bad, best_state = val_loss, 0, copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= FT_PATIENCE: break

    model.load_state_dict(best_state)
    return model

# =========================
# Folds JSON: SOLO sujetos (ignora índices)
# =========================
def read_folds_subjects_only(json_path, current_subjects):
    """
    Lee Kfold5.json y devuelve lista de folds con train_subjects/test_subjects (enteros).
    Ignora por completo tr_idx/te_idx si existen.
    Valida que los sujetos existan en el dataset actual.
    """
    p = Path(json_path)
    if not p.exists():
        raise FileNotFoundError(f"No existe {p}. Debes proporcionar un JSON con sujetos por fold.")
    with open(p, 'r', encoding='utf-8') as f:
        payload = json.load(f)
    folds_raw = payload.get('folds', [])
    if not folds_raw:
        raise ValueError("JSON no contiene 'folds'.")

    cur_set = set(current_subjects)
    folds = []
    for fr in folds_raw:
        train_s = fr.get('train_subjects', fr.get('train', []))
        test_s  = fr.get('test_subjects',  fr.get('test',  []))
        if not train_s or not test_s:
            raise ValueError("Cada fold debe tener 'train_subjects' y 'test_subjects' (o 'train'/'test').")

        def to_int_list(lst):
            out = []
            for s in lst:
                if isinstance(s, int): out.append(s)
                elif isinstance(s, str) and s.upper().startswith('S') and s[1:].isdigit():
                    out.append(int(s[1:]))
                elif isinstance(s, str) and s.isdigit():
                    out.append(int(s))
                else:
                    raise ValueError(f"Sujeto inválido en JSON: {s}")
            return out

        tr_int = to_int_list(train_s)
        te_int = to_int_list(test_s)

        miss_tr = [s for s in tr_int if s not in cur_set]
        miss_te = [s for s in te_int if s not in cur_set]
        if miss_tr or miss_te:
            raise ValueError(f"Fold con sujetos no presentes en dataset actual. "
                             f"Faltan en TRAIN: {miss_tr}, Faltan en TEST: {miss_te}")

        folds.append({'fold': fr.get('fold', len(folds)+1),
                      'train_subjects': tr_int, 'test_subjects': te_int})
    return folds

# =========================
# Experimento
# =========================
def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def run_experiment():
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    # Ensayos balanceados (21/clase/sujeto)
    trials_bal = build_balanced_trials(subs, scenario=CLASS_SCENARIO, window_global=WINDOW_GLOBAL,
                                       n_per_class=N_PER_CLASS_PER_SUBJECT)

    # Sujetos presentes tras balanceo
    subj_in_trials = sorted({t[2] for t in trials_bal})
    print(f"Sujetos con ensayos balanceados: {len(subj_in_trials)}")

    # Leer folds SOLO por sujetos
    folds = read_folds_subjects_only(FOLDS_JSON, current_subjects=subj_in_trials)

    global_folds_trial, all_true_trial, all_pred_trial, ft_prog_folds = [], [], [], []

    for f in folds:
        fold = f['fold']
        train_subs = set(f['train_subjects'])
        test_subs  = set(f['test_subjects'])

        # Split por sujeto (antes de windowing)
        tr_trials = [t for t in trials_bal if t[2] in train_subs]
        te_trials = [t for t in trials_bal if t[2] in test_subs]

        # Sanity: todas las clases presentes en TEST por ensayo
        lbl_te = [t[1] for t in te_trials]
        present = {0,1,2,3}.intersection(set(lbl_te))
        if present != {0,1,2,3}:
            print(f"[WARN] Fold {fold}: clases en TEST (trial-level) = {sorted(present)}; se salta este fold.")
            continue

        # VALID por sujeto dentro de TRAIN
        train_subs_list = sorted(list(train_subs))
        rng = np.random.RandomState(RANDOM_STATE + fold)
        rng.shuffle(train_subs_list)
        n_va = max(1, int(round(GLOBAL_VAL_SPLIT * len(train_subs_list))))
        val_subs = set(train_subs_list[:n_va])
        real_train_subs = set(train_subs_list[n_va:])

        tr_trials_main = [t for t in tr_trials if t[2] in real_train_subs]
        va_trials      = [t for t in tr_trials if t[2] in val_subs]

        # Windowing por split (sin fuga)
        Xtr, ytr, gtr, ttr, _ = window_trials(trials=tr_trials_main, win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)
        Xva, yva, gva, tva, _ = window_trials(trials=va_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)
        Xte, yte, gte, tte, _ = window_trials(trials=te_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)

        # DataLoaders
        ds_tr = EEGTrials(Xtr, ytr, gtr, ttr)
        ds_va = EEGTrials(Xva, yva, gva, tva)
        ds_te = EEGTrials(Xte, yte, gte, tte)

        tr_loader = DataLoader(ds_tr, batch_size=BATCH_SIZE,
                               sampler=build_weighted_sampler(ytr, gtr), drop_last=False)
        va_loader = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # Modelo
        Tw, C = Xtr.shape[1], Xtr.shape[2]
        n_classes = len(set(ytr.tolist()) | set(yva.tolist()) | set(yte.tolist()))
        model = HybridCNNTransformer(
            n_ch=C, n_classes=n_classes, seq_len=Tw,
            d_model=128, nhead=4, depth=3, dim_ff=256,
            k_temporal=64, stride_t=2, drop=0.2, chdrop_p=0.10, use_cls_token=True
        ).to(DEVICE)

        opt = torch.optim.Adam(model.parameters(), lr=LR_INIT)
        criterion = SoftCE(n_classes=n_classes, label_smoothing=0.05, class_weights=None)

        def _acc_trial(loader): return evaluate_per_trial(model, loader, use_tta=False, n_classes=n_classes)[2]

        # Entrenamiento global (early stopping por ENSAYO)
        history = {'train_acc': [], 'val_acc': []}
        print(f"\n[Fold {fold}/{len(folds)}] Entrenando global (VAL por ensayo)...")
        best_state = copy.deepcopy(model.state_dict()); best_val = -1.0; bad = 0

        for epoch in range(1, EPOCHS_GLOBAL+1):
            model.train()
            for xb, yb, _, _ in tr_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                loss = criterion(model(xb), yb)
                opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
                apply_max_norm(model, 2.0, p=2.0)

            tr_acc = _acc_trial(tr_loader)
            va_acc = _acc_trial(va_loader)
            history['train_acc'].append(tr_acc); history['val_acc'].append(va_acc)

            if epoch % 5 == 0 or epoch in (1,10,20,50,80):
                print(f"  Época {epoch:3d} | train(ensayo)={tr_acc:.4f} | val(ensayo)={va_acc:.4f}")

            if va_acc > best_val + 1e-4:
                best_val, bad = va_acc, 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping @ epoch {epoch} (best val ensayo={best_val:.4f})")
                    break

        plot_training_curves(history, f"training_curve_cnnTrans_cosine_fold{fold}.png")
        print(f"↳ Curva: training_curve_cnnTrans_cosine_fold{fold}.png")
        model.load_state_dict(best_state)

        # Test por ENSAYO (TTA)
        y_true_tr, y_pred_tr, acc_tr = evaluate_per_trial(model, te_loader, use_tta=True, tta_n=5, n_classes=n_classes)
        print(f"[Fold {fold}] Ensayo acc={acc_tr:.4f}")
        print_report(y_true_tr, y_pred_tr, CLASS_NAMES, tag="por ENSAYO (agregado)")
        global_folds_trial.append(acc_tr)
        all_true_trial.append(y_true_tr); all_pred_trial.append(y_pred_tr)

        # ====== FT por sujeto ======
        y_true_ft_all, y_pred_ft_all, used_subjects = [], [], 0

        # Índices por sujeto en Xte
        for sid in sorted(test_subs):
            mask_sid = (gte == sid)
            if mask_sid.sum() == 0: continue
            Xs, ys, ts = Xte[mask_sid], yte[mask_sid], tte[mask_sid]

            # GroupKFold por ensayo (sin fuga)
            if len(np.unique(ts)) < CALIB_CV_FOLDS:
                continue
            gkf = GroupKFold(n_splits=CALIB_CV_FOLDS)
            for tr_i, ho_i in gkf.split(Xs, ys, groups=ts):
                Xcal, ycal, tcal = Xs[tr_i], ys[tr_i], ts[tr_i]
                Xho,  yho,  tho  = Xs[ho_i], ys[ho_i], ts[ho_i]

                if FT_STRATEGY == "none_bn_protos":
                    # 1) BN-Adapt (sin etiquetas)
                    bn_adapt(model, Xcal)
                    # 2) Prototipos (opcional; usa etiquetas de calibración)
                    set_prototypes(model, Xcal, ycal, n_classes)
                    m_ft = model  # usamos el propio modelo adaptado (sin copiar)

                elif FT_STRATEGY == "head_l2sp":
                    m_ft = finetune_head_l2sp(model, Xcal, ycal, tcal,
                                              n_classes=n_classes, epochs=FT_EPOCHS,
                                              head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                                              val_ratio=FT_VAL_RATIO, device=DEVICE)
                else:
                    raise ValueError(f"Estrategia FT desconocida: {FT_STRATEGY}")

                with torch.no_grad():
                    xb = torch.from_numpy(Xho).float().unsqueeze(1).to(DEVICE)
                    logits = m_ft(xb).cpu().numpy()

                # Agregar por ensayo
                sums = collections.defaultdict(lambda: np.zeros((n_classes,), dtype=np.float64))
                counts = collections.defaultdict(int)
                labels = {}
                for i in range(len(Xho)):
                    t = int(tho[i]); sums[t] += logits[i]; counts[t] += 1; labels[t] = int(yho[i])
                for t in sums.keys():
                    y_pred_ft_all.append(int((sums[t]/max(1,counts[t])).argmax()))
                    y_true_ft_all.append(labels[t])
                used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.asarray(y_true_ft_all, dtype=int)
            y_pred_ft_all = np.asarray(y_pred_ft_all, dtype=int)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  FT ({FT_STRATEGY}) por ENSAYO acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT - Global) = {acc_ft - acc_tr:+.4f}")
        else:
            acc_ft = np.nan
            print("  FT no ejecutado (fold sin sujetos aptos).")

        ft_prog_folds.append(acc_ft)

    # ----- resultados globales -----
    if len(all_true_trial)>0: all_true_trial = np.concatenate(all_true_trial); all_pred_trial = np.concatenate(all_pred_trial)
    else: all_true_trial = np.array([],dtype=int); all_pred_trial = np.array([],dtype=int)

    print("\n" + "="*60)
    print(f"RESULTADOS FINALES — CNN+Transformer + Cosine head + ({FT_STRATEGY})")
    print("="*60)
    print("Ensayo folds:", [f"{a:.4f}" for a in global_folds_trial])
    if len(global_folds_trial) > 0:
        print(f"Ensayo mean: {np.mean(global_folds_trial):.4f}")
    print("FT prog (ensayo) folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0 and len(global_folds_trial) > 0:
        print(f"Δ(FT - Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds_trial):+.4f}")

    if all_true_trial.size > 0:
        plot_confusion(all_true_trial, all_pred_trial, CLASS_NAMES,
                       title="Confusion Matrix - Global Model (Per Trial, All Folds)",
                       fname="confusion_cnnTrans_cosine_trial_allfolds.png")
        print("↳ Matriz de confusión (ensayo): confusion_cnnTrans_cosine_trial_allfolds.png")

# ---------- MAIN ----------
if __name__ == "__main__":
    bp_status = f"ON [{BP_LO}-{BP_HI} Hz]" if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None) else "OFF"
    print("🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + Cosine head + FT opcional")
    print(f"🔧 Escenario: {CLASS_SCENARIO} | rango ensayo={WINDOW_GLOBAL} s | subventana={WIN_LEN_SEC}s | stride={WIN_STRIDE_SEC}s | subwins/ensayo={N_SUBWINS} | band-pass={bp_status}")
    print(f"⚙️  Balance: 21 ensayos por clase y por sujeto (con reemplazo si falta)")
    print(f"⚙️  FT strategy: {FT_STRATEGY}")
    run_experiment()


🚀 Dispositivo: cuda
🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + Cosine head + FT opcional
🔧 Escenario: 4c | rango ensayo=(0.0, 4.0) s | subventana=2.5s | stride=0.5s | subwins/ensayo=4 | band-pass=OFF
⚙️  Balance: 21 ensayos por clase y por sujeto (con reemplazo si falta)
⚙️  FT strategy: none_bn_protos
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Recolectando y balanceando ensayos base por sujeto: 100%|██████████| 103/103 [00:37<00:00,  2.78it/s]


Ensayos base balanceados: 8652  (~ #sujetos_utiles * 84)
Sujetos con ensayos balanceados: 103





[Fold 1/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3952 | val(ensayo)=0.4147
  Época   5 | train(ensayo)=0.4409 | val(ensayo)=0.4534
  Época  10 | train(ensayo)=0.5419 | val(ensayo)=0.4425
  Época  15 | train(ensayo)=0.6056 | val(ensayo)=0.4435
  Early stopping @ epoch 19 (best val ensayo=0.4583)
↳ Curva: training_curve_cnnTrans_cosine_fold1.png
[Fold 1] Ensayo acc=0.3475

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4111    0.2676    0.3242       441
       Right     0.4150    0.2766    0.3320       441
  Both Fists     0.2922    0.6440    0.4020       441
   Both Feet     0.4218    0.2018    0.2730       441

    accuracy                         0.3475      1764
   macro avg     0.3850    0.3475    0.3328      1764
weighted avg     0.3850    0.3475    0.3328      1764

F1-macro: 0.3328 | Kappa: 0.1300
  FT (none_bn_protos) por ENSAYO acc=0.3832 | sujetos=84
  Δ(FT - Global) = +0.0357





[Fold 2/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3844 | val(ensayo)=0.3690
  Época   5 | train(ensayo)=0.4420 | val(ensayo)=0.3998
  Época  10 | train(ensayo)=0.5009 | val(ensayo)=0.4315
  Época  15 | train(ensayo)=0.5718 | val(ensayo)=0.4206
  Early stopping @ epoch 16 (best val ensayo=0.4335)
↳ Curva: training_curve_cnnTrans_cosine_fold2.png
[Fold 2] Ensayo acc=0.4359

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4617    0.3832    0.4188       441
       Right     0.4394    0.6327    0.5186       441
  Both Fists     0.4519    0.2766    0.3432       441
   Both Feet     0.4037    0.4512    0.4261       441

    accuracy                         0.4359      1764
   macro avg     0.4392    0.4359    0.4267      1764
weighted avg     0.4392    0.4359    0.4267      1764

F1-macro: 0.4267 | Kappa: 0.2479
  FT (none_bn_protos) por ENSAYO acc=0.4348 | sujetos=84
  Δ(FT - Global) = -0.0011





[Fold 3/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3501 | val(ensayo)=0.3581
  Época   5 | train(ensayo)=0.4347 | val(ensayo)=0.3800
  Época  10 | train(ensayo)=0.5297 | val(ensayo)=0.4216
  Época  15 | train(ensayo)=0.6162 | val(ensayo)=0.3998
  Early stopping @ epoch 16 (best val ensayo=0.4306)
↳ Curva: training_curve_cnnTrans_cosine_fold3.png
[Fold 3] Ensayo acc=0.4019

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4363    0.4739    0.4543       441
       Right     0.5306    0.2948    0.3790       441
  Both Fists     0.3069    0.3855    0.3417       441
   Both Feet     0.4115    0.4535    0.4315       441

    accuracy                         0.4019      1764
   macro avg     0.4213    0.4019    0.4016      1764
weighted avg     0.4213    0.4019    0.4016      1764

F1-macro: 0.4016 | Kappa: 0.2026
  FT (none_bn_protos) por ENSAYO acc=0.4150 | sujetos=84
  Δ(FT - Global) = +0.0130





[Fold 4/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3516 | val(ensayo)=0.2877
  Época   5 | train(ensayo)=0.4616 | val(ensayo)=0.3938
  Época  10 | train(ensayo)=0.5371 | val(ensayo)=0.3730
  Época  15 | train(ensayo)=0.5627 | val(ensayo)=0.3522
  Early stopping @ epoch 15 (best val ensayo=0.3938)
↳ Curva: training_curve_cnnTrans_cosine_fold4.png
[Fold 4] Ensayo acc=0.4286

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4252    0.6095    0.5010       420
       Right     0.3945    0.5833    0.4707       420
  Both Fists     0.3925    0.1000    0.1594       420
   Both Feet     0.5057    0.4214    0.4597       420

    accuracy                         0.4286      1680
   macro avg     0.4295    0.4286    0.3977      1680
weighted avg     0.4295    0.4286    0.3977      1680

F1-macro: 0.3977 | Kappa: 0.2381
  FT (none_bn_protos) por ENSAYO acc=0.4375 | sujetos=80
  Δ(FT - Global) = +0.0089





[Fold 5/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3258 | val(ensayo)=0.2946
  Época   5 | train(ensayo)=0.4639 | val(ensayo)=0.3879
  Época  10 | train(ensayo)=0.5121 | val(ensayo)=0.3700
  Early stopping @ epoch 13 (best val ensayo=0.4008)
↳ Curva: training_curve_cnnTrans_cosine_fold5.png
[Fold 5] Ensayo acc=0.4506

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.4852    0.4690    0.4770       420
       Right     0.4472    0.5548    0.4952       420
  Both Fists     0.4080    0.3167    0.3566       420
   Both Feet     0.4543    0.4619    0.4581       420

    accuracy                         0.4506      1680
   macro avg     0.4487    0.4506    0.4467      1680
weighted avg     0.4487    0.4506    0.4467      1680

F1-macro: 0.4467 | Kappa: 0.2675
  FT (none_bn_protos) por ENSAYO acc=0.4804 | sujetos=80
  Δ(FT - Global) = +0.0298

RESULTADOS FINALES — CNN+Transformer + Cosine head + (none_b

In [None]:
# -*- coding: utf-8 -*-
# ===========================================
# MI-EEG (PhysioNet eegmmidb) — CNN+Transformer
# Balanced 21/cls/subj + Sliding Window + Cosine head
# + Euclidean Alignment (EA) + (MMD) + TTA-BN opcional
# ✅ Fix: ClassBalancedFocal ahora sincroniza el tensor de pesos (cb_w) con el device de los logits.
# ===========================================

import os, re, math, json, copy, random, itertools
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim.swa_utils import AveragedModel, SWALR

from sklearn.metrics import confusion_matrix, classification_report, f1_score, cohen_kappa_score
from sklearn.utils import check_random_state
from tqdm import tqdm
import matplotlib.pyplot as plt

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'           # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR = PROJ / 'data' / 'cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
print(f"🚀 Dispositivo: {DEVICE}")

FS = 160.0

CLASS_SCENARIO = '4c'
WINDOW_GLOBAL = (0.0, 4.0)

WIN_LEN_SEC = 2.5
WIN_STRIDE_SEC = 0.5
def _num_subwins(win_len_s, stride_s, rng):
    return int(math.floor((rng[1]-rng[0]-win_len_s)/stride_s) + 1)
N_SUBWINS = _num_subwins(WIN_LEN_SEC, WIN_STRIDE_SEC, WINDOW_GLOBAL)

N_PER_CLASS_PER_SUBJECT = 21

BATCH_SIZE = 64
EPOCHS_GLOBAL = 80
LR_INIT = 1e-3
GLOBAL_VAL_SPLIT_SUBJ = 0.15
GLOBAL_PATIENCE  = 10

USE_SWA = True
SWA_START_FRAC = 0.6
SWA_LR_FACTOR = 0.5

TTA_BN_ON_TEST = True

MMD_LAMBDA = 0.05
MMD_MAX_SAMPLES_PER_GROUP = 64

USE_EA = True
EA_EPS = 1e-6

USE_BANDPASS = False
BP_LO, BP_HI = 8.0, 30.0
BP_METHOD, BP_PHASE = 'fir', 'zero'

EXCLUDE_SUBJECTS = {38,88,89,92,100,104}
MI_RUNS_LR = [4,8,12]
MI_RUNS_OF = [6,10,14]
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['Left', 'Right', 'Both Fists', 'Both Feet']

PER_CLASS_TAU = np.array([1.00, 1.00, 1.15, 1.10], dtype=np.float32)

# =========================
# UTIL: canales/edf
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en {getattr(raw,'filenames',[''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    return None

_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None: return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try: _SNR_TABLE = pd.read_csv(csv_path)
        except Exception:
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:   return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None: return None

    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None):
        raw.filter(l_freq=float(BP_LO), h_freq=float(BP_HI),
                   picks='eeg', method=BP_METHOD, phase=BP_PHASE, verbose=False)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    out = []
    if raw.annotations is None or len(raw.annotations) == 0: return out
    def _norm(s): return str(s).strip().upper().replace(' ','')
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'): out.append((float(onset), tag))
    out.sort()
    dedup = []; last_t1 = last_t2 = -1e9
    for t, tag in out:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# Dataset helpers
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR+MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_base_trials_from_run(edf_path: Path, scenario: str, window_global=(0.0,4.0)):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF'): return []

    raw = read_raw_edf(edf_path)
    if raw is None: return []
    data = raw.get_data()
    fs = raw.info['sfreq']
    out = []
    events = collect_events_T1T2(raw)
    rel_start, rel_end = window_global
    ev_idx = 0
    for onset_sec, tag in events:
        ev_idx += 1
        if kind == 'LR':
            label = 0 if tag=='T1' else 1
        else:
            label = 2 if tag=='T1' else 3
        if   scenario == '2c' and label not in (0,1): continue
        elif scenario == '3c' and label not in (0,1,2): continue
        elif scenario == '4c' and label not in (0,1,2,3): continue

        s0 = int(round((raw.first_time + onset_sec + rel_start) * fs))
        e0 = int(round((raw.first_time + onset_sec + rel_end)   * fs))
        if s0 < 0 or e0 > data.shape[1] or e0 <= s0: continue

        seg = data[:, s0:e0].T.astype(np.float32)  # (T_base, C)
        seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)
        trial_key = f"S{subj:03d}_R{run:02d}_E{ev_idx:03d}"
        out.append((seg, label, subj, trial_key))
    return out

def build_balanced_trials(subjects, scenario='4c', window_global=(0.0,4.0), n_per_class=N_PER_CLASS_PER_SUBJECT):
    trials_balanced = []
    for s in tqdm(subjects, desc="Recolectando y balanceando ensayos base por sujeto"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        bins = {0:[],1:[],2:[],3:[]}
        for r in (MI_RUNS_LR + MI_RUNS_OF):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            for t in extract_base_trials_from_run(p, scenario, window_global):
                bins[t[1]].append(t)
        if any(len(bins[c]) == 0 for c in (0,1,2,3)): continue
        rng = check_random_state(RANDOM_STATE + s)
        for c in (0,1,2,3):
            pool = bins[c]
            if len(pool) >= n_per_class:
                rng.shuffle(pool); take = pool[:n_per_class]
            else:
                idx = rng.choice(len(pool), size=n_per_class, replace=True)
                take = [pool[i] for i in idx]
            trials_balanced.extend(take)
    print(f"Ensayos base balanceados: {len(trials_balanced)}  (~ #sujetos_utiles * {n_per_class*4})")
    return trials_balanced

# ---------- EA ----------
def _sym_invsqrt(mat, eps=1e-6):
    w, v = np.linalg.eigh(mat)
    w = np.clip(w, eps, None)
    invsqrt = (v * (w**-0.5))[..., :] @ v.T
    return invsqrt

def compute_ea_whiteners_per_subject(trials_balanced, eps=EA_EPS):
    by_subj = {}
    for seg, _, subj, _ in trials_balanced:
        by_subj.setdefault(subj, []).append(seg)
    whiten = {}
    for subj, segs in by_subj.items():
        covs = []
        for X in segs:
            C = (X.T @ X) / float(max(1, X.shape[0]))
            covs.append(C)
        Cmean = np.mean(covs, axis=0)
        W = _sym_invsqrt(Cmean, eps=eps)  # (C,C)
        whiten[subj] = W.astype(np.float32)
    return whiten

def apply_ea_to_trials(trials_balanced, whiteners):
    out = []
    for seg, lab, subj, tkey in trials_balanced:
        W = whiteners.get(subj, None)
        if W is not None:
            seg = seg @ W.T
        out.append((seg, lab, subj, tkey))
    return out

# ---------- Windowing ----------
def window_trials(trials, win_len_sec=2.5, win_stride_sec=0.5, fs=160.0, n_subwins=N_SUBWINS):
    X, y, g, tkeys = [], [], [], []
    win_len = int(round(win_len_sec * fs))
    stride  = int(round(win_stride_sec * fs))
    for seg, lab, subj, tkey in trials:
        T = seg.shape[0]
        starts = [i for i in range(0, T - win_len + 1, stride)]
        if len(starts) > n_subwins: starts = starts[:n_subwins]
        while len(starts) < n_subwins and (T - win_len) >= 0:
            starts.append(T - win_len)
        for s in starts:
            e = s + win_len
            sub = seg[s:e, :].astype(np.float32)
            sub = (sub - sub.mean(axis=0, keepdims=True)) / (sub.std(axis=0, keepdims=True) + 1e-6)
            X.append(sub); y.append(lab); g.append(subj); tkeys.append(tkey)
    uniq = {k:i for i,k in enumerate(sorted(set(tkeys)))}
    trial_ids = np.asarray([uniq[k] for k in tkeys], dtype=np.int64)
    X = np.stack(X, axis=0).astype(np.float32)
    y = np.asarray(y, dtype=np.int64)
    g = np.asarray(g, dtype=np.int64)
    return X, y, g, trial_ids, uniq

# =========================
# Dataset torch
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups, trial_ids, subjs_for_mmd=None):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.t = trial_ids.astype(np.int64)
        self.s = self.g if subjs_for_mmd is None else subjs_for_mmd.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]; x = np.expand_dims(x, 0)  # (1,Tw,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx]), torch.tensor(self.t[idx]), torch.tensor(self.s[idx])

def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w; w = w / w.mean()
    return WeightedRandomSampler(weights=torch.from_numpy(w).float(), num_samples=len(w), replacement=True)

# =========================
# Modelo: CNN + Transformer + Cosine head
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1): super().__init__(); self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    for m in model.modules():
        if hasattr(m,'weight') and isinstance(m,(nn.Conv2d, nn.Linear)):
            w = m.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)
    def forward(self, x):
        T = x.size(1); return x + self.pe[:, :T, :]

class CNNTokenizer(nn.Module):
    def __init__(self, n_ch: int, d_feat: int = 128, k_temporal: int = 64, stride_t: int = 2, chdrop_p: float = 0.1, drop_p: float = 0.2):
        super().__init__()
        self.chdrop = ChannelDropout(chdrop_p)
        self.conv_t = nn.Conv2d(1, d_feat, kernel_size=(k_temporal,1), stride=(stride_t,1), padding=(k_temporal//2,0), bias=False)
        self.bn_t   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.act    = nn.ELU()
        self.conv_sp_dw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,n_ch), groups=d_feat, bias=False)
        self.bn_sp      = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.conv_pw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,1), bias=False)
        self.bn_pw   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.drop    = nn.Dropout(drop_p)
    def forward(self, x):
        z = self.chdrop(x)
        z = self.conv_t(z); z = self.bn_t(z); z = self.act(z)
        z = self.conv_sp_dw(z); z = self.bn_sp(z); z = self.act(z)
        z = self.conv_pw(z); z = self.bn_pw(z); z = self.act(z)
        z = self.drop(z)
        z = z.squeeze(-1).transpose(1,2)  # (B,T',D)
        return z

def _enc_layer(d_model=128, nhead=4, dim_ff=256, drop=0.2):
    return nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
                                      dropout=drop, activation='gelu', batch_first=True, norm_first=True)

class TemporalTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, dim_ff=256, depth=3, drop=0.2):
        super().__init__()
        self.pe = PositionalEncoding(d_model)
        layer = _enc_layer(d_model, nhead, dim_ff, drop)
        self.enc = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x):
        x = self.pe(x)
        x = self.enc(x)
        return self.norm(x)

class CosineClassifier(nn.Module):
    def __init__(self, d, n_classes, tau=10.0):
        super().__init__()
        self.W = nn.Parameter(torch.randn(n_classes, d))
        nn.init.xavier_uniform_(self.W)
        self.log_tau = nn.Parameter(torch.log(torch.tensor(tau, dtype=torch.float32)))
    def forward(self, f):  # f: (B,d)
        f_n = F.normalize(f, dim=1)
        W_n = F.normalize(self.W, dim=1)
        cos = torch.mm(f_n, W_n.t())
        return torch.exp(self.log_tau) * cos

class HybridCNNTransformer(nn.Module):
    def __init__(self, n_ch: int, n_classes: int, seq_len: int,
                 d_model: int = 128, nhead: int = 4, depth: int = 3,
                 dim_ff: int = 256, k_temporal: int = 64, stride_t: int = 2,
                 drop: float = 0.2, chdrop_p: float = 0.1, use_cls_token: bool = True):
        super().__init__()
        self.use_cls = use_cls_token
        self.tokenizer = CNNTokenizer(n_ch, d_model, k_temporal, stride_t, chdrop_p, drop)
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)) if use_cls_token else None
        self.transformer = TemporalTransformer(d_model, nhead, dim_ff, depth, drop)
        self.norm_head = nn.LayerNorm(d_model)
        self.head = CosineClassifier(d_model, n_classes, tau=10.0)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def tokenize(self, x):
        tok = self.tokenizer(x)                 # (B,T',D)
        if self.use_cls:
            B = tok.size(0); cls_tok = self.cls.expand(B,-1,-1)
            tok = torch.cat([cls_tok, tok], dim=1)
        return tok

    def forward_from_tokens(self, tok):
        z = self.transformer(tok)
        feat = z[:,0,:] if self.use_cls else z.mean(dim=1)
        feat = self.norm_head(feat)
        return self.head(feat)

    def forward(self, x):
        tok = self.tokenize(x)
        return self.forward_from_tokens(tok)

    @torch.no_grad()
    def extract_feats(self, x):
        tok = self.tokenize(x)
        z = self.transformer(tok)
        feat = z[:,0,:] if self.use_cls else z.mean(dim=1)
        return self.norm_head(feat)

# =========================
# Losses / Métricas
# =========================
class ClassBalancedFocal(nn.Module):
    """Class-Balanced Focal Loss con label smoothing.
       ✅ FIX DEVICE: cb_w siempre se usa en el mismo device que los logits."""
    def __init__(self, n_classes: int, beta: float = 0.999, gamma: float = 2.0, label_smoothing: float = 0.05):
        super().__init__()
        self.n = n_classes
        self.beta = beta
        self.gamma = gamma
        self.ls = label_smoothing
        self.register_buffer('cb_w', torch.ones(n_classes))  # buffer en CPU por defecto

    @torch.no_grad()
    def update_weights(self, counts: np.ndarray):
        eff_num = (1.0 - np.power(self.beta, counts)) / (1.0 - self.beta + 1e-8)
        w = (eff_num.sum() / (eff_num + 1e-8))
        w = w / w.mean()
        # mantener device del buffer
        self.cb_w = torch.tensor(w, dtype=torch.float32, device=self.cb_w.device)

    def forward(self, logits, y):
        B, C = logits.size()
        device = logits.device
        with torch.no_grad():
            y_one = F.one_hot(y, C).float()
            if self.ls > 0:
                y_one = (1 - self.ls) * y_one + self.ls * (1.0 / C)
        logp = F.log_softmax(logits, dim=1)
        p = logp.exp()
        focal = ((1 - p) ** self.gamma)
        loss = -(y_one * focal * logp)
        # mover cb_w al device de logits si hace falta
        cw = self.cb_w if self.cb_w.device == device else self.cb_w.to(device)
        loss = loss * cw.unsqueeze(0)
        return loss.sum(dim=1).mean()

def print_report(y_true, y_pred, classes, tag=""):
    print(f"\n=== Reporte {tag} ===")
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))
    print(f"F1-macro: {f1_score(y_true, y_pred, average='macro'):.4f} | Kappa: {cohen_kappa_score(y_true,y_pred):.4f}")

# ---------- MMD ----------
def _rbf_mmd2(x, y, gamma=None):
    if gamma is None:
        with torch.no_grad():
            z = torch.cat([x, y], dim=0)
            d2 = torch.cdist(z, z, p=2.0).pow(2)
            med2 = torch.median(d2[d2>0]).clamp(min=1e-8)
            gamma = 1.0 / (2.0 * med2)
    Kxx = torch.exp(-gamma * torch.cdist(x, x, p=2.0).pow(2))
    Kyy = torch.exp(-gamma * torch.cdist(y, y, p=2.0).pow(2))
    Kxy = torch.exp(-gamma * torch.cdist(x, y, p=2.0).pow(2))
    n = x.size(0); m = y.size(0)
    mmd2 = (Kxx.sum() - torch.diag(Kxx).sum()) / (n*(n-1) + 1e-8) \
         + (Kyy.sum() - torch.diag(Kyy).sum()) / (m*(m-1) + 1e-8) \
         - 2.0 * Kxy.mean()
    return mmd2

def mmd_between_subject_groups(model: HybridCNNTransformer, xb, subj_ids, max_samples=64):
    subj_ids = subj_ids.detach().cpu().numpy().tolist()
    uniq = sorted(set(subj_ids))
    if len(uniq) < 2: return None
    counts = {s: subj_ids.count(s) for s in uniq}
    top = sorted(counts.items(), key=lambda p: p[1], reverse=True)[:2]
    sA, sB = top[0][0], top[1][0]
    maskA = torch.tensor([si==sA for si in subj_ids], device=xb.device)
    maskB = torch.tensor([si==sB for si in subj_ids], device=xb.device)
    if maskA.sum()<2 or maskB.sum()<2: return None
    idxA = torch.where(maskA)[0]
    idxB = torch.where(maskB)[0]
    nA = min(max_samples, idxA.numel())
    nB = min(max_samples, idxB.numel())
    idxA = idxA[torch.randperm(idxA.numel(), device=xb.device)[:nA]]
    idxB = idxB[torch.randperm(idxB.numel(), device=xb.device)[:nB]]
    fA = model.extract_feats(xb[idxA]).detach()
    fB = model.extract_feats(xb[idxB]).detach()
    mmd2 = _rbf_mmd2(fA, fB)
    return mmd2

# ---------- Feature Mixup en tokens ----------
def feature_mixup(tok, alpha=0.2):
    if alpha <= 0:
        return tok, None
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(tok.size(0), device=tok.device)
    tok_mix = lam * tok + (1 - lam) * tok[perm]
    return tok_mix, perm

# =========================
# Inferencia por ensayo (τ per-class)
# =========================
@torch.no_grad()
def evaluate_per_trial(model, loader, use_tta=True, tta_n=5, n_classes=4, per_class_tau=PER_CLASS_TAU):
    model.eval()
    sums, counts, labels = {}, {}, {}
    for xb, yb, _, tb, _ in loader:
        xb = xb.to(DEVICE)
        if use_tta:
            def time_jitter(x, max_ms=25, fs=160.0):
                max_shift = int(round(max_ms/1000.0 * fs))
                if max_shift <= 0: return x
                B,_,T,C = x.shape
                shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
                out = torch.zeros_like(x)
                for i,s in enumerate(shifts):
                    if s >= 0: out[i,:,s:,:] = x[i,:,:T-s,:]
                    else: s=-s; out[i,:,:T-s,:] = x[i,:,s:,:]
                return out
            outs = []
            for _ in range(tta_n):
                outs.append(model(time_jitter(xb,25)))
            logits = torch.stack(outs, dim=0).mean(dim=0)
        else:
            logits = model(xb)
        logits = logits.detach().cpu().numpy()
        logits = logits / per_class_tau.reshape(1, -1)  # calibración por-clase

        yb = yb.numpy(); tb = tb.numpy()
        for i in range(xb.size(0)):
            t = int(tb[i])
            if t not in sums:
                sums[t] = np.zeros((n_classes,), dtype=np.float64)
                counts[t] = 0
                labels[t] = int(yb[i])
            sums[t] += logits[i]; counts[t] += 1
    ids = sorted(sums.keys())
    y_true_trial = np.asarray([labels[t] for t in ids], dtype=int)
    y_pred_trial = np.asarray([int((sums[t]/max(1,counts[t])).argmax()) for t in ids], dtype=int)
    acc = (y_true_trial == y_pred_trial).mean()
    return y_true_trial, y_pred_trial, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j,i,format(cm_norm[i,j],fmt),ha="center", color="white" if cm_norm[i,j]>thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred'); plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

# =========================
# TTA BN-only por sujeto (smoothing)
# =========================
@torch.no_grad()
def bn_adapt_subject_smooth(model, X_subj, batch=256, momentum=0.2):
    saved = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            saved.append((m, m.running_mean.clone(), m.running_var.clone()))
    was_training = model.training
    model.train()
    for i in range(0, len(X_subj), batch):
        xb = torch.from_numpy(X_subj[i:i+batch]).float().unsqueeze(1).to(DEVICE)
        _ = model(xb)
    model.eval()
    for m, old_mean, old_var in saved:
        m.running_mean = (1 - momentum) * old_mean + momentum * m.running_mean
        m.running_var  = (1 - momentum) * old_var  + momentum * m.running_var
    if was_training: model.train()

# =========================
# Folds (subjects-only)
# =========================
def read_folds_subjects_only(json_path, current_subjects):
    p = Path(json_path)
    if not p.exists():
        raise FileNotFoundError(f"No existe {p}. Provee Kfold5.json con sujetos por fold.")
    with open(p, 'r', encoding='utf-8') as f:
        payload = json.load(f)
    folds_raw = payload.get('folds', [])
    if not folds_raw:
        raise ValueError("JSON no contiene 'folds'.")

    cur_set = set(current_subjects)
    folds = []
    for fr in folds_raw:
        train_s = fr.get('train_subjects', fr.get('train', []))
        test_s  = fr.get('test_subjects',  fr.get('test',  []))
        if not train_s or not test_s:
            raise ValueError("Cada fold debe tener 'train_subjects' y 'test_subjects' (o 'train'/'test').")

        def to_int_list(lst):
            out = []
            for s in lst:
                if isinstance(s, int): out.append(s)
                elif isinstance(s, str) and s.upper().startswith('S') and s[1:].isdigit():
                    out.append(int(s[1:]))
                elif isinstance(s, str) and s.isdigit():
                    out.append(int(s))
                else:
                    raise ValueError(f"Sujeto inválido en JSON: {s}")
            return out

        tr_int = to_int_list(train_s)
        te_int = to_int_list(test_s)

        miss_tr = [s for s in tr_int if s not in cur_set]
        miss_te = [s for s in te_int if s not in cur_set]
        if miss_tr or miss_te:
            raise ValueError(f"Fold con sujetos fuera del dataset. Falta TRAIN: {miss_tr}, TEST: {miss_te}")

        folds.append({'fold': fr.get('fold', len(folds)+1),
                      'train_subjects': tr_int, 'test_subjects': te_int})
    return folds

# =========================
# Experimento
# =========================
def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def run_experiment():
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    trials_bal = build_balanced_trials(subs, scenario=CLASS_SCENARIO, window_global=WINDOW_GLOBAL,
                                       n_per_class=N_PER_CLASS_PER_SUBJECT)
    subj_in_trials = sorted({t[2] for t in trials_bal})
    print(f"Sujetos con ensayos balanceados: {len(subj_in_trials)}")

    if USE_EA:
        print("🔧 EA: calculando whiteners por sujeto...")
        Wsubj = compute_ea_whiteners_per_subject(trials_bal)
        trials_bal = apply_ea_to_trials(trials_bal, Wsubj)
        print(f"EA aplicado a {len(Wsubj)} sujetos.")

    folds = read_folds_subjects_only(FOLDS_JSON, current_subjects=subj_in_trials)

    global_folds_trial, all_true_trial, all_pred_trial = [], [], []

    for f in folds:
        fold = f['fold']
        train_subs = set(f['train_subjects'])
        test_subs  = set(f['test_subjects'])

        tr_trials = [t for t in trials_bal if t[2] in train_subs]
        te_trials = [t for t in trials_bal if t[2] in test_subs]

        train_subs_list = sorted(list(train_subs))
        rng = np.random.RandomState(RANDOM_STATE + fold)
        rng.shuffle(train_subs_list)
        n_va = max(1, int(round(GLOBAL_VAL_SPLIT_SUBJ * len(train_subs_list))))
        val_subs = set(train_subs_list[:n_va])
        real_train_subs = set(train_subs_list[n_va:])

        tr_trials_main = [t for t in tr_trials if t[2] in real_train_subs]
        va_trials      = [t for t in tr_trials if t[2] in val_subs]

        Xtr, ytr, gtr, ttr, _ = window_trials(trials=tr_trials_main, win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS, n_subwins=N_SUBWINS)
        Xva, yva, gva, tva, _ = window_trials(trials=va_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS, n_subwins=N_SUBWINS)
        Xte, yte, gte, tte, _ = window_trials(trials=te_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS, n_subwins=N_SUBWINS)

        ds_tr = EEGTrials(Xtr, ytr, gtr, ttr)
        ds_va = EEGTrials(Xva, yva, gva, tva)
        ds_te = EEGTrials(Xte, yte, gte, tte)

        tr_loader = DataLoader(ds_tr, batch_size=BATCH_SIZE,
                               sampler=build_weighted_sampler(ytr, gtr), drop_last=False)
        va_loader = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        Tw, C = Xtr.shape[1], Xtr.shape[2]
        n_classes = 4
        model = HybridCNNTransformer(
            n_ch=C, n_classes=n_classes, seq_len=Tw,
            d_model=128, nhead=4, depth=3, dim_ff=256,
            k_temporal=64, stride_t=2, drop=0.2, chdrop_p=0.10, use_cls_token=True
        ).to(DEVICE)

        opt = torch.optim.Adam(model.parameters(), lr=LR_INIT)

        criterion = ClassBalancedFocal(n_classes=n_classes, beta=0.999, gamma=2.0, label_smoothing=0.05)
        class_counts = np.bincount(ytr, minlength=n_classes)
        criterion.update_weights(class_counts)

        if USE_SWA:
            swa_start_epoch = max(5, int(SWA_START_FRAC * EPOCHS_GLOBAL))
            swa_model = AveragedModel(model)
            swa_scheduler = SWALR(opt, anneal_strategy="cos", anneal_epochs=5, swa_lr=LR_INIT * SWA_LR_FACTOR)

        def _acc_trial(loader):
            y_true_tr, y_pred_tr, acc_tr = evaluate_per_trial(model, loader, use_tta=False, n_classes=n_classes, per_class_tau=PER_CLASS_TAU)
            return acc_tr

        history = {'train_acc': [], 'val_acc': []}
        print(f"\n[Fold {fold}/{len(folds)}] Entrenando global (VAL por ensayo)...")
        best_state = copy.deepcopy(model.state_dict()); best_val = -1.0; bad = 0

        for epoch in range(1, EPOCHS_GLOBAL+1):
            model.train()
            for xb, yb, gb, tb, sb in tr_loader:
                xb, yb, sb = xb.to(DEVICE), yb.to(DEVICE), sb.to(DEVICE)

                tok = model.tokenize(xb)
                tok, _ = feature_mixup(tok, alpha=0.2)

                logits = model.forward_from_tokens(tok)
                loss = criterion(logits, yb)

                if MMD_LAMBDA > 0.0:
                    mmd2 = mmd_between_subject_groups(model, xb, sb, max_samples=MMD_MAX_SAMPLES_PER_GROUP)
                    if mmd2 is not None:
                        loss = loss + MMD_LAMBDA * mmd2

                opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
                apply_max_norm(model, 2.0, p=2.0)

            if USE_SWA and epoch >= swa_start_epoch:
                swa_model.update_parameters(model)
                swa_scheduler.step()

            tr_acc = _acc_trial(tr_loader)
            va_acc = _acc_trial(va_loader)
            history['train_acc'].append(tr_acc); history['val_acc'].append(va_acc)

            if epoch % 5 == 0 or epoch in (1,10,20,50,80):
                print(f"  Época {epoch:3d} | train(ensayo)={tr_acc:.4f} | val(ensayo)={va_acc:.4f}")

            if va_acc > best_val + 1e-4:
                best_val, bad = va_acc, 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping @ epoch {epoch} (best val ensayo={best_val:.4f})")
                    break

        if USE_SWA:
            torch.optim.swa_utils.update_bn(tr_loader, swa_model, device=DEVICE)
            model.load_state_dict(swa_model.module.state_dict())
        else:
            model.load_state_dict(best_state)

        plot_training_curves(history, f"training_curve_cnnTrans_ea_mmd_fold{fold}.png")
        print(f"↳ Curva: training_curve_cnnTrans_ea_mmd_fold{fold}.png")

        if TTA_BN_ON_TEST:
            print("🔧 TTA-BN: adaptando BN por sujeto en TEST...")
            subj2idx = {}
            for i, sid in enumerate(gte):
                subj2idx.setdefault(int(sid), []).append(i)

            sums_all, counts_all, labels_all = {}, {}, {}
            model.eval()

            for sid, idxs in subj2idx.items():
                Xs = Xte[idxs]
                bn_adapt_subject_smooth(model, Xs, batch=256, momentum=0.2)

                for i0 in range(0, len(idxs), BATCH_SIZE):
                    sel = idxs[i0:i0+BATCH_SIZE]
                    xb = torch.from_numpy(Xte[sel]).float().unsqueeze(1).to(DEVICE)

                    outs = []
                    with torch.no_grad():
                        for _ in range(5):
                            def time_jitter(x, max_ms=25, fs=160.0):
                                max_shift = int(round(max_ms/1000.0 * fs))
                                if max_shift <= 0: return x
                                B,_,T,C = x.shape
                                shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
                                out = torch.zeros_like(x)
                                for ii,s in enumerate(shifts):
                                    if s >= 0: out[ii,:,s:,:] = x[ii,:,:T-s,:]
                                    else: s = -s; out[ii,:,:T-s,:] = x[ii,:,s:,:]
                                return out
                            outs.append(model(time_jitter(xb,25)))
                    logits = torch.stack(outs, 0).mean(0).cpu().numpy()
                    logits = logits / PER_CLASS_TAU.reshape(1, -1)

                    for k, j in enumerate(sel):
                        trial_id = int(tte[j])
                        ytrue    = int(yte[j])
                        if trial_id not in sums_all:
                            sums_all[trial_id] = np.zeros((n_classes,), dtype=np.float64)
                            counts_all[trial_id] = 0
                            labels_all[trial_id] = ytrue
                        sums_all[trial_id] += logits[k]
                        counts_all[trial_id] += 1

            ids = sorted(sums_all.keys())
            y_true_tr = np.asarray([labels_all[t] for t in ids], dtype=int)
            y_pred_tr = np.asarray([int((sums_all[t]/max(1,counts_all[t])).argmax()) for t in ids], dtype=int)
            acc_tr = (y_true_tr == y_pred_tr).mean()
        else:
            y_true_tr, y_pred_tr, acc_tr = evaluate_per_trial(model, te_loader, use_tta=True, n_classes=n_classes, per_class_tau=PER_CLASS_TAU)

        print(f"[Fold {fold}] Ensayo acc={acc_tr:.4f}")
        print("Distribución real por clase en test (por ensayo):", np.bincount(y_true_tr, minlength=4))
        print_report(y_true_tr, y_pred_tr, CLASS_NAMES, tag="por ENSAYO (agregado)")
        global_folds_trial.append(acc_tr)
        all_true_trial.append(y_true_tr); all_pred_trial.append(y_pred_tr)

    if len(all_true_trial)>0: all_true_trial = np.concatenate(all_true_trial); all_pred_trial = np.concatenate(all_pred_trial)
    else: all_true_trial = np.array([],dtype=int); all_pred_trial = np.array([],dtype=int)

    print("\n" + "="*60)
    print(f"RESULTADOS FINALES — CNN+Transformer + Cosine head + EA={USE_EA} + MMD={MMD_LAMBDA} + TTA-BN={TTA_BN_ON_TEST}")
    print("="*60)
    print("Ensayo folds:", [f"{a:.4f}" for a in global_folds_trial])
    if len(global_folds_trial) > 0:
        print(f"Ensayo mean: {np.mean(global_folds_trial):.4f}")

    if all_true_trial.size > 0:
        plot_confusion(all_true_trial, all_pred_trial, CLASS_NAMES,
                       title="Confusion Matrix - Global Model (Per Trial, All Folds)",
                       fname="confusion_cnnTrans_ea_mmd_ttabn_trial_allfolds.png")
        print("↳ Matriz de confusión (ensayo): confusion_cnnTrans_ea_mmd_ttabn_trial_allfolds.png")

# ---------- MAIN ----------
if __name__ == "__main__":
    bp_status = f"ON [{BP_LO}-{BP_HI} Hz]" if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None) else "OFF"
    print("🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + Cosine head")
    print(f"🔧 Escenario: {CLASS_SCENARIO} | rango ensayo={WINDOW_GLOBAL} s | subventana={WIN_LEN_SEC}s | stride={WIN_STRIDE_SEC}s | subwins/ensayo={N_SUBWINS} | band-pass={bp_status}")
    print(f"⚙️  EA: {USE_EA} | MMD_LAMBDA={MMD_LAMBDA} | TTA-BN test: {TTA_BN_ON_TEST}")
    run_experiment()


🚀 Dispositivo: cuda
🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + Cosine head
🔧 Escenario: 4c | rango ensayo=(0.0, 4.0) s | subventana=2.5s | stride=0.5s | subwins/ensayo=4 | band-pass=OFF
⚙️  EA: True | MMD_LAMBDA=0.05 | TTA-BN test: True
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Recolectando y balanceando ensayos base por sujeto:   0%|          | 0/103 [00:00<?, ?it/s]

Recolectando y balanceando ensayos base por sujeto: 100%|██████████| 103/103 [00:37<00:00,  2.75it/s]


Ensayos base balanceados: 8652  (~ #sujetos_utiles * 84)
Sujetos con ensayos balanceados: 103
🔧 EA: calculando whiteners por sujeto...
EA aplicado a 103 sujetos.





[Fold 1/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3594 | val(ensayo)=0.3919
  Época   5 | train(ensayo)=0.4083 | val(ensayo)=0.3899
  Época  10 | train(ensayo)=0.4701 | val(ensayo)=0.5020
  Época  15 | train(ensayo)=0.4811 | val(ensayo)=0.4474
  Época  20 | train(ensayo)=0.4964 | val(ensayo)=0.4345
  Época  25 | train(ensayo)=0.5145 | val(ensayo)=0.4544
  Early stopping @ epoch 28 (best val ensayo=0.5119)
↳ Curva: training_curve_cnnTrans_ea_mmd_fold1.png
🔧 TTA-BN: adaptando BN por sujeto en TEST...
[Fold 1] Ensayo acc=0.2500
Distribución real por clase en test (por ensayo): [441 441 441 441]

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.2500    1.0000    0.4000       441
       Right     0.0000    0.0000    0.0000       441
  Both Fists     0.0000    0.0000    0.0000       441
   Both Feet     0.0000    0.0000    0.0000       441

    accuracy                         0.2500      1764
   mac

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[Fold 2/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3441 | val(ensayo)=0.3552
  Época   5 | train(ensayo)=0.3860 | val(ensayo)=0.3889
  Época  10 | train(ensayo)=0.4271 | val(ensayo)=0.4187
  Época  15 | train(ensayo)=0.4353 | val(ensayo)=0.4286
  Época  20 | train(ensayo)=0.4554 | val(ensayo)=0.4127
  Época  25 | train(ensayo)=0.4230 | val(ensayo)=0.3810
  Época  30 | train(ensayo)=0.4722 | val(ensayo)=0.3968
  Early stopping @ epoch 31 (best val ensayo=0.4554)
↳ Curva: training_curve_cnnTrans_ea_mmd_fold2.png
🔧 TTA-BN: adaptando BN por sujeto en TEST...
[Fold 2] Ensayo acc=0.2426
Distribución real por clase en test (por ensayo): [441 441 441 441]

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.0000    0.0000    0.0000       441
       Right     0.1993    0.1224    0.1517       441
  Both Fists     0.0000    0.0000    0.0000       441
   Both Feet     0.2505    0.8481    0.3868       441

    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



[Fold 3/5] Entrenando global (VAL por ensayo)...


KeyboardInterrupt: 