In [4]:
# -*- coding: utf-8 -*-
# Transformer-only para MI (4 clases, 8 canales) - VERSIÓN FINAL CON WINDOW SLIDING

import os
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import re, json, random
from pathlib import Path
from glob import glob
import numpy as np
import mne
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report

# =========================
# REPRODUCIBILIDAD
# =========================
RANDOM_STATE = 42
def seed_everything(seed: int = 42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.use_deterministic_algorithms(True)
    except Exception:
        pass

def seed_worker(worker_id: int):
    worker_seed = RANDOM_STATE + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)

seed_everything(RANDOM_STATE)

# =========================
# CONFIG - CON WINDOW SLIDING ACTIVADO
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'

EPOCHS = 60
BATCH_SIZE = 64
BASE_LR = 1e-3
WARMUP_EPOCHS = 8
PATIENCE = 15

# Split de validación por fold (por sujetos)
VAL_SUBJECT_FRAC = 0.18

# Prepro
RESAMPLE_HZ = None
DO_NOTCH = True
DO_BANDPASS = False
BP_LO, BP_HI = 4.0, 38.0
DO_CAR = False
ZSCORE_PER_EPOCH = False

# Modelo (transformer-only)
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 4
P_DROP = 0.2

# Ventana temporal (epoch)
TMIN, TMAX = -1.0, 5.0

# TTA (sólo en TEST) - ACTIVADO CON CONFIGURACIÓN MEJORADA
SW_MODE = 'tta'   # 'none'|'subwin'|'tta'
SW_ENABLE = True     # Activado
TTA_SHIFTS_S = [0.0]  # Solo shifts pequeños para empezar
SW_LEN, SW_STRIDE = 2.5, 0.3  # Configuración más conservadora

EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
CLASS_NAMES = ['left', 'right', 'both_fists', 'both_feet']

IMAGERY_RUNS_LR = {4, 8, 12}
IMAGERY_RUNS_BF = {6, 10, 14}

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# =========================
# UTILIDADES I/O (igual que antes)
# =========================
def normalize_ch_name(name: str) -> str:
    s = re.sub(r'[^A-Za-z0-9]', '', name)
    return s.upper()

NORMALIZED_TARGETS = [normalize_ch_name(c) for c in EXPECTED_8]

def pick_8_channels(raw: mne.io.BaseRaw) -> mne.io.BaseRaw:
    chs = raw.info['ch_names']
    norm_map = {normalize_ch_name(ch): ch for ch in chs}
    picked = []
    for target_norm, target_orig in zip(NORMALIZED_TARGETS, EXPECTED_8):
        if target_norm in norm_map:
            picked.append(norm_map[target_norm])
        else:
            raise RuntimeError(f"Canal requerido '{target_orig}' no encontrado. Disponibles: {chs}")
    return raw.pick(picks=picked)

def list_subject_imagery_edfs(subject_id: str) -> list:
    subj_dir = DATA_RAW / subject_id
    edfs = []
    for r in [4, 6, 8, 10, 12, 14]:
        edfs.extend(glob(str(subj_dir / f"{subject_id}R{r:02d}.edf")))
    return sorted(edfs)

def subject_id_to_int(s: str) -> int:
    m = re.match(r'[Ss](\d+)', s)
    return int(m.group(1)) if m else -1

def load_subject_epochs(subject_id: str, resample_hz: int, do_notch: bool, do_bandpass: bool,
                        do_car: bool, bp_lo: float, bp_hi: float):
    edfs = list_subject_imagery_edfs(subject_id)
    if len(edfs) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_list, y_list, sfreq_list = [], [], []

    for edf_path in edfs:
        m = re.search(r"R(\d{2})", Path(edf_path).name)
        run = int(m.group(1)) if m else -1

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose='ERROR')
        raw = pick_8_channels(raw)

        if do_notch:
            raw.notch_filter(freqs=[60.0], picks='all', verbose='ERROR')
        if do_bandpass:
            raw.filter(l_freq=bp_lo, h_freq=bp_hi, picks='all', verbose='ERROR')
        if do_car:
            raw.set_eeg_reference('average', projection=False, verbose='ERROR')

        if resample_hz is not None and resample_hz > 0:
            raw.resample(resample_hz)
        sfreq = raw.info['sfreq']

        events, event_id = mne.events_from_annotations(raw, verbose='ERROR')
        keep = {k: v for k, v in event_id.items() if k in {'T1', 'T2'}}
        if len(keep) == 0:
            continue

        epochs = mne.Epochs(raw, events=events, event_id=keep, tmin=TMIN, tmax=TMAX,
                            baseline=None, preload=True, verbose='ERROR')
        X = epochs.get_data()

        if ZSCORE_PER_EPOCH:
            X = X.astype(np.float32)
            eps = 1e-6
            mu = X.mean(axis=2, keepdims=True)
            sd = X.std(axis=2, keepdims=True) + eps
            X = (X - mu) / sd

        ev_codes = epochs.events[:, 2]
        inv = {v: k for k, v in keep.items()}
        y_run = []
        for code in ev_codes:
            lab = inv[code]
            if run in IMAGERY_RUNS_LR:
                y_run.append(0 if lab == 'T1' else 1)
            elif run in IMAGERY_RUNS_BF:
                y_run.append(2 if lab == 'T1' else 3)
            else:
                y_run.append(-1)
        y_run = np.array(y_run, dtype=int)
        mask = y_run >= 0
        if mask.sum() == 0:
            continue

        X_list.append(X[mask])
        y_list.append(y_run[mask])
        sfreq_list.append(sfreq)

    if len(X_list) == 0:
        return np.empty((0,8,1), dtype=np.float32), np.empty((0,), dtype=int), None

    X_all = np.concatenate(X_list, axis=0).astype(np.float32)
    y_all = np.concatenate(y_list, axis=0).astype(int)

    if len(set([int(round(s)) for s in sfreq_list])) != 1:
        raise RuntimeError(f"Sampling rates inconsistentes: {sfreq_list}")

    return X_all, y_all, sfreq_list[0]

def load_fold_subjects(folds_json: Path, fold: int):
    with open(folds_json, 'r') as f:
        data = json.load(f)
    for item in data.get('folds', []):
        if int(item.get('fold', -1)) == int(fold):
            return list(item.get('train', [])), list(item.get('test', []))
    raise ValueError(f"Fold {fold} not found in {folds_json}")

def standardize_per_channel_consistent(train_X, other_X):
    """Versión mejorada que asegura consistencia entre train y test"""
    C = train_X.shape[1]
    train_X = train_X.astype(np.float32)
    other_X = other_X.astype(np.float32)
    
    # Calcular estadísticas globales del train
    channel_means = train_X.mean(axis=(0, 2), keepdims=True)
    channel_stds = train_X.std(axis=(0, 2), keepdims=True) + 1e-6
    
    # Aplicar a todos los conjuntos
    train_X_std = (train_X - channel_means) / channel_stds
    other_X_std = (other_X - channel_means) / channel_stds
    
    return train_X_std, other_X_std

# =========================
# MODELO (igual que antes)
# =========================
class PatchEmbedding(nn.Module):
    def __init__(self, n_ch, patch_size_samples, patch_stride_samples, d_model):
        super().__init__()
        assert patch_size_samples >= 1 and patch_stride_samples >= 1
        self.n_ch = n_ch
        self.p_size = int(patch_size_samples)
        self.p_stride = int(patch_stride_samples)
        self.proj = nn.Linear(n_ch * self.p_size, d_model, bias=False)

    def forward(self, x):
        B, C, T = x.shape
        if T < self.p_size:
            pad = self.p_size - T
            x = F.pad(x, (0, pad), mode='replicate')
            T = x.shape[-1]
        starts = list(range(0, T - self.p_size + 1, self.p_stride))
        patches = []
        for s in starts:
            seg = x[:, :, s:s + self.p_size]
            seg = seg.reshape(B, C * self.p_size)
            patches.append(seg)
        if len(patches) == 0:
            seg = x[:, :, :self.p_size].reshape(B, C * self.p_size)
            patches = [seg]
        patches = torch.stack(patches, dim=1)
        out = self.proj(patches)
        return out

class EEGTransformer(nn.Module):
    def __init__(self, n_ch=8, n_cls=4, d_model=128, n_heads=4, n_layers=4, p_drop=0.2,
                 patch_size_samples=16, patch_stride_samples=None):
        super().__init__()
        if patch_stride_samples is None:
            patch_stride_samples = patch_size_samples
        self.embed = PatchEmbedding(n_ch, patch_size_samples, patch_stride_samples, d_model)
        self.dropout = nn.Dropout(p_drop)
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=4*d_model,
            batch_first=True, activation='gelu', dropout=p_drop, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.normal_(self.cls, std=0.02)
        self.pos_encoding = None
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_cls))

    def _positional_encoding(self, L, d):
        pos = torch.arange(0, L, dtype=torch.float32).unsqueeze(1)
        i   = torch.arange(0, d, dtype=torch.float32).unsqueeze(0)
        angle = pos / torch.pow(10000, (2 * (i//2)) / d)
        pe = torch.zeros(L, d, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angle[:, 0::2])
        pe[:, 1::2] = torch.cos(angle[:, 1::2])
        return pe

    def forward(self, x):
        z = self.embed(x)
        B, L, D = z.shape
        if (self.pos_encoding is None) or (self.pos_encoding.shape[0] != L) or (self.pos_encoding.shape[1] != D):
            self.pos_encoding = self._positional_encoding(L, D).to(z.device)
        z = z + self.pos_encoding[None, :, :]
        cls_tok = self.cls.expand(B, -1, -1)
        z = torch.cat([cls_tok, z], dim=1)
        z = self.encoder(z)
        cls = z[:, 0, :]
        out = self.head(cls)
        return out

# =========================
# FOCAL LOSS CORREGIDO
# =========================
class FocalLoss(nn.Module):
    def __init__(self, alpha: torch.Tensor = None, gamma: float = 1.5, reduction: str = 'mean'):
        super().__init__()
        if alpha is not None:
            self.alpha = alpha / alpha.sum()
        else:
            self.alpha = None
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, logits, target):
        logp = nn.functional.log_softmax(logits, dim=-1)
        p = logp.exp()
        idx = torch.arange(target.shape[0], device=logits.device)
        pt = p[idx, target]
        logpt = logp[idx, target]
        
        if self.alpha is not None:
            at = self.alpha[target]
        else:
            at = 1.0
            
        loss = - at * ((1 - pt) ** self.gamma) * logpt
        
        if self.reduction == 'mean': 
            return loss.mean()
        elif self.reduction == 'sum':  
            return loss.sum()
        return loss

# =========================
# AUGMENTS (igual que antes)
# =========================
def augment_batch(
    xb,
    p_jitter=0.35, p_noise=0.35, p_chdrop=0.15,
    max_jitter_frac=0.03, noise_std=0.03, max_chdrop=1
):
    B, C, T = xb.shape
    if np.random.rand() < p_jitter:
        max_shift = int(max(1, T*max_jitter_frac))
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=xb.device)
        for i in range(B):
            xb[i] = torch.roll(xb[i], shifts=int(shifts[i].item()), dims=-1)
    if np.random.rand() < p_noise:
        xb = xb + noise_std*torch.randn_like(xb)
    if np.random.rand() < p_chdrop and max_chdrop > 0:
        k = min(max_chdrop, C)
        for i in range(B):
            idx = torch.randperm(C, device=xb.device)[:k]
            xb[i, idx, :] = 0.0
    return xb

# =========================
# INFERENCIA CON WINDOW SLIDING CORREGIDA
# =========================
def subwindow_logits(model, X, sfreq, sw_len, sw_stride, device):
    """Window sliding con preprocesamiento consistente"""
    model.eval()
    wl = int(round(sw_len * sfreq))
    st = int(round(sw_stride * sfreq))
    wl = max(1, min(wl, X.shape[-1]))
    st = max(1, st)
    
    print(f"🔧 Window Sliding: len={wl}samples ({sw_len}s), stride={st}samples ({sw_stride}s), total_windows={max(1, (X.shape[-1]-wl+st)//st)}")
    
    out = []
    with torch.no_grad():
        for i in tqdm(range(X.shape[0]), desc="Processing windows"):
            x = X[i]
            acc = []
            
            # Calcular estadísticas de normalización para esta muestra
            if not ZSCORE_PER_EPOCH:
                x_mean = x.mean(axis=1, keepdims=True)
                x_std = x.std(axis=1, keepdims=True) + 1e-6
            
            for s in range(0, max(1, X.shape[-1]-wl+1), st):
                seg = x[:, s:s+wl]
                
                # Aplicar padding si es necesario
                if seg.shape[-1] < wl:
                    pad = wl - seg.shape[-1]
                    seg = np.pad(seg, ((0,0),(0,pad)), mode='constant', constant_values=0)
                
                # ✅ NORMALIZACIÓN CONSISTENTE para cada ventana
                if not ZSCORE_PER_EPOCH:
                    seg = (seg - x_mean) / x_std
                
                xb = torch.tensor(seg[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            
            # Promedio de logits de todas las ventanas
            if len(acc) > 0:
                avg_logits = np.mean(np.stack(acc, axis=0), axis=0)
            else:
                avg_logits = np.zeros(4, dtype=np.float32)
            out.append(avg_logits)
    
    return np.stack(out, axis=0)

def time_shift_tta_logits(model, X, sfreq, shifts_s, device):
    """TTA con preprocesamiento consistente"""
    model.eval()
    T = X.shape[-1]
    out = []
    with torch.no_grad():
        for i in tqdm(range(X.shape[0]), desc="Processing TTA"):
            x0 = X[i]
            acc = []
            
            # Calcular estadísticas de normalización para esta muestra
            if not ZSCORE_PER_EPOCH:
                x0_mean = x0.mean(axis=1, keepdims=True)
                x0_std = x0.std(axis=1, keepdims=True) + 1e-6
            
            for sh in shifts_s:
                shift = int(round(sh * sfreq))
                if shift == 0:
                    x = x0
                elif shift > 0:
                    x = np.pad(x0[:, shift:], ((0,0),(0,shift)), mode='constant', constant_values=0)[:, :T]
                else:
                    shift = -shift
                    x = np.pad(x0[:, :-shift], ((0,0),(shift,0)), mode='constant', constant_values=0)[:, :T]
                
                # ✅ NORMALIZACIÓN CONSISTENTE
                if not ZSCORE_PER_EPOCH:
                    x = (x - x0_mean) / x0_std
                
                xb = torch.tensor(x[None, ...], dtype=torch.float32, device=device)
                logit = model(xb).detach().cpu().numpy()[0]
                acc.append(logit)
            
            out.append(np.mean(np.stack(acc, axis=0), axis=0))
    return np.stack(out, axis=0)

def standard_inference(model, X, device):
    """Inference estándar sin complicaciones"""
    model.eval()
    X_tensor = torch.tensor(X, dtype=torch.float32, device=device)
    with torch.no_grad():
        logits = model(X_tensor).cpu().numpy()
    return logits

def unified_inference(model, X, sfreq, device, mode='standard'):
    """Función unificada para inference con diferentes modos"""
    model.eval()
    
    if mode == 'standard':
        return standard_inference(model, X, device)
    
    elif mode == 'subwindow':
        return subwindow_logits(model, X, sfreq, SW_LEN, SW_STRIDE, device)
    
    elif mode == 'tta':
        return time_shift_tta_logits(model, X, sfreq, TTA_SHIFTS_S, device)
    
    else:
        raise ValueError(f"Modo desconocido: {mode}")

# =========================
# TRAIN/EVAL CON WINDOW SLIDING
# =========================
def train_one_fold(fold:int, device):
    train_sub, test_sub = load_fold_subjects(FOLDS_JSON, fold)
    train_sub = [s for s in train_sub if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]
    test_sub  = [s for s in test_sub  if subject_id_to_int(s) not in EXCLUDE_SUBJECTS]

    # Split de validación
    rng = random.Random(RANDOM_STATE + fold)
    tr_subjects = train_sub.copy()
    rng.shuffle(tr_subjects)
    n_val_subj = max(1, int(round(len(tr_subjects) * VAL_SUBJECT_FRAC)))
    val_subjects = sorted(tr_subjects[:n_val_subj])
    train_subjects = sorted(tr_subjects[n_val_subj:])

    # Carga de datos
    X_tr_list, y_tr_list, sub_tr_list = [], [], []
    X_val_list, y_val_list, sub_val_list = [], [], []
    X_te_list, y_te_list, sub_te_list = [], [], []
    sfreq = None

    for sid in tqdm(train_subjects, desc=f"Cargando train fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_tr_list.append(Xs); y_tr_list.append(ys)
        sub_tr_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(val_subjects, desc=f"Cargando val fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_val_list.append(Xs); y_val_list.append(ys)
        sub_val_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    for sid in tqdm(test_sub, desc=f"Cargando test fold{fold}"):
        Xs, ys, sf = load_subject_epochs(sid, RESAMPLE_HZ, DO_NOTCH, DO_BANDPASS, DO_CAR, BP_LO, BP_HI)
        if len(ys) == 0: continue
        X_te_list.append(Xs); y_te_list.append(ys)
        sub_te_list.append(np.full_like(ys, fill_value=subject_id_to_int(sid)))
        sfreq = sf if sfreq is None else sfreq

    # Concatenar
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)
    X_val = np.concatenate(X_val_list, axis=0); y_val = np.concatenate(y_val_list, axis=0)
    X_te = np.concatenate(X_te_list, axis=0); y_te = np.concatenate(y_te_list, axis=0)

    print(f"[Fold {fold}/5] n_train={len(y_tr)} | n_val={len(y_val)} | n_test={len(y_te)}")
    print(f"📊 Distribución de clases - Train: {np.bincount(y_tr)}")

    if sfreq is None:
        raise RuntimeError("No se pudo inferir sfreq.")

    # Normalización
    X_tr_std, X_val_std = standardize_per_channel_consistent(X_tr, X_val)
    _, X_te_std = standardize_per_channel_consistent(X_tr, X_te)

    # Datasets
    tr_ds  = TensorDataset(torch.tensor(X_tr_std), torch.tensor(y_tr).long())
    val_ds = TensorDataset(torch.tensor(X_val_std), torch.tensor(y_val).long())
    te_ds  = TensorDataset(torch.tensor(X_te_std), torch.tensor(y_te).long())

    tr_ld  = DataLoader(tr_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, worker_init_fn=seed_worker)
    val_ld = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)
    te_ld  = DataLoader(te_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, worker_init_fn=seed_worker)

    # Configuración de patches
    patch_duration_s = 0.100
    patch_stride_fraction = 0.5

    patch_size_samples = max(1, int(round(patch_duration_s * sfreq)))
    patch_stride_samples = max(1, int(round(patch_size_samples * patch_stride_fraction)))

    print(f"📊 sfreq={sfreq:.1f} Hz -> patch_size={patch_size_samples} | stride={patch_stride_samples}")

    # Modelo
    model = EEGTransformer(n_ch=8, n_cls=4, d_model=D_MODEL, n_heads=N_HEADS,
                           n_layers=N_LAYERS, p_drop=P_DROP,
                           patch_size_samples=patch_size_samples,
                           patch_stride_samples=patch_stride_samples).to(device)

    # Optimizador
    opt = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)

    # Focal Loss balanceado
    class_counts = np.bincount(y_tr, minlength=4).astype(np.float32)
    total = class_counts.sum()
    alpha = torch.tensor([total / (4.0 * count) for count in class_counts], 
                         dtype=torch.float32, device=device)
    alpha = alpha / alpha.sum()
    
    print(f"📊 Pesos Focal Loss: {alpha.cpu().numpy()}")
    crit = FocalLoss(alpha=alpha, gamma=1.0, reduction='mean')

    # Scheduler
    from torch.optim.lr_scheduler import LambdaLR
    total_epochs = EPOCHS
    warmup_epochs = max(1, int(WARMUP_EPOCHS))
    min_factor = 0.1
    
    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            return (current_epoch + 1) / warmup_epochs
        progress = (current_epoch - warmup_epochs) / max(1, (total_epochs - warmup_epochs))
        progress = min(1.0, max(0.0, progress))
        return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + np.cos(np.pi * progress))
    
    scheduler = LambdaLR(opt, lr_lambda=lr_lambda)

    # Entrenamiento
    best_f1, best_state, wait = 0.0, None, 0
    hist = {"ep": [], "tr_loss": [], "tr_acc": [], "val_acc": [], "val_f1m": [], "lr": []}

    for ep in range(1, EPOCHS+1):
        # Train
        model.train()
        tr_loss, n_seen, tr_correct = 0.0, 0, 0
        for xb, yb in tr_ld:
            xb = xb.to(device); yb = yb.to(device)
            xb = augment_batch(xb)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            tr_loss += loss.item() * len(yb)
            n_seen += len(yb)
            tr_correct += (logits.argmax(1) == yb).sum().item()
        
        tr_loss /= max(1, n_seen)
        tr_acc = tr_correct / max(1, n_seen)

        # Validation
        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for xb, yb in val_ld:
                xb = xb.to(device)
                p = model(xb).argmax(dim=1).cpu().numpy()
                preds.append(p); gts.append(yb.numpy())
        
        preds = np.concatenate(preds); gts = np.concatenate(gts)
        acc = accuracy_score(gts, preds)
        f1m = f1_score(gts, preds, average='macro')

        hist["ep"].append(ep)
        hist["tr_loss"].append(tr_loss)
        hist["tr_acc"].append(tr_acc)
        hist["val_acc"].append(acc)
        hist["val_f1m"].append(f1m)
        hist["lr"].append(scheduler.get_last_lr()[0])

        print(f"  Época {ep:3d} | train_loss={tr_loss:.4f} | train_acc={tr_acc:.4f} | val_acc={acc:.4f} | val_f1m={f1m:.4f}")

        improved = f1m > best_f1 + 1e-4
        if improved:
            best_f1 = f1m
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        scheduler.step()
        
        if wait >= PATIENCE:
            print(f"  Early stopping en época {ep} (mejor val_f1m={best_f1:.4f})")
            break

    # Cargar mejor modelo
    if best_state is not None:
        model.load_state_dict(best_state)

    # Evaluación en TEST CON WINDOW SLIDING
    print(f"🎯 Evaluando TEST con {SW_MODE}...")
    model.eval()
    
    # Usar inference unificada
    logits = unified_inference(model, X_te_std, sfreq, device, SW_MODE)
    preds = logits.argmax(axis=1)
    gts = y_te

    acc = accuracy_score(gts, preds)
    f1m = f1_score(gts, preds, average='macro')
    
    print(f"[Fold {fold}/5] TEST - acc={acc:.4f}, F1-macro={f1m:.4f}")
    print(classification_report(gts, preds, target_names=[c.replace('_',' ') for c in CLASS_NAMES], digits=4))
    print("Confusion matrix:")
    print(confusion_matrix(gts, preds, labels=[0,1,2,3]))

    return acc, f1m

# =========================
# LOOP PRINCIPAL
# =========================
if __name__ == "__main__":
    print("🚀 INICIANDO ENTRENAMIENTO CON WINDOW SLIDING ACTIVADO")
    print(f"🔧 Configuración Window Sliding: SW_LEN={SW_LEN}s, SW_STRIDE={SW_STRIDE}s")
    print(f"🎯 Modo: {SW_MODE}")
    
    acc_folds, f1_folds = [], []
    for fold in range(1, 6):
        print(f"\n{'='*60}")
        print(f"PROCESANDO FOLD {fold}/5")
        print(f"{'='*60}")
        acc, f1m = train_one_fold(fold, DEVICE)
        acc_folds.append(f"{acc:.4f}")
        f1_folds.append(f"{f1m:.4f}")

    acc_mean = float(np.mean([float(a) for a in acc_folds]))
    f1_mean  = float(np.mean([float(f) for f in f1_folds]))

    print("\n" + "="*60)
    print("🎉 RESULTADOS FINALES CON WINDOW SLIDING")
    print("="*60)
    print(f"Global folds (ACC): {acc_folds}")
    print(f"Global mean ACC: {acc_mean:.4f}")
    print(f"F1 folds (MACRO): {f1_folds}")
    print(f"F1 mean (MACRO): {f1_mean:.4f}")
    print(f"🔧 Configuración usada: SW_MODE={SW_MODE}, SW_LEN={SW_LEN}, SW_STRIDE={SW_STRIDE}")

🚀 Usando dispositivo: cuda
🚀 INICIANDO ENTRENAMIENTO CON WINDOW SLIDING ACTIVADO
🔧 Configuración Window Sliding: SW_LEN=2.5s, SW_STRIDE=0.3s
🎯 Modo: tta

PROCESANDO FOLD 1/5


Cargando train fold1:   0%|          | 0/67 [00:00<?, ?it/s]

Cargando train fold1: 100%|██████████| 67/67 [00:08<00:00,  7.96it/s]
Cargando val fold1: 100%|██████████| 15/15 [00:01<00:00,  7.72it/s]
Cargando test fold1: 100%|██████████| 21/21 [00:02<00:00,  8.13it/s]


[Fold 1/5] n_train=5628 | n_val=1260 | n_test=1764
📊 Distribución de clases - Train: [1407 1407 1407 1407]
📊 sfreq=160.0 Hz -> patch_size=16 | stride=8
📊 Pesos Focal Loss: [0.25 0.25 0.25 0.25]




  Época   1 | train_loss=0.2663 | train_acc=0.2562 | val_acc=0.2897 | val_f1m=0.2138
  Época   2 | train_loss=0.2554 | train_acc=0.2999 | val_acc=0.3595 | val_f1m=0.2781
  Época   3 | train_loss=0.2382 | train_acc=0.3849 | val_acc=0.4683 | val_f1m=0.4573
  Época   4 | train_loss=0.2207 | train_acc=0.4563 | val_acc=0.5270 | val_f1m=0.5256
  Época   5 | train_loss=0.2127 | train_acc=0.4796 | val_acc=0.4825 | val_f1m=0.4785
  Época   6 | train_loss=0.2063 | train_acc=0.5016 | val_acc=0.5405 | val_f1m=0.5365
  Época   7 | train_loss=0.2074 | train_acc=0.4915 | val_acc=0.5167 | val_f1m=0.5197
  Época   8 | train_loss=0.2001 | train_acc=0.5135 | val_acc=0.5071 | val_f1m=0.4783
  Época   9 | train_loss=0.1991 | train_acc=0.5176 | val_acc=0.5183 | val_f1m=0.5018
  Época  10 | train_loss=0.1997 | train_acc=0.5091 | val_acc=0.5262 | val_f1m=0.5219
  Época  11 | train_loss=0.1964 | train_acc=0.5229 | val_acc=0.5071 | val_f1m=0.5030
  Época  12 | train_loss=0.1945 | train_acc=0.5355 | val_acc=0.51

Processing TTA: 100%|██████████| 1764/1764 [00:06<00:00, 253.57it/s]


[Fold 1/5] TEST - acc=0.4796, F1-macro=0.4811
              precision    recall  f1-score   support

        left     0.5597    0.5850    0.5721       441
       right     0.5746    0.4104    0.4788       441
  both fists     0.3694    0.4490    0.4053       441
   both feet     0.4624    0.4739    0.4681       441

    accuracy                         0.4796      1764
   macro avg     0.4915    0.4796    0.4811      1764
weighted avg     0.4915    0.4796    0.4811      1764

Confusion matrix:
[[258  32  83  68]
 [ 48 181 129  83]
 [ 94  57 198  92]
 [ 61  45 126 209]]

PROCESANDO FOLD 2/5


Cargando train fold2: 100%|██████████| 67/67 [00:08<00:00,  8.18it/s]
Cargando val fold2: 100%|██████████| 15/15 [00:01<00:00,  8.40it/s]
Cargando test fold2: 100%|██████████| 21/21 [00:02<00:00,  8.37it/s]


[Fold 2/5] n_train=5628 | n_val=1260 | n_test=1764
📊 Distribución de clases - Train: [1407 1407 1407 1407]
📊 sfreq=160.0 Hz -> patch_size=16 | stride=8
📊 Pesos Focal Loss: [0.25 0.25 0.25 0.25]




  Época   1 | train_loss=0.2652 | train_acc=0.2596 | val_acc=0.2746 | val_f1m=0.1936
  Época   2 | train_loss=0.2556 | train_acc=0.3070 | val_acc=0.3857 | val_f1m=0.3809
  Época   3 | train_loss=0.2313 | train_acc=0.4128 | val_acc=0.4016 | val_f1m=0.3556
  Época   4 | train_loss=0.2241 | train_acc=0.4387 | val_acc=0.4754 | val_f1m=0.4741
  Época   5 | train_loss=0.2137 | train_acc=0.4753 | val_acc=0.4603 | val_f1m=0.4314
  Época   6 | train_loss=0.2130 | train_acc=0.4707 | val_acc=0.4452 | val_f1m=0.4335
  Época   7 | train_loss=0.2101 | train_acc=0.4801 | val_acc=0.4794 | val_f1m=0.4782
  Época   8 | train_loss=0.2083 | train_acc=0.4849 | val_acc=0.4571 | val_f1m=0.4328
  Época   9 | train_loss=0.2048 | train_acc=0.4925 | val_acc=0.4865 | val_f1m=0.4840
  Época  10 | train_loss=0.2011 | train_acc=0.5050 | val_acc=0.4905 | val_f1m=0.4897
  Época  11 | train_loss=0.1979 | train_acc=0.5210 | val_acc=0.5262 | val_f1m=0.5183
  Época  12 | train_loss=0.1947 | train_acc=0.5277 | val_acc=0.50

Processing TTA: 100%|██████████| 1764/1764 [00:07<00:00, 251.61it/s]


[Fold 2/5] TEST - acc=0.5204, F1-macro=0.5058
              precision    recall  f1-score   support

        left     0.5127    0.7778    0.6180       441
       right     0.5768    0.5873    0.5820       441
  both fists     0.4539    0.4127    0.4323       441
   both feet     0.5469    0.3039    0.3907       441

    accuracy                         0.5204      1764
   macro avg     0.5226    0.5204    0.5058      1764
weighted avg     0.5226    0.5204    0.5058      1764

Confusion matrix:
[[343  26  49  23]
 [ 61 259  76  45]
 [152  64 182  43]
 [113 100  94 134]]

PROCESANDO FOLD 3/5


Cargando train fold3: 100%|██████████| 67/67 [00:07<00:00,  8.38it/s]
Cargando val fold3: 100%|██████████| 15/15 [00:01<00:00,  8.39it/s]
Cargando test fold3: 100%|██████████| 21/21 [00:02<00:00,  7.97it/s]


[Fold 3/5] n_train=5628 | n_val=1260 | n_test=1764
📊 Distribución de clases - Train: [1407 1407 1407 1407]
📊 sfreq=160.0 Hz -> patch_size=16 | stride=8
📊 Pesos Focal Loss: [0.25 0.25 0.25 0.25]




  Época   1 | train_loss=0.2663 | train_acc=0.2496 | val_acc=0.2635 | val_f1m=0.1613
  Época   2 | train_loss=0.2494 | train_acc=0.3351 | val_acc=0.3524 | val_f1m=0.3445
  Época   3 | train_loss=0.2290 | train_acc=0.4275 | val_acc=0.4048 | val_f1m=0.3895
  Época   4 | train_loss=0.2179 | train_acc=0.4598 | val_acc=0.4111 | val_f1m=0.4035
  Época   5 | train_loss=0.2136 | train_acc=0.4783 | val_acc=0.4365 | val_f1m=0.4378
  Época   6 | train_loss=0.2064 | train_acc=0.4998 | val_acc=0.4381 | val_f1m=0.4213
  Época   7 | train_loss=0.2024 | train_acc=0.5046 | val_acc=0.4468 | val_f1m=0.4404
  Época   8 | train_loss=0.1977 | train_acc=0.5194 | val_acc=0.4429 | val_f1m=0.4369
  Época   9 | train_loss=0.1922 | train_acc=0.5275 | val_acc=0.4286 | val_f1m=0.4148
  Época  10 | train_loss=0.1926 | train_acc=0.5299 | val_acc=0.4698 | val_f1m=0.4618
  Época  11 | train_loss=0.1881 | train_acc=0.5462 | val_acc=0.4619 | val_f1m=0.4527
  Época  12 | train_loss=0.1835 | train_acc=0.5574 | val_acc=0.42

Processing TTA: 100%|██████████| 1764/1764 [00:06<00:00, 253.23it/s]


[Fold 3/5] TEST - acc=0.4705, F1-macro=0.4641
              precision    recall  f1-score   support

        left     0.5021    0.5351    0.5181       441
       right     0.5000    0.6213    0.5541       441
  both fists     0.4387    0.3243    0.3729       441
   both feet     0.4214    0.4014    0.4111       441

    accuracy                         0.4705      1764
   macro avg     0.4656    0.4705    0.4641      1764
weighted avg     0.4656    0.4705    0.4641      1764

Confusion matrix:
[[236  62  75  68]
 [ 40 274  49  78]
 [102  99 143  97]
 [ 92 113  59 177]]

PROCESANDO FOLD 4/5


Cargando train fold4: 100%|██████████| 68/68 [00:08<00:00,  8.41it/s]
Cargando val fold4: 100%|██████████| 15/15 [00:01<00:00,  8.47it/s]
Cargando test fold4: 100%|██████████| 20/20 [00:02<00:00,  8.20it/s]


[Fold 4/5] n_train=5712 | n_val=1260 | n_test=1680
📊 Distribución de clases - Train: [1428 1428 1428 1428]
📊 sfreq=160.0 Hz -> patch_size=16 | stride=8
📊 Pesos Focal Loss: [0.25 0.25 0.25 0.25]




  Época   1 | train_loss=0.2660 | train_acc=0.2544 | val_acc=0.2532 | val_f1m=0.1128
  Época   2 | train_loss=0.2522 | train_acc=0.3235 | val_acc=0.3825 | val_f1m=0.3818
  Época   3 | train_loss=0.2296 | train_acc=0.4205 | val_acc=0.4238 | val_f1m=0.4064
  Época   4 | train_loss=0.2139 | train_acc=0.4774 | val_acc=0.4286 | val_f1m=0.4276
  Época   5 | train_loss=0.2114 | train_acc=0.4834 | val_acc=0.4357 | val_f1m=0.4337
  Época   6 | train_loss=0.2087 | train_acc=0.4914 | val_acc=0.4373 | val_f1m=0.4306
  Época   7 | train_loss=0.2086 | train_acc=0.4970 | val_acc=0.4611 | val_f1m=0.4572
  Época   8 | train_loss=0.2026 | train_acc=0.5116 | val_acc=0.4548 | val_f1m=0.4378
  Época   9 | train_loss=0.2014 | train_acc=0.5165 | val_acc=0.4556 | val_f1m=0.4569
  Época  10 | train_loss=0.1970 | train_acc=0.5235 | val_acc=0.4698 | val_f1m=0.4693
  Época  11 | train_loss=0.1950 | train_acc=0.5221 | val_acc=0.4397 | val_f1m=0.4155
  Época  12 | train_loss=0.1925 | train_acc=0.5361 | val_acc=0.46

Processing TTA: 100%|██████████| 1680/1680 [00:06<00:00, 253.52it/s]


[Fold 4/5] TEST - acc=0.5155, F1-macro=0.5169
              precision    recall  f1-score   support

        left     0.5986    0.6214    0.6098       420
       right     0.6160    0.5310    0.5703       420
  both fists     0.4254    0.4619    0.4429       420
   both feet     0.4413    0.4476    0.4444       420

    accuracy                         0.5155      1680
   macro avg     0.5203    0.5155    0.5169      1680
weighted avg     0.5203    0.5155    0.5169      1680

Confusion matrix:
[[261  23  73  63]
 [ 36 223  77  84]
 [ 77  58 194  91]
 [ 62  58 112 188]]

PROCESANDO FOLD 5/5


Cargando train fold5: 100%|██████████| 68/68 [00:08<00:00,  8.48it/s]
Cargando val fold5: 100%|██████████| 15/15 [00:01<00:00,  8.53it/s]
Cargando test fold5: 100%|██████████| 20/20 [00:02<00:00,  8.50it/s]


[Fold 5/5] n_train=5712 | n_val=1260 | n_test=1680
📊 Distribución de clases - Train: [1428 1428 1428 1428]
📊 sfreq=160.0 Hz -> patch_size=16 | stride=8
📊 Pesos Focal Loss: [0.25 0.25 0.25 0.25]




  Época   1 | train_loss=0.2662 | train_acc=0.2512 | val_acc=0.2635 | val_f1m=0.1674
  Época   2 | train_loss=0.2598 | train_acc=0.2847 | val_acc=0.3413 | val_f1m=0.3144
  Época   3 | train_loss=0.2389 | train_acc=0.3972 | val_acc=0.5246 | val_f1m=0.5147
  Época   4 | train_loss=0.2215 | train_acc=0.4550 | val_acc=0.5317 | val_f1m=0.5047
  Época   5 | train_loss=0.2191 | train_acc=0.4582 | val_acc=0.5579 | val_f1m=0.5350
  Época   6 | train_loss=0.2199 | train_acc=0.4520 | val_acc=0.5730 | val_f1m=0.5713
  Época   7 | train_loss=0.2152 | train_acc=0.4662 | val_acc=0.5429 | val_f1m=0.5359
  Época   8 | train_loss=0.2126 | train_acc=0.4727 | val_acc=0.5556 | val_f1m=0.5337
  Época   9 | train_loss=0.2111 | train_acc=0.4781 | val_acc=0.5294 | val_f1m=0.5079
  Época  10 | train_loss=0.2058 | train_acc=0.4965 | val_acc=0.5460 | val_f1m=0.5239
  Época  11 | train_loss=0.2066 | train_acc=0.4907 | val_acc=0.5476 | val_f1m=0.5442
  Época  12 | train_loss=0.2016 | train_acc=0.5021 | val_acc=0.52

Processing TTA: 100%|██████████| 1680/1680 [00:06<00:00, 250.78it/s]

[Fold 5/5] TEST - acc=0.5089, F1-macro=0.5025
              precision    recall  f1-score   support

        left     0.6024    0.6095    0.6059       420
       right     0.5396    0.6000    0.5682       420
  both fists     0.4058    0.5333    0.4609       420
   both feet     0.5212    0.2929    0.3750       420

    accuracy                         0.5089      1680
   macro avg     0.5172    0.5089    0.5025      1680
weighted avg     0.5172    0.5089    0.5025      1680

Confusion matrix:
[[256  30  98  36]
 [ 29 252 101  38]
 [ 78  79 224  39]
 [ 62 106 129 123]]

🎉 RESULTADOS FINALES CON WINDOW SLIDING
Global folds (ACC): ['0.4796', '0.5204', '0.4705', '0.5155', '0.5089']
Global mean ACC: 0.4990
F1 folds (MACRO): ['0.4811', '0.5058', '0.4641', '0.5169', '0.5025']
F1 mean (MACRO): 0.4941
🔧 Configuración usada: SW_MODE=tta, SW_LEN=2.5, SW_STRIDE=0.3



