In [3]:
# -*- coding: utf-8 -*-
# Parallel-EEGNet + S-TCN + MHSA con VAT (global) + fine-tuning por sujeto
# - Misma construcción de dataset que EEGNet (6s, 8 canales, sujetos excluidos)
# - Cross-subject global (5 folds) + FT por sujeto como etapa secundaria
# - Notch adaptativo 50/60 Hz
# - VAT durante entrenamiento global (opcional)

import os, re, math, json, copy, random, itertools
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, GroupShuffleSplit, StratifiedKFold, StratifiedShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# RUTAS / CONFIG GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Datos
FS = 160.0
WINDOW_MODE = '3s'   # [-1, +5] seg
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

# Escenario
CLASS_SCENARIO = '4c'
CLASS_NAMES_4C = ['Left', 'Right', 'Both Fists', 'Both Feet']
N_CLASSES = 4
N_FOLDS = 5

# Entrenamiento global
BATCH_SIZE = 64
EPOCHS_GLOBAL = 100
LR_INIT = 1e-2
SGDR_T0 = 6
SGDR_Tmult = 2
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE = 10
LOG_EVERY = 5

# Fine-tuning por sujeto
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# VAT (Virtual Adversarial Training)
USE_VAT = True
VAT_XI = 1e-6     # tamaño de paso para power-iteration
VAT_EPS = 1.5     # radio de perturbación adversaria
VAT_IP = 1        # iteraciones de potencia (1 suele bastar)

# Augments ligeros
USE_MIXUP = False
MIXUP_ALPHA = 0.2

# Runs (mismo set tuyo)
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

# =========================
# CANALES / LECTURA EDF
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# --- Notch por SNR (50/60) ---
_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None:
        return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try:
            _SNR_TABLE = pd.read_csv(csv_path)
        except:
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:
        return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:
        return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None
    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    # de-dup mínimo por clase
    dedup = []; last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# DATASET (MISMA LÓGICA 6s)
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])
    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])
    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6
    out = []
    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:
            rel_start, rel_end = -1.0, 5.0  # 6 s
        for onset_sec, tag in events:
            if kind == 'LR':
                label = 0 if tag=='T1' else 1 if tag=='T2' else None
            else:
                label = 2 if tag=='T1' else 3 if tag=='T2' else None
            if label is None: continue
            if scenario == '2c' and label not in (0,1): continue
            if scenario == '3c' and label not in (0,1,2): continue
            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end)   * fs))
            if s < 0 or e > data.shape[1]: continue
            seg = data[:, s:e].T.astype(np.float32)  # (T,C)
            # z-score por época canal-a-canal
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)
            out.append((seg, label, subj))
    elif kind == 'EO':
        return ([], raw.ch_names)
    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='6s'):
    X, y, groups = [], [], []
    ch_template = None
    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        trials = []
        for r in (MI_RUNS_LR + MI_RUNS_OF):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            trials.extend(outs)
        # balanceo a 21 por clase/sujeto
        per_class = {0:[],1:[],2:[],3:[]}
        for seg, lab, _ in trials: per_class[lab].append(seg)
        need = 21
        rng = check_random_state(RANDOM_STATE + s)
        valid = all(len(per_class[c])>0 for c in range(4))
        if not valid: continue
        def pick(arr):
            if len(arr) >= need:
                rng.shuffle(arr); return arr[:need]
            idx = rng.choice(len(arr), size=need, replace=True)
            return [arr[i] for i in idx]
        for c in range(4):
            segs = pick(per_class[c])
            for seg in segs:
                X.append(seg); y.append(c); groups.append(s)
    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)
    n, T, C = X.shape
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={len(np.unique(y))} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# AUGMENTS + VAT
# =========================
def do_time_jitter(x, max_ms=50, fs=160.0):
    B,_,T,C = x.shape
    max_shift = int(round(max_ms/1000.0 * fs))
    if max_shift <= 0: return x
    shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
    out = torch.empty_like(x)
    for i,s in enumerate(shifts):
        if s==0: out[i] = x[i]; continue
        if s>0:
            out[i,:,s:,:] = x[i,:,:T-s,:]; out[i,:,:s,:] = 0
        else:
            s = -s
            out[i,:,:T-s,:] = x[i,:,s:,:]; out[i,:,T-s:,:] = 0
    return out

def do_gaussian_noise(x, sigma=0.01):
    if sigma<=0: return x
    return x + sigma*torch.randn_like(x)

def do_temporal_cutout_masked(x, y, classes_mask={2,3}, min_ms=30, max_ms=60, fs=160.0):
    B,_,T,C = x.shape
    Lmin = int(round(min_ms/1000.0*fs)); Lmax = int(round(max_ms/1000.0*fs))
    if Lmin<=0 or Lmax<=0 or Lmin>Lmax: return x
    out = x.clone(); y_np = y.detach().cpu().numpy()
    for i in range(B):
        if int(y_np[i]) not in classes_mask: continue
        L = random.randint(Lmin, Lmax)
        if L>=T: continue
        s = random.randint(0, T-L)
        out[i,:,s:s+L,:] = 0
    return out

def mixup_batch(x, y, n_classes, alpha=0.2):
    if alpha<=0:
        y_onehot = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
        return x, y_onehot, 1.0
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam*x + (1-lam)*x[perm]
    y_a = torch.nn.functional.one_hot(y, num_classes=n_classes).float()
    y_b = y_a[perm]
    y_mix = lam*y_a + (1-lam)*y_b
    return x_mix, y_mix, lam

@torch.no_grad()
def _normalize_like(x, ref):
    # normaliza vector perturbación como ref (L2 sobre todo el tensor de entrada)
    eps = 1e-6
    r = x.view(x.size(0), -1)
    n = torch.norm(r, p=2, dim=1, keepdim=True) + eps
    r = (r / n) * ref
    return r.view_as(x)

def kl_div_with_logits(p_logits, q_logits):
    p = torch.log_softmax(p_logits, dim=1)
    q = torch.softmax(q_logits, dim=1)
    return torch.mean(torch.sum(q * (torch.log(q + 1e-8) - p), dim=1))

def virtual_adversarial_loss(model, x, logits=None, xi=1e-6, eps=1.5, ip=1):
    # x: (B,1,T,C)
    with torch.no_grad():
        if logits is None:
            logits = model(x)
        p = logits.detach()
    # d inicial aleatorio
    d = torch.randn_like(x)
    d = _normalize_like(d, ref=xi)
    for _ in range(ip):
        d.requires_grad_()
        q = model(x + d)
        adv_kl = kl_div_with_logits(p, q)
        adv_kl.backward()
        g = d.grad.detach()
        d = _normalize_like(g, ref=xi)
        model.zero_grad(set_to_none=True)
    # perturb final
    r_adv = _normalize_like(d, ref=eps)
    q = model(x + r_adv)
    vat_loss = kl_div_with_logits(p, q)
    return vat_loss

# =========================
# DATASET TORCH
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]  # (T,C)
        x = np.expand_dims(x, 0)  # (1,T,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx])

# =========================
# MODELO: Parallel-EEGNet + S-TCN + MHSA
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1):
        super().__init__(); self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

class SeparableTemporal(nn.Module):
    def __init__(self, in_ch, out_ch, k, p_drop=0.2):
        super().__init__()
        self.depth = nn.Conv2d(in_ch, in_ch, kernel_size=(k,1), groups=in_ch, padding=(k//2,0), bias=False)
        self.point = nn.Conv2d(in_ch, out_ch, kernel_size=(1,1), bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ELU()
        self.drop = nn.Dropout(p_drop)
    def forward(self, x):
        z = self.depth(x); z = self.point(z)
        z = self.bn(z); z = self.act(z)
        z = self.drop(z)
        return z

class ParallelEEGStem(nn.Module):
    """
    Dos ramas temporales (kernels distintos) + depthwise espacial y separable temporal.
    Entrada (B,1,T,C) -> Salida (B, F2_total, T', 1)
    """
    def __init__(self, n_ch=8, F1=16, D=2, k_t_short=32, k_t_long=96,
                 pool1_t=4, pool2_t=6, drop1=0.25, drop2=0.5, chdrop=0.1):
        super().__init__()
        self.chdrop = ChannelDropout(p=chdrop)

        # Rama A (temporal corto)
        self.convA = nn.Conv2d(1, F1, kernel_size=(k_t_short,1), padding=(k_t_short//2,0), bias=False)
        self.bnA = nn.BatchNorm2d(F1); self.act = nn.ELU()

        # Rama B (temporal largo)
        self.convB = nn.Conv2d(1, F1, kernel_size=(k_t_long,1), padding=(k_t_long//2,0), bias=False)
        self.bnB = nn.BatchNorm2d(F1)

        # Depthwise espacial por rama
        self.dwA = nn.Conv2d(F1, F1*D, kernel_size=(1,n_ch), groups=F1, bias=False)
        self.dwB = nn.Conv2d(F1, F1*D, kernel_size=(1,n_ch), groups=F1, bias=False)
        self.bn_dwA = nn.BatchNorm2d(F1*D)
        self.bn_dwB = nn.BatchNorm2d(F1*D)
        self.pool1A = nn.AvgPool2d(kernel_size=(pool1_t,1), stride=(pool1_t,1))
        self.pool1B = nn.AvgPool2d(kernel_size=(pool1_t,1), stride=(pool1_t,1))
        self.drop1A = nn.Dropout(drop1); self.drop1B = nn.Dropout(drop1)

        # Separable temporal por rama
        self.sepA = SeparableTemporal(F1*D, F1*D, k=16, p_drop=drop2)
        self.sepB = SeparableTemporal(F1*D, F1*D, k=32, p_drop=drop2)
        self.pool2A = nn.AvgPool2d(kernel_size=(pool2_t,1), stride=(pool2_t,1))
        self.pool2B = nn.AvgPool2d(kernel_size=(pool2_t,1), stride=(pool2_t,1))

        self.Fout = 2 * (F1*D)

    def forward(self, x):
        # x: (B,1,T,C)
        x = self.chdrop(x)

        a = self.convA(x); a = self.bnA(a); a = self.act(a)
        b = self.convB(x); b = self.bnB(b); b = self.act(b)

        a = self.dwA(a); a = self.bn_dwA(a); a = self.act(a)
        b = self.dwB(b); b = self.bn_dwB(b); b = self.act(b)
        a = self.pool1A(a); a = self.drop1A(a)
        b = self.pool1B(b); b = self.drop1B(b)

        a = self.sepA(a); a = self.pool2A(a)   # (B, F, T', 1)
        b = self.sepB(b); b = self.pool2B(b)
        z = torch.cat([a, b], dim=1)           # concat canales
        return z  # (B, F_total, T', 1)

class STCNBlock(nn.Module):
    def __init__(self, d_model, k=7, dilation=1, p_drop=0.2):
        super().__init__()
        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size=k, padding=dilation*(k//2), dilation=dilation, groups=d_model)
        self.bn1 = nn.BatchNorm1d(d_model)
        self.act = nn.GELU()
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.bn2 = nn.BatchNorm1d(d_model)
        self.drop = nn.Dropout(p_drop)
    def forward(self, x):  # x: (B, D, T)
        y = self.conv1(x); y = self.bn1(y); y = self.act(y)
        y = self.conv2(y); y = self.bn2(y); y = self.drop(y)
        return x + y

class MHSA1D(nn.Module):
    def __init__(self, d_model, n_heads=4, p_drop=0.2):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, n_heads, dropout=p_drop, batch_first=True)
        self.drop = nn.Dropout(p_drop)
    def forward(self, x):  # x: (B,T,D)
        y = self.ln(x)
        y,_ = self.mha(y,y,y, need_weights=False)
        return x + self.drop(y)

class EEG_Parallel_STCN_MHSA(nn.Module):
    """
    Entrada: (B,1,T,C)
    1) Parallel EEG stem -> (B, F, T', 1)
    2) Reordenar a (B, T', F)
    3) S-TCN (dilations 1-2-4)
    4) MHSA sobre T'
    5) MeanPool + Head
    """
    def __init__(self, n_ch=8, n_classes=4,
                 F1=16, D=2, k_t_short=32, k_t_long=96,
                 pool1_t=4, pool2_t=6, drop1=0.25, drop2=0.5, chdrop=0.1,
                 d_model=None, n_heads=4, stcn_k=7, p_drop=0.2):
        super().__init__()
        self.stem = ParallelEEGStem(n_ch=n_ch, F1=F1, D=D, k_t_short=k_t_short, k_t_long=k_t_long,
                                    pool1_t=pool1_t, pool2_t=pool2_t, drop1=drop1, drop2=drop2, chdrop=chdrop)
        D_in = self.stem.Fout
        self.proj = nn.Conv1d(D_in, D_in, kernel_size=1)  # opcional
        self.stcn1 = STCNBlock(D_in, k=stcn_k, dilation=1, p_drop=p_drop)
        self.stcn2 = STCNBlock(D_in, k=stcn_k, dilation=2, p_drop=p_drop)
        self.stcn3 = STCNBlock(D_in, k=stcn_k, dilation=4, p_drop=p_drop)
        self.mhsa = MHSA1D(D_in, n_heads=n_heads, p_drop=p_drop)
        self.ln = nn.LayerNorm(D_in)
        self.head = nn.Sequential(
            nn.Linear(D_in, 128),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        # x: (B,1,T,C)
        z = self.stem(x)                # (B, F, T', 1)
        z = z.squeeze(-1)               # (B, F, T')
        z = self.proj(z)                # (B, F, T')
        z = self.stcn1(z); z = self.stcn2(z); z = self.stcn3(z)
        z = z.transpose(1,2)            # (B, T', F)
        z = self.mhsa(z)
        z = self.ln(z)
        z = z.mean(dim=1)               # (B, F)
        logits = self.head(z)
        return logits

# =========================
# ENTRENAMIENTO / EVAL
# =========================
def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w
    w = w / w.mean()
    return WeightedRandomSampler(weights=torch.from_numpy(w).float(), num_samples=len(w), replacement=True)

def make_class_weight_tensor(y_indices, n_classes):
    counts = np.bincount(y_indices, minlength=n_classes).astype(np.float32)
    counts[counts == 0] = 1.0
    w = counts.sum() / counts
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

def train_one_epoch(model, loader, opt, criterion, epoch, n_classes=N_CLASSES,
                    use_vat=True, vat_xi=VAT_XI, vat_eps=VAT_EPS, vat_ip=VAT_IP):
    model.train()
    for xb, yb, _ in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)

        # Augments ligeros
        xb = do_time_jitter(xb, max_ms=50, fs=FS)
        xb = do_gaussian_noise(xb, sigma=0.01)
        xb = do_temporal_cutout_masked(xb, yb, classes_mask={2,3}, min_ms=30, max_ms=60, fs=FS)

        if USE_MIXUP:
            xb, yt, _ = mixup_batch(xb, yb, n_classes=n_classes, alpha=MIXUP_ALPHA)
            logits = model(xb)
            loss_sup = -(yt * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()
        else:
            logits = model(xb)
            loss_sup = criterion(logits, yb)

        loss = loss_sup
        if use_vat:
            vat_loss = virtual_adversarial_loss(model, xb, logits=logits, xi=vat_xi, eps=vat_eps, ip=vat_ip)
            loss = loss + 0.5 * vat_loss  # peso VAT (ajústalo si quieres)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

@torch.no_grad()
def evaluate(model, loader, use_tta=True, tta_n=5):
    model.eval()
    y_true, y_pred = [], []
    for xb, yb, _ in loader:
        xb = xb.to(DEVICE)
        if use_tta:
            outs = []
            for _ in range(tta_n):
                outs.append(model(do_time_jitter(xb, max_ms=25, fs=FS)))
            logits = torch.stack(outs, dim=0).mean(dim=0)
        else:
            logits = model(xb)
        pred = logits.argmax(dim=1).cpu().numpy().tolist()
        y_pred.extend(pred); y_true.extend(yb.numpy().tolist())
    y_true = np.asarray(y_true, dtype=int); y_pred = np.asarray(y_pred, dtype=int)
    acc = (y_true == y_pred).mean()
    return y_true, y_pred, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j, i, format(cm_norm[i,j], fmt), ha="center",
                 color="white" if cm_norm[i,j] > thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred')
    plt.tight_layout(); plt.savefig(fname, dpi=150, bbox_inches='tight'); plt.close()

def print_report(y_true, y_pred, classes):
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))

# =========================
# FINE-TUNING POR SUJETO (mismo estilo que tuyo)
# =========================
def _freeze_all(m):
    for p in m.parameters(): p.requires_grad = False

def subject_cv_finetune_predict_progressive(model_global, Xs, ys, n_classes=4):
    # Dos etapas: (1) head; (2) last blocks (stcn+mhsa) + head con L2-SP
    y_true_full = np.empty_like(ys); y_pred_full = np.empty_like(ys)
    # CV adaptativo
    ys_np = np.asarray(ys)
    counts = np.bincount(ys_np, minlength=n_classes)
    min_class = counts[counts>0].min() if counts.sum()>0 else 0
    splits = min(CALIB_CV_FOLDS, max(2, int(min_class))) if min_class>0 else 2
    splitter = StratifiedKFold(n_splits=splits, shuffle=True, random_state=RANDOM_STATE)

    for tr_idx, te_idx in splitter.split(Xs, ys):
        Xcal, ycal = Xs[tr_idx], ys[tr_idx]
        Xho,  yho  = Xs[te_idx], ys[te_idx]

        # split interno
        if len(np.unique(ycal)) >= 2 and len(ycal) >= 5:
            sss = StratifiedShuffleSplit(n_splits=1, test_size=FT_VAL_RATIO, random_state=RANDOM_STATE)
            (tr2, va2), = sss.split(Xcal, ycal)
        else:
            idx = np.arange(len(ycal)); np.random.shuffle(idx)
            cut = max(1, int(0.85*len(idx)))
            tr2, va2 = idx[:cut], idx[cut:]

        ds_tr2 = torch.utils.data.TensorDataset(torch.from_numpy(Xcal[tr2]).float().unsqueeze(1), torch.from_numpy(ycal[tr2]).long())
        ds_va2 = torch.utils.data.TensorDataset(torch.from_numpy(Xcal[va2]).float().unsqueeze(1), torch.from_numpy(ycal[va2]).long())
        dl_tr2 = DataLoader(ds_tr2, batch_size=32, shuffle=True)
        dl_va2 = DataLoader(ds_va2, batch_size=64, shuffle=False)
        ds_te  = torch.utils.data.TensorDataset(torch.from_numpy(Xho).float().unsqueeze(1), torch.from_numpy(yho).long())
        dl_te  = DataLoader(ds_te, batch_size=64, shuffle=False)

        # ========== (1) HEAD ==========
        m = copy.deepcopy(model_global).to(DEVICE)
        _freeze_all(m)
        # habilitar solo la head
        for p in m.head.parameters(): p.requires_grad = True
        opt = optim.Adam(m.head.parameters(), lr=FT_HEAD_LR)
        crit = nn.CrossEntropyLoss()

        best = copy.deepcopy(m.state_dict()); best_va = 1e9; bad=0
        for _ in range(FT_EPOCHS):
            m.train()
            for xb, yb in dl_tr2:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad(set_to_none=True)
                loss = crit(m(xb), yb); loss.backward(); opt.step()
            # val
            m.eval(); val_loss=0.0; n=0
            with torch.no_grad():
                for xb, yb in dl_va2:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    val_loss += crit(m(xb), yb).item()*xb.size(0); n+=xb.size(0)
            val_loss/=max(1,n)
            if val_loss < best_va-1e-7:
                best_va = val_loss; best=copy.deepcopy(m.state_dict()); bad=0
            else:
                bad+=1
                if bad>=FT_PATIENCE: break
        m.load_state_dict(best)

        # ========== (2) Último bloque (stcn3 + mhsa + ln + head) con L2-SP ==========
        _freeze_all(m)
        train_params = []
        for p in m.stcn3.parameters(): p.requires_grad = True; train_params.append(p)
        for p in m.mhsa.parameters():  p.requires_grad = True; train_params.append(p)
        for p in m.ln.parameters():    p.requires_grad = True; train_params.append(p)
        for p in m.head.parameters():  p.requires_grad = True; train_params.append(p)

        ref = [p.detach().clone() for p in train_params]
        opt = optim.Adam([
            {"params": list(m.stcn3.parameters()) + list(m.mhsa.parameters()) + list(m.ln.parameters()), "lr": FT_BASE_LR},
            {"params": m.head.parameters(), "lr": FT_HEAD_LR},
        ])
        crit = nn.CrossEntropyLoss()

        best = copy.deepcopy(m.state_dict()); best_va = 1e9; bad=0
        for _ in range(FT_EPOCHS):
            m.train()
            for xb, yb in dl_tr2:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad(set_to_none=True)
                logits = m(xb)
                loss = crit(logits, yb)
                # L2-SP
                reg = 0.0
                for p_cur, p_ref in zip(train_params, ref):
                    reg = reg + torch.sum((p_cur - p_ref)**2)
                loss = loss + FT_L2SP * reg
                loss.backward(); opt.step()
            # val
            m.eval(); val_loss=0.0; n=0
            with torch.no_grad():
                for xb, yb in dl_va2:
                    xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                    val_loss += crit(m(xb), yb).item()*xb.size(0); n+=xb.size(0)
            val_loss/=max(1,n)
            if val_loss < best_va-1e-7:
                best_va = val_loss; best=copy.deepcopy(m.state_dict()); bad=0
            else:
                bad+=1
                if bad>=FT_PATIENCE: break
        m.load_state_dict(best)

        # test hold-out del CV interno
        with torch.no_grad():
            m.eval()
            preds=[]
            for xb, _ in dl_te:
                xb = xb.to(DEVICE)
                logits = m(xb)
                preds.extend(logits.argmax(1).cpu().numpy().tolist())
        y_pred_full[te_idx] = np.asarray(preds); y_true_full[te_idx] = yho

    return y_true_full, y_pred_full

# =========================
# HELPERS FOLDS JSON
# =========================
def save_group_folds_json_with_indices(subject_ids_str, groups_array, n_splits, out_json_path,
                                       created_by="ParallelEEG_STCN_MHSA", description=None):
    out_json_path = Path(out_json_path)
    unique_subjects_int = sorted(np.unique(groups_array).tolist())
    subject_ids = [f"S{sid:03d}" for sid in unique_subjects_int]
    if len(subject_ids) < n_splits:
        raise ValueError(f"n_splits={n_splits} mayor que número de sujetos={len(subject_ids)}")
    groups = np.arange(len(subject_ids))
    gkf = GroupKFold(n_splits=n_splits)
    folds = []
    fold_i = 0
    for train_idx_grp, test_idx_grp in gkf.split(groups, groups, groups):
        fold_i += 1
        train_sids = [subject_ids[int(i)] for i in train_idx_grp]
        test_sids  = [subject_ids[int(i)] for i in test_idx_grp]
        train_sids_int = [int(s[1:]) for s in train_sids]
        test_sids_int  = [int(s[1:]) for s in test_sids]
        tr_idx = np.where(np.isin(groups_array, train_sids_int))[0].tolist()
        te_idx = np.where(np.isin(groups_array, test_sids_int))[0].tolist()
        folds.append({
            "fold": int(fold_i),
            "train": train_sids,
            "test": test_sids,
            "tr_idx": tr_idx,
            "te_idx": te_idx
        })
    payload = {
        "created_at": datetime.now().isoformat(),
        "created_by": created_by,
        "description": description if description is not None else "",
        "n_splits": int(n_splits),
        "n_subjects": len(subject_ids),
        "subject_ids": subject_ids,
        "folds": folds
    }
    out_json_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    print(f"Folds JSON con índices guardado → {out_json_path}")
    return out_json_path

def load_group_folds_json(path_json, expected_subject_ids=None, strict_check=True):
    path_json = Path(path_json)
    if not path_json.exists():
        raise FileNotFoundError(f"No existe {path_json}")
    with open(path_json, "r", encoding="utf-8") as f:
        payload = json.load(f)
    subj_json = payload.get("subject_ids", [])
    if expected_subject_ids is not None:
        expected = sorted(list(expected_subject_ids))
        if subj_json != expected:
            msg = ("Los subject_ids del JSON no coinciden con expected_subject_ids.\n"
                   f"JSON has {len(subj_json)} subjects, expected {len(expected)}.\n"
                   f"First 10 JSON: {subj_json[:10]}\nFirst 10 expected: {expected[:10]}")
            if strict_check:
                raise ValueError(msg)
            else:
                print("WARNING: " + msg)
    return payload

# =========================
# EXPERIMENTO
# =========================
def run_experiment(save_folds_json=True, folds_json_path=FOLDS_DIR,
                   folds_json_description="GroupKFold folds for ParallelEEG_STCN_MHSA + VAT"):
    mne.set_log_level('WARNING')

    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Listo para entrenar: N={N} | T={T} | C={C} | clases={n_classes} | sujetos={len(np.unique(groups))}")

    ds = EEGTrials(X, y, groups)

    # Folds JSON
    if folds_json_path is None:
        folds_json_path = Path("folds") / f"group_folds_{N_FOLDS}splits.json"
    else:
        folds_json_path = Path(folds_json_path)
    folds_json_path.parent.mkdir(parents=True, exist_ok=True)
    unique_subs = sorted(np.unique(groups).tolist())
    subject_ids_str = [f"S{s:03d}" for s in unique_subs]
    if not folds_json_path.exists():
        if not save_folds_json:
            raise FileNotFoundError(f"Folds JSON no encontrado en {folds_json_path} y save_folds_json=False.")
        save_group_folds_json_with_indices(subject_ids_str, groups, n_splits=N_FOLDS,
                                           out_json_path=folds_json_path,
                                           created_by="ParallelEEG_STCN_MHSA",
                                           description=f"{folds_json_description}")
    payload = load_group_folds_json(folds_json_path, expected_subject_ids=subject_ids_str, strict_check=False)
    folds = payload["folds"]

    global_folds = []
    ft_prog_folds = []
    all_true = []
    all_pred = []

    for f in folds:
        fold = f["fold"]
        tr_idx = np.asarray(f.get("tr_idx", []), dtype=int)
        te_idx = np.asarray(f.get("te_idx", []), dtype=int)
        if tr_idx.size == 0 or te_idx.size == 0:
            print(f"[WARN] fold {fold} sin índices tr/te válidos. Saltando.")
            continue

        # split de validación por sujetos
        gss = GroupShuffleSplit(n_splits=1, test_size=GLOBAL_VAL_SPLIT, random_state=RANDOM_STATE)
        tr_subj_idx, va_subj_idx = next(gss.split(tr_idx, groups[tr_idx], groups[tr_idx]))
        tr_sub_idx = tr_idx[tr_subj_idx]
        va_idx     = tr_idx[va_subj_idx]

        sampler = build_weighted_sampler(y[tr_sub_idx], groups[tr_sub_idx])
        tr_loader = DataLoader(Subset(ds, tr_sub_idx), batch_size=BATCH_SIZE, sampler=sampler, drop_last=False)
        va_loader = DataLoader(Subset(ds, va_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(Subset(ds, te_idx),     batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # modelo
        model = EEG_Parallel_STCN_MHSA(n_ch=C, n_classes=n_classes,
                                       F1=16, D=2, k_t_short=32, k_t_long=96,
                                       pool1_t=4, pool2_t=6,
                                       drop1=0.25, drop2=0.5, chdrop=0.1,
                                       n_heads=4, stcn_k=7, p_drop=0.2).to(DEVICE)

        opt = optim.Adam(model.parameters(), lr=LR_INIT, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=SGDR_T0, T_mult=SGDR_Tmult)
        class_weights = make_class_weight_tensor(y[tr_sub_idx], n_classes)
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        best_state = copy.deepcopy(model.state_dict())
        best_val = -1.0; bad = 0

        print(f"\n[Fold {fold}/{N_FOLDS}] Entrenando modelo global..."
              f" (n_train={len(tr_sub_idx)} | n_val={len(va_idx)} | n_test={len(te_idx)})")

        for epoch in range(1, EPOCHS_GLOBAL + 1):
            train_one_epoch(model, tr_loader, opt, criterion, epoch,
                            n_classes=n_classes, use_vat=USE_VAT,
                            vat_xi=VAT_XI, vat_eps=VAT_EPS, vat_ip=VAT_IP)
            scheduler.step(epoch-1 + 1e-8)

            _, _, tr_acc = evaluate(model, tr_loader, use_tta=True, tta_n=5)
            _, _, va_acc = evaluate(model, va_loader, use_tta=True, tta_n=5)

            if (epoch % LOG_EVERY == 0) or epoch in (1, 10, 20, 50, 100):
                cur_lr = opt.param_groups[0]['lr']
                print(f"  Época {epoch:3d} | train_acc={tr_acc:.4f} | val_acc={va_acc:.4f} | LR={cur_lr:.5f}")

            if va_acc > best_val + 1e-4:
                best_val = va_acc; best_state = copy.deepcopy(model.state_dict()); bad = 0
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping en época {epoch} (mejor val_acc={best_val:.4f})")
                    break

        # eval global
        model.load_state_dict(best_state)
        y_true, y_pred, acc_global = evaluate(model, te_loader, use_tta=True, tta_n=5)
        global_folds.append(acc_global); all_true.append(y_true); all_pred.append(y_pred)
        print(f"[Fold {fold}/{N_FOLDS}] Global acc={acc_global:.4f}")
        print_report(y_true, y_pred, CLASS_NAMES_4C)

        # ---------- Fine-tuning por sujeto (secundario) ----------
        X_te, y_te, g_te = X[te_idx], y[te_idx], groups[te_idx]
        y_true_ft_all, y_pred_ft_all = [], []
        used_subjects = 0
        for sid in np.unique(g_te):
            idx = np.where(g_te == sid)[0]
            Xs, ys = X_te[idx], y_te[idx]
            if len(np.unique(ys)) < 2:
                continue
            y_true_subj, y_pred_subj = subject_cv_finetune_predict_progressive(model, Xs, ys, n_classes=n_classes)
            y_true_ft_all.append(y_true_subj); y_pred_ft_all.append(y_pred_subj); used_subjects += 1
        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.concatenate(y_true_ft_all)
            y_pred_ft_all = np.concatenate(y_pred_ft_all)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  Fine-tuning (por sujeto) acc={acc_ft:.4f} | Δ(FT-Global)={acc_ft - acc_global:+.4f} | sujetos={used_subjects}")
        else:
            acc_ft = np.nan
            print("  Fine-tuning no ejecutado (sujeto(s) con muestras insuficientes).")
        ft_prog_folds.append(acc_ft)

    # ---------- resultados finales ----------
    if len(all_true) > 0:
        all_true = np.concatenate(all_true)
        all_pred = np.concatenate(all_pred)
    else:
        all_true = np.array([], dtype=int)
        all_pred = np.array([], dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES")
    print("="*60)
    print("Global folds:", [f"{a:.4f}" for a in global_folds])
    if len(global_folds) > 0:
        print(f"Global mean: {np.mean(global_folds):.4f}")
    print("Fine-tune folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"Fine-tune mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Global) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds):+.4f}")

    if all_true.size > 0:
        plot_confusion(all_true, all_pred, CLASS_NAMES_4C,
                       title="Confusion Matrix - Global (ParallelEEG+STCN+MHSA+VAT)",
                       fname="confusion_parallel_stcn_mhsa_vat_global_allfolds.png")
        print("↳ Matriz de confusión guardada: confusion_parallel_stcn_mhsa_vat_global_allfolds.png")

    return {
        "global_folds": global_folds,
        "ft_prog_folds": ft_prog_folds,
        "all_true": all_true,
        "all_pred": all_pred,
        "folds_json_path": str(folds_json_path)
    }

# ---------- MAIN ----------
if __name__ == "__main__":
    print("🧠 INICIANDO EXPERIMENTO Parallel-EEGNet + S-TCN + MHSA + VAT")
    print(f"🔧 Config: {CLASS_SCENARIO}, 8 canales, ventana {WINDOW_MODE}, VAT={'ON' if USE_VAT else 'OFF'}")
    print(f"⚙️  FT: epochs={FT_EPOCHS}, base_lr={FT_BASE_LR}, head_lr={FT_HEAD_LR}, L2SP={FT_L2SP}, patience={FT_PATIENCE}, CV={CALIB_CV_FOLDS}")
    run_experiment()


🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO Parallel-EEGNet + S-TCN + MHSA + VAT
🔧 Config: 4c, 8 canales, ventana 3s, VAT=ON
⚙️  FT: epochs=30, base_lr=5e-05, head_lr=0.001, L2SP=0.0001, patience=5, CV=4
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

Construyendo dataset (RAW): 100%|██████████| 103/103 [00:44<00:00,  2.32it/s]


Dataset construido: N=8652 | T=480 | C=8 | clases=4 | sujetos únicos=103
Listo para entrenar: N=8652 | T=480 | C=8 | clases=4 | sujetos=103

[Fold 1/5] Entrenando modelo global... (n_train=5796 | n_val=1092 | n_test=1764)
  Época   1 | train_acc=0.4072 | val_acc=0.3956 | LR=0.01000
  Época   5 | train_acc=0.5002 | val_acc=0.4130 | LR=0.00250
  Época  10 | train_acc=0.3463 | val_acc=0.3214 | LR=0.00854
  Época  15 | train_acc=0.5547 | val_acc=0.4405 | LR=0.00250
  Época  20 | train_acc=0.5328 | val_acc=0.4222 | LR=0.00996
  Época  25 | train_acc=0.5488 | val_acc=0.4350 | LR=0.00854
  Early stopping en época 25 (mejor val_acc=0.4405)
[Fold 1/5] Global acc=0.4552
              precision    recall  f1-score   support

        Left     0.5225    0.5261    0.5243       441
       Right     0.4331    0.5578    0.4876       441
  Both Fists     0.3918    0.3447    0.3667       441
   Both Feet     0.4753    0.3923    0.4298       441

    accuracy                         0.4552      1764
   ma

In [4]:
# -*- coding: utf-8 -*-
# Two-Stage Transformer Architecture for Physionet 2099 - 4-class MI
# Versión corregida - Dimensiones compatibles

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'
FS = 160.0
N_FOLDS = 5

# Arquitectura Two-Stage
BATCH_SIZE = 16  # Reducido para evitar problemas de memoria
EPOCHS_STAGE1 = 50
EPOCHS_STAGE2 = 30
LR_STAGE1 = 1e-3
LR_STAGE2 = 1e-4

# Parámetros Transformer
N_HEADS = 2
D_MODEL = 64  # Dimensión del modelo para atención
DROPOUT_ATTN = 0.3
TCN_FILTERS = 32
TCN_KERNEL_SIZE = 4
TCN_N_BLOCKS = 2
TIME_SLICES = 3  # Reducido para simplificar

# TabNet
TABNET_N_D = 16
TABNET_N_A = 16
TABNET_N_HIDDEN = 64

# Data Augmentation
N_CLUSTERS = 3
AUG_SAMPLES_PER_CLASS = 20  # Reducido para prueba

# =========================
# UTILIDADES DE CANALES (MANTENIDAS)
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS (MANTENIDAS)
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

# Runs MI
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    # Filtrado básico
    raw.filter(4.0, 40.0, method='iir', verbose=False)
    
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS (MANTENIDAS)
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:  # 6s
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            # Normalización z-score por canal
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 15  # Reducido para prueba
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    if len(X) == 0:
        print("No se pudieron cargar datos. Verifica las rutas de archivos.")
        return None, None, None, None
        
    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# ARQUITECTURA TWO-STAGE TRANSFORMER - CORREGIDA
# =========================

class SimpleEEGNet(nn.Module):
    """
    EEGNet simplificado y estable
    """
    def __init__(self, n_ch=8, F1=16, D=2, T=960):
        super().__init__()
        self.F1 = F1
        self.F2 = F1 * D
        
        # Bloque 1: Temporal + Espacial
        self.conv1 = nn.Conv2d(1, F1, (64, 1), padding=(32, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        
        # Bloque 2: Depthwise
        self.conv2 = nn.Conv2d(F1, self.F2, (1, n_ch), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2)
        
        # Pooling
        self.pool1 = nn.AvgPool2d((8, 1))
        self.dropout1 = nn.Dropout(0.3)
        
        # Bloque 3: Separable
        self.conv3 = nn.Conv2d(self.F2, self.F2, (16, 1), groups=self.F2, padding=(8, 0), bias=False)
        self.conv4 = nn.Conv2d(self.F2, self.F2, (1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2)
        
        self.pool2 = nn.AvgPool2d((8, 1))
        self.dropout2 = nn.Dropout(0.4)
        
        # Calcular dimensión de salida
        self.output_dim = self._get_output_dim(n_ch, T)
        
    def _get_output_dim(self, n_ch, T):
        # Pasar un tensor dummy para calcular la dimensión de salida
        x = torch.randn(1, 1, T, n_ch)
        with torch.no_grad():
            x = self.forward_features(x)
        return x.view(1, -1).size(1)
    
    def forward_features(self, x):
        # Bloque 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.elu(x)
        
        # Bloque 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        
        # Bloque 3
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.bn3(x)
        x = F.elu(x)
        x = self.pool2(x)
        x = self.dropout2(x)
        
        return x
    
    def forward(self, x):
        x = self.forward_features(x)
        return x.flatten(1)

import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """
    Mecanismo de atención multi-cabeza simplificado
    """
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Proyecciones lineales
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, T, d_model = x.shape
        
        # Proyecciones Q, K, V
        Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        
        # Atención escalada
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Aplicar atención
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(B, T, self.d_model)
        
        # Proyección final
        output = self.w_o(context)
        return output, attn_weights

class TemporalConvBlock(nn.Module):
    """
    Bloque convolucional temporal simple
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, 
                              padding=(kernel_size-1)*dilation//2, 
                              dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 
                              padding=(kernel_size-1)*dilation//2, 
                              dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        self.residual = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        
    def forward(self, x):
        residual = self.residual(x)
        x = F.elu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return F.elu(x + residual)

class Stage1Architecture(nn.Module):
    """
    Arquitectura Stage 1 simplificada y estable
    """
    def __init__(self, n_ch=8, n_classes=4, T=960, time_slices=3):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.T = T
        self.time_slices = time_slices
        
        # EEGNet base
        self.eegnet = SimpleEEGNet(n_ch=n_ch, T=T)
        eegnet_output_dim = self.eegnet.output_dim
        
        # Proyección para atención
        self.projection = nn.Linear(eegnet_output_dim, D_MODEL)
        
        # Multi-Head Attention
        self.attention = MultiHeadAttention(d_model=D_MODEL, n_heads=N_HEADS, dropout=DROPOUT_ATTN)
        
        # Bloques temporales
        self.temporal_blocks = nn.ModuleList()
        for i in range(TCN_N_BLOCKS):
            block = TemporalConvBlock(D_MODEL, D_MODEL, kernel_size=TCN_KERNEL_SIZE, dilation=2**i)
            self.temporal_blocks.append(block)
        
        # Capa de salida
        self.output_dim = D_MODEL * time_slices
        self.classifier = nn.Linear(self.output_dim, n_classes)
        
    def forward(self, x):
        B, _, T, C = x.shape
        
        # 1. EEGNet features
        eeg_features = self.eegnet(x)  # (B, eegnet_output_dim)
        
        # 2. Proyección para time slicing
        projected = self.projection(eeg_features).view(B, -1, D_MODEL)  # (B, seq_len, D_MODEL)
        
        # 3. Time slicing simple
        seq_len = projected.shape[1]
        slice_length = max(1, seq_len // self.time_slices)
        
        slice_embeddings = []
        for i in range(self.time_slices):
            start_idx = i * slice_length
            end_idx = min((i + 1) * slice_length, seq_len)
            
            if start_idx < seq_len:
                slice_data = projected[:, start_idx:end_idx, :]
                
                # Aplicar atención
                attended, _ = self.attention(slice_data)
                
                # Aplicar bloques temporales
                temporal_feat = attended.transpose(1, 2)  # (B, D_MODEL, T_slice)
                for block in self.temporal_blocks:
                    temporal_feat = block(temporal_feat)
                
                # Pooling global
                slice_embed = temporal_feat.mean(dim=-1)  # (B, D_MODEL)
                slice_embeddings.append(slice_embed)
        
        # Concatenar embeddings
        if slice_embeddings:
            embeddings = torch.cat(slice_embeddings, dim=1)  # (B, D_MODEL * n_slices)
        else:
            embeddings = eeg_features  # Fallback
        
        # Clasificación
        logits = self.classifier(embeddings)
        
        return logits, embeddings

class SimpleTabNet(nn.Module):
    """
    TabNet simplificado para Stage 2
    """
    def __init__(self, input_dim, n_classes=4, n_d=16, n_steps=3):
        super().__init__()
        self.n_d = n_d
        self.n_steps = n_steps
        
        # Batch normalization inicial
        self.bn = nn.BatchNorm1d(input_dim)
        
        # Transformadores de características
        self.feature_transformers = nn.ModuleList()
        for i in range(n_steps):
            transformer = nn.Sequential(
                nn.Linear(input_dim, n_d * 2),
                nn.BatchNorm1d(n_d * 2),
                nn.GLU(dim=1)
            )
            self.feature_transformers.append(transformer)
        
        # Atención por pasos
        self.attentive_transformers = nn.ModuleList()
        for i in range(n_steps):
            attention = nn.Sequential(
                nn.Linear(n_d, input_dim),
                nn.BatchNorm1d(input_dim)
            )
            self.attentive_transformers.append(attention)
        
        # Clasificador final
        self.classifier = nn.Linear(n_d * n_steps, n_classes)
        
    def forward(self, x):
        x = self.bn(x)
        
        decision_outputs = []
        total_attention = torch.zeros(x.size(0), x.size(1), device=x.device)
        
        for step in range(self.n_steps):
            # Aplicar máscara de atención
            mask = torch.softmax(self.attentive_transformers[step](x), dim=1)
            total_attention = total_attention + mask
            
            # Transformar características
            masked_x = x * mask
            transformed = self.feature_transformers[step](masked_x)
            
            decision_outputs.append(transformed)
        
        # Combinar decisiones
        final_decision = torch.cat(decision_outputs, dim=1)
        logits = self.classifier(final_decision)
        
        return logits

class TwoStageTransformer(nn.Module):
    """
    Arquitectura completa de dos etapas - CORREGIDA
    """
    def __init__(self, n_ch=8, n_classes=4, T=960, time_slices=3):
        super().__init__()
        
        # Stage 1
        self.stage1 = Stage1Architecture(n_ch, n_classes, T, time_slices)
        stage1_output_dim = self.stage1.output_dim
        
        # Stage 2 - TabNet
        # Características PSD: 8 canales * 5 bandas = 40
        psd_feature_dim = 40
        tabnet_input_dim = stage1_output_dim + psd_feature_dim
        
        self.stage2 = SimpleTabNet(tabnet_input_dim, n_classes, n_d=TABNET_N_D)
        
    def forward(self, x, psd_features=None):
        # Stage 1
        logits_stage1, embeddings_stage1 = self.stage1(x)
        
        # Si no hay características PSD, retornar solo Stage 1
        if psd_features is None:
            return logits_stage1, embeddings_stage1
        
        # Stage 2: Combinar embeddings con características PSD
        combined_features = torch.cat([embeddings_stage1, psd_features], dim=1)
        logits_stage2 = self.stage2(combined_features)
        
        return logits_stage1, logits_stage2, embeddings_stage1

# =========================
# EXTRACCIÓN DE CARACTERÍSTICAS PSD
# =========================
def extract_psd_features(X, fs=160.0):
    """
    Extraer características de densidad espectral de potencia (PSD)
    """
    print("Extrayendo características PSD...")
    
    n_trials, T, n_channels = X.shape
    n_bands = 5
    psd_features = np.zeros((n_trials, n_channels * n_bands))
    
    # Bandas de frecuencia
    freq_bands = [(1, 4), (4, 8), (8, 14), (14, 31), (31, 49)]
    
    for i in tqdm(range(n_trials)):
        trial_psd = []
        for ch in range(n_channels):
            signal = X[i, :, ch]
            
            # Calcular PSD usando FFT
            fft_vals = np.fft.rfft(signal)
            freqs = np.fft.rfftfreq(len(signal), 1/fs)
            psd = np.abs(fft_vals) ** 2
            
            for f_low, f_high in freq_bands:
                band_mask = (freqs >= f_low) & (freqs <= f_high)
                if np.any(band_mask):
                    band_power = np.trapz(psd[band_mask], freqs[band_mask])
                else:
                    band_power = 0.0
                trial_psd.append(band_power)
        
        psd_features[i] = np.array(trial_psd)
    
    # Normalizar características PSD
    psd_features = (psd_features - psd_features.mean(axis=0)) / (psd_features.std(axis=0) + 1e-8)
    
    return psd_features.astype(np.float32)

# =========================
# DATA AUGMENTATION SIMPLIFICADA
# =========================
def simple_augmentation(X, y, groups, n_augmented_per_class=20):
    """
    Aumento de datos simplificado para prueba
    """
    print("Aplicando aumento de datos simplificado...")
    
    X_aug, y_aug, groups_aug = [], [], []
    n_classes = len(np.unique(y))
    
    for class_idx in range(n_classes):
        class_mask = (y == class_idx)
        X_class = X[class_mask]
        groups_class = groups[class_mask]
        
        if len(X_class) == 0:
            continue
            
        n_generated = 0
        while n_generated < n_augmented_per_class:
            # Seleccionar trial aleatorio
            idx = np.random.randint(len(X_class))
            original_trial = X_class[idx]
            
            # Aplicar transformaciones simples
            # 1. Ruido gaussiano
            noise = np.random.normal(0, 0.01, original_trial.shape)
            augmented_trial = original_trial + noise
            
            X_aug.append(augmented_trial)
            y_aug.append(class_idx)
            groups_aug.append(groups_class[idx])
            n_generated += 1
    
    if X_aug:
        X_aug = np.stack(X_aug, axis=0)
        y_aug = np.array(y_aug, dtype=np.int64)
        groups_aug = np.array(groups_aug, dtype=np.int64)
        
        # Combinar con datos originales
        X_combined = np.concatenate([X, X_aug], axis=0)
        y_combined = np.concatenate([y, y_aug], axis=0)
        groups_combined = np.concatenate([groups, groups_aug], axis=0)
        
        print(f"Aumento completado: {len(X_aug)} muestras añadidas")
        return X_combined, y_combined, groups_combined
    
    return X, y, groups

# =========================
# DATASET Y ENTRENAMIENTO
# =========================
class EEGDataset(Dataset):
    def __init__(self, X, y, groups, psd_features=None):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.psd_features = psd_features.astype(np.float32) if psd_features is not None else None
        
    def __len__(self): 
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)  # (1, T, C)
        
        if self.psd_features is not None:
            return (torch.from_numpy(x), 
                    torch.from_numpy(self.psd_features[idx]), 
                    torch.tensor(self.y[idx]), 
                    torch.tensor(self.g[idx]))
        else:
            return (torch.from_numpy(x), 
                    torch.tensor(self.y[idx]), 
                    torch.tensor(self.g[idx]))

def train_stage1(model, train_loader, val_loader, epochs=50, lr=1e-3):
    """Entrenamiento de la Etapa 1"""
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    history = {'train_loss': [], 'val_acc': []}
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for batch in train_loader:
            if len(batch) == 4:  # Con PSD features
                x, _, y, _ = batch
            else:  # Sin PSD features
                x, y, _ = batch
                
            x, y = x.to(DEVICE), y.to(DEVICE)
            
            optimizer.zero_grad()
            logits, _ = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validación
        val_acc = evaluate_stage1(model, val_loader)
        history['train_loss'].append(train_loss / len(train_loader))
        history['val_acc'].append(val_acc)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_stage1.pth')
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_loss/len(train_loader):.4f}, Val Acc = {val_acc:.4f}')
    
    return history, best_val_acc

def train_stage2(model, train_loader, val_loader, epochs=30, lr=1e-4):
    """Entrenamiento de la Etapa 2 (TabNet)"""
    optimizer = optim.Adam(model.stage2.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    history = {'train_loss': [], 'val_acc': []}
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for batch in train_loader:
            x, psd, y, _ = batch
            x, psd, y = x.to(DEVICE), psd.to(DEVICE), y.to(DEVICE)
            
            optimizer.zero_grad()
            _, logits_stage2, _ = model(x, psd)
            loss = criterion(logits_stage2, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validación
        val_acc = evaluate_stage2(model, val_loader)
        history['train_loss'].append(train_loss / len(train_loader))
        history['val_acc'].append(val_acc)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_stage2.pth')
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_loss/len(train_loader):.4f}, Val Acc = {val_acc:.4f}')
    
    return history, best_val_acc

@torch.no_grad()
def evaluate_stage1(model, loader):
    """Evaluación de la Etapa 1"""
    model.eval()
    correct = 0
    total = 0
    
    for batch in loader:
        if len(batch) == 4:  # Con PSD features
            x, _, y, _ = batch
        else:  # Sin PSD features
            x, y, _ = batch
            
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    
    return correct / total

@torch.no_grad()
def evaluate_stage2(model, loader):
    """Evaluación de la Etapa 2"""
    model.eval()
    correct = 0
    total = 0
    
    for batch in loader:
        x, psd, y, _ = batch
        x, psd, y = x.to(DEVICE), psd.to(DEVICE), y.to(DEVICE)
        _, logits_stage2, _ = model(x, psd)
        pred = logits_stage2.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    
    return correct / total

# =========================
# EXPERIMENTO PRINCIPAL
# =========================
def run_two_stage_experiment():
    print("🧠 INICIANDO EXPERIMENTO TWO-STAGE TRANSFORMER ARCHITECTURE")
    print("📖 Basado en: 'A two-stage transformer based network for motor imagery classification'")
    
    # Cargar datos
    subs = subjects_available()
    print(f"Sujetos disponibles: {len(subs)}")
    
    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    
    if X is None:
        print("Error: No se pudieron cargar los datos. Verifica las rutas.")
        return None
        
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    
    print(f"Datos cargados: N={N}, T={T}, C={C}, Clases={n_classes}")
    
    # Aumento de datos simplificado
    X_aug, y_aug, groups_aug = simple_augmentation(
        X, y, groups, n_augmented_per_class=AUG_SAMPLES_PER_CLASS
    )
    
    # Extraer características PSD
    psd_features = extract_psd_features(X_aug, fs=FS)
    
    # Crear dataset
    dataset = EEGDataset(X_aug, y_aug, groups_aug, psd_features)
    
    # Evaluación inter-sujeto con GroupKFold
    unique_subjects = np.unique(groups_aug)
    gkf = GroupKFold(n_splits=min(N_FOLDS, len(unique_subjects)))
    
    fold_results = {'stage1': [], 'stage2': []}
    
    for fold, (train_idx, test_idx) in enumerate(gkf.split(X_aug, y_aug, groups_aug)):
        print(f"\n=== FOLD {fold + 1}/{min(N_FOLDS, len(unique_subjects))} ===")
        
        # Split train/val
        train_subjects = np.unique(groups_aug[train_idx])
        val_size = max(1, int(0.15 * len(train_subjects)))
        val_subjects = np.random.choice(train_subjects, val_size, replace=False)
        
        val_mask = np.isin(groups_aug[train_idx], val_subjects)
        val_idx = train_idx[val_mask]
        train_idx_fold = train_idx[~val_mask]
        
        # DataLoaders
        train_loader = DataLoader(Subset(dataset, train_idx_fold), batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(Subset(dataset, test_idx), batch_size=BATCH_SIZE, shuffle=False)
        
        # Modelo
        model = TwoStageTransformer(n_ch=C, n_classes=n_classes, T=T, time_slices=TIME_SLICES).to(DEVICE)
        
        # Entrenamiento Stage 1
        print("Entrenando Stage 1...")
        try:
            history_stage1, best_stage1 = train_stage1(
                model, train_loader, val_loader, epochs=EPOCHS_STAGE1, lr=LR_STAGE1
            )
            
            # Evaluación Stage 1 en test
            test_acc_stage1 = evaluate_stage1(model, test_loader)
            fold_results['stage1'].append(test_acc_stage1)
            print(f"Stage 1 Test Accuracy: {test_acc_stage1:.4f}")
            
            # Entrenamiento Stage 2
            print("Entrenando Stage 2 (TabNet)...")
            history_stage2, best_stage2 = train_stage2(
                model, train_loader, val_loader, epochs=EPOCHS_STAGE2, lr=LR_STAGE2
            )
            
            # Evaluación Stage 2 en test
            test_acc_stage2 = evaluate_stage2(model, test_loader)
            fold_results['stage2'].append(test_acc_stage2)
            print(f"Stage 2 Test Accuracy: {test_acc_stage2:.4f}")
            print(f"Improvement: +{(test_acc_stage2 - test_acc_stage1)*100:.2f}%")
            
        except Exception as e:
            print(f"Error en el fold {fold + 1}: {e}")
            continue
    
    # Resultados finales
    if fold_results['stage1']:
        print("\n" + "="*60)
        print("RESULTADOS FINALES - TWO-STAGE TRANSFORMER ARCHITECTURE")
        print("="*60)
        
        stage1_mean = np.mean(fold_results['stage1'])
        stage1_std = np.std(fold_results['stage1'])
        stage2_mean = np.mean(fold_results['stage2'])
        stage2_std = np.std(fold_results['stage2'])
        
        print(f"Stage 1 (EEGNet+Attention+TCN): {stage1_mean:.4f} ± {stage1_std:.4f}")
        print(f"Stage 2 (TabNet): {stage2_mean:.4f} ± {stage2_std:.4f}")
        print(f"Overall Improvement: +{(stage2_mean - stage1_mean)*100:.2f}%")
        
        return {
            'stage1_results': fold_results['stage1'],
            'stage2_results': fold_results['stage2'],
            'stage1_mean': stage1_mean,
            'stage2_mean': stage2_mean,
            'improvement': stage2_mean - stage1_mean
        }
    else:
        print("No se completó ningún fold exitosamente.")
        return None

# =========================
# EJECUCIÓN
# =========================
if __name__ == "__main__":
    results = run_two_stage_experiment()

🚀 Usando dispositivo: cuda
🧠 INICIANDO EXPERIMENTO TWO-STAGE TRANSFORMER ARCHITECTURE
📖 Basado en: 'A two-stage transformer based network for motor imagery classification'
Sujetos disponibles: 103


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   1%|          | 1/103 [00:00<00:17,  5.93it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   2%|▏         | 2/103 [00:00<00:16,  6.14it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   3%|▎         | 3/103 [00:00<00:16,  6.03it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   4%|▍         | 4/103 [00:00<00:16,  6.09it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   5%|▍         | 5/103 [00:00<00:15,  6.20it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   6%|▌         | 6/103 [00:00<00:15,  6.25it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   7%|▋         | 7/103 [00:01<00:16,  5.87it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   8%|▊         | 8/103 [00:01<00:17,  5.58it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   9%|▊         | 9/103 [00:01<00:17,  5.36it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  10%|▉         | 10/103 [00:01<00:17,  5.28it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  11%|█         | 11/103 [00:01<00:17,  5.19it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  12%|█▏        | 12/103 [00:02<00:17,  5.14it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  13%|█▎        | 13/103 [00:02<00:17,  5.11it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  14%|█▎        | 14/103 [00:02<00:17,  5.05it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  15%|█▍        | 15/103 [00:02<00:17,  4.98it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  16%|█▌        | 16/103 [00:03<00:21,  3.96it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  17%|█▋        | 17/103 [00:03<00:20,  4.24it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  17%|█▋        | 18/103 [00:03<00:19,  4.46it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  18%|█▊        | 19/103 [00:03<00:18,  4.62it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  19%|█▉        | 20/103 [00:03<00:17,  4.75it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  20%|██        | 21/103 [00:04<00:16,  4.84it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  21%|██▏       | 22/103 [00:04<00:16,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  22%|██▏       | 23/103 [00:04<00:16,  4.95it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  23%|██▎       | 24/103 [00:04<00:15,  5.00it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  24%|██▍       | 25/103 [00:04<00:15,  5.02it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  25%|██▌       | 26/103 [00:05<00:15,  5.02it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  26%|██▌       | 27/103 [00:05<00:15,  5.04it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  27%|██▋       | 28/103 [00:05<00:14,  5.05it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  28%|██▊       | 29/103 [00:05<00:14,  5.06it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  29%|██▉       | 30/103 [00:05<00:14,  5.02it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  30%|███       | 31/103 [00:06<00:14,  5.04it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  31%|███       | 32/103 [00:06<00:14,  5.04it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  32%|███▏      | 33/103 [00:06<00:13,  5.05it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  33%|███▎      | 34/103 [00:06<00:13,  5.05it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  34%|███▍      | 35/103 [00:06<00:13,  5.04it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  35%|███▍      | 36/103 [00:07<00:13,  5.03it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  36%|███▌      | 37/103 [00:07<00:13,  5.05it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  37%|███▋      | 38/103 [00:07<00:12,  5.08it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  38%|███▊      | 39/103 [00:07<00:12,  5.09it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  39%|███▉      | 40/103 [00:07<00:12,  5.10it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  40%|███▉      | 41/103 [00:08<00:12,  5.09it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  41%|████      | 42/103 [00:08<00:11,  5.10it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  42%|████▏     | 43/103 [00:08<00:11,  5.11it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  43%|████▎     | 44/103 [00:08<00:11,  5.12it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  44%|████▎     | 45/103 [00:08<00:11,  5.10it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  45%|████▍     | 46/103 [00:09<00:11,  5.09it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  46%|████▌     | 47/103 [00:09<00:10,  5.11it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  47%|████▋     | 48/103 [00:09<00:10,  5.10it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  48%|████▊     | 49/103 [00:09<00:10,  5.07it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  49%|████▊     | 50/103 [00:09<00:10,  5.08it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  50%|████▉     | 51/103 [00:10<00:10,  5.05it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  50%|█████     | 52/103 [00:10<00:10,  5.00it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  51%|█████▏    | 53/103 [00:10<00:09,  5.01it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  52%|█████▏    | 54/103 [00:10<00:09,  4.99it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  53%|█████▎    | 55/103 [00:10<00:09,  4.94it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  54%|█████▍    | 56/103 [00:11<00:09,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  55%|█████▌    | 57/103 [00:11<00:09,  4.81it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  56%|█████▋    | 58/103 [00:11<00:09,  4.71it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  57%|█████▋    | 59/103 [00:11<00:09,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  58%|█████▊    | 60/103 [00:11<00:09,  4.57it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  59%|█████▉    | 61/103 [00:12<00:09,  4.52it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  60%|██████    | 62/103 [00:12<00:09,  4.45it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  61%|██████    | 63/103 [00:12<00:08,  4.46it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  62%|██████▏   | 64/103 [00:12<00:08,  4.40it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  63%|██████▎   | 65/103 [00:13<00:08,  4.41it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  64%|██████▍   | 66/103 [00:13<00:08,  4.37it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  65%|██████▌   | 67/103 [00:13<00:08,  4.37it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  66%|██████▌   | 68/103 [00:13<00:07,  4.40it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  67%|██████▋   | 69/103 [00:14<00:07,  4.36it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  68%|██████▊   | 70/103 [00:14<00:07,  4.37it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  69%|██████▉   | 71/103 [00:14<00:07,  4.37it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  70%|██████▉   | 72/103 [00:14<00:07,  4.41it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  71%|███████   | 73/103 [00:14<00:06,  4.42it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  72%|███████▏  | 74/103 [00:15<00:06,  4.24it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  73%|███████▎  | 75/103 [00:15<00:07,  3.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  74%|███████▍  | 76/103 [00:15<00:06,  4.02it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  75%|███████▍  | 77/103 [00:15<00:06,  4.15it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  76%|███████▌  | 78/103 [00:16<00:05,  4.17it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  77%|███████▋  | 79/103 [00:16<00:05,  4.27it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  78%|███████▊  | 80/103 [00:16<00:05,  4.29it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  79%|███████▊  | 81/103 [00:16<00:05,  4.32it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  80%|███████▉  | 82/103 [00:17<00:04,  4.33it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  81%|████████  | 83/103 [00:17<00:04,  4.32it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  82%|████████▏ | 84/103 [00:17<00:04,  4.35it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  83%|████████▎ | 85/103 [00:17<00:04,  4.36it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  83%|████████▎ | 86/103 [00:18<00:03,  4.36it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  84%|████████▍ | 87/103 [00:18<00:03,  4.36it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  85%|████████▌ | 88/103 [00:18<00:03,  4.34it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  86%|████████▋ | 89/103 [00:18<00:03,  4.38it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  87%|████████▋ | 90/103 [00:18<00:02,  4.34it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  88%|████████▊ | 91/103 [00:19<00:03,  3.98it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  89%|████████▉ | 92/103 [00:19<00:02,  3.94it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  90%|█████████ | 93/103 [00:19<00:02,  4.07it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  91%|█████████▏| 94/103 [00:19<00:02,  4.13it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  92%|█████████▏| 95/103 [00:20<00:01,  4.20it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  93%|█████████▎| 96/103 [00:20<00:01,  4.25it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  94%|█████████▍| 97/103 [00:20<00:01,  4.28it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  95%|█████████▌| 98/103 [00:20<00:01,  4.31it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  96%|█████████▌| 99/103 [00:21<00:00,  4.26it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  97%|█████████▋| 100/103 [00:21<00:00,  4.30it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  98%|█████████▊| 101/103 [00:21<00:00,  4.32it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  99%|█████████▉| 102/103 [00:21<00:00,  4.29it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:22<00:00,  4.68it/s]

Dataset construido: N=6180 | T=960 | C=8 | clases=4 | sujetos únicos=103
Datos cargados: N=6180, T=960, C=8, Clases=4
Aplicando aumento de datos simplificado...





Aumento completado: 80 muestras añadidas
Extrayendo características PSD...


  band_power = np.trapz(psd[band_mask], freqs[band_mask])
100%|██████████| 6260/6260 [00:03<00:00, 1600.47it/s]



=== FOLD 1/5 ===
Entrenando Stage 1...
Error en el fold 1: Calculated padded input size per channel: (3). Kernel size: (4). Kernel size can't be greater than actual input size

=== FOLD 2/5 ===
Entrenando Stage 1...
Error en el fold 2: Calculated padded input size per channel: (3). Kernel size: (4). Kernel size can't be greater than actual input size

=== FOLD 3/5 ===
Entrenando Stage 1...
Error en el fold 3: Calculated padded input size per channel: (3). Kernel size: (4). Kernel size can't be greater than actual input size

=== FOLD 4/5 ===
Entrenando Stage 1...
Error en el fold 4: Calculated padded input size per channel: (3). Kernel size: (4). Kernel size can't be greater than actual input size

=== FOLD 5/5 ===
Entrenando Stage 1...
Error en el fold 5: Calculated padded input size per channel: (3). Kernel size: (4). Kernel size can't be greater than actual input size
No se completó ningún fold exitosamente.


In [7]:
# -*- coding: utf-8 -*-
# Two-Stage Transformer Architecture for Physionet 2099 - 4-class MI
# Versión FINAL - Dimensiones corregidas

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'
FS = 160.0
N_FOLDS = 5

# Arquitectura Two-Stage
BATCH_SIZE = 32
EPOCHS_STAGE1 = 80
EPOCHS_STAGE2 = 40
LR_STAGE1 = 1e-3
LR_STAGE2 = 5e-4

# Parámetros Transformer
N_HEADS = 4
D_MODEL = 64
DROPOUT_ATTN = 0.2
TCN_FILTERS = 64
TCN_KERNEL_SIZE = 3  # Reducido
TCN_N_BLOCKS = 2
TIME_SLICES = 4

# TabNet
TABNET_N_D = 32
TABNET_N_A = 32
TABNET_N_HIDDEN = 128

# Data Augmentation
N_CLUSTERS = 3
AUG_SAMPLES_PER_CLASS = 30

# =========================
# ARQUITECTURA CORREGIDA - DIMENSIONES COMPATIBLES
# =========================

import torch.nn.functional as F

class RobustEEGNet(nn.Module):
    """
    EEGNet robusto con dimensiones compatibles
    """
    def __init__(self, n_ch=8, F1=32, D=2, T=960):
        super().__init__()
        self.F1 = F1
        self.F2 = F1 * D
        
        # Bloque 1: Temporal - kernel más pequeño
        self.conv1 = nn.Conv2d(1, F1, (32, 1), padding=(16, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        
        # Bloque 2: Depthwise
        self.conv2 = nn.Conv2d(F1, self.F2, (1, n_ch), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2)
        
        # Pooling más conservador
        self.pool1 = nn.AvgPool2d((4, 1))  # Reducido
        self.dropout1 = nn.Dropout(0.3)
        
        # Bloque 3: Separable con kernel más pequeño
        self.conv3 = nn.Conv2d(self.F2, self.F2, (8, 1), groups=self.F2, padding=(4, 0), bias=False)
        self.conv4 = nn.Conv2d(self.F2, self.F2, (1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2)
        
        self.pool2 = nn.AvgPool2d((4, 1))  # Reducido
        self.dropout2 = nn.Dropout(0.4)
        
        # Calcular dimensión de salida
        self.output_dim = self._get_output_dim(n_ch, T)
        
    def _get_output_dim(self, n_ch, T):
        # Calcular dimensión de salida
        with torch.no_grad():
            x = torch.randn(1, 1, T, n_ch)
            x = F.elu(self.bn1(self.conv1(x)))
            x = F.elu(self.bn2(self.conv2(x)))
            x = self.pool1(x)
            x = self.dropout1(x)
            x = F.elu(self.bn3(self.conv4(self.conv3(x))))
            x = self.pool2(x)
            x = self.dropout2(x)
            return x.view(1, -1).size(1)
    
    def forward(self, x):
        # Bloque 1
        x = F.elu(self.bn1(self.conv1(x)))
        
        # Bloque 2
        x = F.elu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = self.dropout1(x)
        
        # Bloque 3
        x = self.conv3(x)
        x = self.conv4(x)
        x = F.elu(self.bn3(x))
        x = self.pool2(x)
        x = self.dropout2(x)
        
        return x.flatten(1)

class MultiHeadAttention(nn.Module):
    """
    Atención multi-cabeza con dimensiones robustas
    """
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, T, d_model = x.shape
        
        # Proyecciones Q, K, V
        Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        
        # Atención escalada
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Aplicar atención
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(B, T, self.d_model)
        
        return self.w_o(context), attn_weights

class AdaptiveTemporalBlock(nn.Module):
    """
    Bloque temporal adaptativo que se ajusta al tamaño de entrada
    """
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        
        # Usar kernel size 1 si la entrada es muy pequeña
        effective_ks = kernel_size if kernel_size > 1 else 1
        padding = (effective_ks - 1) // 2
        
        self.conv1 = nn.Conv1d(in_channels, out_channels, effective_ks, padding=padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, effective_ks, padding=padding)
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        self.residual = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.activation = nn.ELU()
        
    def forward(self, x):
        residual = self.residual(x)
        
        x = self.activation(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        
        return self.activation(x + residual)

class Stage1Architecture(nn.Module):
    """
    Stage 1 con manejo robusto de dimensiones
    """
    def __init__(self, n_ch=8, n_classes=4, T=960, time_slices=4):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.time_slices = time_slices
        
        # EEGNet base
        self.eegnet = RobustEEGNet(n_ch=n_ch, T=T)
        eegnet_output_dim = self.eegnet.output_dim
        
        print(f"EEGNet output dimension: {eegnet_output_dim}")
        
        # Proyección para atención
        self.projection = nn.Sequential(
            nn.Linear(eegnet_output_dim, 256),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(256, D_MODEL * time_slices)
        )
        
        # Multi-Head Attention
        self.attention = MultiHeadAttention(d_model=D_MODEL, n_heads=N_HEADS, dropout=DROPOUT_ATTN)
        
        # Bloques temporales adaptativos
        self.temporal_blocks = nn.ModuleList()
        for i in range(TCN_N_BLOCKS):
            block = AdaptiveTemporalBlock(D_MODEL, TCN_FILTERS, kernel_size=TCN_KERNEL_SIZE)
            self.temporal_blocks.append(block)
        
        # Capa de salida
        self.output_dim = TCN_FILTERS * time_slices
        self.classifier = nn.Sequential(
            nn.Linear(self.output_dim, 128),
            nn.ELU(),
            nn.Dropout(0.4),
            nn.Linear(128, n_classes)
        )
        
    def forward(self, x):
        B, _, T, C = x.shape
        
        # 1. EEGNet features
        eeg_features = self.eegnet(x)  # (B, eegnet_output_dim)
        
        # 2. Proyección y reshaping para time slicing
        projected = self.projection(eeg_features)  # (B, D_MODEL * time_slices)
        projected = projected.view(B, self.time_slices, D_MODEL)  # (B, time_slices, D_MODEL)
        
        # 3. Aplicar atención a cada "slice temporal"
        attended, _ = self.attention(projected)  # (B, time_slices, D_MODEL)
        
        # 4. Procesar con bloques temporales
        temporal_feat = attended.transpose(1, 2)  # (B, D_MODEL, time_slices)
        
        for block in self.temporal_blocks:
            temporal_feat = block(temporal_feat)
        
        # 5. Flatten y clasificar
        embeddings = temporal_feat.flatten(1)  # (B, TCN_FILTERS * time_slices)
        logits = self.classifier(embeddings)
        
        return logits, embeddings

class TabNetClassifier(nn.Module):
    """
    TabNet simplificado pero efectivo
    """
    def __init__(self, input_dim, n_classes=4, n_d=32, n_steps=4):
        super().__init__()
        self.n_d = n_d
        self.n_steps = n_steps
        
        # Capa inicial
        self.initial_bn = nn.BatchNorm1d(input_dim)
        self.initial_fc = nn.Linear(input_dim, n_d * 2)
        self.initial_glu = nn.GLU(dim=1)
        
        # Bloques de decisión
        self.decision_blocks = nn.ModuleList()
        for i in range(n_steps):
            block = nn.Sequential(
                nn.Linear(n_d, n_d * 2),
                nn.BatchNorm1d(n_d * 2),
                nn.GLU(dim=1),
                nn.Dropout(0.2)
            )
            self.decision_blocks.append(block)
        
        # Capa de atención (simplificada)
        self.attention_layers = nn.ModuleList()
        for i in range(n_steps):
            attention = nn.Sequential(
                nn.Linear(n_d, input_dim),
                nn.BatchNorm1d(input_dim),
                nn.Sigmoid()
            )
            self.attention_layers.append(attention)
        
        # Clasificador final
        self.classifier = nn.Sequential(
            nn.Linear(n_d * n_steps, TABNET_N_HIDDEN),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(TABNET_N_HIDDEN, n_classes)
        )
        
    def forward(self, x):
        x = self.initial_bn(x)
        x = self.initial_glu(self.initial_fc(x))
        
        decision_outputs = []
        
        for i in range(self.n_steps):
            # Aplicar atención
            attention_weights = self.attention_layers[i](x)
            attended_x = x * attention_weights
            
            # Transformar características
            transformed = self.decision_blocks[i](attended_x)
            decision_outputs.append(transformed)
            
            # Actualizar x para siguiente paso
            x = x + transformed  # Conexión residual
        
        # Combinar todas las decisiones
        final_features = torch.cat(decision_outputs, dim=1)
        return self.classifier(final_features)

class TwoStageTransformer(nn.Module):
    """
    Arquitectura final de dos etapas
    """
    def __init__(self, n_ch=8, n_classes=4, T=960, time_slices=4):
        super().__init__()
        
        # Stage 1
        self.stage1 = Stage1Architecture(n_ch, n_classes, T, time_slices)
        stage1_output_dim = self.stage1.output_dim
        
        # Stage 2 - TabNet
        psd_feature_dim = 40  # 8 canales * 5 bandas
        tabnet_input_dim = stage1_output_dim + psd_feature_dim
        
        self.stage2 = TabNetClassifier(tabnet_input_dim, n_classes, TABNET_N_D)
        
        print(f"Stage 1 output dim: {stage1_output_dim}")
        print(f"TabNet input dim: {tabnet_input_dim}")
        
    def forward(self, x, psd_features=None):
        # Stage 1
        logits_stage1, embeddings_stage1 = self.stage1(x)
        
        if psd_features is None:
            return logits_stage1, embeddings_stage1
        
        # Stage 2
        combined_features = torch.cat([embeddings_stage1, psd_features], dim=1)
        logits_stage2 = self.stage2(combined_features)
        
        return logits_stage1, logits_stage2, embeddings_stage1

# =========================
# MANTENER EL RESTO DEL CÓDIGO IGUAL PERO ACTUALIZAR FUNCIONES DE ENTRENAMIENTO
# =========================

# [MANTENER TODAS LAS FUNCIONES DE UTILIDAD, DATASET, Y PREPROCESAMIENTO DEL CÓDIGO ANTERIOR]

def extract_psd_features(X, fs=160.0):
    """
    Extraer características PSD corregida
    """
    print("Extrayendo características PSD...")
    
    n_trials, T, n_channels = X.shape
    n_bands = 5
    psd_features = np.zeros((n_trials, n_channels * n_bands))
    
    freq_bands = [(1, 4), (4, 8), (8, 14), (14, 31), (31, 49)]
    
    for i in tqdm(range(n_trials)):
        trial_psd = []
        for ch in range(n_channels):
            signal = X[i, :, ch]
            
            # Calcular PSD usando FFT
            fft_vals = np.fft.rfft(signal)
            freqs = np.fft.rfftfreq(len(signal), 1/fs)
            psd = np.abs(fft_vals) ** 2
            
            for f_low, f_high in freq_bands:
                band_mask = (freqs >= f_low) & (freqs <= f_high)
                if np.any(band_mask):
                    band_power = np.trapz(psd[band_mask], freqs[band_mask])
                else:
                    band_power = 0.0
                trial_psd.append(band_power)
        
        psd_features[i] = np.array(trial_psd)
    
    # Normalizar
    psd_features = (psd_features - psd_features.mean(axis=0)) / (psd_features.std(axis=0) + 1e-8)
    
    return psd_features.astype(np.float32)

def channel_cluster_swapping_augmentation(X, y, groups, n_clusters=3, n_augmented_per_class=30):
    """
    Data augmentation del paper original
    """
    print("Aplicando Channel Cluster Swapping Augmentation...")
    
    X_aug, y_aug, groups_aug = [], [], []
    n_classes = len(np.unique(y))
    
    for class_idx in range(n_classes):
        class_mask = (y == class_idx)
        X_class = X[class_mask]
        groups_class = groups[class_mask]
        
        if len(X_class) < 2:
            continue
            
        # Calcular matriz de correlación
        corr_matrix = np.zeros((X_class.shape[2], X_class.shape[2]))
        for i in range(len(X_class)):
            trial_corr = np.corrcoef(X_class[i].T)
            corr_matrix += np.nan_to_num(trial_corr)
        corr_matrix /= len(X_class)
        
        # K-means clustering
        kmeans = KMeans(n_clusters=n_clusters, random_state=RANDOM_STATE, n_init=10)
        channel_labels = kmeans.fit_predict(corr_matrix)
        
        # Generar muestras aumentadas
        n_generated = 0
        while n_generated < n_augmented_per_class:
            idx1, idx2 = np.random.choice(len(X_class), 2, replace=False)
            
            new_trial = X_class[idx1].copy()
            for cluster_id in range(n_clusters):
                cluster_channels = np.where(channel_labels == cluster_id)[0]
                if len(cluster_channels) > 0:
                    swap_channel = np.random.choice(cluster_channels)
                    new_trial[:, swap_channel] = X_class[idx2][:, swap_channel]
            
            X_aug.append(new_trial)
            y_aug.append(class_idx)
            groups_aug.append(groups_class[idx1])
            n_generated += 1
    
    if X_aug:
        X_aug = np.stack(X_aug, axis=0)
        y_aug = np.array(y_aug, dtype=np.int64)
        groups_aug = np.array(groups_aug, dtype=np.int64)
        
        X_combined = np.concatenate([X, X_aug], axis=0)
        y_combined = np.concatenate([y, y_aug], axis=0)
        groups_combined = np.concatenate([groups, groups_aug], axis=0)
        
        print(f"Aumento completado: {len(X_aug)} muestras añadidas")
        return X_combined, y_combined, groups_combined
    
    return X, y, groups

class EEGDataset(Dataset):
    def __init__(self, X, y, groups, psd_features=None):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.psd_features = psd_features.astype(np.float32) if psd_features is not None else None
        
    def __len__(self): 
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)  # (1, T, C)
        
        if self.psd_features is not None:
            return (torch.from_numpy(x), 
                    torch.from_numpy(self.psd_features[idx]), 
                    torch.tensor(self.y[idx]), 
                    torch.tensor(self.g[idx]))
        else:
            return (torch.from_numpy(x), 
                    torch.tensor(self.y[idx]), 
                    torch.tensor(self.g[idx]))

def train_stage1(model, train_loader, val_loader, epochs=80, lr=1e-3):
    """Entrenamiento Stage 1 mejorado"""
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    history = {'train_loss': [], 'val_acc': []}
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for batch in train_loader:
            if len(batch) == 4:
                x, _, y, _ = batch
            else:
                x, y, _ = batch
                
            x, y = x.to(DEVICE), y.to(DEVICE)
            
            optimizer.zero_grad()
            logits, _ = model(x)
            loss = criterion(logits, y)
            loss.backward()
            
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
        
        scheduler.step()
        
        # Validación
        val_acc = evaluate_stage1(model, val_loader)
        history['train_loss'].append(train_loss / len(train_loader))
        history['val_acc'].append(val_acc)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_stage1.pth')
        
        if epoch % 10 == 0:
            current_lr = scheduler.get_last_lr()[0]
            print(f'Epoch {epoch}: Loss = {train_loss/len(train_loader):.4f}, Val Acc = {val_acc:.4f}, LR = {current_lr:.6f}')
    
    return history, best_val_acc

def train_stage2(model, train_loader, val_loader, epochs=40, lr=5e-4):
    """Entrenamiento Stage 2 mejorado"""
    optimizer = optim.AdamW(model.stage2.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    history = {'train_loss': [], 'val_acc': []}
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        
        for batch in train_loader:
            x, psd, y, _ = batch
            x, psd, y = x.to(DEVICE), psd.to(DEVICE), y.to(DEVICE)
            
            optimizer.zero_grad()
            _, logits_stage2, _ = model(x, psd)
            loss = criterion(logits_stage2, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validación
        val_acc = evaluate_stage2(model, val_loader)
        history['train_loss'].append(train_loss / len(train_loader))
        history['val_acc'].append(val_acc)
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_stage2.pth')
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Loss = {train_loss/len(train_loader):.4f}, Val Acc = {val_acc:.4f}')
    
    return history, best_val_acc

@torch.no_grad()
def evaluate_stage1(model, loader):
    model.eval()
    correct = 0
    total = 0
    
    for batch in loader:
        if len(batch) == 4:
            x, _, y, _ = batch
        else:
            x, y, _ = batch
            
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    
    return correct / total

@torch.no_grad()
def evaluate_stage2(model, loader):
    model.eval()
    correct = 0
    total = 0
    
    for batch in loader:
        x, psd, y, _ = batch
        x, psd, y = x.to(DEVICE), psd.to(DEVICE), y.to(DEVICE)
        _, logits_stage2, _ = model(x, psd)
        pred = logits_stage2.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    
    return correct / total

# =========================
# EXPERIMENTO PRINCIPAL MEJORADO
# =========================
def run_two_stage_experiment():
    print("🧠 INICIANDO TWO-STAGE TRANSFORMER - VERSIÓN FINAL")
    
    # Cargar datos (usar tus funciones existentes)
    subs = subjects_available()
    print(f"Sujetos disponibles: {len(subs)}")
    
    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    
    if X is None:
        print("Error: No se pudieron cargar los datos.")
        return None
        
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    
    print(f"Datos: N={N}, T={T}, C={C}, Clases={n_classes}")
    
    # Data augmentation
    X_aug, y_aug, groups_aug = channel_cluster_swapping_augmentation(
        X, y, groups, n_clusters=N_CLUSTERS, n_augmented_per_class=AUG_SAMPLES_PER_CLASS
    )
    
    # Extraer PSD
    psd_features = extract_psd_features(X_aug, fs=FS)
    
    # Dataset
    dataset = EEGDataset(X_aug, y_aug, groups_aug, psd_features)
    
    # Cross-validation
    unique_subjects = np.unique(groups_aug)
    gkf = GroupKFold(n_splits=min(N_FOLDS, len(unique_subjects)))
    
    fold_results = {'stage1': [], 'stage2': []}
    
    for fold, (train_idx, test_idx) in enumerate(gkf.split(X_aug, y_aug, groups_aug)):
        print(f"\n=== FOLD {fold + 1}/{min(N_FOLDS, len(unique_subjects))} ===")
        
        # Split train/val
        train_subjects = np.unique(groups_aug[train_idx])
        val_size = max(1, int(0.15 * len(train_subjects)))
        val_subjects = np.random.choice(train_subjects, val_size, replace=False)
        
        val_mask = np.isin(groups_aug[train_idx], val_subjects)
        val_idx = train_idx[val_mask]
        train_idx_fold = train_idx[~val_mask]
        
        # DataLoaders
        train_loader = DataLoader(Subset(dataset, train_idx_fold), batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(Subset(dataset, test_idx), batch_size=BATCH_SIZE, shuffle=False)
        
        print(f"Train: {len(train_idx_fold)}, Val: {len(val_idx)}, Test: {len(test_idx)}")
        
        # Modelo
        model = TwoStageTransformer(n_ch=C, n_classes=n_classes, T=T, time_slices=TIME_SLICES).to(DEVICE)
        
        try:
            # Stage 1
            print("Entrenando Stage 1...")
            history_stage1, best_stage1 = train_stage1(
                model, train_loader, val_loader, EPOCHS_STAGE1, LR_STAGE1
            )
            
            test_acc_stage1 = evaluate_stage1(model, test_loader)
            fold_results['stage1'].append(test_acc_stage1)
            print(f"✅ Stage 1 Test Acc: {test_acc_stage1:.4f}")
            
            # Stage 2
            print("Entrenando Stage 2...")
            history_stage2, best_stage2 = train_stage2(
                model, train_loader, val_loader, EPOCHS_STAGE2, LR_STAGE2
            )
            
            test_acc_stage2 = evaluate_stage2(model, test_loader)
            fold_results['stage2'].append(test_acc_stage2)
            print(f"✅ Stage 2 Test Acc: {test_acc_stage2:.4f}")
            print(f"📈 Improvement: +{(test_acc_stage2 - test_acc_stage1)*100:.2f}%")
            
        except Exception as e:
            print(f"❌ Error en fold {fold + 1}: {e}")
            continue
    
    # Resultados
    if fold_results['stage1']:
        print("\n" + "="*60)
        print("🎯 RESULTADOS FINALES")
        print("="*60)
        
        stage1_mean = np.mean(fold_results['stage1'])
        stage1_std = np.std(fold_results['stage1'])
        stage2_mean = np.mean(fold_results['stage2'])
        stage2_std = np.std(fold_results['stage2'])
        
        print(f"Stage 1: {stage1_mean:.4f} ± {stage1_std:.4f}")
        print(f"Stage 2: {stage2_mean:.4f} ± {stage2_std:.4f}")
        print(f"Mejora: +{(stage2_mean - stage1_mean)*100:.2f}%")
        
        return {
            'stage1_results': fold_results['stage1'],
            'stage2_results': fold_results['stage2'],
            'stage1_mean': stage1_mean,
            'stage2_mean': stage2_mean,
            'improvement': stage2_mean - stage1_mean
        }
    else:
        print("No se completaron folds exitosamente.")
        return None

if __name__ == "__main__":
    results = run_two_stage_experiment()

🚀 Usando dispositivo: cuda
🧠 INICIANDO TWO-STAGE TRANSFORMER - VERSIÓN FINAL
Sujetos disponibles: 103


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   1%|          | 1/103 [00:00<00:17,  5.84it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   2%|▏         | 2/103 [00:00<00:16,  6.04it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   3%|▎         | 3/103 [00:00<00:18,  5.54it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   4%|▍         | 4/103 [00:00<00:17,  5.76it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   5%|▍         | 5/103 [00:00<00:16,  5.96it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   6%|▌         | 6/103 [00:01<00:17,  5.63it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   7%|▋         | 7/103 [00:01<00:17,  5.43it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   8%|▊         | 8/103 [00:01<00:17,  5.31it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   9%|▊         | 9/103 [00:01<00:18,  5.21it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  10%|▉         | 10/103 [00:01<00:18,  5.14it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  11%|█         | 11/103 [00:02<00:18,  5.10it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  12%|█▏        | 12/103 [00:02<00:17,  5.08it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  13%|█▎        | 13/103 [00:02<00:17,  5.05it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  14%|█▎        | 14/103 [00:02<00:17,  5.01it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  15%|█▍        | 15/103 [00:02<00:17,  5.02it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  16%|█▌        | 16/103 [00:03<00:17,  5.01it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  17%|█▋        | 17/103 [00:03<00:17,  4.96it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  17%|█▋        | 18/103 [00:03<00:17,  4.96it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  18%|█▊        | 19/103 [00:03<00:16,  4.97it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  19%|█▉        | 20/103 [00:03<00:16,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  20%|██        | 21/103 [00:04<00:17,  4.79it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  21%|██▏       | 22/103 [00:04<00:17,  4.68it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  22%|██▏       | 23/103 [00:04<00:17,  4.63it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  23%|██▎       | 24/103 [00:04<00:17,  4.62it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  24%|██▍       | 25/103 [00:04<00:17,  4.54it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  25%|██▌       | 26/103 [00:05<00:16,  4.55it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  26%|██▌       | 27/103 [00:05<00:16,  4.51it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  27%|██▋       | 28/103 [00:05<00:16,  4.53it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  28%|██▊       | 29/103 [00:05<00:16,  4.51it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  29%|██▉       | 30/103 [00:06<00:16,  4.47it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  30%|███       | 31/103 [00:06<00:16,  4.48it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  31%|███       | 32/103 [00:06<00:16,  4.41it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  32%|███▏      | 33/103 [00:06<00:15,  4.44it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  33%|███▎      | 34/103 [00:06<00:15,  4.47it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  34%|███▍      | 35/103 [00:07<00:16,  4.24it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  35%|███▍      | 36/103 [00:07<00:16,  4.04it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  36%|███▌      | 37/103 [00:07<00:17,  3.85it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  37%|███▋      | 38/103 [00:08<00:16,  4.00it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  38%|███▊      | 39/103 [00:08<00:15,  4.14it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  39%|███▉      | 40/103 [00:08<00:14,  4.23it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  40%|███▉      | 41/103 [00:08<00:14,  4.31it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  41%|████      | 42/103 [00:08<00:13,  4.37it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  42%|████▏     | 43/103 [00:09<00:13,  4.44it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  43%|████▎     | 44/103 [00:09<00:13,  4.48it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  44%|████▎     | 45/103 [00:09<00:12,  4.50it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  45%|████▍     | 46/103 [00:09<00:12,  4.55it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  46%|████▌     | 47/103 [00:10<00:12,  4.55it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  47%|████▋     | 48/103 [00:10<00:11,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  48%|████▊     | 49/103 [00:10<00:11,  4.60it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  49%|████▊     | 50/103 [00:10<00:11,  4.58it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  50%|████▉     | 51/103 [00:10<00:11,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  50%|█████     | 52/103 [00:11<00:11,  4.58it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  51%|█████▏    | 53/103 [00:11<00:10,  4.60it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  52%|█████▏    | 54/103 [00:11<00:10,  4.58it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  53%|█████▎    | 55/103 [00:11<00:10,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  54%|█████▍    | 56/103 [00:11<00:10,  4.60it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  55%|█████▌    | 57/103 [00:12<00:10,  4.52it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  56%|█████▋    | 58/103 [00:12<00:09,  4.54it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  57%|█████▋    | 59/103 [00:12<00:09,  4.53it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  58%|█████▊    | 60/103 [00:12<00:09,  4.55it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  59%|█████▉    | 61/103 [00:13<00:09,  4.56it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  60%|██████    | 62/103 [00:13<00:09,  4.53it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  61%|██████    | 63/103 [00:13<00:08,  4.53it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  62%|██████▏   | 64/103 [00:13<00:08,  4.51it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  63%|██████▎   | 65/103 [00:13<00:08,  4.47it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  64%|██████▍   | 66/103 [00:14<00:08,  4.49it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  65%|██████▌   | 67/103 [00:14<00:07,  4.51it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  66%|██████▌   | 68/103 [00:14<00:07,  4.56it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  67%|██████▋   | 69/103 [00:14<00:07,  4.54it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  68%|██████▊   | 70/103 [00:15<00:07,  4.57it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  69%|██████▉   | 71/103 [00:15<00:07,  4.54it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  70%|██████▉   | 72/103 [00:15<00:06,  4.58it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  71%|███████   | 73/103 [00:15<00:06,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  72%|███████▏  | 74/103 [00:15<00:06,  4.57it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  73%|███████▎  | 75/103 [00:16<00:06,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  74%|███████▍  | 76/103 [00:16<00:05,  4.56it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  75%|███████▍  | 77/103 [00:16<00:05,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  76%|███████▌  | 78/103 [00:16<00:05,  4.59it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  77%|███████▋  | 79/103 [00:17<00:05,  4.56it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  78%|███████▊  | 80/103 [00:17<00:05,  4.58it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  79%|███████▊  | 81/103 [00:17<00:04,  4.55it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  80%|███████▉  | 82/103 [00:17<00:05,  4.18it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  81%|████████  | 83/103 [00:17<00:04,  4.25it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  82%|████████▏ | 84/103 [00:18<00:04,  4.37it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  83%|████████▎ | 85/103 [00:18<00:04,  4.43it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  83%|████████▎ | 86/103 [00:18<00:03,  4.45it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  84%|████████▍ | 87/103 [00:18<00:03,  4.50it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  85%|████████▌ | 88/103 [00:19<00:03,  4.35it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  86%|████████▋ | 89/103 [00:19<00:03,  4.17it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  87%|████████▋ | 90/103 [00:19<00:03,  4.26it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  88%|████████▊ | 91/103 [00:19<00:02,  4.37it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  89%|████████▉ | 92/103 [00:20<00:02,  4.42it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  90%|█████████ | 93/103 [00:20<00:02,  4.39it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  91%|█████████▏| 94/103 [00:20<00:02,  4.42it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  92%|█████████▏| 95/103 [00:20<00:01,  4.43it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  93%|█████████▎| 96/103 [00:20<00:01,  4.45it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  94%|█████████▍| 97/103 [00:21<00:01,  4.47it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  95%|█████████▌| 98/103 [00:21<00:01,  4.43it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  96%|█████████▌| 99/103 [00:21<00:00,  4.48it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  97%|█████████▋| 100/103 [00:21<00:00,  4.47it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  98%|█████████▊| 101/103 [00:22<00:00,  4.50it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  99%|█████████▉| 102/103 [00:22<00:00,  4.50it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:22<00:00,  4.58it/s]

Dataset construido: N=6180 | T=960 | C=8 | clases=4 | sujetos únicos=103
Datos: N=6180, T=960, C=8, Clases=4
Aplicando Channel Cluster Swapping Augmentation...





Aumento completado: 120 muestras añadidas
Extrayendo características PSD...


  band_power = np.trapz(psd[band_mask], freqs[band_mask])
100%|██████████| 6300/6300 [00:03<00:00, 1601.77it/s]



=== FOLD 1/5 ===
Train: 4286, Val: 730, Test: 1284
EEGNet output dimension: 3840
Stage 1 output dim: 256
TabNet input dim: 296
Entrenando Stage 1...
Epoch 0: Loss = 1.4512, Val Acc = 0.2479, LR = 0.001000
Epoch 10: Loss = 1.0990, Val Acc = 0.3836, LR = 0.000954
Epoch 20: Loss = 0.5335, Val Acc = 0.3781, LR = 0.000839
Epoch 30: Loss = 0.2989, Val Acc = 0.3438, LR = 0.000673
Epoch 40: Loss = 0.1885, Val Acc = 0.3658, LR = 0.000480
Epoch 50: Loss = 0.1234, Val Acc = 0.3644, LR = 0.000291
Epoch 60: Loss = 0.1007, Val Acc = 0.3699, LR = 0.000133
Epoch 70: Loss = 0.0667, Val Acc = 0.3685, LR = 0.000031
✅ Stage 1 Test Acc: 0.3621
Entrenando Stage 2...
❌ Error en fold 1: The size of tensor a (32) must match the size of tensor b (296) at non-singleton dimension 1

=== FOLD 2/5 ===
Train: 4282, Val: 734, Test: 1284
EEGNet output dimension: 3840
Stage 1 output dim: 256
TabNet input dim: 296
Entrenando Stage 1...
Epoch 0: Loss = 1.4432, Val Acc = 0.2548, LR = 0.001000
Epoch 10: Loss = 1.0591, Val

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


In [9]:
# -*- coding: utf-8 -*-
# Two-Stage Transformer Architecture - VERSIÓN COMPLETA AUTO-CONTENIDA

import os, re, math, random, json, itertools, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler

from sklearn.model_selection import GroupKFold, StratifiedKFold, StratifiedShuffleSplit, GroupShuffleSplit
from sklearn.utils import check_random_state
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
from tqdm import tqdm

# =========================
# CONFIGURACIÓN GENERAL
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'
CACHE_DIR = PROJ / 'data' / 'cache'
FOLDS_DIR = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Dispositivo y semilla
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Usando dispositivo: {DEVICE}")

# Escenario y ventana
CLASS_SCENARIO = '4c'
WINDOW_MODE = '6s'
FS = 160.0
N_FOLDS = 5

# Arquitectura Two-Stage
BATCH_SIZE = 32
EPOCHS_STAGE1 = 80
EPOCHS_STAGE2 = 40
LR_STAGE1 = 1e-3
LR_STAGE2 = 5e-4

# Parámetros Transformer
N_HEADS = 4
D_MODEL = 64
DROPOUT_ATTN = 0.2
TCN_FILTERS = 64
TCN_KERNEL_SIZE = 3
TCN_N_BLOCKS = 2
TIME_SLICES = 4

# TabNet
TABNET_N_D = 64
TABNET_N_A = 64
TABNET_N_HIDDEN = 128
TABNET_N_STEPS = 3

# Data Augmentation
N_CLUSTERS = 3
AUG_SAMPLES_PER_CLASS = 30

# =========================
# UTILIDADES DE CANALES
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

# Canales (8 con FCz)
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"Warning: faltan canales {missing} en archivo {getattr(raw,'filenames', [''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

# =========================
# LECTURA DE EDF y EVENTOS
# =========================
_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')

def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

# Runs MI
MI_RUNS_LR = [4, 8, 12]
MI_RUNS_OF = [6, 10, 14]
BASELINE_RUNS_EO = [1]

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    if run_id in BASELINE_RUNS_EO: return 'EO'
    return None

# Sujetos excluidos
EXCLUDE_SUBJECTS = {38, 88, 89, 92, 100, 104}

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None:
        return None

    # Filtrado básico
    raw.filter(4.0, 40.0, method='iir', verbose=False)
    
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    if raw.annotations is None or len(raw.annotations) == 0:
        return []
    def _norm(s): return str(s).strip().upper().replace(' ', '')
    res = []
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'):
            res.append((float(onset), tag))
    res.sort()
    dedup = []
    last_t1 = last_t2 = -1e9
    for t, tag in res:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# CONSTRUCCIÓN DE DATASETS
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR + MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_trials_from_run(edf_path: Path, scenario: str, window_mode: str):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF','EO'):
        return ([], [])

    raw = read_raw_edf(edf_path)
    if raw is None:
        return ([], [])

    data = raw.get_data()
    fs = raw.info['sfreq']
    assert abs(fs - FS) < 1e-6

    out = []

    if kind in ('LR','OF'):
        events = collect_events_T1T2(raw)
        if window_mode == '3s':
            rel_start, rel_end = 0.0, 3.0
        else:  # 6s
            rel_start, rel_end = -1.0, 5.0

        for onset_sec, tag in events:
            if kind == 'LR':
                if tag == 'T1': label = 'L'
                elif tag == 'T2': label = 'R'
                else: continue
            else:
                if tag == 'T1': label = 'BFISTS'
                elif tag == 'T2': label = 'BFEET'
                else: continue

            if scenario == '2c' and label not in ('L','R'): continue
            if scenario == '3c' and label not in ('L','R','BFISTS'): continue
            if scenario == '4c' and label not in ('L','R','BFISTS','BFEET'): continue

            s = int(round((raw.first_time + onset_sec + rel_start) * fs))
            e = int(round((raw.first_time + onset_sec + rel_end) * fs))
            if s < 0 or e > data.shape[1]:
                continue

            seg = data[:, s:e].T.astype(np.float32)
            # Normalización z-score por canal
            seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)

            if label == 'L':       y = 0
            elif label == 'R':     y = 1
            elif label == 'BFISTS':y = 2
            elif label == 'BFEET': y = 3
            else: continue

            out.append((seg, y, subj))

    elif kind == 'EO':
        return ([], raw.ch_names)

    return out, raw.ch_names

def build_dataset_all(subjects, scenario='4c', window_mode='3s'):
    X, y, groups = [], [], []
    ch_template = None

    for s in tqdm(subjects, desc="Construyendo dataset (RAW)"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue

        trials_L, trials_R, trials_FISTS, trials_FEET = [], [], [], []

        for r in MI_RUNS_LR:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 0: trials_L.append(seg)
                elif lab == 1: trials_R.append(seg)

        for r in MI_RUNS_OF:
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            outs, chs = extract_trials_from_run(p, scenario, window_mode)
            if ch_template is None and chs: ch_template = chs
            for seg, lab, _ in outs:
                if lab == 2: trials_FISTS.append(seg)
                elif lab == 3: trials_FEET.append(seg)

        need_per_class = 21
        def pick(trials, n, rng):
            if len(trials) < n:
                idx = rng.choice(len(trials), size=n, replace=True)
                return [trials[i] for i in idx]
            rng.shuffle(trials)
            return trials[:n]

        rng = check_random_state(RANDOM_STATE + s)
        if len(trials_L)==0 or len(trials_R)==0 or len(trials_FISTS)==0 or len(trials_FEET)==0:
            continue

        Lp  = pick(trials_L,     need_per_class, rng)
        Rp  = pick(trials_R,     need_per_class, rng)
        FIp = pick(trials_FISTS, need_per_class, rng)
        FEp = pick(trials_FEET,  need_per_class, rng)

        pack = [(Lp, 0), (Rp, 1), (FIp, 2), (FEp, 3)]
        for segs, lab in pack:
            for seg in segs:
                X.append(seg); y.append(lab); groups.append(s)

    if len(X) == 0:
        print("No se pudieron cargar datos. Verifica las rutas de archivos.")
        return None, None, None, None
        
    X = np.stack(X, axis=0)
    y = np.asarray(y, dtype=np.int64)
    groups = np.asarray(groups, dtype=np.int64)

    n, T, C = X.shape
    n_classes = len(np.unique(y))
    print(f"Dataset construido: N={n} | T={T} | C={C} | clases={n_classes} | sujetos únicos={len(np.unique(groups))}")
    return X, y, groups, ch_template

# =========================
# ARQUITECTURA TWO-STAGE TRANSFORMER
# =========================

import torch.nn.functional as F

class RobustEEGNet(nn.Module):
    def __init__(self, n_ch=8, F1=32, D=2, T=960):
        super().__init__()
        self.F1 = F1
        self.F2 = F1 * D
        
        self.conv1 = nn.Conv2d(1, F1, (32, 1), padding=(16, 0), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        
        self.conv2 = nn.Conv2d(F1, self.F2, (1, n_ch), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F2)
        
        self.pool1 = nn.AvgPool2d((4, 1))
        self.dropout1 = nn.Dropout(0.3)
        
        self.conv3 = nn.Conv2d(self.F2, self.F2, (8, 1), groups=self.F2, padding=(4, 0), bias=False)
        self.conv4 = nn.Conv2d(self.F2, self.F2, (1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(self.F2)
        
        self.pool2 = nn.AvgPool2d((4, 1))
        self.dropout2 = nn.Dropout(0.4)
        
        self.output_dim = self._get_output_dim(n_ch, T)
        
    def _get_output_dim(self, n_ch, T):
        with torch.no_grad():
            x = torch.randn(1, 1, T, n_ch)
            x = F.elu(self.bn1(self.conv1(x)))
            x = F.elu(self.bn2(self.conv2(x)))
            x = self.pool1(x)
            x = self.dropout1(x)
            x = F.elu(self.bn3(self.conv4(self.conv3(x))))
            x = self.pool2(x)
            x = self.dropout2(x)
            return x.view(1, -1).size(1)
    
    def forward(self, x):
        x = F.elu(self.bn1(self.conv1(x)))
        x = F.elu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = self.dropout1(x)
        x = F.elu(self.bn3(self.conv4(self.conv3(x))))
        x = self.pool2(x)
        x = self.dropout2(x)
        return x.flatten(1)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, T, d_model = x.shape
        
        Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(B, T, self.d_model)
        
        return self.w_o(context), attn_weights

class AdaptiveTemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        
        effective_ks = kernel_size if kernel_size > 1 else 1
        padding = (effective_ks - 1) // 2
        
        self.conv1 = nn.Conv1d(in_channels, out_channels, effective_ks, padding=padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, effective_ks, padding=padding)
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        self.residual = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.activation = nn.ELU()
        
    def forward(self, x):
        residual = self.residual(x)
        x = self.activation(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return self.activation(x + residual)

class Stage1Architecture(nn.Module):
    def __init__(self, n_ch=8, n_classes=4, T=960, time_slices=4):
        super().__init__()
        self.n_ch = n_ch
        self.n_classes = n_classes
        self.time_slices = time_slices
        
        self.eegnet = RobustEEGNet(n_ch=n_ch, T=T)
        eegnet_output_dim = self.eegnet.output_dim
        
        print(f"EEGNet output dimension: {eegnet_output_dim}")
        
        self.projection = nn.Sequential(
            nn.Linear(eegnet_output_dim, 256),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(256, D_MODEL * time_slices)
        )
        
        self.attention = MultiHeadAttention(d_model=D_MODEL, n_heads=N_HEADS, dropout=DROPOUT_ATTN)
        
        self.temporal_blocks = nn.ModuleList()
        for i in range(TCN_N_BLOCKS):
            block = AdaptiveTemporalBlock(D_MODEL, TCN_FILTERS, kernel_size=TCN_KERNEL_SIZE)
            self.temporal_blocks.append(block)
        
        self.output_dim = TCN_FILTERS * time_slices
        self.classifier = nn.Sequential(
            nn.Linear(self.output_dim, 128),
            nn.ELU(),
            nn.Dropout(0.4),
            nn.Linear(128, n_classes)
        )
        
    def forward(self, x):
        B, _, T, C = x.shape
        
        eeg_features = self.eegnet(x)
        projected = self.projection(eeg_features)
        projected = projected.view(B, self.time_slices, D_MODEL)
        
        attended, _ = self.attention(projected)
        temporal_feat = attended.transpose(1, 2)
        
        for block in self.temporal_blocks:
            temporal_feat = block(temporal_feat)
        
        embeddings = temporal_feat.flatten(1)
        logits = self.classifier(embeddings)
        
        return logits, embeddings

class TabNetClassifier(nn.Module):
    def __init__(self, input_dim, n_classes=4, n_d=64, n_steps=3):
        super().__init__()
        self.n_d = n_d
        self.n_steps = n_steps
        self.input_dim = input_dim
        
        self.initial_bn = nn.BatchNorm1d(input_dim)
        self.initial_fc = nn.Linear(input_dim, n_d * 2)
        
        self.decision_blocks = nn.ModuleList()
        for i in range(n_steps):
            block = nn.Sequential(
                nn.Linear(n_d, n_d * 2),
                nn.BatchNorm1d(n_d * 2),
                nn.GLU(dim=1),
                nn.Dropout(0.2)
            )
            self.decision_blocks.append(block)
        
        self.attention_layers = nn.ModuleList()
        for i in range(n_steps):
            attention = nn.Sequential(
                nn.Linear(n_d, input_dim),
                nn.BatchNorm1d(input_dim),
                nn.Sigmoid()
            )
            self.attention_layers.append(attention)
        
        self.classifier = nn.Sequential(
            nn.Linear(n_d * n_steps, TABNET_N_HIDDEN),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(TABNET_N_HIDDEN, n_classes)
        )
        
    def forward(self, x):
        x = self.initial_bn(x)
        x = self.initial_fc(x)
        x = F.glu(x, dim=1)
        
        decision_outputs = []
        current_x = x
        
        for i in range(self.n_steps):
            attention_weights = self.attention_layers[i](current_x)
            attended_x = x * attention_weights
            
            transformed = self.decision_blocks[i](attended_x)
            decision_outputs.append(transformed)
            
            current_x = current_x + transformed
        
        final_features = torch.cat(decision_outputs, dim=1)
        return self.classifier(final_features)

class TwoStageTransformer(nn.Module):
    def __init__(self, n_ch=8, n_classes=4, T=960, time_slices=4):
        super().__init__()
        
        self.stage1 = Stage1Architecture(n_ch, n_classes, T, time_slices)
        stage1_output_dim = self.stage1.output_dim
        
        psd_feature_dim = 40
        tabnet_input_dim = stage1_output_dim + psd_feature_dim
        
        self.stage2 = TabNetClassifier(tabnet_input_dim, n_classes, TABNET_N_D, TABNET_N_STEPS)
        
        print(f"Stage 1 output dim: {stage1_output_dim}")
        print(f"TabNet input dim: {tabnet_input_dim}")
        print(f"TabNet n_d: {TABNET_N_D}")
        
    def forward(self, x, psd_features=None):
        logits_stage1, embeddings_stage1 = self.stage1(x)
        
        if psd_features is None:
            return logits_stage1, embeddings_stage1
        
        combined_features = torch.cat([embeddings_stage1, psd_features], dim=1)
        logits_stage2 = self.stage2(combined_features)
        
        return logits_stage1, logits_stage2, embeddings_stage1

# =========================
# DATA AUGMENTATION Y PREPROCESAMIENTO
# =========================

def extract_psd_features(X, fs=160.0):
    print("Extrayendo características PSD...")
    
    n_trials, T, n_channels = X.shape
    n_bands = 5
    psd_features = np.zeros((n_trials, n_channels * n_bands))
    
    freq_bands = [(1, 4), (4, 8), (8, 14), (14, 31), (31, 49)]
    
    for i in tqdm(range(n_trials)):
        trial_psd = []
        for ch in range(n_channels):
            signal = X[i, :, ch]
            
            fft_vals = np.fft.rfft(signal)
            freqs = np.fft.rfftfreq(len(signal), 1/fs)
            psd = np.abs(fft_vals) ** 2
            
            for f_low, f_high in freq_bands:
                band_mask = (freqs >= f_low) & (freqs <= f_high)
                if np.any(band_mask):
                    band_power = np.trapz(psd[band_mask], freqs[band_mask])
                else:
                    band_power = 0.0
                trial_psd.append(band_power)
        
        psd_features[i] = np.array(trial_psd)
    
    psd_features = (psd_features - psd_features.mean(axis=0)) / (psd_features.std(axis=0) + 1e-8)
    
    return psd_features.astype(np.float32)

def channel_cluster_swapping_augmentation(X, y, groups, n_clusters=3, n_augmented_per_class=30):
    print("Aplicando Channel Cluster Swapping Augmentation...")
    
    X_aug, y_aug, groups_aug = [], [], []
    n_classes = len(np.unique(y))
    
    for class_idx in range(n_classes):
        class_mask = (y == class_idx)
        X_class = X[class_mask]
        groups_class = groups[class_mask]
        
        if len(X_class) < 2:
            continue
            
        corr_matrix = np.zeros((X_class.shape[2], X_class.shape[2]))
        for i in range(len(X_class)):
            trial_corr = np.corrcoef(X_class[i].T)
            corr_matrix += np.nan_to_num(trial_corr)
        corr_matrix /= len(X_class)
        
        kmeans = KMeans(n_clusters=n_clusters, random_state=RANDOM_STATE, n_init=10)
        channel_labels = kmeans.fit_predict(corr_matrix)
        
        n_generated = 0
        while n_generated < n_augmented_per_class:
            idx1, idx2 = np.random.choice(len(X_class), 2, replace=False)
            
            new_trial = X_class[idx1].copy()
            for cluster_id in range(n_clusters):
                cluster_channels = np.where(channel_labels == cluster_id)[0]
                if len(cluster_channels) > 0:
                    swap_channel = np.random.choice(cluster_channels)
                    new_trial[:, swap_channel] = X_class[idx2][:, swap_channel]
            
            X_aug.append(new_trial)
            y_aug.append(class_idx)
            groups_aug.append(groups_class[idx1])
            n_generated += 1
    
    if X_aug:
        X_aug = np.stack(X_aug, axis=0)
        y_aug = np.array(y_aug, dtype=np.int64)
        groups_aug = np.array(groups_aug, dtype=np.int64)
        
        X_combined = np.concatenate([X, X_aug], axis=0)
        y_combined = np.concatenate([y, y_aug], axis=0)
        groups_combined = np.concatenate([groups, groups_aug], axis=0)
        
        print(f"Aumento completado: {len(X_aug)} muestras añadidas")
        return X_combined, y_combined, groups_combined
    
    return X, y, groups

class EEGDataset(Dataset):
    def __init__(self, X, y, groups, psd_features=None):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.psd_features = psd_features.astype(np.float32) if psd_features is not None else None
        
    def __len__(self): 
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        x = self.X[idx]
        x = np.expand_dims(x, 0)
        
        if self.psd_features is not None:
            return (torch.from_numpy(x), 
                    torch.from_numpy(self.psd_features[idx]), 
                    torch.tensor(self.y[idx]), 
                    torch.tensor(self.g[idx]))
        else:
            return (torch.from_numpy(x), 
                    torch.tensor(self.y[idx]), 
                    torch.tensor(self.g[idx]))

# =========================
# FUNCIONES DE ENTRENAMIENTO CON MÉTRICAS COMPLETAS
# =========================

def calculate_accuracy(output, target):
    pred = output.argmax(dim=1)
    correct = (pred == target).sum().item()
    return correct / target.size(0)

def train_stage1_with_metrics(model, train_loader, val_loader, epochs=80, lr=1e-3):
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'lr': []
    }
    
    for epoch in range(epochs):
        # Fase de entrenamiento
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        train_batches = 0
        
        for batch in train_loader:
            if len(batch) == 4:
                x, _, y, _ = batch
            else:
                x, y, _ = batch
                
            x, y = x.to(DEVICE), y.to(DEVICE)
            
            optimizer.zero_grad()
            logits, _ = model(x)
            loss = criterion(logits, y)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
            train_acc += calculate_accuracy(logits, y)
            train_batches += 1
        
        # Fase de validación
        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        val_batches = 0
        
        with torch.no_grad():
            for batch in val_loader:
                if len(batch) == 4:
                    x, _, y, _ = batch
                else:
                    x, y, _ = batch
                    
                x, y = x.to(DEVICE), y.to(DEVICE)
                logits, _ = model(x)
                loss = criterion(logits, y)
                
                val_loss += loss.item()
                val_acc += calculate_accuracy(logits, y)
                val_batches += 1
        
        # Calcular promedios
        avg_train_loss = train_loss / train_batches
        avg_train_acc = train_acc / train_batches
        avg_val_loss = val_loss / val_batches
        avg_val_acc = val_acc / val_batches
        
        # Actualizar historial
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(avg_val_acc)
        history['lr'].append(scheduler.get_last_lr()[0])
        
        scheduler.step()
        
        # Guardar mejor modelo
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save(model.state_dict(), 'best_stage1.pth')
        
        # Log cada 10 épocas
        if epoch % 10 == 0 or epoch == epochs - 1:
            current_lr = scheduler.get_last_lr()[0]
            print(f'Epoch {epoch:3d}: '
                  f'Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f} | '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f} | '
                  f'LR: {current_lr:.6f}')
            
            # Detectar overfitting
            if avg_train_acc - avg_val_acc > 0.15:
                print('⚠️  Posible overfitting detectado!')
    
    return history, best_val_acc

def train_stage2_with_metrics(model, train_loader, val_loader, epochs=40, lr=5e-4):
    optimizer = optim.AdamW(model.stage2.parameters(), lr=lr, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0.0
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': []
    }
    
    for epoch in range(epochs):
        # Fase de entrenamiento
        model.train()
        train_loss = 0.0
        train_acc = 0.0
        train_batches = 0
        
        for batch in train_loader:
            x, psd, y, _ = batch
            x, psd, y = x.to(DEVICE), psd.to(DEVICE), y.to(DEVICE)
            
            optimizer.zero_grad()
            _, logits_stage2, _ = model(x, psd)
            loss = criterion(logits_stage2, y)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_acc += calculate_accuracy(logits_stage2, y)
            train_batches += 1
        
        # Fase de validación
        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        val_batches = 0
        
        with torch.no_grad():
            for batch in val_loader:
                x, psd, y, _ = batch
                x, psd, y = x.to(DEVICE), psd.to(DEVICE), y.to(DEVICE)
                _, logits_stage2, _ = model(x, psd)
                loss = criterion(logits_stage2, y)
                
                val_loss += loss.item()
                val_acc += calculate_accuracy(logits_stage2, y)
                val_batches += 1
        
        # Calcular promedios
        avg_train_loss = train_loss / train_batches
        avg_train_acc = train_acc / train_batches
        avg_val_loss = val_loss / val_batches
        avg_val_acc = val_acc / val_batches
        
        # Actualizar historial
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(avg_train_acc)
        history['val_loss'].append(avg_val_loss)
        history['val_acc'].append(avg_val_acc)
        
        # Guardar mejor modelo
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save(model.state_dict(), 'best_stage2.pth')
        
        # Log cada 10 épocas
        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f'Epoch {epoch:3d}: '
                  f'Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f} | '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}')
            
            if avg_train_acc - avg_val_acc > 0.15:
                print('⚠️  Posible overfitting detectado en Stage 2!')
    
    return history, best_val_acc

@torch.no_grad()
def evaluate_stage1(model, loader):
    model.eval()
    correct = 0
    total = 0
    
    for batch in loader:
        if len(batch) == 4:
            x, _, y, _ = batch
        else:
            x, y, _ = batch
            
        x, y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = model(x)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    
    return correct / total

@torch.no_grad()
def evaluate_stage2(model, loader):
    model.eval()
    correct = 0
    total = 0
    
    for batch in loader:
        x, psd, y, _ = batch
        x, psd, y = x.to(DEVICE), psd.to(DEVICE), y.to(DEVICE)
        _, logits_stage2, _ = model(x, psd)
        pred = logits_stage2.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    
    return correct / total

def plot_training_curves(history, fold, stage):
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title(f'Fold {fold} - Stage {stage} - Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title(f'Fold {fold} - Stage {stage} - Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig(f'training_curves_fold{fold}_stage{stage}.png', dpi=150, bbox_inches='tight')
    plt.close()

# =========================
# EXPERIMENTO PRINCIPAL COMPLETO
# =========================
def run_two_stage_experiment():
    print("🧠 TWO-STAGE TRANSFORMER - VERSIÓN COMPLETA AUTO-CONTENIDA")
    
    # Cargar datos
    subs = subjects_available()
    print(f"Sujetos disponibles: {len(subs)}")
    
    X, y, groups, chs = build_dataset_all(subs, scenario=CLASS_SCENARIO, window_mode=WINDOW_MODE)
    
    if X is None:
        print("Error: No se pudieron cargar los datos.")
        return None
        
    N, T, C = X.shape
    n_classes = len(np.unique(y))
    
    print(f"Datos cargados: N={N}, T={T}, C={C}, Clases={n_classes}")
    
    # Data augmentation
    X_aug, y_aug, groups_aug = channel_cluster_swapping_augmentation(
        X, y, groups, n_clusters=N_CLUSTERS, n_augmented_per_class=AUG_SAMPLES_PER_CLASS
    )
    
    # Extraer características PSD
    psd_features = extract_psd_features(X_aug, fs=FS)
    
    # Crear dataset
    dataset = EEGDataset(X_aug, y_aug, groups_aug, psd_features)
    
    # Cross-validation con métricas completas
    unique_subjects = np.unique(groups_aug)
    gkf = GroupKFold(n_splits=min(N_FOLDS, len(unique_subjects)))
    
    fold_results = {
        'stage1': {'train_acc': [], 'val_acc': [], 'test_acc': []},
        'stage2': {'train_acc': [], 'val_acc': [], 'test_acc': []}
    }
    
    for fold, (train_idx, test_idx) in enumerate(gkf.split(X_aug, y_aug, groups_aug)):
        print(f"\n{'='*60}")
        print(f"=== FOLD {fold + 1}/{min(N_FOLDS, len(unique_subjects))} ===")
        print(f"{'='*60}")
        
        # Split train/val
        train_subjects = np.unique(groups_aug[train_idx])
        val_size = max(1, int(0.15 * len(train_subjects)))
        val_subjects = np.random.choice(train_subjects, val_size, replace=False)
        
        val_mask = np.isin(groups_aug[train_idx], val_subjects)
        val_idx = train_idx[val_mask]
        train_idx_fold = train_idx[~val_mask]
        
        # DataLoaders
        train_loader = DataLoader(Subset(dataset, train_idx_fold), batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(Subset(dataset, val_idx), batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(Subset(dataset, test_idx), batch_size=BATCH_SIZE, shuffle=False)
        
        print(f"📊 Dataset sizes: Train: {len(train_idx_fold)}, Val: {len(val_idx)}, Test: {len(test_idx)}")
        
        # Modelo
        model = TwoStageTransformer(n_ch=C, n_classes=n_classes, T=T, time_slices=TIME_SLICES).to(DEVICE)
        
        try:
            # Stage 1 con métricas completas
            print("\n🎯 STAGE 1 TRAINING")
            history_stage1, best_val_stage1 = train_stage1_with_metrics(
                model, train_loader, val_loader, EPOCHS_STAGE1, LR_STAGE1
            )
            
            # Evaluar en test
            test_acc_stage1 = evaluate_stage1(model, test_loader)
            train_acc_stage1 = history_stage1['train_acc'][-1]
            val_acc_stage1 = history_stage1['val_acc'][-1]
            
            fold_results['stage1']['train_acc'].append(train_acc_stage1)
            fold_results['stage1']['val_acc'].append(val_acc_stage1)
            fold_results['stage1']['test_acc'].append(test_acc_stage1)
            
            print(f"✅ Stage 1 Final - Train: {train_acc_stage1:.4f}, Val: {val_acc_stage1:.4f}, Test: {test_acc_stage1:.4f}")
            
            # Graficar curvas
            plot_training_curves(history_stage1, fold + 1, 1)
            
            # Stage 2 con métricas completas
            print("\n🎯 STAGE 2 TRAINING")
            history_stage2, best_val_stage2 = train_stage2_with_metrics(
                model, train_loader, val_loader, EPOCHS_STAGE2, LR_STAGE2
            )
            
            # Evaluar en test
            test_acc_stage2 = evaluate_stage2(model, test_loader)
            train_acc_stage2 = history_stage2['train_acc'][-1]
            val_acc_stage2 = history_stage2['val_acc'][-1]
            
            fold_results['stage2']['train_acc'].append(train_acc_stage2)
            fold_results['stage2']['val_acc'].append(val_acc_stage2)
            fold_results['stage2']['test_acc'].append(test_acc_stage2)
            
            print(f"✅ Stage 2 Final - Train: {train_acc_stage2:.4f}, Val: {val_acc_stage2:.4f}, Test: {test_acc_stage2:.4f}")
            print(f"📈 Improvement: +{(test_acc_stage2 - test_acc_stage1)*100:.2f}%")
            
            # Graficar curvas
            plot_training_curves(history_stage2, fold + 1, 2)
            
        except Exception as e:
            print(f"❌ Error en fold {fold + 1}: {e}")
            import traceback
            traceback.print_exc()
            continue
    
    # Resultados finales detallados
    if fold_results['stage1']['test_acc']:
        print("\n" + "="*70)
        print("🎯 RESULTADOS FINALES DETALLADOS")
        print("="*70)
        
        for stage in ['stage1', 'stage2']:
            train_mean = np.mean(fold_results[stage]['train_acc'])
            train_std = np.std(fold_results[stage]['train_acc'])
            val_mean = np.mean(fold_results[stage]['val_acc'])
            val_std = np.std(fold_results[stage]['val_acc'])
            test_mean = np.mean(fold_results[stage]['test_acc'])
            test_std = np.std(fold_results[stage]['test_acc'])
            
            print(f"\n{stage.upper()}:")
            print(f"  Train: {train_mean:.4f} ± {train_std:.4f}")
            print(f"  Val:   {val_mean:.4f} ± {val_std:.4f}")
            print(f"  Test:  {test_mean:.4f} ± {test_std:.4f}")
            print(f"  Overfitting (Train-Val): {(train_mean - val_mean)*100:.2f}%")
        
        improvement = (np.mean(fold_results['stage2']['test_acc']) - 
                     np.mean(fold_results['stage1']['test_acc'])) * 100
        print(f"\n📈 MEJORA TOTAL Stage 2 vs Stage 1: +{improvement:.2f}%")
        
        return fold_results
    else:
        print("❌ No se completaron folds exitosamente.")
        return None

# =========================
# EJECUCIÓN
# =========================
if __name__ == "__main__":
    results = run_two_stage_experiment()

🚀 Usando dispositivo: cuda
🧠 TWO-STAGE TRANSFORMER - VERSIÓN COMPLETA AUTO-CONTENIDA
Sujetos disponibles: 103


Construyendo dataset (RAW):   0%|          | 0/103 [00:00<?, ?it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   1%|          | 1/103 [00:00<00:20,  4.93it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   2%|▏         | 2/103 [00:00<00:20,  5.00it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   3%|▎         | 3/103 [00:00<00:20,  4.97it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   4%|▍         | 4/103 [00:00<00:19,  4.98it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   5%|▍         | 5/103 [00:01<00:19,  5.00it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   6%|▌         | 6/103 [00:01<00:19,  4.99it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   7%|▋         | 7/103 [00:01<00:19,  4.99it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   8%|▊         | 8/103 [00:01<00:19,  4.95it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):   9%|▊         | 9/103 [00:01<00:19,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  10%|▉         | 10/103 [00:02<00:18,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  11%|█         | 11/103 [00:02<00:18,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  12%|█▏        | 12/103 [00:02<00:18,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  13%|█▎        | 13/103 [00:02<00:18,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  14%|█▎        | 14/103 [00:02<00:18,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  15%|█▍        | 15/103 [00:03<00:18,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  16%|█▌        | 16/103 [00:03<00:17,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  17%|█▋        | 17/103 [00:03<00:17,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  17%|█▋        | 18/103 [00:03<00:17,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  18%|█▊        | 19/103 [00:03<00:17,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  19%|█▉        | 20/103 [00:04<00:16,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  20%|██        | 21/103 [00:04<00:16,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  21%|██▏       | 22/103 [00:04<00:16,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  22%|██▏       | 23/103 [00:04<00:16,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  23%|██▎       | 24/103 [00:04<00:16,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  24%|██▍       | 25/103 [00:05<00:15,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  25%|██▌       | 26/103 [00:05<00:15,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  26%|██▌       | 27/103 [00:05<00:15,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  27%|██▋       | 28/103 [00:05<00:15,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  28%|██▊       | 29/103 [00:05<00:15,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  29%|██▉       | 30/103 [00:06<00:14,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  30%|███       | 31/103 [00:06<00:14,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  31%|███       | 32/103 [00:06<00:14,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  32%|███▏      | 33/103 [00:06<00:14,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  33%|███▎      | 34/103 [00:06<00:14,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  34%|███▍      | 35/103 [00:07<00:13,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  35%|███▍      | 36/103 [00:07<00:13,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  36%|███▌      | 37/103 [00:07<00:13,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  37%|███▋      | 38/103 [00:07<00:13,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  38%|███▊      | 39/103 [00:07<00:13,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  39%|███▉      | 40/103 [00:08<00:12,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  40%|███▉      | 41/103 [00:08<00:12,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  41%|████      | 42/103 [00:08<00:12,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  42%|████▏     | 43/103 [00:08<00:12,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  43%|████▎     | 44/103 [00:08<00:12,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  44%|████▎     | 45/103 [00:09<00:11,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  45%|████▍     | 46/103 [00:09<00:11,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  46%|████▌     | 47/103 [00:09<00:11,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  47%|████▋     | 48/103 [00:09<00:11,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  48%|████▊     | 49/103 [00:09<00:11,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  49%|████▊     | 50/103 [00:10<00:10,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  50%|████▉     | 51/103 [00:10<00:10,  4.93it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  50%|█████     | 52/103 [00:10<00:10,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  51%|█████▏    | 53/103 [00:10<00:10,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  52%|█████▏    | 54/103 [00:11<00:10,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  53%|█████▎    | 55/103 [00:11<00:09,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  54%|█████▍    | 56/103 [00:11<00:09,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  55%|█████▌    | 57/103 [00:11<00:09,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  56%|█████▋    | 58/103 [00:11<00:09,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  57%|█████▋    | 59/103 [00:12<00:08,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  58%|█████▊    | 60/103 [00:12<00:08,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  59%|█████▉    | 61/103 [00:12<00:08,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  60%|██████    | 62/103 [00:12<00:08,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  61%|██████    | 63/103 [00:12<00:08,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  62%|██████▏   | 64/103 [00:13<00:07,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  63%|██████▎   | 65/103 [00:13<00:07,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  64%|██████▍   | 66/103 [00:13<00:07,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  65%|██████▌   | 67/103 [00:13<00:07,  4.90it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  66%|██████▌   | 68/103 [00:13<00:07,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  67%|██████▋   | 69/103 [00:14<00:06,  4.87it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  68%|██████▊   | 70/103 [00:14<00:06,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  69%|██████▉   | 71/103 [00:14<00:06,  4.87it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  70%|██████▉   | 72/103 [00:14<00:06,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  71%|███████   | 73/103 [00:14<00:06,  4.92it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  72%|███████▏  | 74/103 [00:15<00:05,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  73%|███████▎  | 75/103 [00:15<00:05,  4.93it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  74%|███████▍  | 76/103 [00:15<00:05,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  75%|███████▍  | 77/103 [00:15<00:05,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  76%|███████▌  | 78/103 [00:15<00:05,  4.91it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  77%|███████▋  | 79/103 [00:16<00:04,  4.87it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  78%|███████▊  | 80/103 [00:16<00:04,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  79%|███████▊  | 81/103 [00:16<00:04,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  80%|███████▉  | 82/103 [00:16<00:04,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  81%|████████  | 83/103 [00:16<00:04,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  82%|████████▏ | 84/103 [00:17<00:03,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  83%|████████▎ | 85/103 [00:17<00:03,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  83%|████████▎ | 86/103 [00:17<00:03,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  84%|████████▍ | 87/103 [00:17<00:03,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  85%|████████▌ | 88/103 [00:17<00:03,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  86%|████████▋ | 89/103 [00:18<00:02,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  87%|████████▋ | 90/103 [00:18<00:02,  4.87it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  88%|████████▊ | 91/103 [00:18<00:02,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  89%|████████▉ | 92/103 [00:18<00:02,  4.85it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  90%|█████████ | 93/103 [00:18<00:02,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  91%|█████████▏| 94/103 [00:19<00:01,  4.88it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  92%|█████████▏| 95/103 [00:19<00:01,  4.89it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  93%|█████████▎| 96/103 [00:19<00:01,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  94%|█████████▍| 97/103 [00:19<00:01,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  95%|█████████▌| 98/103 [00:20<00:01,  4.86it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  96%|█████████▌| 99/103 [00:20<00:00,  4.85it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  97%|█████████▋| 100/103 [00:20<00:00,  4.84it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  98%|█████████▊| 101/103 [00:20<00:00,  4.84it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW):  99%|█████████▉| 102/103 [00:20<00:00,  4.85it/s]

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


Construyendo dataset (RAW): 100%|██████████| 103/103 [00:21<00:00,  4.90it/s]

Dataset construido: N=8652 | T=960 | C=8 | clases=4 | sujetos únicos=103
Datos cargados: N=8652, T=960, C=8, Clases=4
Aplicando Channel Cluster Swapping Augmentation...





Aumento completado: 120 muestras añadidas
Extrayendo características PSD...


  band_power = np.trapz(psd[band_mask], freqs[band_mask])
100%|██████████| 8772/8772 [00:04<00:00, 1860.60it/s]



=== FOLD 1/5 ===
📊 Dataset sizes: Train: 5966, Val: 1018, Test: 1788
EEGNet output dimension: 3840
Stage 1 output dim: 256
TabNet input dim: 296
TabNet n_d: 64

🎯 STAGE 1 TRAINING
Epoch   0: Train Loss: 1.4421, Train Acc: 0.2524 | Val Loss: 1.3998, Val Acc: 0.2579 | LR: 0.001000
Epoch  10: Train Loss: 1.0575, Train Acc: 0.5432 | Val Loss: 1.5017, Val Acc: 0.3338 | LR: 0.000954
⚠️  Posible overfitting detectado!
Epoch  20: Train Loss: 0.5647, Train Acc: 0.7851 | Val Loss: 2.1539, Val Acc: 0.3389 | LR: 0.000839
⚠️  Posible overfitting detectado!
Epoch  30: Train Loss: 0.3190, Train Acc: 0.8857 | Val Loss: 2.4607, Val Acc: 0.3514 | LR: 0.000673
⚠️  Posible overfitting detectado!
Epoch  40: Train Loss: 0.2117, Train Acc: 0.9264 | Val Loss: 2.9458, Val Acc: 0.3377 | LR: 0.000480
⚠️  Posible overfitting detectado!
Epoch  50: Train Loss: 0.1370, Train Acc: 0.9575 | Val Loss: 3.3939, Val Acc: 0.3307 | LR: 0.000291
⚠️  Posible overfitting detectado!
Epoch  60: Train Loss: 0.1066, Train Acc: 0.

Traceback (most recent call last):
  File "<ipython-input-9-e22fa1ddfae2>", line 989, in run_two_stage_experiment
    history_stage2, best_val_stage2 = train_stage2_with_metrics(
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-9-e22fa1ddfae2>", line 788, in train_stage2_with_metrics
    _, logits_stage2, _ = model(x, psd)
                          ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<ipython-input-9-e22fa1ddfae2>", line 546, in forward
    logits_stage2 = self.stage2(combined_features)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/to

Epoch   0: Train Loss: 1.4439, Train Acc: 0.2482 | Val Loss: 1.4022, Val Acc: 0.2598 | LR: 0.001000
Epoch  10: Train Loss: 1.0932, Train Acc: 0.5215 | Val Loss: 1.3784, Val Acc: 0.3598 | LR: 0.000954
⚠️  Posible overfitting detectado!
Epoch  20: Train Loss: 0.6109, Train Acc: 0.7632 | Val Loss: 2.0457, Val Acc: 0.3423 | LR: 0.000839
⚠️  Posible overfitting detectado!
Epoch  30: Train Loss: 0.3448, Train Acc: 0.8786 | Val Loss: 2.4451, Val Acc: 0.3433 | LR: 0.000673
⚠️  Posible overfitting detectado!


KeyboardInterrupt: 

In [2]:
# -*- coding: utf-8 -*-
# ===========================================
# MI-EEG (PhysioNet eegmmidb) — CNN+Transformer (Balanced 21/cls/subj) + Sliding Window
# Patch: usar Kfold5.json SOLO para sujetos por fold (ignorar tr_idx/te_idx)
# Split por sujeto/ensayo -> luego windowing (sin fuga). Métrica por ENSAYO.
# FT por sujeto (L2-SP) agrupando por ensayo (sin fuga).
# ===========================================

import os, re, math, json, copy, random, itertools, collections
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import mne

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset

from sklearn.model_selection import GroupKFold, GroupShuffleSplit, StratifiedKFold
from sklearn.metrics import confusion_matrix, classification_report, f1_score, cohen_kappa_score
from sklearn.utils import check_random_state
from tqdm import tqdm
import matplotlib.pyplot as plt

# =========================
# CONFIG
# =========================
PROJ = Path('..').resolve().parent
DATA_RAW = PROJ / 'data' / 'raw'           # .../S###/S###R##.edf
FOLDS_JSON = PROJ / 'models' / 'folds' / 'Kfold5.json'
CACHE_DIR = PROJ / 'data' / 'cache'
CACHE_DIR.mkdir(parents=True, exist_ok=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE); np.random.seed(RANDOM_STATE); random.seed(RANDOM_STATE)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
print(f"🚀 Dispositivo: {DEVICE}")

FS = 160.0

# Escenario y rango base por ensayo
CLASS_SCENARIO = '4c'            # '2c','3c','4c'
WINDOW_GLOBAL = (0.0, 4.0)       # segundos rel. T1/T2 (ensayo base)

# Sliding window (mismo nº por ensayo)
WIN_LEN_SEC = 2.0
WIN_STRIDE_SEC = 0.5
N_SUBWINS = int(math.floor((WINDOW_GLOBAL[1]-WINDOW_GLOBAL[0]-WIN_LEN_SEC)/WIN_STRIDE_SEC) + 1)

# Balance 21 por clase/sujeto
N_PER_CLASS_PER_SUBJECT = 21

# CV
N_FOLDS = 5
GLOBAL_VAL_SPLIT = 0.15
GLOBAL_PATIENCE  = 10

# Entrenamiento
BATCH_SIZE = 64
EPOCHS_GLOBAL = 80
LR_INIT = 1e-3

# FT por sujeto
CALIB_CV_FOLDS = 4
FT_EPOCHS = 30
FT_BASE_LR = 5e-5
FT_HEAD_LR = 1e-3
FT_L2SP = 1e-4
FT_PATIENCE = 5
FT_VAL_RATIO = 0.2

# Band-pass
USE_BANDPASS = True
BP_LO, BP_HI = 4.0, 38.0
BP_METHOD, BP_PHASE = 'fir', 'zero'

# Dataset
EXCLUDE_SUBJECTS = {38,88,89,92,100,104}
MI_RUNS_LR = [4,8,12]            # Left/Right
MI_RUNS_OF = [6,10,14]           # Both Fists / Both Feet
EXPECTED_8 = ['C3','C4','Cz','CP3','CP4','FC3','FC4','FCz']

CLASS_NAMES = ['Left', 'Right', 'Both Fists', 'Both Feet']

# =========================
# UTIL: canales/edf
# =========================
def normalize_label(s: str) -> str:
    if s is None: return s
    s = s.strip()
    s = re.sub(r'[^A-Za-z0-9]', '', s)
    s = re.sub(r'([A-Za-z])0([0-9])', r'\1\2', s)
    s = re.sub(r'([A-Za-z])Z$', r'\1z', s)
    s = s.replace('fp', 'Fp').replace('FP', 'Fp')
    s = ''.join(ch.upper() if ch != 'z' else 'z' for ch in s)
    return s

def rename_channels_1010(raw: mne.io.BaseRaw):
    mapping = {}
    for ch in raw.ch_names:
        lab = normalize_label(ch)
        lab = lab[:-1] + 'z' if lab.endswith('Z') else lab
        lab = re.sub(r'([A-Z])Z$', r'\1z', lab)
        mapping[ch] = lab
    mne.rename_channels(raw.info, mapping)

def ensure_channels_order(raw: mne.io.BaseRaw, desired_channels=EXPECTED_8):
    missing = [ch for ch in desired_channels if ch not in raw.ch_names]
    if missing:
        print(f"[WARN] faltan canales {missing} en {getattr(raw,'filenames',[''])[0]}")
        return None
    raw.reorder_channels([ch for ch in raw.ch_names if ch in desired_channels] +
                         [ch for ch in raw.ch_names if ch not in desired_channels])
    raw.pick_channels(desired_channels, ordered=True)
    return raw

_re_file = re.compile(r'[Ss](\d{3}).*?[Rr](\d{2})')
def parse_subject_run(path: Path):
    m = _re_file.search(str(path))
    if not m: return None, None
    return int(m.group(1)), int(m.group(2))

def run_kind(run_id:int):
    if run_id in MI_RUNS_LR: return 'LR'
    if run_id in MI_RUNS_OF: return 'OF'
    return None

_SNR_TABLE = None
def _load_snr_table():
    global _SNR_TABLE
    if _SNR_TABLE is not None: return _SNR_TABLE
    csv_path = PROJ / 'reports' / 'psd_mains' / 'psd_mains_summary.csv'
    if csv_path.exists():
        try: _SNR_TABLE = pd.read_csv(csv_path)
        except Exception as e:
            print(f"[SNR] No se pudo leer {csv_path}: {e}")
            _SNR_TABLE = None
    return _SNR_TABLE

def _decide_notch(subject, run, th_db=10.0):
    df = _load_snr_table()
    if df is None:  return 60.0
    row = df[(df['subject']==subject) & (df['run']==run)]
    if row.empty:   return 60.0
    snr50 = float(row['snr50_db'].iloc[0]); snr60 = float(row['snr60_db'].iloc[0])
    if snr60 >= th_db and snr60 >= snr50: return 60.0
    if snr50 >= th_db and snr50 >  snr60: return 50.0
    return None

def read_raw_edf(path: Path):
    raw = mne.io.read_raw_edf(path, preload=True, verbose=False)
    raw.pick(mne.pick_types(raw.info, eeg=True))
    rename_channels_1010(raw)
    try:
        mont = mne.channels.make_standard_montage('standard_1020')
        raw.set_montage(mont, on_missing='ignore')
    except Exception:
        pass
    if abs(raw.info['sfreq'] - FS) > 1e-6:
        raw.resample(FS, npad="auto")
    raw = ensure_channels_order(raw, EXPECTED_8)
    if raw is None: return None

    sid, rid = parse_subject_run(path)
    notch = _decide_notch(sid, rid)
    if notch is not None:
        raw.notch_filter(freqs=[float(notch)], picks='eeg', method='spectrum_fit', phase='zero')

    if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None):
        raw.filter(l_freq=float(BP_LO), h_freq=float(BP_HI),
                   picks='eeg', method=BP_METHOD, phase=BP_PHASE, verbose=False)
    return raw

def collect_events_T1T2(raw: mne.io.BaseRaw):
    out = []
    if raw.annotations is None or len(raw.annotations) == 0: return out
    def _norm(s): return str(s).strip().upper().replace(' ','')
    for onset, desc in zip(raw.annotations.onset, raw.annotations.description):
        tag = _norm(desc)
        if tag in ('T1','T2'): out.append((float(onset), tag))
    out.sort()
    dedup = []; last_t1 = last_t2 = -1e9
    for t, tag in out:
        if tag == 'T1':
            if (t - last_t1) >= 0.5: dedup.append((t, tag)); last_t1 = t
        else:
            if (t - last_t2) >= 0.5: dedup.append((t, tag)); last_t2 = t
    return dedup

# =========================
# Dataset helpers
# =========================
def subjects_available():
    subs = []
    for sdir in sorted(DATA_RAW.glob('S*')):
        if not sdir.is_dir(): continue
        try: sid = int(sdir.name[1:])
        except: continue
        if sid in EXCLUDE_SUBJECTS: continue
        any_mi = any((sdir / f"S{sid:03d}R{r:02d}.edf").exists() for r in (MI_RUNS_LR+MI_RUNS_OF))
        if any_mi: subs.append(sid)
    return subs

def extract_base_trials_from_run(edf_path: Path, scenario: str, window_global=(0.0,4.0)):
    subj, run = parse_subject_run(edf_path)
    kind = run_kind(run)
    if kind not in ('LR','OF'): return []

    raw = read_raw_edf(edf_path)
    if raw is None: return []
    data = raw.get_data()
    fs = raw.info['sfreq']
    out = []
    events = collect_events_T1T2(raw)
    rel_start, rel_end = window_global
    ev_idx = 0
    for onset_sec, tag in events:
        ev_idx += 1
        if kind == 'LR':
            label = 0 if tag=='T1' else 1
        else:
            label = 2 if tag=='T1' else 3
        if   scenario == '2c' and label not in (0,1): continue
        elif scenario == '3c' and label not in (0,1,2): continue
        elif scenario == '4c' and label not in (0,1,2,3): continue

        s0 = int(round((raw.first_time + onset_sec + rel_start) * fs))
        e0 = int(round((raw.first_time + onset_sec + rel_end)   * fs))
        if s0 < 0 or e0 > data.shape[1] or e0 <= s0: continue

        seg = data[:, s0:e0].T.astype(np.float32)  # (T_base, C)
        seg = (seg - seg.mean(axis=0, keepdims=True)) / (seg.std(axis=0, keepdims=True) + 1e-6)
        trial_key = f"S{subj:03d}_R{run:02d}_E{ev_idx:03d}"
        out.append((seg, label, subj, trial_key))
    return out

def build_balanced_trials(subjects, scenario='4c', window_global=(0.0,4.0), n_per_class=N_PER_CLASS_PER_SUBJECT):
    trials_balanced = []
    for s in tqdm(subjects, desc="Recolectando y balanceando ensayos base por sujeto"):
        sdir = DATA_RAW / f"S{s:03d}"
        if not sdir.exists(): continue
        bins = {0:[],1:[],2:[],3:[]}
        for r in (MI_RUNS_LR + MI_RUNS_OF):
            p = sdir / f"S{s:03d}R{r:02d}.edf"
            if not p.exists(): continue
            for t in extract_base_trials_from_run(p, scenario, window_global):
                bins[t[1]].append(t)
        # descarta sujeto si no tiene todas las clases
        if any(len(bins[c]) == 0 for c in (0,1,2,3)): continue
        rng = check_random_state(RANDOM_STATE + s)
        for c in (0,1,2,3):
            pool = bins[c]
            if len(pool) >= n_per_class:
                rng.shuffle(pool); take = pool[:n_per_class]
            else:
                idx = rng.choice(len(pool), size=n_per_class, replace=True)
                take = [pool[i] for i in idx]
            trials_balanced.extend(take)
    print(f"Ensayos base balanceados: {len(trials_balanced)}  (~ #sujetos_utiles * {n_per_class*4})")
    return trials_balanced

def window_trials(trials, win_len_sec=2.0, win_stride_sec=0.5, fs=160.0):
    """Windowing para una lista de ensayos (de un split). Retorna X,y,subj,trial_ids."""
    X, y, g, tkeys = [], [], [], []
    win_len = int(round(win_len_sec * fs))
    stride  = int(round(win_stride_sec * fs))
    for seg, lab, subj, tkey in trials:
        T = seg.shape[0]
        starts = [i for i in range(0, T - win_len + 1, stride)]
        # asegura N_SUBWINS por ensayo
        if len(starts) > N_SUBWINS: starts = starts[:N_SUBWINS]
        while len(starts) < N_SUBWINS and (T - win_len) >= 0:
            starts.append(T - win_len)
        for s in starts:
            e = s + win_len
            sub = seg[s:e, :].astype(np.float32)
            sub = (sub - sub.mean(axis=0, keepdims=True)) / (sub.std(axis=0, keepdims=True) + 1e-6)
            X.append(sub); y.append(lab); g.append(subj); tkeys.append(tkey)
    uniq = {k:i for i,k in enumerate(sorted(set(tkeys)))}
    trial_ids = np.asarray([uniq[k] for k in tkeys], dtype=np.int64)
    X = np.stack(X, axis=0).astype(np.float32)
    y = np.asarray(y, dtype=np.int64)
    g = np.asarray(g, dtype=np.int64)
    return X, y, g, trial_ids, uniq

# =========================
# Dataset torch
# =========================
class EEGTrials(Dataset):
    def __init__(self, X, y, groups, trial_ids):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)
        self.g = groups.astype(np.int64)
        self.t = trial_ids.astype(np.int64)
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        x = self.X[idx]; x = np.expand_dims(x, 0)  # (1,Tw,C)
        return torch.from_numpy(x), torch.tensor(self.y[idx]), torch.tensor(self.g[idx]), torch.tensor(self.t[idx])

def build_weighted_sampler(y, groups):
    y = np.asarray(y); groups = np.asarray(groups)
    class_counts = np.bincount(y, minlength=len(np.unique(y))).astype(float)
    class_w = 1.0 / class_counts[y]
    subj_vals, subj_counts = np.unique(groups, return_counts=True)
    subj_map = {s:c for s,c in zip(subj_vals, subj_counts)}
    subj_w = np.array([1.0/subj_map[g] for g in groups], dtype=float)
    w = class_w * subj_w; w = w / w.mean()
    return WeightedRandomSampler(weights=torch.from_numpy(w).float(), num_samples=len(w), replacement=True)

# =========================
# Modelo: CNN + Transformer
# =========================
class ChannelDropout(nn.Module):
    def __init__(self, p=0.1): super().__init__(); self.p = p
    def forward(self, x):
        if not self.training or self.p<=0: return x
        B,_,T,C = x.shape
        mask = (torch.rand(B,1,1,C, device=x.device) > self.p).float()
        return x * mask

@torch.no_grad()
def apply_max_norm(model, max_value=2.0, p=2.0):
    for m in model.modules():
        if hasattr(m,'weight') and isinstance(m,(nn.Conv2d, nn.Linear)):
            w = m.weight.data
            norms = w.view(w.size(0), -1).norm(p=p, dim=1, keepdim=True)
            desired = torch.clamp(norms, max=max_value)
            w.view(w.size(0), -1).mul_(desired / (1e-8 + norms))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 4000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)
    def forward(self, x):
        T = x.size(1); return x + self.pe[:, :T, :]

class CNNTokenizer(nn.Module):
    def __init__(self, n_ch: int, d_feat: int = 128, k_temporal: int = 128, stride_t: int = 4, chdrop_p: float = 0.1, drop_p: float = 0.2):
        super().__init__()
        self.chdrop = ChannelDropout(chdrop_p)
        self.conv_t = nn.Conv2d(1, d_feat, kernel_size=(k_temporal,1), stride=(stride_t,1), padding=(k_temporal//2,0), bias=False)
        self.bn_t   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.act    = nn.ELU()
        self.conv_sp_dw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,n_ch), groups=d_feat, bias=False)
        self.bn_sp      = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.conv_pw = nn.Conv2d(d_feat, d_feat, kernel_size=(1,1), bias=False)
        self.bn_pw   = nn.BatchNorm2d(d_feat, momentum=0.99, eps=1e-3)
        self.drop    = nn.Dropout(drop_p)
    def forward(self, x):
        z = self.chdrop(x)
        z = self.conv_t(z); z = self.bn_t(z); z = self.act(z)         # (B,D,T',C)
        z = self.conv_sp_dw(z); z = self.bn_sp(z); z = self.act(z)    # (B,D,T',1)
        z = self.conv_pw(z); z = self.bn_pw(z); z = self.act(z)
        z = self.drop(z)
        z = z.squeeze(-1).transpose(1,2)  # (B,T',D)
        return z

def _enc_layer(d_model=128, nhead=4, dim_ff=256, drop=0.2):
    return nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_ff,
                                      dropout=drop, activation='gelu', batch_first=True, norm_first=True)

class TemporalTransformer(nn.Module):
    def __init__(self, d_model=128, nhead=4, dim_ff=256, depth=3, drop=0.2):
        super().__init__()
        self.pe = PositionalEncoding(d_model)
        layer = _enc_layer(d_model, nhead, dim_ff, drop)
        self.enc = nn.TransformerEncoder(layer, num_layers=depth)
        self.norm = nn.LayerNorm(d_model)
    def forward(self, x):
        x = self.pe(x)
        x = self.enc(x)
        return self.norm(x)

class HybridCNNTransformer(nn.Module):
    def __init__(self, n_ch: int, n_classes: int, seq_len: int,
                 d_model: int = 128, nhead: int = 4, depth: int = 3,
                 dim_ff: int = 256, k_temporal: int = 128, stride_t: int = 4,
                 drop: float = 0.2, chdrop_p: float = 0.1, use_cls_token: bool = False):
        super().__init__()
        self.use_cls = use_cls_token
        self.tokenizer = CNNTokenizer(n_ch, d_model, k_temporal, stride_t, chdrop_p, drop)
        self.cls = nn.Parameter(torch.zeros(1,1,d_model)) if use_cls_token else None
        self.transformer = TemporalTransformer(d_model, nhead, dim_ff, depth, drop)
        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(d_model, n_classes)
        )
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, x):
        tok = self.tokenizer(x)                   # (B,T',D)
        if self.use_cls:
            B = tok.size(0); cls_tok = self.cls.expand(B,-1,-1)
            tok = torch.cat([cls_tok, tok], dim=1)
        z = self.transformer(tok)
        feat = z[:,0,:] if self.use_cls else z.mean(dim=1)
        return self.head(feat)

# =========================
# Entrenamiento / Eval
# =========================
class SoftCE(nn.Module):
    def __init__(self, n_classes: int, label_smoothing: float = 0.05, class_weights: torch.Tensor | None = None):
        super().__init__()
        self.n_classes = n_classes
        self.ls = float(label_smoothing)
        self.register_buffer('w', class_weights if class_weights is not None else None)
    def forward(self, logits, targets_idx):
        B, C = logits.size()
        yt = F.one_hot(targets_idx, num_classes=C).float()
        if self.ls > 0: yt = (1 - self.ls) * yt + self.ls * (1.0 / C)
        logp = F.log_softmax(logits, dim=1)
        loss = -(yt * logp)
        if self.w is not None: loss = loss * self.w.unsqueeze(0)
        return loss.sum(dim=1).mean()

@torch.no_grad()
def _predict_tta(model, xb, n=5, fs=160.0):
    def time_jitter(x, max_ms=25):
        max_shift = int(round(max_ms/1000.0 * fs))
        if max_shift <= 0: return x
        B,_,T,C = x.shape
        shifts = torch.randint(low=-max_shift, high=max_shift+1, size=(B,), device=x.device)
        out = torch.zeros_like(x)
        for i,s in enumerate(shifts):
            if s >= 0: out[i,:,s:,:] = x[i,:,:T-s,:]
            else: s=-s; out[i,:,:T-s,:] = x[i,:,s:,:]
        return out
    outs = []
    for _ in range(n):
        outs.append(model(time_jitter(xb,25)))
    return torch.stack(outs, dim=0).mean(dim=0)

@torch.no_grad()
def evaluate_per_trial(model, loader, use_tta=True, tta_n=5, n_classes=4):
    model.eval()
    sums, counts, labels = {}, {}, {}
    for xb, yb, _, tb in loader:
        xb = xb.to(DEVICE)
        logits = _predict_tta(model, xb, n=tta_n) if use_tta else model(xb)
        logits = logits.detach().cpu().numpy()
        yb = yb.numpy(); tb = tb.numpy()
        for i in range(xb.size(0)):
            t = int(tb[i])
            if t not in sums:
                sums[t] = np.zeros((n_classes,), dtype=np.float64)
                counts[t] = 0
                labels[t] = int(yb[i])
            sums[t] += logits[i]; counts[t] += 1
    ids = sorted(sums.keys())
    y_true_trial = np.asarray([labels[t] for t in ids], dtype=int)
    y_pred_trial = np.asarray([int((sums[t]/max(1,counts[t])).argmax()) for t in ids], dtype=int)
    acc = (y_true_trial == y_pred_trial).mean()
    return y_true_trial, y_pred_trial, float(acc)

def plot_confusion(y_true, y_pred, classes, title, fname):
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    with np.errstate(invalid='ignore'):
        cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
    cm_norm = np.nan_to_num(cm_norm)
    plt.figure(figsize=(6,5))
    plt.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(title); plt.colorbar(fraction=0.046, pad=0.04)
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, ha='right'); plt.yticks(ticks, classes)
    fmt = '.2f'; thresh = cm_norm.max()/2.
    for i,j in itertools.product(range(cm_norm.shape[0]), range(cm_norm.shape[1])):
        plt.text(j,i,format(cm_norm[i,j],fmt),ha="center", color="white" if cm_norm[i,j]>thresh else "black")
    plt.ylabel('True'); plt.xlabel('Pred'); plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def print_report(y_true, y_pred, classes, tag=""):
    print(f"\n=== Reporte {tag} ===")
    print(classification_report(y_true, y_pred, target_names=classes, digits=4))
    print(f"F1-macro: {f1_score(y_true, y_pred, average='macro'):.4f} | Kappa: {cohen_kappa_score(y_true,y_pred):.4f}")

# =========================
# FT (L2-SP), SIN fuga (agrupa por ensayo)
# =========================
def _freeze_for_mode(model: HybridCNNTransformer, mode: str):
    for p in model.parameters(): p.requires_grad = False
    if   mode=='out':
        for p in model.head[-1].parameters(): p.requires_grad = True
    elif mode=='head':
        for p in model.head.parameters(): p.requires_grad = True
    elif mode=='cnn+trans+head':
        for p in model.tokenizer.parameters():   p.requires_grad = True
        for p in model.transformer.parameters(): p.requires_grad = True
        for p in model.head.parameters():        p.requires_grad = True
    else:
        raise ValueError(mode)

def _param_groups(model: HybridCNNTransformer, mode: str):
    if   mode=='out': return list(model.head[-1].parameters())
    elif mode=='head': return list(model.head.parameters())
    elif mode=='cnn+trans+head': 
        return list(model.tokenizer.parameters()) + list(model.transformer.parameters()) + list(model.head.parameters())

def finetune_l2sp(model_global: HybridCNNTransformer, Xcal, ycal, trial_ids_cal,
                  n_classes: int, mode='head', epochs=FT_EPOCHS, batch_size=32,
                  base_lr=FT_BASE_LR, head_lr=FT_HEAD_LR, l2sp_lambda=FT_L2SP,
                  val_ratio=FT_VAL_RATIO, seed=RANDOM_STATE, device=DEVICE):
    # split interno sin fuga: por ensayo (GroupShuffleSplit)
    gss = GroupShuffleSplit(n_splits=1, test_size=val_ratio, random_state=seed)
    (tr_idx, va_idx), = gss.split(Xcal, ycal, groups=trial_ids_cal)
    Xtr, ytr = Xcal[tr_idx], ycal[tr_idx]
    Xva, yva = Xcal[va_idx], ycal[va_idx]

    ds_tr = torch.utils.data.TensorDataset(torch.from_numpy(Xtr).float().unsqueeze(1),
                                           torch.from_numpy(ytr).long())
    ds_va = torch.utils.data.TensorDataset(torch.from_numpy(Xva).float().unsqueeze(1),
                                           torch.from_numpy(yva).long())
    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False, drop_last=False)

    model = copy.deepcopy(model_global).to(device)
    _freeze_for_mode(model, mode)
    params = _param_groups(model, mode)

    if mode=='cnn+trans+head':
        tok_params = list(model.tokenizer.parameters()) + list(model.transformer.parameters())
        head_params = list(model.head.parameters())
        opt = torch.optim.Adam([
            {'params': tok_params,  'lr': base_lr},
            {'params': head_params, 'lr': head_lr},
        ])
        train_params = tok_params + head_params
    else:
        opt = torch.optim.Adam(params, lr=head_lr)
        train_params = params

    ref = [p.detach().clone().to(device) for p in train_params]   # L2-SP ref
    criterion = nn.CrossEntropyLoss()

    best_state, best_val, bad = copy.deepcopy(model.state_dict()), float('inf'), 0
    for _ in range(epochs):
        model.train()
        for xb, yb in dl_tr:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = criterion(logits, yb)
            reg = 0.0
            for p_cur, p_ref in zip(train_params, ref):
                reg = reg + torch.sum((p_cur - p_ref)**2)
            (loss + l2sp_lambda*reg).backward(); opt.step()
            apply_max_norm(model, 2.0, p=2.0)

        model.eval()
        val_loss, nval = 0.0, 0
        with torch.no_grad():
            for xb, yb in dl_va:
                xb, yb = xb.to(device), yb.to(device)
                val_loss += F.cross_entropy(model(xb), yb).item() * xb.size(0)
                nval += xb.size(0)
        val_loss /= max(1,nval)
        if val_loss + 1e-7 < best_val: best_val, bad, best_state = val_loss, 0, copy.deepcopy(model.state_dict())
        else:
            bad += 1
            if bad >= FT_PATIENCE: break

    model.load_state_dict(best_state)
    return model

# =========================
# Folds JSON: SOLO sujetos (ignora índices)
# =========================
def read_folds_subjects_only(json_path, current_subjects):
    """
    Lee Kfold5.json y devuelve lista de folds con train_subjects/test_subjects (enteros).
    Ignora por completo tr_idx/te_idx si existen.
    Valida que los sujetos existan en el dataset actual.
    """
    p = Path(json_path)
    if not p.exists():
        raise FileNotFoundError(f"No existe {p}. Debes proporcionar un JSON con sujetos por fold.")
    with open(p, 'r', encoding='utf-8') as f:
        payload = json.load(f)
    folds_raw = payload.get('folds', [])
    if not folds_raw:
        raise ValueError("JSON no contiene 'folds'.")

    cur_set = set(current_subjects)
    folds = []
    for fr in folds_raw:
        train_s = fr.get('train_subjects', fr.get('train', []))
        test_s  = fr.get('test_subjects',  fr.get('test',  []))
        if not train_s or not test_s:
            raise ValueError("Cada fold debe tener 'train_subjects' y 'test_subjects' (o 'train'/'test').")

        def to_int_list(lst):
            out = []
            for s in lst:
                if isinstance(s, int): out.append(s)
                elif isinstance(s, str) and s.upper().startswith('S') and s[1:].isdigit():
                    out.append(int(s[1:]))
                elif isinstance(s, str) and s.isdigit():
                    out.append(int(s))
                else:
                    raise ValueError(f"Sujeto inválido en JSON: {s}")
            return out

        tr_int = to_int_list(train_s)
        te_int = to_int_list(test_s)

        # validación: sujetos del JSON deben existir
        miss_tr = [s for s in tr_int if s not in cur_set]
        miss_te = [s for s in te_int if s not in cur_set]
        if miss_tr or miss_te:
            raise ValueError(f"Fold con sujetos no presentes en dataset actual. "
                             f"Faltan en TRAIN: {miss_tr}, Faltan en TEST: {miss_te}")

        folds.append({'fold': fr.get('fold', len(folds)+1),
                      'train_subjects': tr_int, 'test_subjects': te_int})
    return folds

# =========================
# Experimento
# =========================
def plot_training_curves(history, fname):
    plt.figure(figsize=(6,4))
    plt.plot(history['train_acc'], label='train_acc')
    plt.plot(history['val_acc'], label='val_acc')
    plt.xlabel('Epoch'); plt.ylabel('Accuracy'); plt.title('Training curve'); plt.legend()
    plt.tight_layout(); plt.savefig(fname, dpi=150); plt.close()

def run_experiment():
    mne.set_log_level('WARNING')
    subs = subjects_available()
    print(f"Sujetos elegibles: {len(subs)} → {subs[:10]}{'...' if len(subs)>10 else ''}")

    # 1) Ensayos balanceados (21/clase/sujeto) sobre TODOS los sujetos elegibles
    trials_bal = build_balanced_trials(subs, scenario=CLASS_SCENARIO, window_global=WINDOW_GLOBAL,
                                       n_per_class=N_PER_CLASS_PER_SUBJECT)

    # Sujetos presentes tras balanceo
    subj_in_trials = sorted({t[2] for t in trials_bal})
    print(f"Sujetos con ensayos balanceados: {len(subj_in_trials)}")

    # 2) Leer folds SOLO por sujetos
    folds = read_folds_subjects_only(FOLDS_JSON, current_subjects=subj_in_trials)

    # 3) Loop de folds
    global_folds_trial, all_true_trial, all_pred_trial, ft_prog_folds = [], [], [], []

    for f in folds:
        fold = f['fold']
        train_subs = set(f['train_subjects'])
        test_subs  = set(f['test_subjects'])

        # Split por sujeto (antes de windowing): seleccionamos ensayos (trial-level)
        tr_trials = [t for t in trials_bal if t[2] in train_subs]
        te_trials = [t for t in trials_bal if t[2] in test_subs]

        # Sanity: cada clase presente en TEST por ensayo (antes de windowing)
        lbl_te = [t[1] for t in te_trials]
        present = {0,1,2,3}.intersection(set(lbl_te))
        if present != {0,1,2,3}:
            print(f"[WARN] Fold {fold}: clases en TEST (trial-level) = {sorted(present)}; se salta este fold.")
            continue

        # 3a) Split interno de VALID por sujeto dentro de TRAIN (sin fuga de sujetos)
        train_subs_list = sorted(list(train_subs))
        rng = np.random.RandomState(RANDOM_STATE + fold)
        rng.shuffle(train_subs_list)
        n_va = max(1, int(round(GLOBAL_VAL_SPLIT * len(train_subs_list))))
        val_subs = set(train_subs_list[:n_va])
        real_train_subs = set(train_subs_list[n_va:])

        tr_trials_main = [t for t in tr_trials if t[2] in real_train_subs]
        va_trials      = [t for t in tr_trials if t[2] in val_subs]

        # 4) Windowing por split (sin fuga entre ensayos)
        Xtr, ytr, gtr, ttr, _ = window_trials(trials=tr_trials_main, win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)
        Xva, yva, gva, tva, _ = window_trials(trials=va_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)
        Xte, yte, gte, tte, _ = window_trials(trials=te_trials,      win_len_sec=WIN_LEN_SEC, win_stride_sec=WIN_STRIDE_SEC, fs=FS)

        # DataLoaders
        ds_tr = EEGTrials(Xtr, ytr, gtr, ttr)
        ds_va = EEGTrials(Xva, yva, gva, tva)
        ds_te = EEGTrials(Xte, yte, gte, tte)

        tr_loader = DataLoader(ds_tr, batch_size=BATCH_SIZE,
                               sampler=build_weighted_sampler(ytr, gtr), drop_last=False)
        va_loader = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
        te_loader = DataLoader(ds_te, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

        # 5) Modelo
        Tw, C = Xtr.shape[1], Xtr.shape[2]
        n_classes = len(set(ytr.tolist()) | set(yva.tolist()) | set(yte.tolist()))
        model = HybridCNNTransformer(
            n_ch=C, n_classes=n_classes, seq_len=Tw,
            d_model=128, nhead=4, depth=3, dim_ff=256,
            k_temporal=128, stride_t=4, drop=0.2, chdrop_p=0.10, use_cls_token=False
        ).to(DEVICE)

        opt = torch.optim.Adam(model.parameters(), lr=LR_INIT)
        criterion = SoftCE(n_classes=n_classes, label_smoothing=0.05, class_weights=None)

        def _acc_trial(loader): return evaluate_per_trial(model, loader, use_tta=False, n_classes=n_classes)[2]

        # 6) Entrenamiento (early stopping por ENSAYO)
        history = {'train_acc': [], 'val_acc': []}
        print(f"\n[Fold {fold}/{len(folds)}] Entrenando global (VAL por ensayo)...")
        best_state = copy.deepcopy(model.state_dict()); best_val = -1.0; bad = 0

        for epoch in range(1, EPOCHS_GLOBAL+1):
            model.train()
            for xb, yb, _, _ in tr_loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                loss = criterion(model(xb), yb)
                opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
                apply_max_norm(model, 2.0, p=2.0)

            tr_acc = _acc_trial(tr_loader)
            va_acc = _acc_trial(va_loader)
            history['train_acc'].append(tr_acc); history['val_acc'].append(va_acc)

            if epoch % 5 == 0 or epoch in (1,10,20,50,80):
                print(f"  Época {epoch:3d} | train(ensayo)={tr_acc:.4f} | val(ensayo)={va_acc:.4f}")

            if va_acc > best_val + 1e-4:
                best_val, bad = va_acc, 0
                best_state = copy.deepcopy(model.state_dict())
            else:
                bad += 1
                if bad >= GLOBAL_PATIENCE:
                    print(f"  Early stopping @ epoch {epoch} (best val ensayo={best_val:.4f})")
                    break

        plot_training_curves(history, f"training_curve_cnnTrans_subjFolds_fold{fold}.png")
        print(f"↳ Curva: training_curve_cnnTrans_subjFolds_fold{fold}.png")
        model.load_state_dict(best_state)

        # 7) Test por ENSAYO (TTA)
        y_true_tr, y_pred_tr, acc_tr = evaluate_per_trial(model, te_loader, use_tta=True, tta_n=5, n_classes=n_classes)
        global_folds_trial.append(acc_tr)
        all_true_trial.append(y_true_tr); all_pred_trial.append(y_pred_tr)
        print(f"[Fold {fold}] Ensayo acc={acc_tr:.4f}")
        print_report(y_true_tr, y_pred_tr, CLASS_NAMES, tag="por ENSAYO (agregado)")

        # 8) FT por sujeto (sin fuga, agrupando por ensayo)
        #    Trabajamos dentro del set de TEST (por cada sujeto en test_subs)
        y_true_ft_all, y_pred_ft_all, used_subjects = [], [], 0

        # Para comodidad, indices por sujeto dentro de Xte
        trial_to_label = {}
        for idx in range(len(Xte)):
            trial_to_label[int(tte[idx])] = int(yte[idx])

        for sid in sorted(test_subs):
            mask_sid = (gte == sid)
            if mask_sid.sum() == 0: continue
            Xs, ys, ts = Xte[mask_sid], yte[mask_sid], tte[mask_sid]

            # GroupKFold por ensayo (sin fuga)
            if len(np.unique(ts)) < CALIB_CV_FOLDS:
                continue
            gkf = GroupKFold(n_splits=CALIB_CV_FOLDS)
            for tr_i, ho_i in gkf.split(Xs, ys, groups=ts):
                Xcal, ycal, tcal = Xs[tr_i], ys[tr_i], ts[tr_i]
                Xho,  yho,  tho  = Xs[ho_i], ys[ho_i], ts[ho_i]

                m_head = finetune_l2sp(model, Xcal, ycal, tcal,
                                       n_classes=n_classes, mode='head',
                                       epochs=FT_EPOCHS, base_lr=FT_BASE_LR, head_lr=FT_HEAD_LR,
                                       l2sp_lambda=FT_L2SP, val_ratio=FT_VAL_RATIO, device=DEVICE)

                with torch.no_grad():
                    xb = torch.from_numpy(Xho).float().unsqueeze(1).to(DEVICE)
                    logits = m_head(xb).cpu().numpy()

                sums = collections.defaultdict(lambda: np.zeros((n_classes,), dtype=np.float64))
                counts = collections.defaultdict(int)
                labels = {}
                for i in range(len(Xho)):
                    t = int(tho[i]); sums[t] += logits[i]; counts[t] += 1; labels[t] = int(yho[i])
                for t in sums.keys():
                    y_pred_ft_all.append(int((sums[t]/max(1,counts[t])).argmax()))
                    y_true_ft_all.append(labels[t])
                used_subjects += 1

        if len(y_true_ft_all) > 0:
            y_true_ft_all = np.asarray(y_true_ft_all, dtype=int)
            y_pred_ft_all = np.asarray(y_pred_ft_all, dtype=int)
            acc_ft = (y_true_ft_all == y_pred_ft_all).mean()
            print(f"  FT progresivo por ENSAYO acc={acc_ft:.4f} | sujetos={used_subjects}")
            print(f"  Δ(FT-Ensayo - Global-Ensayo) = {acc_ft - acc_tr:+.4f}")
        else:
            acc_ft = np.nan
            print("  FT no ejecutado (fold sin sujetos aptos).")
        ft_prog_folds.append(acc_ft)

    # ----- resultados globales -----
    if len(all_true_trial)>0: all_true_trial = np.concatenate(all_true_trial); all_pred_trial = np.concatenate(all_pred_trial)
    else: all_true_trial = np.array([],dtype=int); all_pred_trial = np.array([],dtype=int)

    print("\n" + "="*60)
    print("RESULTADOS FINALES — CNN+Transformer (21/cls/subj) + Sliding Window + Kfold SUBJECTS-only")
    print("="*60)
    print("Ensayo folds:", [f"{a:.4f}" for a in global_folds_trial])
    if len(global_folds_trial) > 0:
        print(f"Ensayo mean: {np.mean(global_folds_trial):.4f}")
    print("FT prog (ensayo) folds:", [("nan" if (a is None or np.isnan(a)) else f"{a:.4f}") for a in ft_prog_folds])
    if len(ft_prog_folds) > 0:
        print(f"FT (ensayo) mean: {np.nanmean(ft_prog_folds):.4f}")
        print(f"Δ(FT-Ensayo - Global-Ensayo) mean: {np.nanmean(ft_prog_folds) - np.mean(global_folds_trial):+.4f}")

    if all_true_trial.size > 0:
        plot_confusion(all_true_trial, all_pred_trial, CLASS_NAMES,
                       title="Confusion Matrix - Global Model (Per Trial, All Folds)",
                       fname="confusion_cnnTrans_subjectFolds_trial_allfolds.png")
        print("↳ Matriz de confusión (ensayo): confusion_cnnTrans_subjectFolds_trial_allfolds.png")

# ---------- MAIN ----------
if __name__ == "__main__":
    bp_status = f"ON [{BP_LO}-{BP_HI} Hz]" if USE_BANDPASS and (BP_LO is not None) and (BP_HI is not None) else "OFF"
    print("🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + FT progresivo")
    print(f"🔧 Escenario: {CLASS_SCENARIO} | rango ensayo={WINDOW_GLOBAL} s | subventana={WIN_LEN_SEC}s | stride={WIN_STRIDE_SEC}s | subwins/ensayo={N_SUBWINS} | band-pass={bp_status}")
    print(f"⚙️  Balance: 21 ensayos por clase y por sujeto (con reemplazo si falta)")
    print(f"⚙️  FT (sin fuga): GroupKFold por ensayo + GroupShuffleSplit en validación interna")
    run_experiment()


🚀 Dispositivo: cuda
🧠 MI-EEG — CNN+Transformer (21/cls/subj) + Sliding Window + FT progresivo
🔧 Escenario: 4c | rango ensayo=(0.0, 4.0) s | subventana=2.0s | stride=0.5s | subwins/ensayo=5 | band-pass=ON [4.0-38.0 Hz]
⚙️  Balance: 21 ensayos por clase y por sujeto (con reemplazo si falta)
⚙️  FT (sin fuga): GroupKFold por ensayo + GroupShuffleSplit en validación interna
Sujetos elegibles: 103 → [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]...


Recolectando y balanceando ensayos base por sujeto: 100%|██████████| 103/103 [00:39<00:00,  2.63it/s]


Ensayos base balanceados: 8652  (~ #sujetos_utiles * 84)
Sujetos con ensayos balanceados: 103





[Fold 1/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.2879 | val(ensayo)=0.2688
  Época   5 | train(ensayo)=0.3847 | val(ensayo)=0.3313
  Época  10 | train(ensayo)=0.5178 | val(ensayo)=0.3403
  Época  15 | train(ensayo)=0.6275 | val(ensayo)=0.3700
  Época  20 | train(ensayo)=0.6860 | val(ensayo)=0.3750
  Época  25 | train(ensayo)=0.7478 | val(ensayo)=0.3472
  Early stopping @ epoch 28 (best val ensayo=0.3869)
↳ Curva: training_curve_cnnTrans_subjFolds_fold1.png
[Fold 1] Ensayo acc=0.2914

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.3232    0.1202    0.1752       441
       Right     0.2875    0.3129    0.2997       441
  Both Fists     0.2755    0.5941    0.3764       441
   Both Feet     0.3609    0.1383    0.2000       441

    accuracy                         0.2914      1764
   macro avg     0.3118    0.2914    0.2628      1764
weighted avg     0.3118    0.2914    0.2628      1764

F1-macr




[Fold 2/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.3022 | val(ensayo)=0.2917
  Época   5 | train(ensayo)=0.3604 | val(ensayo)=0.3165
  Época  10 | train(ensayo)=0.5625 | val(ensayo)=0.3244
  Época  15 | train(ensayo)=0.6230 | val(ensayo)=0.3185
  Época  20 | train(ensayo)=0.7325 | val(ensayo)=0.3313
  Early stopping @ epoch 23 (best val ensayo=0.3522)
↳ Curva: training_curve_cnnTrans_subjFolds_fold2.png
[Fold 2] Ensayo acc=0.3316

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.3186    0.5215    0.3955       441
       Right     0.3562    0.1882    0.2463       441
  Both Fists     0.3160    0.3855    0.3473       441
   Both Feet     0.3764    0.2313    0.2865       441

    accuracy                         0.3316      1764
   macro avg     0.3418    0.3316    0.3189      1764
weighted avg     0.3418    0.3316    0.3189      1764

F1-macro: 0.3189 | Kappa: 0.1088
  FT progresivo por ENSAYO acc




[Fold 3/5] Entrenando global (VAL por ensayo)...
  Época   1 | train(ensayo)=0.2892 | val(ensayo)=0.2579
  Época   5 | train(ensayo)=0.4379 | val(ensayo)=0.3065
  Época  10 | train(ensayo)=0.5562 | val(ensayo)=0.2897
  Época  15 | train(ensayo)=0.6463 | val(ensayo)=0.2827
  Early stopping @ epoch 16 (best val ensayo=0.3165)
↳ Curva: training_curve_cnnTrans_subjFolds_fold3.png
[Fold 3] Ensayo acc=0.3413

=== Reporte por ENSAYO (agregado) ===
              precision    recall  f1-score   support

        Left     0.3236    0.2268    0.2667       441
       Right     0.3617    0.3469    0.3542       441
  Both Fists     0.3053    0.4036    0.3477       441
   Both Feet     0.3808    0.3878    0.3843       441

    accuracy                         0.3413      1764
   macro avg     0.3429    0.3413    0.3382      1764
weighted avg     0.3429    0.3413    0.3382      1764

F1-macro: 0.3382 | Kappa: 0.1217


KeyboardInterrupt: 