# Setup

In [None]:
# ================================================
# 📦 INSTALACIÓN Y CONFIGURACIÓN COMPLETA
# ================================================

import subprocess
import sys

print("üöÄ CONFIGURANDO ENTORNO COMPLETO")
print("="*80)

# ================================================
# PASO 1: INSTALAR DEPENDENCIAS
# ================================================
print("\n[1/4] Instalando dependencias principales...")
subprocess.check_call([
    sys.executable, '-m', 'pip', 'install', '-q', 
    'mne', 'pyedflib', 'numpy', 'pandas', 'scipy', 'tqdm', 'ipywidgets'
])
print("   ‚úÖ Dependencias instaladas")

# ================================================
# PASO 2: CONFIGURAR IPYWIDGETS PARA BARRAS
# ================================================
print("\n[2/4] Configurando ipywidgets para barras de progreso...")
try:
    subprocess.check_call([
        sys.executable, '-m', 'pip', 'install', '-q', '--upgrade', 'ipywidgets'
    ])
    print("   ‚úÖ ipywidgets actualizado")
    
    # Intentar habilitar extensi√≥n
    try:
        subprocess.check_call([
            sys.executable, '-m', 'jupyter', 'nbextension', 
            'enable', '--py', 'widgetsnbextension', '--sys-prefix'
        ])
        print("   ‚úÖ Extensi√≥n de Jupyter habilitada")
    except:
        print("   ‚ö†Ô∏è  Extensi√≥n no habilitada (puede no ser necesario)")
        
except Exception as e:
    print(f"   ‚ö†Ô∏è  Advertencia: {e}")

# ================================================
# PASO 3: IMPORTAR LIBRERÍAS
# ================================================
print("\n[3/4] Importando librer√≠as...")

# Core libraries used throughout the notebook.
import mne
import pandas as pd
import numpy as np
import pickle
import multiprocessing as mp
import time
from pathlib import Path
from datetime import datetime, timedelta
from collections import defaultdict
import warnings

# Flush pending output, then import the IPython display helpers for Jupyter.
sys.stdout.flush()
from IPython.display import clear_output, display

# Silence RuntimeWarnings for the rest of the session.
# NOTE(review): this hides every RuntimeWarning, including ones that may point
# to real numerical problems; consider narrowing it.
warnings.filterwarnings("ignore", category=RuntimeWarning)

print("   ‚úÖ Librer√≠as importadas")

# ================================================
# PASO 4: CONFIGURAR TQDM
# ================================================
print("\n[4/4] Configurando barras de progreso...")


class TqdmManual:
    """Minimal text progress bar, used as ``tqdm`` when the package is missing.

    Supports the subset of the ``tqdm`` API used in this notebook: wrap an
    iterable (``for x in TqdmManual(items, desc=...)``), or drive the bar by
    hand with ``update()`` / ``close()``, optionally as a context manager.

    Args:
        iterable: Items to iterate over. When ``None``, ``range(total)`` is used.
        total: Expected number of items. When not given (or 0), it is inferred
            with ``len(iterable)``; it stays ``None`` for unsized iterables such
            as generators, in which case only a counter and rate are shown.
        desc: Text printed before the bar.
        leave: Keep the finished bar on screen (``True``) or erase it (``False``).
        **kwargs: Accepted and ignored, for call compatibility with ``tqdm``.
    """

    BAR_LEN = 40

    def __init__(self, iterable=None, total=None, desc="", leave=True, **kwargs):
        if iterable is None:
            total = 0 if total is None else total
            iterable = range(total)
        elif not total:
            try:
                total = len(iterable)
            except TypeError:
                total = None  # generators and other unsized iterables
        self.iterable = iterable
        self.total = total
        self.desc = desc
        self.leave = leave
        self.n = 0
        self.start = time.time()
        self._last_len = 0
        self._closed = False

    def __iter__(self):
        # ``finally`` also closes the bar when the consumer stops early (break).
        try:
            for item in self.iterable:
                yield item
                self.update(1)
        finally:
            self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def update(self, n=1):
        """Advance the counter by ``n`` and redraw the bar in place."""
        self.n += n
        elapsed = time.time() - self.start
        rate = self.n / elapsed if elapsed > 0 else 0
        if self.total:
            pct = self.n / self.total * 100
            eta = max(0.0, (self.total - self.n) / rate) if rate > 0 else 0
            # Clamp so the bar never grows past its width when n > total.
            filled = min(self.BAR_LEN, int(self.BAR_LEN * self.n / self.total))
            bar = '‚ñà' * filled + '‚ñë' * (self.BAR_LEN - filled)
            line = (
                f'{self.desc}: |{bar}| {self.n}/{self.total} '
                f'[{pct:.1f}%] [{elapsed:.0f}s<{eta:.0f}s, {rate:.2f}it/s]'
            )
        else:
            # Unknown total: no bar, percentage or ETA can be computed.
            line = f'{self.desc}: {self.n}it [{elapsed:.0f}s, {rate:.2f}it/s]'
        # Pad so a shorter line fully overwrites the previous one.
        sys.stdout.write('\r' + line.ljust(self._last_len))
        self._last_len = max(self._last_len, len(line))
        sys.stdout.flush()

    def close(self):
        """Finish the bar once: keep it (newline) or erase it, per ``leave``."""
        if self._closed:
            return
        self._closed = True
        if self.leave:
            sys.stdout.write('\n')
        else:
            sys.stdout.write('\r' + ' ' * self._last_len + '\r')
        sys.stdout.flush()


# Prefer visual notebook bars, then text tqdm, then the manual fallback.
try:
    from tqdm.notebook import tqdm
    TQDM_DISPONIBLE = "notebook"
    print("   ‚úÖ tqdm.notebook disponible (barras visuales)")

    # Quick smoke test: creating the bar fails here if ipywidgets is unusable.
    print("\n   üß™ Probando barra de progreso:")
    for _ in tqdm(range(3), desc="   Test", leave=False):
        time.sleep(0.3)
    print("   ‚úÖ Barras funcionando correctamente!")

except Exception:
    # Not ``except:`` -- KeyboardInterrupt/SystemExit must still propagate.
    try:
        from tqdm import tqdm
        TQDM_DISPONIBLE = "texto"
        print("   ‚úÖ tqdm disponible (barras de texto)")
    except ImportError:
        TQDM_DISPONIBLE = "manual"
        print("   ‚ö†Ô∏è  tqdm no disponible, se usar√°n barras manuales")
        tqdm = TqdmManual

# ================================================
# RESUMEN DE CONFIGURACIÓN
# ================================================
# Final report of the environment set up by this cell.
_separador = "=" * 80
print("\n" + _separador)
print("‚úÖ CONFIGURACI√ìN COMPLETADA")
print(_separador)
print("\nüìã Informaci√≥n del sistema:")
for _etiqueta, _valor in (
    ("Python:", sys.version.split()[0]),
    ("MNE:", mne.__version__),
    ("pandas:", pd.__version__),
    ("numpy:", np.__version__),
    ("Cores:", mp.cpu_count()),
    ("Tqdm:", TQDM_DISPONIBLE),
):
    # Label column padded to 13 characters so the values line up.
    print(f"   ‚Ä¢ {_etiqueta:<13}{_valor}")
print()
print(_separador)

## Configuración

In [None]:
from pathlib import Path
import multiprocessing as mp
import os
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# === CONFIGURATION ===
# Root folder of the Sleep-EDF "sleep-cassette" recordings. When the
# SLEEP_EDF_DIR environment variable is set it overrides the machine-specific
# default below, so the notebook runs elsewhere without editing this cell.
# NOTE(review): confirm the accented user-name segment of the default path
# survived any re-encoding of this file.
RAW_DIR_DEFAULT = r"C:\Users\Mart√≠n\Desktop\TransporteProyectoIC\TransporteProyectoIC\sleep-edf-database-expanded-1.0.0\sleep-edf-database-expanded-1.0.0\sleep-cassette"
RAW_DIR = Path(os.environ.get("SLEEP_EDF_DIR", RAW_DIR_DEFAULT))
OUTPUT_DIR = RAW_DIR / "ventanas_out"
WINDOWS_DIR = OUTPUT_DIR / "ventanas_extraidas"   # one window file per (patient, channel)
ANALYSIS_DIR = OUTPUT_DIR / "analisis_canales"    # holds resumen_global.csv

# parents=True on every call so each directory can be created on its own.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
WINDOWS_DIR.mkdir(parents=True, exist_ok=True)
ANALYSIS_DIR.mkdir(parents=True, exist_ok=True)

# Windowing parameters (seconds).
WINDOW_SIZE = 30.0
OVERLAP = 15.0
STRIDE = WINDOW_SIZE - OVERLAP
N_WORKERS = max(1, int(mp.cpu_count() * 0.75))  # use ~75% of the CPU cores

print(f"üìÅ RAW_DIR:      {RAW_DIR}")
print(f"üìÅ OUTPUT_DIR:   {OUTPUT_DIR}")
print(f"‚öôÔ∏è  Ventana={WINDOW_SIZE}s | Overlap={OVERLAP}s")
print(f"‚ö° Workers={N_WORKERS}/{mp.cpu_count()}")

## Funciones auxiliares

In [None]:
from datetime import datetime, timedelta
from collections import defaultdict

def key_from_psg(name: str) -> str:
    """Return the recording key of a PSG file name: the first 7 characters before the first '-'.

    >>> key_from_psg("SC4001E0-PSG.edf")
    'SC4001E'
    """
    prefix, _sep, _rest = Path(name).name.partition('-')
    return prefix[:7]

def key_from_hyp(name: str) -> str:
    """Return the recording key of a hypnogram file name (first 7 chars before the first '-').

    >>> key_from_hyp("SC4001EC-Hypnogram.edf")
    'SC4001E'
    """
    base_name = Path(name).name
    first_field = base_name.split('-', 1)[0]
    return first_field[:7]

def encontrar_pares(raw_dir: Path):
    """Pair every Sleep-EDF PSG recording in ``raw_dir`` with its hypnogram.

    Only ``SC*-PSG.edf`` and ``SC*-Hypnogram.edf`` files directly inside
    ``raw_dir`` are considered. Files are keyed with ``key_from_psg`` /
    ``key_from_hyp``; when several files share a key, the last one returned by
    ``glob`` wins.

    Returns:
        list of ``(key, psg_path, hyp_path)`` tuples, sorted by key, for the
        keys that have both a PSG and a hypnogram.
    """
    psg_by_key = {}
    hyp_by_key = {}
    for path in raw_dir.glob("*.edf"):
        fname = path.name
        if fname.startswith("SC") and fname.endswith("-PSG.edf"):
            psg_by_key[key_from_psg(fname)] = path
        elif fname.startswith("SC") and fname.endswith("-Hypnogram.edf"):
            hyp_by_key[key_from_hyp(fname)] = path

    shared_keys = sorted(psg_by_key.keys() & hyp_by_key.keys())
    return [(key, psg_by_key[key], hyp_by_key[key]) for key in shared_keys]

def leer_hypnograma_mne(hyp_path: Path) -> pd.DataFrame:
    """Read the sleep-stage annotations of a hypnogram EDF file with MNE.

    Keeps only annotations whose description contains "Sleep stage" and whose
    stage code (the description with that prefix removed, stripped) is one of
    W, 1, 2, 3, 4 or R. All other annotations are ignored.

    Returns:
        DataFrame with one row per kept annotation, in file order, and columns
        ``inicio`` (onset), ``duracion`` (duration) and ``etapa`` (stage code).

    Raises:
        RuntimeError: if no annotation survives the filter.
    """
    codigos_validos = {"W", "1", "2", "3", "4", "R"}
    anotaciones = mne.read_annotations(str(hyp_path))

    filas = []
    for descripcion, inicio, duracion in zip(anotaciones.description, anotaciones.onset, anotaciones.duration):
        if "Sleep stage" not in descripcion:
            continue
        codigo = descripcion.replace("Sleep stage", "").strip()
        if codigo not in codigos_validos:
            continue
        filas.append({"inicio": float(inicio), "duracion": float(duracion), "etapa": codigo})

    tabla = pd.DataFrame(filas)
    if tabla.empty:
        raise RuntimeError(f"Hipnograma vac√≠o: {hyp_path}")
    return tabla

def _to_datetime(x):
    if x is None:
        return None
    if isinstance(x, tuple) and len(x) == 2:
        return datetime.fromtimestamp(x[0]) + timedelta(microseconds=x[1])
    try:
        return pd.to_datetime(x).to_pydatetime()
    except Exception:
        return x

def calcular_offset_segundos(psg_path: Path, hyp_path: Path) -> float:
    """Return how many seconds the hypnogram's ``meas_date`` is after the PSG's.

    Both files are opened without preloading their data, and each
    ``meas_date`` is normalised with ``_to_datetime``. If either file has no
    measurement date, the offset is taken to be 0.0.
    """
    fechas = []
    for ruta in (psg_path, hyp_path):
        raw = mne.io.read_raw_edf(str(ruta), preload=False, verbose=False)
        fechas.append(_to_datetime(raw.info.get('meas_date')))
    inicio_psg, inicio_hyp = fechas

    if inicio_psg is None or inicio_hyp is None:
        return 0.0
    return (inicio_hyp - inicio_psg).total_seconds()

def analizar_canales(psg_path: Path):
    """Describe every channel of a PSG EDF file.

    Each channel is typed by a case-insensitive substring match on its name,
    checked in this order: EEG, EOG, EMG, ECG/EKG -> 'ECG', EVENT/MARKER ->
    'EVENTO'; anything else is 'OTRO'.

    Returns:
        dict ``{channel_name: {'tipo', 'freq', 'n_samples', 'duracion'}}``.
        'freq', 'n_samples' and 'duracion' come from the recording as a whole
        (``raw.info['sfreq']``, ``raw.n_times`` and the last time stamp), so they
        are the same for every channel.
    """
    reglas = (
        (('EEG',), 'EEG'),
        (('EOG',), 'EOG'),
        (('EMG',), 'EMG'),
        (('ECG', 'EKG'), 'ECG'),
        (('EVENT', 'MARKER'), 'EVENTO'),
    )

    raw = mne.io.read_raw_edf(str(psg_path), preload=False, verbose=False)
    freq = float(raw.info['sfreq'])
    n_samples = int(raw.n_times)
    duracion = float(raw.times[-1]) if raw.n_times > 0 else 0.0

    def _tipo(nombre):
        mayus = nombre.upper()
        return next((tipo for claves, tipo in reglas if any(c in mayus for c in claves)), 'OTRO')

    return {
        canal: {'tipo': _tipo(canal), 'freq': freq, 'n_samples': n_samples, 'duracion': duracion}
        for canal in raw.ch_names
    }

print("‚úÖ Funciones auxiliares cargadas")



# Función para extraer ventanas de un canal específico

def extraer_ventanas_por_canal(psg_path, hyp_df, canal_nombre, window_size, stride, hyp_offset=0.0):
    """Cut one PSG channel into fixed-length windows labelled from the hypnogram.

    Args:
        psg_path: Path to the PSG EDF file.
        hyp_df: Hypnogram table with columns ``inicio``, ``duracion`` and
            ``etapa`` (as returned by ``leer_hypnograma_mne``). Its intervals
            are assumed non-overlapping; pandas' ``get_indexer`` rejects
            overlapping intervals.
        canal_nombre: Name of the channel to extract.
        window_size: Window length in seconds.
        stride: Step between consecutive window starts, in seconds.
        hyp_offset: Seconds added to the hypnogram onsets to place them on the
            PSG time axis (e.g. the value of ``calcular_offset_segundos``).

    Returns:
        dict with
          - 'ventanas': array ``(n_windows, window_samples)`` of raw samples;
          - 'etiquetas': list with the ``etapa`` of the hypnogram interval that
            contains each window's midpoint;
          - 'tiempos_inicio': list of window start times in seconds;
          - 'freq_muestreo': sampling frequency in Hz;
          - 'nombre_canal': ``canal_nombre``.
        Windows whose midpoint is outside every hypnogram interval are dropped.
    """
    # Only one channel is needed: read it on demand instead of preloading the
    # whole multi-channel recording into memory.
    raw = mne.io.read_raw_edf(str(psg_path), preload=False, verbose=False)
    data, _ = raw[canal_nombre, :]
    x = data.ravel()
    fs = float(raw.info['sfreq'])

    win_samps = int(round(window_size * fs))
    stride_samp = int(round(stride * fs))

    if win_samps <= 0 or stride_samp <= 0 or len(x) < win_samps:
        return {
            # Same (0, window_samples) layout and dtype as the non-empty case.
            'ventanas': np.empty((0, max(win_samps, 0)), dtype=x.dtype),
            'etiquetas': [],
            'tiempos_inicio': [],
            'freq_muestreo': fs,
            'nombre_canal': canal_nombre
        }

    starts = hyp_df['inicio'].to_numpy(dtype=float) + hyp_offset
    ends = (hyp_df['inicio'] + hyp_df['duracion']).to_numpy(dtype=float) + hyp_offset
    intervals = pd.IntervalIndex.from_arrays(starts, ends, closed='left')
    etapas = hyp_df['etapa'].to_numpy()

    n_vent = 1 + (len(x) - win_samps) // stride_samp
    start_samples = np.arange(n_vent) * stride_samp
    # Times derive from the (rounded) sample grid actually used to cut the
    # signal, so they cannot drift from the data when stride * fs is not integral.
    t_ini = start_samples / fs
    t_mid = t_ini + window_size / 2.0
    idx = intervals.get_indexer(t_mid)   # positional row in hyp_df, -1 if none
    keep = idx >= 0

    # Zero-copy strided view over every window position; the boolean selection
    # copies only the labelled windows.
    todas = np.lib.stride_tricks.sliding_window_view(x, win_samps)[::stride_samp]
    ventanas = todas[keep]

    return {
        'ventanas': ventanas,
        'etiquetas': etapas[idx[keep]].tolist(),
        'tiempos_inicio': t_ini[keep].tolist(),
        'freq_muestreo': fs,
        'nombre_canal': canal_nombre
    }

print("‚úÖ Funci√≥n de extracci√≥n lista")

# Cargar datos

In [None]:
# ================================================
# CELDA: Cargar y visualizar ventanas (.npz o .pkl)
# ================================================
import numpy as np
import pickle
from collections import Counter
import matplotlib.pyplot as plt

ID2LABEL = {0:"W", 1:"N1", 2:"N2", 3:"N3", 4:"REM"}

def _load_npz(path):
    d = np.load(path, allow_pickle=False)
    # Estructura esperada de la Celda 6 TURBO (np.savez_compressed):
    # X (n_vent, n_samps) float16, y (n_vent) uint8, t (n_vent) float32, fs float32, canal str
    out = {
        "ventanas": d["X"].astype(np.float32, copy=False),   # para graficar m√°s c√≥modo
        "etiquetas": d["y"].astype(np.uint8, copy=False),
        "tiempos_inicio": d["t"].astype(np.float32, copy=False),
        "freq_muestreo": float(d["fs"]),
        "nombre_canal": str(d["canal"])
    }
    return out

def _load_pkl(path):
    with open(path, "rb") as f:
        data = pickle.load(f)
    # {'ventanas', 'etiquetas', 'tiempos_inicio', 'freq_muestreo', 'nombre_canal'}
    # Si las etiquetas vienen como strings, convertimos a IDs para homogeneizar
    if data and isinstance(data.get("etiquetas", []), list) and data["etiquetas"] and isinstance(data["etiquetas"][0], str):
        label_map = {"W":0, "1":1, "N1":1, "2":2, "N2":2, "3":3, "4":3, "N3":3, "R":4, "REM":4}
        y = np.array([label_map.get(s, 255) for s in data["etiquetas"]], dtype=np.uint8)
    else:
        y = np.array(data.get("etiquetas", []), dtype=np.uint8)
    out = {
        "ventanas": np.asarray(data.get("ventanas", []), dtype=np.float32),
        "etiquetas": y,
        "tiempos_inicio": np.asarray(data.get("tiempos_inicio", []), dtype=np.float32),
        "freq_muestreo": float(data.get("freq_muestreo", 100.0)),
        "nombre_canal": data.get("nombre_canal", "CANAL")
    }
    return out

def cargar_ventanas(paciente: str, canal: str, return_ids: bool=False, mmap_npz: bool=True,
                    windows_dir=None):
    """Load the windows of one (patient, channel) pair and print a short summary.

    Looks for ``<windows_dir>/<paciente>_<canal with ' ' -> '_'>.npz`` first and
    falls back to the ``.pkl`` file with the same base name.

    Args:
        paciente: Patient/recording id used in the file name.
        canal: Channel name used in the file name.
        return_ids: If True, labels are returned as uint8 stage ids. Otherwise
            they are returned as an object array of names ("W", "N1", "N2",
            "N3", "REM"; unknown ids become "id<k>").
        mmap_npz: Passed to ``np.load`` as ``mmap_mode='r'``. NumPy only
            memory-maps plain .npy files, so this has no effect on .npz
            archives; kept for backward compatibility.
        windows_dir: Directory holding the window files; defaults to the
            notebook-level ``WINDOWS_DIR``.

    Returns:
        dict with 'ventanas' (float32 array), 'etiquetas', 'tiempos_inicio'
        (float32 array), 'freq_muestreo' (float) and 'nombre_canal', or None if
        neither file exists.
    """
    from collections import Counter

    if windows_dir is None:
        windows_dir = WINDOWS_DIR
    base = Path(windows_dir) / f"{paciente}_{canal.replace(' ', '_')}"
    npz_path, pkl_path = base.with_suffix(".npz"), base.with_suffix(".pkl")

    if npz_path.exists():
        # ``with`` closes the archive (a leaked handle can lock the file on
        # Windows); indexing reads each member fully into memory first.
        with np.load(npz_path, allow_pickle=False, mmap_mode='r' if mmap_npz else None) as d:
            X = d["X"].astype(np.float32, copy=False)
            y = d["y"].astype(np.uint8, copy=False)
            t = d["t"].astype(np.float32, copy=False)
            fs = float(d["fs"])
            canal_name = str(d["canal"])
        fmt = ".npz"
    elif pkl_path.exists():
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # produced by this pipeline.
        with open(pkl_path, "rb") as f:
            raw = pickle.load(f)
        # list() also accepts NumPy arrays, whose truth value is ambiguous.
        etiquetas = list(raw["etiquetas"])
        if etiquetas and isinstance(etiquetas[0], str):
            map_ = {"W":0,"1":1,"N1":1,"2":2,"N2":2,"3":3,"4":3,"N3":3,"R":4,"REM":4}
            y = np.array([map_.get(s, 255) for s in etiquetas], dtype=np.uint8)
        else:
            y = np.asarray(etiquetas, dtype=np.uint8)
        X = np.asarray(raw["ventanas"], dtype=np.float32)
        t = np.asarray(raw["tiempos_inicio"], dtype=np.float32)
        fs = float(raw.get("freq_muestreo", 100.0))
        canal_name = raw.get("nombre_canal", "CANAL")
        fmt = ".pkl"
    else:
        print(f"‚ùå No se encontr√≥ ni {npz_path.name} ni {pkl_path.name}")
        return None

    id2label = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
    nombres = [id2label.get(int(v), f"id{int(v)}") for v in y]
    y_out = y if return_ids else np.array(nombres, dtype=object)
    dist = Counter(nombres)
    total = max(1, X.shape[0])   # avoid dividing by zero for empty files

    print(f"‚úÖ Cargado ({fmt}): {paciente} - {canal_name}")
    print(f"   ‚Ä¢ Ventanas: {X.shape} | Fs: {fs} Hz")
    print(f"   ‚Ä¢ Distribuci√≥n:")
    for k,c in sorted(dist.items()):
        print(f"     {k}: {c} ({100.0*c/total:.1f}%)")

    return {"ventanas": X, "etiquetas": y_out, "tiempos_inicio": t,
            "freq_muestreo": fs, "nombre_canal": canal_name}


def visualizar_ventanas(data: dict, indices=(0, 1, 2), figsize=(15, 8)):
    """Plot selected windows of a dataset loaded with ``cargar_ventanas()``.

    Args:
        data: dict with at least 'ventanas' (n_windows x n_samples),
            'etiquetas' and 'freq_muestreo'.
        indices: Window indices to plot. Indices outside ``[0, n_windows)``,
            including negative ones, are skipped. The default is an immutable
            tuple, so it cannot be modified between calls.
        figsize: Matplotlib figure size.

    Prints a warning and returns without plotting when ``data`` is None, has no
    windows, or none of the indices is valid.
    """
    if data is None:
        print("‚ö†Ô∏è 'data' es None")
        return

    X = data['ventanas']
    etiquetas = data['etiquetas']
    fs = float(data['freq_muestreo'])

    if X is None or len(X) == 0:
        print("‚ö†Ô∏è No hay ventanas para mostrar.")
        return

    idx_validos = [i for i in indices if 0 <= i < len(X)]
    if not idx_validos:
        print("‚ö†Ô∏è √çndices fuera de rango.")
        return

    n_plots = len(idx_validos)
    # squeeze=False always yields a 2-D array of Axes, even for a single plot.
    fig, axes = plt.subplots(n_plots, 1, figsize=figsize, squeeze=False)

    for ax, idx in zip(axes[:, 0], idx_validos):
        ventana = X[idx]
        t = np.arange(len(ventana)) / fs
        ax.plot(t, ventana, linewidth=0.8)
        ax.set_title(f"Ventana {idx} ‚Äî Etapa: {etiquetas[idx]}", fontweight='bold')
        ax.set_xlabel("Tiempo (s)")
        ax.set_ylabel("Amplitud")
        ax.grid(True, alpha=0.3)
        # A single-sample window would otherwise give identical limits (0, 0).
        ax.set_xlim(0, t[-1] if len(t) > 1 else 1)

    plt.tight_layout()
    plt.show()

print("‚úÖ Funciones de carga/visualizaci√≥n listas (compatibles con .npz y .pkl)")

In [None]:
import pandas as pd
import numpy as np
from collections import Counter
from contextlib import redirect_stdout
import io

# Global summary table written by the channel-analysis step.
df = pd.read_csv(ANALYSIS_DIR / "resumen_global.csv")

# Distinct (patient, channel) pairs with both fields present, re-indexed 0..n-1.
columnas_clave = ["Paciente", "Canal"]
pares = df[columnas_clave].dropna().drop_duplicates().reset_index(drop=True)

# Stage label -> number of windows, accumulated over all pairs below.
conteo_global = Counter()

def extraer_etiquetas(data):
    """Return the labels of a loaded dataset as a flat list of strings (never prints).

    Supported inputs:
      * a list/tuple of length 2, read as ``(X, y)`` -> ``y``;
      * a dict -> value of the first key present among etiquetas, labels, y,
        stage, stages, etapa, etapas (None if that value is None or no key is
        present);
      * a DataFrame -> first matching column, otherwise its last column;
      * a Series -> its values.
    Any other type yields None.
    """
    candidatos = ("etiquetas", "labels", "y", "stage", "stages", "etapa", "etapas")

    if isinstance(data, (list, tuple)) and len(data) == 2:
        valores = data[1]
    elif isinstance(data, dict):
        valores = next((data[k] for k in candidatos if k in data), None)
        if valores is None:
            return None
    elif isinstance(data, pd.DataFrame):
        columna = next((k for k in candidatos if k in data.columns), None)
        valores = data.iloc[:, -1].values if columna is None else data[columna].values
    elif isinstance(data, pd.Series):
        valores = data.values
    else:
        return None

    return [str(e) for e in np.asarray(valores).ravel()]

# Accumulate label counts over every (patient, channel) pair.
for _, row in pares.iterrows():
    paciente, canal = row["Paciente"], row["Canal"]

    # Capture and discard anything cargar_ventanas prints, so this cell ends
    # with a single line of output.
    sink = io.StringIO()
    try:
        with redirect_stdout(sink):
            data = cargar_ventanas(paciente, canal)
    except Exception:
        # Best effort by design: pairs that fail to load are skipped silently.
        continue
    if data is None:
        continue

    etiquetas = extraer_etiquetas(data)
    if etiquetas:
        conteo_global.update(etiquetas)

# Single final output: share of windows per stage.
total = sum(conteo_global.values())
if not total:
    print("No se encontraron etiquetas para calcular proporciones.")
else:
    # Standard sleep-stage order first, then any unexpected labels, sorted.
    orden_std = ["W", "N1", "N2", "N3", "REM"]
    orden_final = orden_std + sorted(set(conteo_global) - set(orden_std))
    porcentajes = (f"{etapa}: {conteo_global.get(etapa, 0) / total * 100:.2f}%" for etapa in orden_final)
    print(" | ".join(porcentajes))

## Exportar a otros formatos

In [None]:
# ================================================
# CELDA: Exportar a NumPy (versi√≥n en memoria)
# ================================================
import numpy as np
import pandas as pd
from pathlib import Path

LABEL2ID = {"W":0, "N1":1, "N2":2, "N3":3, "REM":4}
ID2LABEL = {v:k for k,v in LABEL2ID.items()}

def exportar_a_numpy_mem(paciente: str, canal: str, dtype="float32"):
    """Export one (patient, channel) pair to NumPy arrays in memory (nothing is written to disk).

    Args:
        paciente: Patient/recording id.
        canal: Channel name.
        dtype: dtype of the returned windows. Any NumPy dtype spec is honoured
            ("float32", "float16", "float64", ...); an invalid spec raises
            TypeError.

    Returns:
        ``(X, y, meta)``: the windows array, the uint8 stage ids (string labels
        not in ``LABEL2ID`` become 255) and a metadata dict. Returns
        ``(None, None, None)`` if the pair cannot be loaded.
    """
    data = cargar_ventanas(paciente, canal)
    if data is None:
        print(f"‚ö†Ô∏è No se pudo cargar {paciente}-{canal}")
        return None, None, None

    X = np.asarray(data["ventanas"], dtype=np.dtype(dtype))

    y_in = np.asarray(data["etiquetas"])
    if np.issubdtype(y_in.dtype, np.integer):
        y = y_in.astype(np.uint8)
    else:
        y = np.array([LABEL2ID.get(str(lbl), 255) for lbl in y_in], dtype=np.uint8)

    meta = {
        "fs": float(data["freq_muestreo"]),
        "canal": data["nombre_canal"],
        "paciente": paciente,
        "shape": tuple(X.shape),
        "dtype": str(X.dtype),
        # Copy, so callers editing meta cannot corrupt the global mapping.
        "label_map": dict(LABEL2ID)
    }

    print(f"‚úÖ Exportado en memoria: {paciente}-{canal}")
    print(f"   ‚Ä¢ X shape: {X.shape} ({X.dtype})")
    # np.unique is already sorted; tolist() prints plain ints (not np.uint8 reprs).
    print(f"   ‚Ä¢ y √∫nicos: {np.unique(y).tolist()}")
    return X, y, meta

# ===================== Usage example =====================
resumen_csv = ANALYSIS_DIR / "resumen_global.csv"
df = pd.read_csv(resumen_csv)

# Export the first EEG channel listed in the summary. astype(str) keeps the mask
# purely boolean even when some channel names are missing (NaN -> "nan").
es_eeg = df["Canal"].astype(str).str.contains("EEG", case=False)
if es_eeg.any():
    eeg_row = df[es_eeg].iloc[0]
    X_mem, y_mem, meta_mem = exportar_a_numpy_mem(eeg_row["Paciente"], eeg_row["Canal"], dtype="float32")
else:
    print("No se encontraron canales EEG en resumen_global.csv")
    X_mem = y_mem = meta_mem = None


## Verificación de integridad de datos

In [None]:
# ================================================
# CELDA: Verificación de integridad de archivos
# ================================================
import numpy as np
import pickle
from pathlib import Path

def verificar_integridad(window_dir: Path = WINDOWS_DIR):
    """
    Verifica que todos los archivos de ventanas (.npz o .pkl)
    sean v√°lidos y consistentes.
    """
    print("üîç VERIFICACI√ìN DE INTEGRIDAD DE ARCHIVOS\n")

    archivos_npz = list(window_dir.glob("*.npz"))
    archivos_pkl = list(window_dir.glob("*.pkl"))
    total_archivos = len(archivos_npz) + len(archivos_pkl)
    print(f"üìÅ Directorio: {window_dir}")
    print(f"   ‚Ä¢ .npz encontrados: {len(archivos_npz)}")
    print(f"   ‚Ä¢ .pkl encontrados: {len(archivos_pkl)}")
    print(f"   ‚Ä¢ Total archivos:  {total_archivos}\n")

    errores = []
    validos = 0
    total_ventanas = 0

    for archivo in archivos_npz + archivos_pkl:
        try:
            if archivo.suffix == ".npz":
                data = np.load(archivo, allow_pickle=False)
                # verificar claves esperadas
                for k in ["X", "y", "t", "fs", "canal"]:
                    assert k in data.keys(), f"Falta clave '{k}'"
                n_ventanas = data["X"].shape[0]
                assert n_ventanas == len(data["y"]) == len(data["t"]), "Longitudes inconsistentes"
                assert np.isfinite(data["X"]).all(), "Hay NaNs en X"
            else:  # .pkl
                with open(archivo, "rb") as f:
                    data = pickle.load(f)
                for k in ["ventanas", "etiquetas", "tiempos_inicio", "freq_muestreo"]:
                    assert k in data, f"Falta clave '{k}'"
                n_ventanas = len(data["ventanas"])
                assert n_ventanas == len(data["etiquetas"]) == len(data["tiempos_inicio"]), "Longitudes inconsistentes"
                assert np.isfinite(np.asarray(data["ventanas"])).all(), "Hay NaNs en ventanas"

            total_ventanas += n_ventanas
            validos += 1

        except Exception as e:
            errores.append((archivo.name, str(e)))

    # ====== Reporte final ======
    print("=" * 80)
    print(f"‚úÖ Archivos v√°lidos: {validos}/{total_archivos}")
    print(f"üìä Total de ventanas revisadas: {total_ventanas:,}")
    print("=" * 80)

    if errores:
        print(f"\n‚ö†Ô∏è  Se detectaron {len(errores)} errores. Primeros 5:")
        for nombre, err in errores[:5]:
            print(f"   ‚Ä¢ {nombre}: {err}")
        if len(errores) > 5:
            print(f"   ... y {len(errores) - 5} m√°s.\n")
    else:
        print("\n‚ú® ¬°Todos los archivos son v√°lidos y consistentes!\n")

    return validos, errores


# Run the integrity check on the notebook's window directory.
validos, errores = verificar_integridad(WINDOWS_DIR)

## Resumen y limpieza

In [None]:
# ================================================
# CELDA FINAL: RESUMEN COMPLETO + (opc) LIMPIEZA
# ================================================
import pandas as pd
import numpy as np
from pathlib import Path

def human(nbytes):
    """Format a byte count with 1024-based units (B, KB, MB, GB, TB) and 2 decimals."""
    unidades = ('B', 'KB', 'MB', 'GB', 'TB')
    nivel = 0
    # TB is the largest unit; larger amounts are simply shown as many TB.
    while nivel < len(unidades) - 1 and not nbytes < 1024:
        nbytes /= 1024
        nivel += 1
    return f"{nbytes:.2f} {unidades[nivel]}"

def _peso_carpeta(carpeta: Path) -> int:
    """Return the total size in bytes of every file below ``carpeta`` (0 if it does not exist)."""
    if not carpeta.exists():
        return 0
    return sum(f.stat().st_size for f in carpeta.rglob("*") if f.is_file())


separador = "=" * 80
print(separador)
print("üìã RESUMEN FINAL DEL PROCESAMIENTO")
print(separador)

# Known output locations.
resumen_path = ANALYSIS_DIR / "resumen_global.csv"
windows_dir  = WINDOWS_DIR
datasets_dir = Path.cwd() / "datasets_cnn"
numpy_dir    = OUTPUT_DIR / "numpy_exports"

# The global summary is mandatory for this report.
if not resumen_path.exists():
    raise FileNotFoundError(f"No se encontr√≥ {resumen_path}")
df = pd.read_csv(resumen_path)

# Window files (.pkl / .npz): how many and how large.
archivos_pkl = list(windows_dir.glob("*.pkl"))
archivos_npz = list(windows_dir.glob("*.npz"))
archivos_ventanas = archivos_pkl + archivos_npz
total_archivos = len(archivos_ventanas)
size_windows = sum(f.stat().st_size for f in archivos_ventanas)

# Optional derived outputs (0 bytes when absent).
size_datasets = _peso_carpeta(datasets_dir)
size_numpy    = _peso_carpeta(numpy_dir)

print(f"\n‚úÖ Procesamiento completado")
print(f"\nüìÅ Estructura de salida:")
print(f"   {OUTPUT_DIR}/")
print(f"   ‚îú‚îÄ‚îÄ ventanas_extraidas/  ({total_archivos} archivos: {len(archivos_pkl)} .pkl, {len(archivos_npz)} .npz)")
print(f"   ‚îú‚îÄ‚îÄ analisis_canales/ -> resumen_global.csv")
for carpeta_opcional, linea_arbol in (
    (datasets_dir, "   ‚îú‚îÄ‚îÄ datasets_cnn/ (existe)"),
    (numpy_dir, "   ‚îî‚îÄ‚îÄ numpy_exports/ (existe)"),
):
    if carpeta_opcional.exists():
        print(linea_arbol)

# ---------- Global statistics from resumen_global ----------
print(f"\nüìä Estad√≠sticas globales (resumen_global):")
print(f"   ‚Ä¢ Pacientes procesados: {df['Paciente'].nunique()}")
print(f"   ‚Ä¢ Canales totales:      {len(df)}")
print(f"   ‚Ä¢ Ventanas totales:     {df['N_Ventanas'].sum():,}")

if 'Tipo' in df.columns:
    tipos_unicos = df['Tipo'].dropna().unique()
else:
    # No 'Tipo' column: infer the type from the channel name, first rule wins.
    def inferir_tipo(c):
        u = str(c).upper()
        for claves, tipo in (
            (('EEG',), 'EEG'),
            (('EOG',), 'EOG'),
            (('EMG',), 'EMG'),
            (('ECG', 'EKG'), 'ECG'),
            (('RESP', 'AIRFLOW'), 'RESP'),
        ):
            if any(k in u for k in claves):
                return tipo
        return 'OTRO'
    tipos_unicos = df['Canal'].apply(inferir_tipo).unique()
tipos = ', '.join(sorted(tipos_unicos))
print(f"   ‚Ä¢ Tipos de canales:     {tipos}")

# ---------- Disk usage ----------
total_all = size_windows + size_datasets + size_numpy
print(f"\nüíæ Espacio en disco (aprox.):")
print(f"   ‚Ä¢ ventanas_extraidas: {human(size_windows)}")
if datasets_dir.exists():
    print(f"   ‚Ä¢ datasets_cnn:       {human(size_datasets)}")
if numpy_dir.exists():
    print(f"   ‚Ä¢ numpy_exports:      {human(size_numpy)}")
print(f"   ‚Ä¢ TOTAL:              {human(total_all)}")

# ---------- Largest files in ventanas_extraidas ----------
if total_archivos > 0:
    def _peso(p):
        return p.stat().st_size

    top = sorted([*archivos_pkl, *archivos_npz], key=_peso, reverse=True)[:10]
    print(f"\nüì¶ Top 10 archivos m√°s pesados en ventanas_extraidas:")
    for f in top:
        print(f"   - {f.name:60s} {human(_peso(f))}")

# ---------- (Optional) safe clean-up of temporary caches ----------
import shutil


def limpiar_carpetas(carpetas):
    """Recursively delete every existing folder in ``carpetas`` (best effort).

    Missing folders are skipped. Entries that cannot be deleted do not stop the
    clean-up; folders that still exist afterwards are returned so the caller
    can report them instead of silently claiming success.

    Args:
        carpetas: Iterable of folder paths (``str`` or ``Path``).

    Returns:
        list of ``Path`` objects for the folders that could not be fully removed.
    """
    pendientes = []
    for carpeta in map(Path, carpetas):
        if not carpeta.exists():
            continue
        # ignore_errors=True keeps going past undeletable entries; whatever is
        # left is detected by the exists() check below.
        shutil.rmtree(carpeta, ignore_errors=True)
        if carpeta.exists():
            pendientes.append(carpeta)
    return pendientes


DELETE_TEMP = False   # set to True to delete the temporary caches listed below
TEMP_FOLDERS = [
    Path.cwd() / "stft_cache",
    Path.cwd() / "stft_cache_stream",
]
if DELETE_TEMP:
    print("\nüßπ Eliminando cach√©s temporales...")
    no_borradas = limpiar_carpetas(TEMP_FOLDERS)
    for carpeta in no_borradas:
        print(f"   Aviso: no se pudo eliminar por completo {carpeta}")
    if not no_borradas:
        print("   ‚úÖ Limpieza completada")

print("\n‚ú® Listo. Puedes decidir qu√© conservar o limpiar en base a este resumen.")
print("="*80)


# Creación de datasets

In [None]:
# ============================================================
# PIPELINE STREAMING ROBUSTO 
# - Cachea STFT por paciente/canal en .npy (float16)
# - √çndice liviano (index.json)
# - Ensamble con opci√≥n de GUARDA EN SHARDS para archivos grandes
# - Expone x1..x5, y1..y5, meta1..meta5 en memoria si as√≠ se desea
# ============================================================
import numpy as np
import pandas as pd
import json, pickle, shutil, math, gc
from pathlib import Path
from scipy.signal import stft, get_window, resample
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# ========= CONTROL PARAMETERS =========
BUILD_WHICH = [1,2,3,4,5]     # which datasets to build (1..5)
SAVE_DATASETS = False          # True: write datasets to disk; False: keep them only as in-memory variables
SAVE_FORMAT = "npz"           # "npz" (compressed) or "npy"
SHARD_MAX_BYTES = 1_200_000_000  # ~1.2 GB per shard, to avoid OSError on Windows with very large files
OUT_DTYPE = "float32"         # final dtype of the assembled dataset
LIMIT_PATIENTS = None         # e.g. 20 for quick tests; None = all patients

# ========= BASE CONFIG =========
# Stage name -> class id, and the inverse mapping.
LABEL2ID = {"W":0, "N1":1, "N2":2, "N3":3, "REM":4}
ID2LABEL = {v:k for k,v in LABEL2ID.items()}

# Canonical channel key -> candidate channel names. pick_channel_name tries the
# aliases in order (exact case-insensitive match first, then substring match).
CHANNEL_PATTERNS = {
    "EEG1": ["EEG Fpz-Cz", "Fpz-Cz"],
    "EEG2": ["EEG Pz-Oz", "Pz-Oz"],
    "EOG" : ["EOG", "EOG horizontal", "EOG horizontal derivation"],
    "EMG" : ["EMG", "EMG submental", "Submental EMG"]
}
# Frequency band (fmin, fmax) in Hz kept from each channel's spectrogram by
# stft_to_grid. NOTE(review): the EMG upper bound (100 Hz) is above the Nyquist
# frequency of a 100 Hz recording; the band mask then simply ends at fs / 2.
CHANNEL_BANDS = {
    "EEG1": (0.3, 35.0),
    "EEG2": (0.3, 35.0),
    "EOG" : (0.1, 15.0),
    "EMG" : (10.0, 100.0),
}

# Common spectrogram grid: every window becomes H_COMMON frequency rows x
# W_TARGET time columns.
H_COMMON = 128
W_TARGET = 15
WIN_SEC  = 30.0        # window length (s) that stft_to_grid crops/pads to
SEG_SEC  = 2.0         # NOTE(review): not referenced by stft_to_grid, which uses NPERSEG_FIXED samples
HOP_SEC  = 2.0         # STFT hop (s)
NPERSEG_FIXED = 256    # STFT segment length in samples
WINDOW_TYPE   = "hamming"

# Pipeline paths; cache_dir and ds_dir are relative to the current working directory.
resumen_csv = ANALYSIS_DIR / "resumen_global.csv"
cache_dir   = Path("stft_cache_stream")
ds_dir      = Path("datasets_cnn")
ds_dir.mkdir(parents=True, exist_ok=True)

# ---------- Loaders de ventanas (npz/pkl) ----------
def _load_npz(path: Path):
    d = np.load(path, allow_pickle=False)
    return {"X": d["X"], "y": d["y"].astype(np.uint8),
            "t": d["t"].astype(np.float32), "fs": float(d["fs"]),
            "canal": str(d["canal"])}

def _load_pkl(path: Path):
    with open(path, "rb") as f:
        data = pickle.load(f)
    if isinstance(data.get("etiquetas", []), list) and data["etiquetas"]:
        y = []
        for s in data["etiquetas"]:
            y.append(int(s) if isinstance(s, (int, np.integer)) else LABEL2ID.get(str(s), 255))
        y = np.array(y, dtype=np.uint8)
    else:
        y = np.array(data.get("etiquetas", []), dtype=np.uint8)
    return {"X": np.asarray(data["ventanas"], dtype=np.float32),
            "y": y,
            "t": np.asarray(data["tiempos_inicio"], dtype=np.float32),
            "fs": float(data.get("freq_muestreo", 100.0)),
            "canal": str(data.get("nombre_canal", "CANAL"))}

def load_channel_file(paciente: str, canal_nombre: str, windows_dir: Path):
    """Load the window file of one (patient, channel) pair.

    Tries ``<windows_dir>/<paciente>_<canal_nombre with ' ' -> '_'>.npz`` first,
    then the ``.pkl`` file with the same base name.

    Returns:
        ``(data, suffix)``, where ``data`` is the dict from ``_load_npz`` or
        ``_load_pkl`` and ``suffix`` is ".npz" or ".pkl"; ``(None, None)`` if
        neither file exists.
    """
    base = windows_dir / f"{paciente}_{canal_nombre.replace(' ', '_')}"
    for suffix, loader in ((".npz", _load_npz), (".pkl", _load_pkl)):
        candidate = base.with_suffix(suffix)
        if candidate.exists():
            return loader(candidate), suffix
    return None, None

def pick_channel_name(df_patient: pd.DataFrame, aliases: list[str]) -> str | None:
    names = list(df_patient["Canal"].unique())
    u_names = [n.upper() for n in names]
    for alias in aliases:
        alias_u = alias.upper()
        for n, u in zip(names, u_names):
            if u == alias_u: return n
        for n, u in zip(names, u_names):
            if alias_u in u: return n
    return None

# ---------- STFT → rejilla común ----------
def stft_to_grid(x, fs, fmin, fmax, H_out=H_COMMON, W_out=W_TARGET):
    """Convert one window into a log-power spectrogram on a fixed ``(H_out, W_out)`` grid.

    Steps:
      1. Crop or zero-pad ``x`` to ``round(WIN_SEC * fs)`` samples.
      2. Compute a one-sided STFT with ``NPERSEG_FIXED``-sample ``WINDOW_TYPE``
         segments (never longer than the signal) and a hop of ``HOP_SEC`` s.
      3. Take ``log10`` of the power, floored at 1e-12.
      4. Keep the frequency rows within ``[fmin, fmax]`` Hz and FFT-resample
         them to ``H_out`` rows and ``W_out`` columns.

    Returns:
        float32 array of shape ``(H_out, W_out)``.

    Raises:
        ValueError: if no STFT frequency bin lies in ``[fmin, fmax]`` (e.g.
            ``fmin`` above the Nyquist frequency ``fs / 2``).
    """
    x = np.asarray(x, dtype=np.float32)
    expected_len = int(round(WIN_SEC * fs))
    if len(x) > expected_len:
        x = x[:expected_len]
    elif len(x) < expected_len:
        x = np.pad(x, (0, expected_len - len(x)), mode="constant")

    # scipy rejects an explicit window array longer than the input, which
    # happens at low sampling rates (WIN_SEC * fs < NPERSEG_FIXED).
    nperseg = max(1, min(int(NPERSEG_FIXED), len(x)))
    hop_samps = int(round(HOP_SEC * fs))
    # scipy also requires noverlap < nperseg (a zero hop would violate it).
    noverlap = min(nperseg - 1, max(0, nperseg - hop_samps))
    nfft = 1 << (nperseg - 1).bit_length()   # next power of two >= nperseg

    f, _, Z = stft(x, fs=fs,
                   window=get_window(WINDOW_TYPE, nperseg, fftbins=True),
                   nperseg=nperseg, noverlap=noverlap, nfft=nfft,
                   boundary=None, padded=False, detrend=False, return_onesided=True)
    P = np.log10(np.maximum(np.abs(Z)**2, 1e-12)).astype(np.float32)

    mask = (f >= fmin) & (f <= fmax)
    if not mask.any():
        raise ValueError(
            f"No STFT bins in band [{fmin}, {fmax}] Hz for fs={fs} Hz (nperseg={nperseg})"
        )
    P_band = P[mask, :]
    if P_band.shape[0] != H_out:
        P_band = resample(P_band, H_out, axis=0)
    if P_band.shape[1] != W_out:
        P_band = resample(P_band, W_out, axis=1)
    return P_band

# ============================================================
# 1) CACHE STREAMING (guarda por paciente/canal)
# ============================================================
def compute_stft_cache_streaming(
    analysis_csv: Path,
    windows_dir: Path,
    cache_dir: Path,
    max_patients: int | None = LIMIT_PATIENTS,
    force_recompute: bool = False,
    save_dtype: str = "float16"
):
    """Compute STFT images per patient/channel and cache them as .npy files.

    Patients come from ``analysis_csv`` (column "Paciente"), optionally
    truncated to the first ``max_patients``. For each patient and each channel
    key in ``CHANNEL_PATTERNS``:
      - the matching raw-window file is loaded from ``windows_dir``;
      - every window is converted with ``stft_to_grid``;
      - three arrays are written to ``cache_dir/<channel_key>/``:
        ``<patient>_X.npy`` (STFT images), ``<patient>_y.npy`` (labels, uint8)
        and ``<patient>_t.npy`` (window times, float32).

    Args:
        analysis_csv: CSV with at least the columns "Paciente" and "Canal".
        windows_dir: directory holding the per-patient/channel window files.
        cache_dir: output directory; ``index.json`` is written there.
        max_patients: keep only the first N patients (a falsy value keeps all).
        force_recompute: wipe the cache directory contents and rebuild everything.
        save_dtype: "float16" stores X as float16; any other value keeps float32.

    Returns:
        ``{channel_key: {patient: {"X": path, "y": path, "t": path}}}``.
        If ``index.json`` already exists and ``force_recompute`` is False, it is
        returned as-is, without checking it against the current arguments.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    index_path = cache_dir / "index.json"

    if index_path.exists() and not force_recompute:
        with open(index_path, "r") as f:
            return json.load(f)

    if force_recompute and cache_dir.exists():
        # Remove every cached file and sub-directory; index.json is rewritten at the end.
        for p in cache_dir.glob("*"):
            if p.is_file() and p.name != "index.json":
                p.unlink()
            elif p.is_dir():
                shutil.rmtree(p)
        cache_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(analysis_csv)
    patients = list(df["Paciente"].unique())
    if max_patients:
        patients = patients[:max_patients]

    index = {k: {} for k in CHANNEL_PATTERNS.keys()}

    print("\n" + "="*70)
    print("‚ö° CACHE STREAMING DE STFT POR PACIENTE/CANAL")
    print("="*70)
    print(f"üìä Procesando {len(patients)} pacientes √ó 4 canales...")

    for p in tqdm(patients, desc="üîÑ Pacientes"):
        dpf = df[df["Paciente"] == p]
        for ch_key in CHANNEL_PATTERNS.keys():
            ch_name = pick_channel_name(dpf, CHANNEL_PATTERNS[ch_key])
            if ch_name is None:
                continue

            out_dir = cache_dir / ch_key
            out_dir.mkdir(parents=True, exist_ok=True)
            x_path, y_path, t_path = out_dir / f"{p}_X.npy", out_dir / f"{p}_y.npy", out_dir / f"{p}_t.npy"
            entry = {"X": str(x_path), "y": str(y_path), "t": str(t_path)}

            # Resume support: files from an earlier (possibly interrupted) run are reused
            # as-is, even if the STFT settings changed since; use force_recompute then.
            if x_path.exists() and y_path.exists() and t_path.exists():
                index[ch_key][p] = entry
                continue

            # Use the windows_dir argument (previously the global WINDOWS_DIR was read).
            dfile, _ = load_channel_file(p, ch_name, windows_dir)
            if dfile is None:
                continue

            fs = dfile["fs"]
            fmin, fmax = CHANNEL_BANDS[ch_key]
            Xraw, y, t = dfile["X"], dfile["y"], dfile["t"]

            n_win = Xraw.shape[0]
            X_stft = np.empty((n_win, H_COMMON, W_TARGET), dtype=np.float32)
            for i in range(n_win):
                X_stft[i] = stft_to_grid(Xraw[i], fs, fmin, fmax)

            if save_dtype == "float16":
                X_stft = X_stft.astype(np.float16)

            np.save(x_path, X_stft)
            np.save(y_path, y.astype(np.uint8))
            np.save(t_path, t.astype(np.float32))
            index[ch_key][p] = entry

            # Release this channel's arrays before loading the next file. `dfile` must be
            # dropped too: it still references Xraw/y/t and would keep them alive.
            del X_stft, Xraw, y, t, dfile
            gc.collect()

    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)
    print(f"\n‚úÖ Cach√© listo en {cache_dir} (index.json)")
    return index

# ============================================================
# Guardado seguro: shards para arrays grandes
# ============================================================
def _save_array_safely(base_path: Path, X: np.ndarray, y: np.ndarray, meta: dict,
                       fmt="npz", shard_max_bytes=SHARD_MAX_BYTES):
    """Persist a (N, H, W, C) dataset as size-bounded shards plus metadata.

    All files go next to ``base_path`` and use its stem as prefix:
      - shards: ``<stem>_shardNN.npz`` (keys X and y) when ``fmt == "npz"``,
        otherwise ``<stem>_shardNN_X.npy`` / ``<stem>_shardNN_y.npy``;
      - ``<stem>_meta.pkl``: the pickled ``meta``;
      - ``<stem>_manifest.json``: the manifest below.
    Each shard holds at most ``shard_max_bytes`` of uncompressed data, and
    always at least one sample.

    An empty dataset (N == 0) is written as ``<stem>_X.*`` / ``<stem>_y.*``,
    plus the meta file and a manifest with no shards.

    Returns:
        The manifest dict ``{"shards": [{"start", "end", "tag"}, ...],
        "format", "n", "HWC"}`` (the same structure for empty and non-empty data).
    """
    out_dir, stem = base_path.parent, base_path.stem
    out_dir.mkdir(parents=True, exist_ok=True)
    H, W, C = X.shape[1], X.shape[2], X.shape[3]
    n = int(X.shape[0])
    manifest = {"shards": [], "format": fmt, "n": n, "HWC": [H, W, C]}

    if n == 0:
        # Nothing to shard: store the (empty) arrays directly.
        if fmt == "npz":
            np.savez_compressed(out_dir / f"{stem}_X.npz", X=X)
            np.savez_compressed(out_dir / f"{stem}_y.npz", y=y)
        else:
            np.save(out_dir / f"{stem}_X.npy", X)
            np.save(out_dir / f"{stem}_y.npy", y)
    else:
        bytes_per_sample = X.dtype.itemsize * H * W * C + y.dtype.itemsize
        samples_per_shard = max(1, int(shard_max_bytes // bytes_per_sample))
        # Integer ceiling division: exact, and no dependency on the `math` module.
        n_shards = -(-n // samples_per_shard)
        for s in range(n_shards):
            a, b = s * samples_per_shard, min(n, (s + 1) * samples_per_shard)
            Xs, ys = X[a:b], y[a:b]
            shard_tag = f"{stem}_shard{s:02d}"
            if fmt == "npz":
                np.savez_compressed(out_dir / f"{shard_tag}.npz", X=Xs, y=ys)
            else:
                np.save(out_dir / f"{shard_tag}_X.npy", Xs)
                np.save(out_dir / f"{shard_tag}_y.npy", ys)
            manifest["shards"].append({"start": int(a), "end": int(b), "tag": shard_tag})

    with open(out_dir / f"{stem}_meta.pkl", "wb") as f:
        pickle.dump(meta, f)
    with open(out_dir / f"{stem}_manifest.json", "w") as f:
        json.dump(manifest, f, indent=2)
    return manifest

# ============================================================
# 2) ENSAMBLAR DATASET DESDE CACH√â (con opci√≥n shards)
# ============================================================
def assemble_dataset_from_cache_streaming(
    index: dict,
    required_keys: list[str],
    save_path: Path | None = None,
    out_dtype: str = OUT_DTYPE,
    save_format: str = SAVE_FORMAT,
    shard_max_bytes: int = SHARD_MAX_BYTES
):
    """Stack cached per-channel STFT images into one multi-channel dataset.

    Only patients present in ``index`` for every key in ``required_keys`` are
    used, processed in sorted patient order. For each patient, windows are
    matched across channels by their timestamp rounded to 4 decimals. Only
    timestamps shared by all channels are kept. The channel images are then
    stacked on the last axis, in ``required_keys`` order.

    Args:
        index: ``{channel_key: {patient: {"X": path, "y": path, "t": path}}}``,
            as produced by ``compute_stft_cache_streaming``.
        required_keys: channel keys to stack; their order defines the C axis.
        save_path: if given and ``SAVE_DATASETS`` is true, the result is written
            with ``_save_array_safely`` using this path's stem as prefix.
        out_dtype: "float32" casts the stacked images to float32; any other
            value keeps the dtype loaded from the cache.
        save_format, shard_max_bytes: forwarded to ``_save_array_safely``.

    Returns:
        ``(X, y, meta)``: X stacks the per-patient (n, H, W, C) blocks, y holds
        one label per window, and meta describes the shape, labels,
        per-patient window counts, channels, bands and grid.

    Raises:
        ValueError: if no patient has all required channels.
    """
    print(f"\nüî® Ensamblando dataset: {required_keys}")

    # Patients with cached data for every required channel.
    valid_patients = set(index[required_keys[0]].keys())
    for k in required_keys[1:]:
        valid_patients &= set(index[k].keys())
    valid_patients = sorted(list(valid_patients))
    print(f"   Pacientes v√°lidos: {len(valid_patients)}")
    if not valid_patients:
        raise ValueError("‚ùå No hay pacientes con todos los canales requeridos")

    X_list, y_list, counts = [], [], []

    for p in tqdm(valid_patients, desc="Ensamblando", leave=False):
        # Per-channel window timestamps, rounded to 4 decimals so they can be matched exactly.
        times_rounded = {}
        for k in required_keys:
            t = np.load(index[k][p]["t"]).astype(np.float32)
            times_rounded[k] = np.round(t, 4)

        # Keep only timestamps present in every channel; skip the patient (not counted) if none.
        common = set(times_rounded[required_keys[0]])
        for k in required_keys[1:]:
            common &= set(times_rounded[k])
        if not common: continue
        common_sorted = np.array(sorted(list(common)), dtype=np.float32)

        # For each channel, the row index of every common timestamp.
        # NOTE: if a timestamp repeats within a channel, the dict keeps its last row.
        idx_maps = {}
        for k in required_keys:
            t = times_rounded[k]
            t2idx = {float(tt): i for i, tt in enumerate(t)}
            idx_maps[k] = [t2idx[float(tt)] for tt in common_sorted]

        patient_specs_list, patient_labels = [], None
        for k in required_keys:
            X_ch = np.load(index[k][p]["X"])
            X_aligned = X_ch[idx_maps[k]]  # (n, H, W)
            patient_specs_list.append(X_aligned[..., None])  # (n,H,W,1)
            # Labels are taken from the first required channel only
            # (presumably identical across channels; not checked here).
            if patient_labels is None:
                y_ch = np.load(index[k][p]["y"])
                patient_labels = y_ch[idx_maps[k]]

        patient_specs = np.concatenate(patient_specs_list, axis=-1)  # (n,H,W,C)
        # Only "float32" forces a cast; other out_dtype values keep the cached dtype.
        if out_dtype == "float32" and patient_specs.dtype != np.float32:
            patient_specs = patient_specs.astype(np.float32)

        X_list.append(patient_specs)
        y_list.append(patient_labels)
        counts.append((p, int(patient_specs.shape[0])))

        # free per-patient buffers
        del patient_specs_list, patient_labels, X_ch, X_aligned, y_ch
        gc.collect()

    # If every patient was skipped, return correctly shaped empty arrays (one channel per key).
    X = np.concatenate(X_list, axis=0) if X_list else np.empty((0,H_COMMON,W_TARGET,len(required_keys)), dtype=np.float32)
    y = np.concatenate(y_list, axis=0) if y_list else np.empty((0,), dtype=np.uint8)
    del X_list, y_list; gc.collect()

    meta = {
        "shape": tuple(X.shape),
        "labels_unique": sorted(list(map(int, np.unique(y)))) if y.size else [],
        "label_map": ID2LABEL,
        "counts_per_patient": counts,
        "channels_used": {k: CHANNEL_PATTERNS[k] for k in required_keys},
        "channel_bands": {k: CHANNEL_BANDS[k] for k in required_keys},
        "grid": {"H": H_COMMON, "W": W_TARGET}
    }

    if save_path and SAVE_DATASETS:
        print("   üíæ Guardando con shards seguros...")
        _ = _save_array_safely(save_path, X, y, meta, fmt=save_format, shard_max_bytes=shard_max_bytes)
        print(f"   ‚úÖ Guardado en {save_path.parent}")

    return X, y, meta

# ============================================================
# 3) EJECUCI√ìN: CREA x1..x5, y1..y5, meta1..meta5
# ============================================================
# Banner for the dataset-building run.
print("\n" + "="*70)
print("üöÄ CREACI√ìN STREAMING DE 5 DATASETS (ROBUSTA)")
print("="*70)

# Step 1: compute (or reuse) the per-patient/per-channel STFT cache and its index.
print("\nüìç PASO 1/2: Cachear STFT por paciente/canal")
index = compute_stft_cache_streaming(
    analysis_csv=resumen_csv,
    windows_dir=WINDOWS_DIR,
    cache_dir=cache_dir,
    max_patients=LIMIT_PATIENTS,
    force_recompute=False,
    save_dtype="float16"
)

# Step 2: stack the cached channels into the five datasets.
print("\nüìç PASO 2/2: Ensamblar datasets desde el cach√©")
# The 5 datasets: id -> (channel keys stacked along C, file prefix used when saving).
recipes = {
    1: (["EEG1","EEG2","EOG","EMG"], "ds1_4ch"),
    2: (["EEG1","EEG2","EOG"],       "ds2_eeg_eog"),
    3: (["EEG1"],                    "ds3_eeg1"),
    4: (["EEG1","EEG2","EMG"],       "ds4_eeg_emg"),
    5: (["EOG","EMG"],               "ds5_eog_emg"),
}

# Expose the results as notebook globals x1..x5, y1..y5, meta1..meta5;
# datasets not listed in BUILD_WHICH are set to None.
# NOTE(review): globals_map is not used anywhere in this view; candidate for removal.
globals_map = {}
for i in [1,2,3,4,5]:
    if i not in BUILD_WHICH: 
        globals()[f"x{i}"] = None; globals()[f"y{i}"] = None; globals()[f"meta{i}"] = None
        continue
    keys, fname = recipes[i]
    save_path = (ds_dir / fname) if SAVE_DATASETS else None
    Xi, Yi, Metai = assemble_dataset_from_cache_streaming(
        index, keys, save_path=save_path, out_dtype=OUT_DTYPE,
        save_format=SAVE_FORMAT, shard_max_bytes=SHARD_MAX_BYTES
    )
    globals()[f"x{i}"] = Xi
    globals()[f"y{i}"] = Yi
    globals()[f"meta{i}"] = Metai
    print(f"   ‚¨ÜÔ∏è Listo dataset {i}: shape={Metai['shape']}")

# Summary of what was (and was not) built.
print("\n" + "="*70)
print("üìä RESUMEN DE DATASETS CREADOS")
print("="*70)
for i, desc in zip([1,2,3,4,5], ["EEG1+EEG2+EOG+EMG","EEG1+EEG2+EOG","EEG1 only","EEG1+EEG2+EMG","EOG+EMG"]):
    Mi = globals()[f"meta{i}"]
    if Mi is None:
        print(f"üóÇÔ∏è  Dataset {i}: {desc} ‚Äî (NO construido)")
        continue
    print(f"\nüóÇÔ∏è  Dataset {i}: {desc}")
    print(f"   ‚Ä¢ Shape: {Mi['shape']}")
    print(f"   ‚Ä¢ Etiquetas: {[ID2LABEL[j] for j in Mi['labels_unique']]}")
    print(f"   ‚Ä¢ Pacientes: {len(Mi['counts_per_patient'])}")

print("\n‚úÖ Variables disponibles: x1..x5, y1..y5, meta1..meta5")
print(f"üíæ Guardado en shards: {'S√≠' if SAVE_DATASETS else 'No (solo RAM)'}")


## Separación de datos

In [None]:
# ============================================
# SPLIT POR PACIENTE 60/20/20 (robusto a memoria/disco)
# - Si x*, y*, meta* ya existen, los usa.
# - Si no, carga datasets desde datasets_cnn/.
# ============================================

import numpy as np
import pickle
from pathlib import Path
from sklearn.model_selection import GroupShuffleSplit

# Directory the split cell reads datasets from when x*/y*/meta* are not in memory.
# NOTE(review): this must be the folder the dataset-building cell saved to
# (its `ds_dir` is defined outside this view); confirm the two agree.
DS_DIR = Path("datasets_cnn")

def _load_ds_from_disk(tag: str):
    """Carga X, y, meta de disco: tag='ds1_4ch', 'ds2_eeg_eog', ..."""
    x_path = DS_DIR / f"{tag}_X.npy"
    y_path = DS_DIR / f"{tag}_y.npy"
    m_path = DS_DIR / f"{tag}_meta.pkl"
    if not x_path.exists():
        raise FileNotFoundError(f"No existe {x_path}")
    if not y_path.exists():
        raise FileNotFoundError(f"No existe {y_path}")
    if not m_path.exists():
        raise FileNotFoundError(f"No existe {m_path}")
    X = np.load(x_path, mmap_mode=None)     
    y = np.load(y_path)
    with open(m_path, "rb") as f:
        meta = pickle.load(f)
    return X, y, meta

def _ensure_loaded(var_triplet, fallback_tag):
    """Return ``var_triplet`` unchanged, or load the dataset from disk if it is incomplete.

    Args:
        var_triplet: ``(X, y, meta)``, possibly containing None entries.
        fallback_tag: file prefix passed to ``_load_ds_from_disk`` (e.g. "ds1_4ch").

    Returns:
        ``(X, y, meta)``. When any element of ``var_triplet`` is None, the whole
        triplet is replaced by ``_load_ds_from_disk(fallback_tag)``.
    """
    # (A former `locals()` check here was dead code inside a function and was removed.)
    X, y, meta = var_triplet
    if X is None or y is None or meta is None:
        return _load_ds_from_disk(fallback_tag)
    return X, y, meta

# Reuse x*/y*/meta* if a previous cell left them in memory; otherwise load them from disk.
# If any name of a triplet is undefined, the whole triplet is reset to None,
# so _ensure_loaded below falls back to _load_ds_from_disk for it.
try:
    x1, y1, meta1
except NameError:
    x1 = y1 = meta1 = None
try:
    x2, y2, meta2
except NameError:
    x2 = y2 = meta2 = None
try:
    x3, y3, meta3
except NameError:
    x3 = y3 = meta3 = None
try:
    x4, y4, meta4
except NameError:
    x4 = y4 = meta4 = None
try:
    x5, y5, meta5
except NameError:
    x5 = y5 = meta5 = None

# Fill any missing triplet from DS_DIR, using the file prefix of each dataset.
x1, y1, meta1 = _ensure_loaded((x1,y1,meta1), "ds1_4ch")
x2, y2, meta2 = _ensure_loaded((x2,y2,meta2), "ds2_eeg_eog")
x3, y3, meta3 = _ensure_loaded((x3,y3,meta3), "ds3_eeg1")
x4, y4, meta4 = _ensure_loaded((x4,y4,meta4), "ds4_eeg_emg")
x5, y5, meta5 = _ensure_loaded((x5,y5,meta5), "ds5_eog_emg")

def make_patient_ids(meta):
    """Expand ``meta["counts_per_patient"]`` into one patient ID per window.

    ``counts_per_patient`` is a sequence of ``(patient, n_windows)`` pairs in
    dataset order. The result is an object array of length ``sum(n_windows)``;
    the object dtype keeps the original strings instead of a fixed-width
    unicode dtype.
    """
    expanded = [
        patient
        for patient, n_windows in meta["counts_per_patient"]
        for _ in range(int(n_windows))
    ]
    return np.array(expanded, dtype=object)

def split_by_patient(y, patient_ids, test_size=0.20, val_size=0.20, random_state=42):
    """Split window indices into train/val/test so that no patient spans two splits.

    Two chained ``GroupShuffleSplit`` calls are used:
      - first, ``test_size`` of the *patients* go to test;
      - then ``val_size / (1 - test_size)`` of the remaining patients go to
        validation.
    With the defaults this is 60/20/20 by patient count; the proportions of
    windows per split can differ.

    Args:
        y: labels, length N (array-like).
        patient_ids: patient identifier per window, length N (array-like).
        test_size, val_size: fractions of all patients for test and validation.
        random_state: seed of the first split; the second uses ``random_state + 1``.

    Returns:
        ``{"train": idx, "val": idx, "test": idx}`` with integer index arrays into y.

    Raises:
        ValueError: on a length mismatch or invalid fractions.
    """
    y = np.asarray(y)
    patient_ids = np.asarray(patient_ids)
    N = len(y)
    # Real exceptions instead of `assert`, which is stripped under `python -O`.
    if len(patient_ids) != N:
        raise ValueError(f"Desalineación patient_ids vs y: {len(patient_ids)} != {N}")
    if not (0.0 < test_size < 1.0 and 0.0 < val_size < 1.0 and test_size + val_size < 1.0):
        raise ValueError("test_size y val_size deben ser fracciones en (0, 1) con test_size + val_size < 1")

    gss1 = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    train_val_idx, test_idx = next(gss1.split(np.zeros(N), y, groups=patient_ids))

    rel_val = val_size / (1.0 - test_size)  # e.g. 0.20 / 0.80 = 0.25 of the remaining patients
    gss2 = GroupShuffleSplit(n_splits=1, test_size=rel_val, random_state=random_state + 1)
    pv = patient_ids[train_val_idx]
    yv = y[train_val_idx]
    sub_train_idx, val_idx_sub = next(gss2.split(np.zeros(len(train_val_idx)), yv, groups=pv))

    train_idx = train_val_idx[sub_train_idx]
    val_idx = train_val_idx[val_idx_sub]

    # Internal invariant: no patient is shared between splits (what leakage is about).
    p_train, p_val, p_test = (set(patient_ids[idx].tolist()) for idx in (train_idx, val_idx, test_idx))
    assert p_train.isdisjoint(p_val) and p_train.isdisjoint(p_test) and p_val.isdisjoint(p_test)

    return {"train": train_idx, "val": val_idx, "test": test_idx}

def print_split_summary(y, patient_ids, splits, name, label_names=None):
    """Print patient counts and the per-class window distribution of each split.

    Args:
        y: labels per window.
        patient_ids: patient ID per window (an array indexable by the split indices).
        splits: ``{"train": idx, "val": idx, "test": idx}``, as returned by split_by_patient.
        name: dataset name shown in the header.
        label_names: optional ``{class_id: name}``. Defaults to the five sleep
            stages ``{0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}``; unknown ids
            are printed as-is.
    """
    if label_names is None:
        # Built per call instead of being a shared mutable default argument.
        label_names = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}

    p_train = set(np.unique(patient_ids[splits["train"]]))
    p_val   = set(np.unique(patient_ids[splits["val"]]))
    p_test  = set(np.unique(patient_ids[splits["test"]]))

    print(f"\n====== {name}: PACIENTES POR SPLIT ======")
    print(f"Train: {len(p_train)} | Val: {len(p_val)} | Test: {len(p_test)}")
    print("Intersecciones (deben ser 0):",
          len(p_train & p_val), len(p_train & p_test), len(p_val & p_test))

    print("------ Distribuci√≥n de clases (por ventanas) ------")
    for split_name, idx in splits.items():
        yy = y[idx]
        uniq, cnt = np.unique(yy, return_counts=True)
        total = len(yy)
        # An empty split yields no classes, so the division below never sees total == 0.
        nice = ", ".join([f"{label_names.get(int(k),k)}: {int(v)} ({(int(v)/total*100):.1f}%)"
                          for k, v in sorted(zip(uniq, cnt), key=lambda z:int(z[0]))])
        print(f"{split_name:>5} -> N={total} | {nice}")

# Build the per-window patient-ID vectors (one entry per row of x_i / y_i).
patient_ids1 = make_patient_ids(meta1)
patient_ids2 = make_patient_ids(meta2)
patient_ids3 = make_patient_ids(meta3)
patient_ids4 = make_patient_ids(meta4)
patient_ids5 = make_patient_ids(meta5)

# Ensure alignment: patient IDs, labels and samples must have the same length.
assert len(patient_ids1) == len(y1) == x1.shape[0]
assert len(patient_ids2) == len(y2) == x2.shape[0]
assert len(patient_ids3) == len(y3) == x3.shape[0]
assert len(patient_ids4) == len(y4) == x4.shape[0]
assert len(patient_ids5) == len(y5) == x5.shape[0]

# Run the patient-level splits (test_size=0.20, val_size=0.20, random_state=42).
splits1 = split_by_patient(y1, patient_ids1, 0.20, 0.20, 42)
splits2 = split_by_patient(y2, patient_ids2, 0.20, 0.20, 42)
splits3 = split_by_patient(y3, patient_ids3, 0.20, 0.20, 42)
splits4 = split_by_patient(y4, patient_ids4, 0.20, 0.20, 42)
splits5 = split_by_patient(y5, patient_ids5, 0.20, 0.20, 42)

# Summaries: patients per split and class distribution per split.
print_split_summary(y1, patient_ids1, splits1, "DS1 EEG1+EEG2+EOG+EMG")
print_split_summary(y2, patient_ids2, splits2, "DS2 EEG1+EEG2+EOG")
print_split_summary(y3, patient_ids3, splits3, "DS3 EEG1 only")
print_split_summary(y4, patient_ids4, splits4, "DS4 EEG1+EEG2+EMG")
print_split_summary(y5, patient_ids5, splits5, "DS5 EOG+EMG")


# (Opcional) Guardar √≠ndices por dataset para reproducibilidad
#SAVE_SPLITS = True
#if SAVE_SPLITS:
#    sp_dir = DS_DIR / "splits"
#    sp_dir.mkdir(parents=True, exist_ok=True)
#    for k, sp in enumerate([splits1, splits2, splits3, splits4, splits5], start=1):
#        np.save(sp_dir / f"ds{k}_train_idx.npy", sp["train"].astype(np.uint32))
#        np.save(sp_dir / f"ds{k}_val_idx.npy",   sp["val"].astype(np.uint32))
#        np.save(sp_dir / f"ds{k}_test_idx.npy",  sp["test"].astype(np.uint32))
#    print(f"\nüíæ √çndices guardados en {sp_dir}")


# Modelo

In [None]:
import torch
import torch.nn as nn

class MApooling2D(nn.Module):
    """Max + average pooling of the same input, concatenated along channels.

    The output therefore has twice the input channels: (B, C, H, W) ->
    (B, 2*C, H', W'), with the max-pooled half first. Both pools share
    kernel_size, stride and padding. AvgPool2d keeps PyTorch's default
    ``count_include_pad=True``, so zero padding is counted in border averages.
    """
    def __init__(self, kernel_size, stride, padding=1):
        super().__init__()
        pool_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.max_pool = nn.MaxPool2d(**pool_kwargs)
        self.avg_pool = nn.AvgPool2d(**pool_kwargs)

    def forward(self, x):
        pooled = (self.max_pool(x), self.avg_pool(x))
        return torch.cat(pooled, dim=1)

class MCBlock(nn.Module):
    """Multi-scale convolutional block: four parallel branches concatenated on channels.

    The branches output 64 + 96 + 48 + 32 = 240 channels at the input's H x W:
      - branch1: 1x1 conv;
      - branch2: 1x1 conv then 3x3 conv;
      - branch3: 1x1 conv then 7x7 conv;
      - branch4: MApooling2D (3x3, stride 1), which doubles the channels,
        then a 1x1 conv.
    Every convolution is followed by an in-place ReLU.
    """
    def __init__(self, in_channels):
        super().__init__()

        def conv_relu(c_in, c_out, k):
            # Stride 1 and "same" padding for odd kernels (k=1 -> 0, k=3 -> 1, k=7 -> 3).
            return [nn.Conv2d(c_in, c_out, kernel_size=k, stride=1, padding=k // 2),
                    nn.ReLU(inplace=True)]

        self.branch1 = nn.Sequential(*conv_relu(in_channels, 64, 1))
        self.branch2 = nn.Sequential(*conv_relu(in_channels, 64, 1), *conv_relu(64, 96, 3))
        self.branch3 = nn.Sequential(*conv_relu(in_channels, 16, 1), *conv_relu(16, 48, 7))
        self.branch4 = nn.Sequential(
            MApooling2D(kernel_size=3, stride=1, padding=1),  # in -> 2*in channels
            *conv_relu(in_channels * 2, 32, 1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)  # (B, 240, H, W)

class SleepStageModel(nn.Module):
    """Spectrogram CNN for sleep-stage classification.

    Input: (B, in_ch, H, W). The comments assume H=121, W=15. Set in_ch to
    the number of stacked channels (4 for EEG1+EEG2+EOG+EMG, 1 for EEG1 only).

    Pipeline:
      Conv(3x3, stride 2) -> ReLU -> BN -> MApool -> MCBlock -> BN -> MApool
      -> Dropout(0.1) -> global average pool -> Flatten -> Dropout(0.5)
      -> Linear(480 -> num_classes).
    Global average pooling makes the classifier independent of the exact H and W.
    """
    def __init__(self, num_classes=5, in_ch=4):
        super().__init__()
        stem_channels = 256  # conv1 output channels
        mc_channels = 240    # MCBlock output channels (64 + 96 + 48 + 32)

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, stem_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(stem_channels),
        )
        # MApooling2D concatenates max and avg pooling, doubling channels: 256 -> 512.
        self.m_pool1 = MApooling2D(kernel_size=3, stride=2, padding=1)
        self.mc_block = nn.Sequential(
            MCBlock(in_channels=2 * stem_channels),
            nn.BatchNorm2d(mc_channels),
        )
        # Second MApooling2D: 240 -> 480 channels.
        self.m_pool2 = MApooling2D(kernel_size=3, stride=2, padding=1)
        self.dropout1 = nn.Dropout(p=0.1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(p=0.5),
            nn.Linear(2 * mc_channels, num_classes),
        )

    def forward(self, x):
        # Shapes for a (B, 4, 121, 15) input:
        # conv1 (B,256,61,8) -> m_pool1 (B,512,31,4) -> mc_block (B,240,31,4)
        # -> m_pool2 (B,480,16,2) -> dropout1 -> GAP (B,480,1,1) -> classifier (B,num_classes)
        for layer in (self.conv1, self.m_pool1, self.mc_block, self.m_pool2,
                      self.dropout1, self.global_avg_pool, self.classifier):
            x = layer(x)
        return x


## Entrenamiento

In [None]:
# ===== ENTRENAMIENTO =====
import math
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler

# ---------- Dataset a NIVEL M√ìDULO (evita errores de pickle con num_workers>0) ----------
if 'SpectroDataset' not in globals():
    class SpectroDataset(torch.utils.data.Dataset):
        """Index-based view over a (N, H, W, C) sample container for Conv2d training.

        Defined at module level, and only once thanks to the ``globals()`` guard,
        so DataLoader workers can pickle it when ``num_workers > 0``.

        Args:
            X: indexable whose items are (H, W, C) arrays, e.g. an ndarray (N, H, W, C).
            y: labels, one per sample.
            indices: positions to expose; all of ``range(len(y))`` when None.
            dtype: numpy dtype each sample is cast to before conversion.

        Each item is ``(tensor of shape (C, H, W), label tensor of dtype torch.long)``.
        """
        def __init__(self, X, y, indices=None, dtype=np.float32):
            self.X = X
            self.y = y
            if indices is None:
                self.indices = np.arange(len(y))
            else:
                self.indices = np.asarray(indices)
            self.dtype = dtype

        def __len__(self):
            return len(self.indices)

        def __getitem__(self, i):
            pos = int(self.indices[i])
            # Move the channel axis to the front: (H, W, C) -> (C, H, W).
            sample = np.moveaxis(np.asarray(self.X[pos], dtype=self.dtype), -1, 0)
            label = int(self.y[pos])
            return torch.from_numpy(sample), torch.tensor(label, dtype=torch.long)

# ---------- EarlyStopping ----------
class EarlyStopping:
    """Count consecutive non-improving steps and signal when patience runs out.

    Args:
        patience: number of consecutive non-improving steps that triggers a stop.
        mode: 'max' if larger metric values are better; any other value means minimize.
        min_delta: an improvement must beat the best value by strictly more than this.

    ``step(metric)`` returns True once ``patience`` non-improving steps in a row
    have been seen. None, and NaN/inf Python (or NumPy) floats, always count as
    non-improving.
    """

    def __init__(self, patience=15, mode='max', min_delta=0.0):
        self.patience = int(patience)
        self.mode = mode
        self.min_delta = float(min_delta)
        self.best = -np.inf if mode == 'max' else np.inf
        self.num_bad_epochs = 0

    def step(self, metric):
        invalid = metric is None or (isinstance(metric, float) and not math.isfinite(metric))
        improved = False
        if not invalid:
            if self.mode == 'max':
                gain = metric - self.best
            else:
                gain = self.best - metric
            improved = gain > self.min_delta
        if improved:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        return self.num_bad_epochs >= self.patience

# ---------- evaluate ----------
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Compute accuracy and mean loss of ``model`` over ``loader``, without gradients.

    Args:
        model: module, switched to eval mode here (and left in eval mode).
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: callable ``(logits, targets) -> scalar loss tensor``.
        device: device the batches are moved to.

    Returns:
        ``(accuracy, loss)``. The loss is the per-sample mean, i.e. each batch's
        loss is weighted by its size. Batches whose loss is NaN/Inf are skipped
        and excluded from both numbers. Both are 0.0 when nothing was evaluated.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        if torch.isnan(loss) or torch.isinf(loss):
            print("‚ö†Ô∏è  Loss NaN/Inf detectado en evaluaci√≥n; batch omitido")
            continue
        batch_n = yb.size(0)
        n_correct += (logits.argmax(dim=1) == yb).sum().item()
        # Weight by batch size: dividing by len(loader) under-reported the loss when
        # batches were skipped, and over-weighted a short final batch.
        loss_sum += float(loss.item()) * batch_n
        n_seen += batch_n
    val_loss = loss_sum / max(1, n_seen)
    val_acc = n_correct / max(1, n_seen)
    return val_acc, val_loss

# ---------- train_one_epoch ----------
def train_one_epoch(model, loader, optimizer, criterion, device, grad_clip=None, scaler=None):
    """Run one training pass over ``loader`` and return ``(accuracy, mean_loss)``.

    Mixed precision (CUDA float16 autocast plus loss scaling) is used only when
    ``scaler`` is an *enabled* GradScaler and ``device`` is CUDA. A disabled
    scaler (e.g. ``GradScaler(enabled=False)``) runs in full precision, as does
    ``scaler=None``.

    Args:
        model: module, switched to train mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``(logits, targets) -> scalar loss tensor``.
        device: device the batches are moved to.
        grad_clip: max gradient norm, or None for no clipping.
        scaler: optional GradScaler used for backward/step.

    Returns:
        ``(accuracy, loss)`` over the processed batches. The loss is a
        per-sample mean. Batches with a NaN/Inf loss are skipped (no backward,
        no step) and excluded from both numbers.
    """
    from contextlib import nullcontext

    model.train()
    # Previously any non-None scaler entered fp16 autocast on CUDA, even a disabled one.
    use_autocast = scaler is not None and scaler.is_enabled() and device.type == 'cuda'
    loss_sum, n_correct, n_seen, n_skipped = 0.0, 0, 0, 0

    for xb, yb in loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        amp_ctx = (torch.autocast(device_type='cuda', dtype=torch.float16)
                   if use_autocast else nullcontext())
        with amp_ctx:
            logits = model(xb)
            loss = criterion(logits, yb)
        if torch.isnan(loss) or torch.isinf(loss):
            n_skipped += 1
            continue

        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip is not None:
                scaler.unscale_(optimizer)  # clip the real (unscaled) gradients
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        batch_n = yb.size(0)
        n_correct += (logits.argmax(dim=1) == yb).sum().item()
        # Weight by batch size so skipped batches and a short last batch don't skew the mean.
        loss_sum += float(loss.item()) * batch_n
        n_seen += batch_n

    if n_skipped > 0:
        print(f"‚ö†Ô∏è  {n_skipped} batch(es) con NaN/Inf fueron omitidos")
    train_loss = loss_sum / max(1, n_seen)
    train_acc = n_correct / max(1, n_seen)
    return train_acc, train_loss

# ---------- Pesos de clase ----------
def compute_class_weights(y_train, method='sqrt_inverse', clip_range=(0.5, 2.5)):
    uniq, cnt = np.unique(y_train, return_counts=True)
    n_cls = int(np.max(y_train)) + 1
    freq = cnt / cnt.sum()
    if method == 'inverse':
        w = 1.0 / np.maximum(freq, 1e-8)
    elif method == 'sqrt_inverse':
        w = 1.0 / np.sqrt(np.maximum(freq, 1e-8))
    elif method == 'log_inverse':
        w = 1.0 / np.log1p(freq * 100)
    elif method == 'manual':
        return dict(zip(uniq.astype(int), cnt.tolist()))
    else:
        raise ValueError("method debe ser 'inverse', 'sqrt_inverse', 'log_inverse' o 'manual'")
    w = w / w.sum() * len(uniq)
    if clip_range is not None:
        w = np.clip(w, clip_range[0], clip_range[1])
    cw_np = np.ones(n_cls, dtype=np.float32)
    for k, weight in zip(uniq.astype(int), w):
        cw_np[k] = float(weight)
    return cw_np

# ---------- Entrenador ----------
def train_sleep_model(
    model,
    X, y, splits,
    batch_size=128,
    lr=1e-3,
    epochs=100,
    optimizer_type='adam',
    criterion_name="ce",
    label_smoothing=0.0,
    focal_gamma=2.0,
    use_gpu=True,
    num_workers=0,
    class_weights=None,
    weight_clip_range=(0.5, 2.5),
    grad_clip=1.0,
    amp=False,
    save_path="best_model.pt",
    early_stopping_tolerance=15,
    early_stopping_metric="val_acc",
    early_stopping_min_delta=0.0
):
    """Train ``model`` on a patient-level split and evaluate the best checkpoint on test.

    Args:
        model: network mapping (B, C, H, W) batches to class logits.
        X, y: full dataset; samples are read through ``SpectroDataset``.
        splits: ``{"train": idx, "val": idx, "test": idx}`` index arrays into X/y.
        batch_size, lr, epochs: usual training hyper-parameters.
        optimizer_type: 'adam', 'sgd' (momentum 0.9) or 'adamw'.
        criterion_name: 'ce', 'ce_smooth' (uses ``label_smoothing``) or 'focal'
            (uses ``focal_gamma``).
        use_gpu: use CUDA when available.
        num_workers: DataLoader worker processes (0 = load in the main process).
        class_weights: None, a method name for ``compute_class_weights``, a
            ``{class: weight}`` dict, or a list/array of weights.
        weight_clip_range: clip range forwarded to ``compute_class_weights``.
        grad_clip: max gradient norm (None disables clipping).
        amp: enable float16 mixed precision (effective on CUDA only).
        save_path: where the best state_dict is written each time the monitored
            metric improves.
        early_stopping_tolerance, early_stopping_metric, early_stopping_min_delta:
            ``EarlyStopping`` settings. The metric is 'val_acc' (maximized) or,
            for any other value, the validation loss (minimized).

    Returns:
        ``(model, history, {"test_acc", "test_loss"})``. ``model`` holds the best
        validation weights when any epoch improved.

    Note:
        Training batches are always drawn with a class-balanced
        ``WeightedRandomSampler``, so passing ``class_weights`` as well
        compensates class imbalance twice.
    """
    # --- Device
    if use_gpu and torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"üöÄ Usando GPU: {torch.cuda.get_device_name(0)}")
        print(f"   Memoria total: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")
        print("üíª Usando CPU")

    model = model.to(device)

    train_idx, val_idx, test_idx = splits["train"], splits["val"], splits["test"]

    # --- Class weights
    cw_np = None
    if class_weights is not None:
        if isinstance(class_weights, str):
            # NOTE(review): method='manual' returns a dict of counts, which the checks
            # below cannot handle; only the numeric methods are usable here.
            cw_np = compute_class_weights(y[train_idx], method=class_weights, clip_range=weight_clip_range)
            uniq, cnt = np.unique(y[train_idx], return_counts=True)
            print("\nüìä Distribuci√≥n TRAIN:")
            for cls, count in zip(uniq.astype(int), cnt):
                print(f"   Clase {cls}: {count:7d} ({count/cnt.sum()*100:5.2f}%)")
            print("‚öñÔ∏è  Pesos de clase:", [f"{w:.3f}" for w in cw_np])
        elif isinstance(class_weights, dict):
            n_cls = int(np.max(y)) + 1
            cw_np = np.ones(n_cls, dtype=np.float32)
            for k, w in class_weights.items(): cw_np[int(k)] = float(w)
        elif isinstance(class_weights, (list, np.ndarray)):
            cw_np = np.array(class_weights, dtype=np.float32)

        if cw_np is not None and np.any(cw_np > 5.0):
            print(f"‚ö†Ô∏è  WARNING: Pesos muy altos (max={cw_np.max():.2f}). Considera clip m√°s estricto.")

    weight_tensor = torch.tensor(cw_np, dtype=torch.float32, device=device) if cw_np is not None else None

    # --- Criterion
    def make_ce(weight_tensor=None, label_smoothing=0.0):
        if label_smoothing and label_smoothing > 0.0:
            return nn.CrossEntropyLoss(weight=weight_tensor, label_smoothing=float(label_smoothing))
        return nn.CrossEntropyLoss(weight=weight_tensor)

    class FocalLoss(nn.Module):
        def __init__(self, gamma=2.0, alpha=None):
            super().__init__()
            self.gamma = float(gamma)
            self.alpha = alpha
            self.ce = nn.CrossEntropyLoss(reduction='none')
        def forward(self, logits, target):
            # Clamps keep exp/log finite; non-finite per-sample losses are dropped from the mean.
            logits = torch.clamp(logits, min=-50, max=50)
            ce = torch.clamp(self.ce(logits, target), min=1e-7, max=50)
            pt = torch.clamp(torch.exp(-ce), min=1e-7, max=0.9999)
            loss = (1 - pt) ** self.gamma * ce
            if self.alpha is not None:
                loss = self.alpha[target] * loss
            mask = ~(torch.isnan(loss) | torch.isinf(loss))
            return loss[mask].mean() if mask.any() else torch.tensor(0.0, device=logits.device, requires_grad=True)

    if criterion_name == "ce":
        criterion = make_ce(weight_tensor=weight_tensor)
    elif criterion_name == "ce_smooth":
        criterion = make_ce(weight_tensor=weight_tensor, label_smoothing=label_smoothing)
    elif criterion_name == "focal":
        criterion = FocalLoss(gamma=focal_gamma, alpha=weight_tensor)
    else:
        raise ValueError("criterion_name inv√°lido")
    criterion = criterion.to(device)

    # --- Datasets / loaders
    train_ds = SpectroDataset(X, y, indices=train_idx)
    val_ds   = SpectroDataset(X, y, indices=val_idx)
    test_ds  = SpectroDataset(X, y, indices=test_idx)

    # Class-balanced sampler for TRAIN (each sample weighted by 1 / its class count).
    y_train_subset = y[train_idx]
    class_counts = np.bincount(y_train_subset, minlength=int(np.max(y))+1)
    inv_freq = 1.0 / np.maximum(class_counts, 1)
    sample_weights = inv_freq[y_train_subset]
    sampler = WeightedRandomSampler(
        weights=torch.as_tensor(sample_weights, dtype=torch.double),
        num_samples=len(sample_weights),
        replacement=True
    )

    pin_mem = (device.type == 'cuda')
    common_loader_kwargs = dict(num_workers=num_workers, pin_memory=pin_mem)
    if num_workers > 0:
        # Only valid with worker processes; torch < 2.0 rejects prefetch_factor=None
        # (the previous unconditional value) when num_workers == 0.
        common_loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, drop_last=False, **common_loader_kwargs)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, **common_loader_kwargs)
    test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False, **common_loader_kwargs)

    # --- Optimizer
    if optimizer_type == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif optimizer_type == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    elif optimizer_type == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError("optimizer_type debe ser 'adam', 'sgd' o 'adamw'")

    print(f"üéØ Optimizador: {optimizer_type.upper()}  |  üìà LR: {lr}  |  üì¶ Batch: {batch_size}  |  üî¢ √âpocas: {epochs}")
    print(f"‚úÇÔ∏è  Grad clip: {grad_clip if grad_clip else 'OFF'}  |  ‚ö° AMP: {'ON' if (amp and device.type=='cuda') else 'OFF'}")

    # Only create a GradScaler when AMP is really on. A disabled scaler used to be
    # passed anyway, which made train_one_epoch run fp16 autocast on CUDA with amp=False.
    use_amp = bool(amp) and device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    monitor_mode = 'max' if early_stopping_metric == 'val_acc' else 'min'
    es = EarlyStopping(patience=early_stopping_tolerance, mode=monitor_mode, min_delta=early_stopping_min_delta)

    history = {"train_acc": [], "val_acc": [], "train_loss": [], "val_loss": [], "lr": []}
    best_val_metric = -np.inf if monitor_mode == 'max' else np.inf
    best_state = None

    # --- Loop
    for epoch in range(1, epochs + 1):
        train_acc, train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, grad_clip=grad_clip, scaler=scaler)
        val_acc, val_loss = evaluate(model, val_loader, criterion, device)

        monitor_value = val_acc if early_stopping_metric == "val_acc" else val_loss
        # Checkpoint on any strict improvement (min_delta only affects early stopping).
        improved = (monitor_value > best_val_metric) if monitor_mode == 'max' else (monitor_value < best_val_metric)
        if improved:
            best_val_metric = monitor_value
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, save_path)

        history["train_acc"].append(train_acc); history["train_loss"].append(train_loss)
        history["val_acc"].append(val_acc);     history["val_loss"].append(val_loss)
        history["lr"].append(lr)

        star = " ‚≠ê" if improved else ""
        print(f"Epoch {epoch:02d}/{epochs} | Train[L {train_loss:.4f} A {train_acc:.4f}] | "
              f"Val[L {val_loss:.4f} A {val_acc:.4f}] | Best {early_stopping_metric} {best_val_metric:.4f}{star}")

        if es.step(monitor_value):
            print(f"\nüõë Early stopping: {early_stopping_tolerance} √©pocas sin mejora en {early_stopping_metric}")
            break

    # --- Final evaluation with the best validation weights
    if best_state is not None:
        model.load_state_dict(best_state)
    test_acc, test_loss = evaluate(model, test_loader, criterion, device)

    print("\n" + "="*60)
    print("‚úÖ RESULTADOS FINALES")
    print("="*60)
    print(f"Test Loss:     {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)")
    print(f"Best Val {early_stopping_metric}: {best_val_metric:.4f}")
    print(f"üíæ Modelo guardado en: {save_path}")
    print("="*60 + "\n")

    return model, history, {"test_acc": test_acc, "test_loss": test_loss}


In [None]:
# Sanity check of the PyTorch install: library version, the CUDA toolkit it
# was built against (None on CPU-only builds) and whether a GPU is visible.
import torch

print(
    torch.__version__,
    torch.version.cuda,
    torch.cuda.is_available(),
)


# Ejecución (una sola vez)

## Primer dataset: EEG1, EEG2, EOG y EMG

In [None]:
# ===============================================================
# ENTRENAMIENTO + CURVAS + EVALUACI√ìN COMPLETA 
# ===============================================================

import torch, numpy as np, matplotlib.pyplot as plt, seaborn as sns, pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, cohen_kappa_score
from torch.utils.data import Dataset, DataLoader

# -------- CONFIGURACI√ìN --------
import os  # os.devnull: platform-independent "discard" path for the checkpoint

use_gpu = True
device = torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

# Dataset for this cell; swap in (x2, y2, splits2), (x3, y3, splits3), ... as needed.
X, y, splits = x1, y1, splits1
DATASET_NAME = "EEG1+EEG2+EOG+EMG"

# -------- TRAINING --------
# Pass the channel count (last axis of X) when the model accepts `in_ch`;
# otherwise fall back to the constructor without it.
try:
    model = SleepStageModel(num_classes=5, in_ch=X.shape[-1])
except TypeError:
    model = SleepStageModel(num_classes=5)

model, hist, results = train_sleep_model(
    model, X, y, splits,
    lr=1e-4,
    batch_size=256,
    epochs=35,
    criterion_name='ce',
    class_weights=None,
    weight_clip_range=(0.1, 2.5),
    grad_clip=1.0,
    use_gpu=True,
    amp=False,
    num_workers=0,
    early_stopping_tolerance=8,
    early_stopping_metric="val_acc",
    # Discard the checkpoint. The former hard-coded "NUL" is the null device
    # only on Windows; on Linux/macOS it creates a real file named "NUL".
    save_path=os.devnull,
)

print("\n" + "="*60)
print(f"üìä RESULTADOS GENERALES ({DATASET_NAME}):")
print(f"   Test Accuracy: {results['test_acc']:.4f} ({results['test_acc']*100:.2f}%)")
print(f"   Test Loss: {results['test_loss']:.4f}")
print("="*60 + "\n")

# -------- TRAINING CURVES --------
# One point per recorded epoch (may be fewer than `epochs` after early stopping).
# NOTE: this rebinds the global `epochs` to a range object.
epochs = range(1, len(hist["train_loss"]) + 1)
plt.figure(figsize=(12,5))

plt.subplot(1,2,1)
plt.plot(epochs, hist["train_loss"], 'r-', label='Training')
plt.plot(epochs, hist["val_loss"], 'b-', label='Validation')
plt.title('Loss evolution'); plt.xlabel('Epoch'); plt.ylabel('Loss')
plt.legend(); plt.grid(alpha=0.3)

plt.subplot(1,2,2)
plt.plot(epochs, hist["train_acc"], 'r-', label='Training')
plt.plot(epochs, hist["val_acc"], 'b-', label='Validation')
plt.title('Accuracy evolution'); plt.xlabel('Epoch'); plt.ylabel('Accuracy')
plt.legend(); plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()
# -------- EVALUACI√ìN DETALLADA --------
class SpectroDataset(Dataset):
    """Map-style dataset serving channels-first float32 samples of `X`.

    Args:
        X: indexable collection whose items are 3-D channels-last arrays
            (the last axis is moved to the front on access).
        y: indexable collection of integer-convertible labels aligned with `X`.
        indices: positions of `X`/`y` belonging to this subset (e.g. a split).

    Each item is ``(x, label)``: ``x`` is a float32 tensor with the channel
    axis first, ``label`` a 0-d int64 tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        src = self.idx[i]
        # (A, B, C) -> (C, A, B); float32 conversion happens before the move.
        sample = np.asarray(self.X[src], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[src]), dtype=torch.long)
        return torch.from_numpy(sample), label

# Ordered (shuffle=False) loader over the test indices for the detailed
# evaluation; pin_memory is enabled only when running on CUDA.
test_loader = DataLoader(SpectroDataset(X, y, splits["test"]),
                         batch_size=256, shuffle=False, num_workers=0,
                         pin_memory=(device.type=='cuda'))

@torch.no_grad()
def predict(model, loader):
    """Run `model` over every batch of `loader` and collect hard predictions.

    Inputs are moved to the module-level ``device``; labels stay on the CPU.
    The model is put in eval mode and no autograd graph is recorded.

    Args:
        model: torch module returning class scores with classes on dim 1.
        loader: iterable of ``(inputs, labels)`` tensor pairs.

    Returns:
        ``(y_pred, y_true)`` as 1-D numpy arrays in loader order.
        np.concatenate raises ValueError if `loader` yields no batches.
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    for inputs, targets in loader:
        scores = model(inputs.to(device, non_blocking=True))
        pred_chunks.append(scores.argmax(1).cpu().numpy())
        true_chunks.append(targets.numpy())
    return np.concatenate(pred_chunks), np.concatenate(true_chunks)

# Hard predictions and ground truth for the whole test split.
y_pred, y_true = predict(model, test_loader)

# --- per-class metrics ---
# Stage names in the order of integer class ids 0..4.
labels = ["W","N1","N2","N3","REM"]
# labels=range(5) yields one entry per stage even if a stage never occurs;
# zero_division=0 reports 0 instead of warning on empty denominators.
prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=range(5), zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=range(5))
# Row-normalise (rows = true stage); the max(..., 1) keeps empty rows at 0 instead of NaN.
cm_norm = cm / np.maximum(cm.sum(1, keepdims=True), 1)

# --- metrics table ---
df_metrics = pd.DataFrame({
    "etapa": labels,
    "precision": np.round(prec,3),
    "recall": np.round(rec,3),
    "f1_score": np.round(f1,3),
    "soporte": support
})
display(df_metrics.style.set_caption(f"M√©tricas por etapa - {DATASET_NAME}").format(precision=3))

# --- global metrics: overall accuracy and multiclass Cohen's kappa ---
acc_global = accuracy_score(y_true, y_pred)
kappa_global = cohen_kappa_score(y_true, y_pred)
print(f"‚úÖ Accuracy global: {acc_global:.3f}")
print(f"‚úÖ Cohen‚Äôs Œ∫: {kappa_global:.3f}")

# --- row-normalised confusion matrix ---
plt.figure(figsize=(6,5))
sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1)
plt.xlabel("Predicho"); plt.ylabel("Real")
plt.title(f"Matriz de confusi√≥n (normalizada) ‚Äî {DATASET_NAME}")
plt.tight_layout()
plt.show()


## Segundo dataset: EEG1, EEG2 y EOG

In [None]:
# ===============================================================
# ENTRENAMIENTO + CURVAS + EVALUACI√ìN COMPLETA 
# ===============================================================

import torch, numpy as np, matplotlib.pyplot as plt, seaborn as sns, pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, cohen_kappa_score
from torch.utils.data import Dataset, DataLoader

# -------- CONFIGURACI√ìN --------
import os  # os.devnull: platform-independent "discard" path for the checkpoint

use_gpu = True
device = torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

# Dataset for this cell; swap in (x1, y1, splits1), (x3, y3, splits3), ... as needed.
X, y, splits = x2, y2, splits2
DATASET_NAME = "EEG1+EEG2+EOG"

# -------- TRAINING --------
# Pass the channel count (last axis of X) when the model accepts `in_ch`;
# otherwise fall back to the constructor without it.
try:
    model = SleepStageModel(num_classes=5, in_ch=X.shape[-1])
except TypeError:
    model = SleepStageModel(num_classes=5)

model, hist, results = train_sleep_model(
    model, X, y, splits,
    lr=1e-4,
    batch_size=256,
    epochs=35,
    criterion_name='ce',
    class_weights=None,
    weight_clip_range=(0.1, 2.5),
    grad_clip=1.0,
    use_gpu=True,
    amp=False,
    num_workers=0,
    early_stopping_tolerance=8,
    early_stopping_metric="val_acc",
    # Discard the checkpoint. The former hard-coded "NUL" is the null device
    # only on Windows; on Linux/macOS it creates a real file named "NUL".
    save_path=os.devnull,
)

print("\n" + "="*60)
print(f"üìä RESULTADOS GENERALES ({DATASET_NAME}):")
print(f"   Test Accuracy: {results['test_acc']:.4f} ({results['test_acc']*100:.2f}%)")
print(f"   Test Loss: {results['test_loss']:.4f}")
print("="*60 + "\n")

# -------- TRAINING CURVES --------
# One point per recorded epoch (may be fewer than `epochs` after early stopping).
# NOTE: this rebinds the global `epochs` to a range object.
epochs = range(1, len(hist["train_loss"]) + 1)
plt.figure(figsize=(12,5))

plt.subplot(1,2,1)
plt.plot(epochs, hist["train_loss"], 'r-', label='Training')
plt.plot(epochs, hist["val_loss"], 'b-', label='Validation')
plt.title('Loss evolution'); plt.xlabel('Epoch'); plt.ylabel('Loss')
plt.legend(); plt.grid(alpha=0.3)

plt.subplot(1,2,2)
plt.plot(epochs, hist["train_acc"], 'r-', label='Training')
plt.plot(epochs, hist["val_acc"], 'b-', label='Validation')
plt.title('Accuracy evolution'); plt.xlabel('Epoch'); plt.ylabel('Accuracy')
plt.legend(); plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# -------- EVALUACI√ìN DETALLADA --------
class SpectroDataset(Dataset):
    """Map-style dataset serving channels-first float32 samples of `X`.

    Args:
        X: indexable collection whose items are 3-D channels-last arrays
            (the last axis is moved to the front on access).
        y: indexable collection of integer-convertible labels aligned with `X`.
        indices: positions of `X`/`y` belonging to this subset (e.g. a split).

    Each item is ``(x, label)``: ``x`` is a float32 tensor with the channel
    axis first, ``label`` a 0-d int64 tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        src = self.idx[i]
        # (A, B, C) -> (C, A, B); float32 conversion happens before the move.
        sample = np.asarray(self.X[src], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[src]), dtype=torch.long)
        return torch.from_numpy(sample), label

# Ordered (shuffle=False) loader over the test indices for the detailed
# evaluation; pin_memory is enabled only when running on CUDA.
test_loader = DataLoader(SpectroDataset(X, y, splits["test"]),
                         batch_size=256, shuffle=False, num_workers=0,
                         pin_memory=(device.type=='cuda'))

@torch.no_grad()
def predict(model, loader):
    """Run `model` over every batch of `loader` and collect hard predictions.

    Inputs are moved to the module-level ``device``; labels stay on the CPU.
    The model is put in eval mode and no autograd graph is recorded.

    Args:
        model: torch module returning class scores with classes on dim 1.
        loader: iterable of ``(inputs, labels)`` tensor pairs.

    Returns:
        ``(y_pred, y_true)`` as 1-D numpy arrays in loader order.
        np.concatenate raises ValueError if `loader` yields no batches.
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    for inputs, targets in loader:
        scores = model(inputs.to(device, non_blocking=True))
        pred_chunks.append(scores.argmax(1).cpu().numpy())
        true_chunks.append(targets.numpy())
    return np.concatenate(pred_chunks), np.concatenate(true_chunks)

# Hard predictions and ground truth for the whole test split.
y_pred, y_true = predict(model, test_loader)

# --- per-class metrics ---
# Stage names in the order of integer class ids 0..4.
labels = ["W","N1","N2","N3","REM"]
# labels=range(5) yields one entry per stage even if a stage never occurs;
# zero_division=0 reports 0 instead of warning on empty denominators.
prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=range(5), zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=range(5))
# Row-normalise (rows = true stage); the max(..., 1) keeps empty rows at 0 instead of NaN.
cm_norm = cm / np.maximum(cm.sum(1, keepdims=True), 1)

# --- metrics table ---
df_metrics = pd.DataFrame({
    "etapa": labels,
    "precision": np.round(prec,3),
    "recall": np.round(rec,3),
    "f1_score": np.round(f1,3),
    "soporte": support
})
display(df_metrics.style.set_caption(f"M√©tricas por etapa - {DATASET_NAME}").format(precision=3))

# --- global metrics: overall accuracy and multiclass Cohen's kappa ---
acc_global = accuracy_score(y_true, y_pred)
kappa_global = cohen_kappa_score(y_true, y_pred)
print(f"‚úÖ Accuracy global: {acc_global:.3f}")
print(f"‚úÖ Cohen‚Äôs Œ∫: {kappa_global:.3f}")

# --- row-normalised confusion matrix ---
plt.figure(figsize=(6,5))
sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1)
plt.xlabel("Predicho"); plt.ylabel("Real")
plt.title(f"Matriz de confusi√≥n (normalizada) ‚Äî {DATASET_NAME}")
plt.tight_layout()
plt.show()


## Tercer dataset: solo EEG1

In [None]:
# ===============================================================
# ENTRENAMIENTO + CURVAS + EVALUACI√ìN COMPLETA 
# ===============================================================

import torch, numpy as np, matplotlib.pyplot as plt, seaborn as sns, pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, cohen_kappa_score
from torch.utils.data import Dataset, DataLoader

# -------- CONFIGURACI√ìN --------
import os  # os.devnull: platform-independent "discard" path for the checkpoint

use_gpu = True
device = torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

# Dataset for this cell; swap in (x1, y1, splits1), (x2, y2, splits2), ... as needed.
X, y, splits = x3, y3, splits3
DATASET_NAME = "EEG1"

# -------- TRAINING --------
# Pass the channel count (last axis of X) when the model accepts `in_ch`;
# otherwise fall back to the constructor without it.
try:
    model = SleepStageModel(num_classes=5, in_ch=X.shape[-1])
except TypeError:
    model = SleepStageModel(num_classes=5)

model, hist, results = train_sleep_model(
    model, X, y, splits,
    lr=5e-6,
    batch_size=256,
    epochs=50,
    criterion_name='ce',
    class_weights=None,
    weight_clip_range=(0.1, 2.5),
    grad_clip=1.0,
    use_gpu=True,
    amp=False,
    num_workers=0,
    early_stopping_tolerance=4,
    early_stopping_metric="val_acc",
    # Discard the checkpoint. The former hard-coded "NUL" is the null device
    # only on Windows; on Linux/macOS it creates a real file named "NUL".
    save_path=os.devnull,
)

print("\n" + "="*60)
print(f"üìä RESULTADOS GENERALES ({DATASET_NAME}):")
print(f"   Test Accuracy: {results['test_acc']:.4f} ({results['test_acc']*100:.2f}%)")
print(f"   Test Loss: {results['test_loss']:.4f}")
print("="*60 + "\n")

# -------- TRAINING CURVES --------
# One point per recorded epoch (may be fewer than `epochs` after early stopping).
# NOTE: this rebinds the global `epochs` to a range object.
epochs = range(1, len(hist["train_loss"]) + 1)
plt.figure(figsize=(12,5))

plt.subplot(1,2,1)
plt.plot(epochs, hist["train_loss"], 'r-', label='Training')
plt.plot(epochs, hist["val_loss"], 'b-', label='Validation')
plt.title('Loss evolution'); plt.xlabel('Epoch'); plt.ylabel('Loss')
plt.legend(); plt.grid(alpha=0.3)

plt.subplot(1,2,2)
plt.plot(epochs, hist["train_acc"], 'r-', label='Training')
plt.plot(epochs, hist["val_acc"], 'b-', label='Validation')
plt.title('Accuracy evolution'); plt.xlabel('Epoch'); plt.ylabel('Accuracy')
plt.legend(); plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# -------- EVALUACI√ìN DETALLADA --------
class SpectroDataset(Dataset):
    """Map-style dataset serving channels-first float32 samples of `X`.

    Args:
        X: indexable collection whose items are 3-D channels-last arrays
            (the last axis is moved to the front on access).
        y: indexable collection of integer-convertible labels aligned with `X`.
        indices: positions of `X`/`y` belonging to this subset (e.g. a split).

    Each item is ``(x, label)``: ``x`` is a float32 tensor with the channel
    axis first, ``label`` a 0-d int64 tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        src = self.idx[i]
        # (A, B, C) -> (C, A, B); float32 conversion happens before the move.
        sample = np.asarray(self.X[src], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[src]), dtype=torch.long)
        return torch.from_numpy(sample), label

# Ordered (shuffle=False) loader over the test indices for the detailed
# evaluation; pin_memory is enabled only when running on CUDA.
test_loader = DataLoader(SpectroDataset(X, y, splits["test"]),
                         batch_size=256, shuffle=False, num_workers=0,
                         pin_memory=(device.type=='cuda'))

@torch.no_grad()
def predict(model, loader):
    """Run `model` over every batch of `loader` and collect hard predictions.

    Inputs are moved to the module-level ``device``; labels stay on the CPU.
    The model is put in eval mode and no autograd graph is recorded.

    Args:
        model: torch module returning class scores with classes on dim 1.
        loader: iterable of ``(inputs, labels)`` tensor pairs.

    Returns:
        ``(y_pred, y_true)`` as 1-D numpy arrays in loader order.
        np.concatenate raises ValueError if `loader` yields no batches.
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    for inputs, targets in loader:
        scores = model(inputs.to(device, non_blocking=True))
        pred_chunks.append(scores.argmax(1).cpu().numpy())
        true_chunks.append(targets.numpy())
    return np.concatenate(pred_chunks), np.concatenate(true_chunks)

# Hard predictions and ground truth for the whole test split.
y_pred, y_true = predict(model, test_loader)

# --- per-class metrics ---
# Stage names in the order of integer class ids 0..4.
labels = ["W","N1","N2","N3","REM"]
# labels=range(5) yields one entry per stage even if a stage never occurs;
# zero_division=0 reports 0 instead of warning on empty denominators.
prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=range(5), zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=range(5))
# Row-normalise (rows = true stage); the max(..., 1) keeps empty rows at 0 instead of NaN.
cm_norm = cm / np.maximum(cm.sum(1, keepdims=True), 1)

# --- metrics table ---
df_metrics = pd.DataFrame({
    "etapa": labels,
    "precision": np.round(prec,3),
    "recall": np.round(rec,3),
    "f1_score": np.round(f1,3),
    "soporte": support
})
display(df_metrics.style.set_caption(f"M√©tricas por etapa - {DATASET_NAME}").format(precision=3))

# --- global metrics: overall accuracy and multiclass Cohen's kappa ---
acc_global = accuracy_score(y_true, y_pred)
kappa_global = cohen_kappa_score(y_true, y_pred)
print(f"‚úÖ Accuracy global: {acc_global:.3f}")
print(f"‚úÖ Cohen‚Äôs Œ∫: {kappa_global:.3f}")

# --- row-normalised confusion matrix ---
plt.figure(figsize=(6,5))
sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1)
plt.xlabel("Predicho"); plt.ylabel("Real")
plt.title(f"Matriz de confusi√≥n (normalizada) ‚Äî {DATASET_NAME}")
plt.tight_layout()
plt.show()


## Cuarto dataset: EEG1, EEG2 y EMG

In [None]:
# ===============================================================
# ENTRENAMIENTO + CURVAS + EVALUACI√ìN COMPLETA 
# ===============================================================

import torch, numpy as np, matplotlib.pyplot as plt, seaborn as sns, pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, cohen_kappa_score
from torch.utils.data import Dataset, DataLoader

# -------- CONFIGURACI√ìN --------
import os  # os.devnull: platform-independent "discard" path for the checkpoint

use_gpu = True
device = torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

# Dataset for this cell; swap in (x1, y1, splits1), (x2, y2, splits2), ... as needed.
X, y, splits = x4, y4, splits4
DATASET_NAME = "EEG1+EEG2+EMG"

# -------- TRAINING --------
# Pass the channel count (last axis of X) when the model accepts `in_ch`;
# otherwise fall back to the constructor without it.
try:
    model = SleepStageModel(num_classes=5, in_ch=X.shape[-1])
except TypeError:
    model = SleepStageModel(num_classes=5)

model, hist, results = train_sleep_model(
    model, X, y, splits,
    lr=5e-6,
    batch_size=256,
    epochs=50,
    criterion_name='ce',
    class_weights=None,
    weight_clip_range=(0.1, 2.5),
    grad_clip=1.0,
    use_gpu=True,
    amp=False,
    num_workers=0,
    early_stopping_tolerance=4,
    early_stopping_metric="val_acc",
    # Discard the checkpoint. The former hard-coded "NUL" is the null device
    # only on Windows; on Linux/macOS it creates a real file named "NUL".
    save_path=os.devnull,
)

print("\n" + "="*60)
print(f"üìä RESULTADOS GENERALES ({DATASET_NAME}):")
print(f"   Test Accuracy: {results['test_acc']:.4f} ({results['test_acc']*100:.2f}%)")
print(f"   Test Loss: {results['test_loss']:.4f}")
print("="*60 + "\n")

# -------- TRAINING CURVES --------
# One point per recorded epoch (may be fewer than `epochs` after early stopping).
# NOTE: this rebinds the global `epochs` to a range object.
epochs = range(1, len(hist["train_loss"]) + 1)
plt.figure(figsize=(12,5))

plt.subplot(1,2,1)
plt.plot(epochs, hist["train_loss"], 'r-', label='Training')
plt.plot(epochs, hist["val_loss"], 'b-', label='Validation')
plt.title('Loss evolution'); plt.xlabel('Epoch'); plt.ylabel('Loss')
plt.legend(); plt.grid(alpha=0.3)

plt.subplot(1,2,2)
plt.plot(epochs, hist["train_acc"], 'r-', label='Training')
plt.plot(epochs, hist["val_acc"], 'b-', label='Validation')
plt.title('Accuracy evolution'); plt.xlabel('Epoch'); plt.ylabel('Accuracy')
plt.legend(); plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# -------- EVALUACI√ìN DETALLADA --------
class SpectroDataset(Dataset):
    """Map-style dataset serving channels-first float32 samples of `X`.

    Args:
        X: indexable collection whose items are 3-D channels-last arrays
            (the last axis is moved to the front on access).
        y: indexable collection of integer-convertible labels aligned with `X`.
        indices: positions of `X`/`y` belonging to this subset (e.g. a split).

    Each item is ``(x, label)``: ``x`` is a float32 tensor with the channel
    axis first, ``label`` a 0-d int64 tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        src = self.idx[i]
        # (A, B, C) -> (C, A, B); float32 conversion happens before the move.
        sample = np.asarray(self.X[src], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[src]), dtype=torch.long)
        return torch.from_numpy(sample), label

# Ordered (shuffle=False) loader over the test indices for the detailed
# evaluation; pin_memory is enabled only when running on CUDA.
test_loader = DataLoader(SpectroDataset(X, y, splits["test"]),
                         batch_size=256, shuffle=False, num_workers=0,
                         pin_memory=(device.type=='cuda'))

@torch.no_grad()
def predict(model, loader):
    """Run `model` over every batch of `loader` and collect hard predictions.

    Inputs are moved to the module-level ``device``; labels stay on the CPU.
    The model is put in eval mode and no autograd graph is recorded.

    Args:
        model: torch module returning class scores with classes on dim 1.
        loader: iterable of ``(inputs, labels)`` tensor pairs.

    Returns:
        ``(y_pred, y_true)`` as 1-D numpy arrays in loader order.
        np.concatenate raises ValueError if `loader` yields no batches.
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    for inputs, targets in loader:
        scores = model(inputs.to(device, non_blocking=True))
        pred_chunks.append(scores.argmax(1).cpu().numpy())
        true_chunks.append(targets.numpy())
    return np.concatenate(pred_chunks), np.concatenate(true_chunks)

# Hard predictions and ground truth for the whole test split.
y_pred, y_true = predict(model, test_loader)

# --- per-class metrics ---
# Stage names in the order of integer class ids 0..4.
labels = ["W","N1","N2","N3","REM"]
# labels=range(5) yields one entry per stage even if a stage never occurs;
# zero_division=0 reports 0 instead of warning on empty denominators.
prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=range(5), zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=range(5))
# Row-normalise (rows = true stage); the max(..., 1) keeps empty rows at 0 instead of NaN.
cm_norm = cm / np.maximum(cm.sum(1, keepdims=True), 1)

# --- metrics table ---
df_metrics = pd.DataFrame({
    "etapa": labels,
    "precision": np.round(prec,3),
    "recall": np.round(rec,3),
    "f1_score": np.round(f1,3),
    "soporte": support
})
display(df_metrics.style.set_caption(f"M√©tricas por etapa - {DATASET_NAME}").format(precision=3))

# --- global metrics: overall accuracy and multiclass Cohen's kappa ---
acc_global = accuracy_score(y_true, y_pred)
kappa_global = cohen_kappa_score(y_true, y_pred)
print(f"‚úÖ Accuracy global: {acc_global:.3f}")
print(f"‚úÖ Cohen‚Äôs Œ∫: {kappa_global:.3f}")

# --- row-normalised confusion matrix ---
plt.figure(figsize=(6,5))
sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1)
plt.xlabel("Predicho"); plt.ylabel("Real")
plt.title(f"Matriz de confusi√≥n (normalizada) ‚Äî {DATASET_NAME}")
plt.tight_layout()
plt.show()


## Quinto dataset: EOG y EMG

In [None]:
# ===============================================================
# ENTRENAMIENTO + CURVAS + EVALUACI√ìN COMPLETA 
# ===============================================================

import torch, numpy as np, matplotlib.pyplot as plt, seaborn as sns, pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix, cohen_kappa_score
from torch.utils.data import Dataset, DataLoader

# -------- CONFIGURACI√ìN --------
import os  # os.devnull: platform-independent "discard" path for the checkpoint

use_gpu = True
device = torch.device("cuda" if (use_gpu and torch.cuda.is_available()) else "cpu")

# Dataset for this cell; swap in (x1, y1, splits1), (x2, y2, splits2), ... as needed.
X, y, splits = x5, y5, splits5
DATASET_NAME = "EOG+EMG"

# -------- TRAINING --------
# Pass the channel count (last axis of X) when the model accepts `in_ch`;
# otherwise fall back to the constructor without it.
try:
    model = SleepStageModel(num_classes=5, in_ch=X.shape[-1])
except TypeError:
    model = SleepStageModel(num_classes=5)

model, hist, results = train_sleep_model(
    model, X, y, splits,
    lr=1e-4,
    batch_size=256,
    epochs=35,
    criterion_name='ce',
    class_weights=None,
    weight_clip_range=(0.1, 2.5),
    grad_clip=1.0,
    use_gpu=True,
    amp=False,
    num_workers=0,
    early_stopping_tolerance=8,
    early_stopping_metric="val_acc",
    # Discard the checkpoint. The former hard-coded "NUL" is the null device
    # only on Windows; on Linux/macOS it creates a real file named "NUL".
    save_path=os.devnull,
)

print("\n" + "="*60)
print(f"üìä RESULTADOS GENERALES ({DATASET_NAME}):")
print(f"   Test Accuracy: {results['test_acc']:.4f} ({results['test_acc']*100:.2f}%)")
print(f"   Test Loss: {results['test_loss']:.4f}")
print("="*60 + "\n")

# -------- TRAINING CURVES --------
# One point per recorded epoch (may be fewer than `epochs` after early stopping).
# NOTE: this rebinds the global `epochs` to a range object.
epochs = range(1, len(hist["train_loss"]) + 1)
plt.figure(figsize=(12,5))

plt.subplot(1,2,1)
plt.plot(epochs, hist["train_loss"], 'r-', label='Training')
plt.plot(epochs, hist["val_loss"], 'b-', label='Validation')
plt.title('Loss evolution'); plt.xlabel('Epoch'); plt.ylabel('Loss')
plt.legend(); plt.grid(alpha=0.3)

plt.subplot(1,2,2)
plt.plot(epochs, hist["train_acc"], 'r-', label='Training')
plt.plot(epochs, hist["val_acc"], 'b-', label='Validation')
plt.title('Accuracy evolution'); plt.xlabel('Epoch'); plt.ylabel('Accuracy')
plt.legend(); plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

# -------- EVALUACI√ìN DETALLADA --------
class SpectroDataset(Dataset):
    """Map-style dataset serving channels-first float32 samples of `X`.

    Args:
        X: indexable collection whose items are 3-D channels-last arrays
            (the last axis is moved to the front on access).
        y: indexable collection of integer-convertible labels aligned with `X`.
        indices: positions of `X`/`y` belonging to this subset (e.g. a split).

    Each item is ``(x, label)``: ``x`` is a float32 tensor with the channel
    axis first, ``label`` a 0-d int64 tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        src = self.idx[i]
        # (A, B, C) -> (C, A, B); float32 conversion happens before the move.
        sample = np.asarray(self.X[src], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[src]), dtype=torch.long)
        return torch.from_numpy(sample), label

# Ordered (shuffle=False) loader over the test indices for the detailed
# evaluation; pin_memory is enabled only when running on CUDA.
test_loader = DataLoader(SpectroDataset(X, y, splits["test"]),
                         batch_size=256, shuffle=False, num_workers=0,
                         pin_memory=(device.type=='cuda'))

@torch.no_grad()
def predict(model, loader):
    """Run `model` over every batch of `loader` and collect hard predictions.

    Inputs are moved to the module-level ``device``; labels stay on the CPU.
    The model is put in eval mode and no autograd graph is recorded.

    Args:
        model: torch module returning class scores with classes on dim 1.
        loader: iterable of ``(inputs, labels)`` tensor pairs.

    Returns:
        ``(y_pred, y_true)`` as 1-D numpy arrays in loader order.
        np.concatenate raises ValueError if `loader` yields no batches.
    """
    model.eval()
    pred_chunks = []
    true_chunks = []
    for inputs, targets in loader:
        scores = model(inputs.to(device, non_blocking=True))
        pred_chunks.append(scores.argmax(1).cpu().numpy())
        true_chunks.append(targets.numpy())
    return np.concatenate(pred_chunks), np.concatenate(true_chunks)

# Hard predictions and ground truth for the whole test split.
y_pred, y_true = predict(model, test_loader)

# --- per-class metrics ---
# Stage names in the order of integer class ids 0..4.
labels = ["W","N1","N2","N3","REM"]
# labels=range(5) yields one entry per stage even if a stage never occurs;
# zero_division=0 reports 0 instead of warning on empty denominators.
prec, rec, f1, support = precision_recall_fscore_support(y_true, y_pred, labels=range(5), zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=range(5))
# Row-normalise (rows = true stage); the max(..., 1) keeps empty rows at 0 instead of NaN.
cm_norm = cm / np.maximum(cm.sum(1, keepdims=True), 1)

# --- metrics table ---
df_metrics = pd.DataFrame({
    "etapa": labels,
    "precision": np.round(prec,3),
    "recall": np.round(rec,3),
    "f1_score": np.round(f1,3),
    "soporte": support
})
display(df_metrics.style.set_caption(f"M√©tricas por etapa - {DATASET_NAME}").format(precision=3))

# --- global metrics: overall accuracy and multiclass Cohen's kappa ---
acc_global = accuracy_score(y_true, y_pred)
kappa_global = cohen_kappa_score(y_true, y_pred)
print(f"‚úÖ Accuracy global: {acc_global:.3f}")
print(f"‚úÖ Cohen‚Äôs Œ∫: {kappa_global:.3f}")

# --- row-normalised confusion matrix ---
plt.figure(figsize=(6,5))
sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="Blues",
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1)
plt.xlabel("Predicho"); plt.ylabel("Real")
plt.title(f"Matriz de confusi√≥n (normalizada) ‚Äî {DATASET_NAME}")
plt.tight_layout()
plt.show()


# Ejecución (3 veces)

## Primer dataset: EEG1, EEG2, EOG y EMG

In [None]:
# ================================================
# CELDA √öNICA: MULTI-RUN + M√âTRICAS 
# ================================================
import os, json, copy, math, random, pickle, warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, cohen_kappa_score
)
warnings.filterwarnings("ignore")  # NOTE: hides *all* warnings for the rest of the kernel session

# ========= 0) CURRENT DATASET SELECTION =========
X, y, splits = x1, y1, splits1
DATASET_NAME = "EEG1+EEG2+EOG+EMG"

# ========= 1) FLAGS (nothing is written to disk by default) =========
SAVE_CHECKPOINTS   = False   # save best_model.pt per run
SAVE_HISTORIES     = False   # save history.npz (+ config.json) per run
SAVE_PER_RUN_FILES = False   # save per-run CSV / NPY / TXT (metrics and CM)
SAVE_AGGREGATES    = False   # save aggregated table and CM arrays

# ========= 2) GLOBAL CONFIG =========
N_RUNS = 3
BASE_SEED = 42
# Keyword arguments forwarded verbatim to train_sleep_model(**CONFIG) in the run loop.
CONFIG = {
    "lr": 5e-6,
    "batch_size": 256,
    "epochs": 50,
    "criterion_name": "ce",
    "class_weights": None,
    "weight_clip_range": (0.1, 2.5),
    "grad_clip": 1.0,
    "use_gpu": True,
    "amp": False,
    "num_workers": 0,
    "early_stopping_tolerance": 5,
    "early_stopping_metric": "val_acc"
}

# ======= Paths =======
# Reuse an OUTPUT_DIR defined earlier in the session, else ./outputs.
OUTPUT_DIR = Path(OUTPUT_DIR) if 'OUTPUT_DIR' in globals() else (Path.cwd() / "outputs")
RUNS_DIR   = OUTPUT_DIR / "multiple_runs" / DATASET_NAME.replace(" ", "_")
# Only touch the filesystem when something will actually be saved. The
# per-run (run_dir.mkdir(parents=True)) and aggregate save paths later in
# this cell create the directories they need themselves.
if SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

# ========= 3) Utils =========
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU, plus every CUDA device when one
    is available) with `seed`, and put cuDNN in deterministic mode with
    autotuning disabled, so repeated runs with the same seed match.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class SpectroDataset(Dataset):
    """Map-style dataset serving channels-first float32 samples of `X`.

    Args:
        X: indexable collection whose items are 3-D channels-last arrays
            (H, W, C); they are returned as (C, H, W).
        y: indexable collection of integer-convertible labels aligned with `X`.
        indices: positions of `X`/`y` belonging to this subset (e.g. a split).

    Each item is ``(x, label)``: ``x`` is a float32 tensor, ``label`` a 0-d
    int64 tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        src = self.idx[i]
        # (H, W, C) -> (C, H, W); float32 conversion happens before the move.
        sample = np.asarray(self.X[src], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[src]), dtype=torch.long)
        return torch.from_numpy(sample), label

def _build_loaders(X, y, splits, batch_size=256, num_workers=0, pin=True):
    train_ds = SpectroDataset(X, y, splits['train'])
    val_ds   = SpectroDataset(X, y, splits['val'])
    test_ds  = SpectroDataset(X, y, splits['test'])

    # Weighted sampler (balanceo por clase en TRAIN)
    y_train_subset = y[splits['train']]
    class_counts = np.bincount(y_train_subset, minlength=int(np.max(y))+1)
    class_weights = 1.0 / np.maximum(class_counts, 1)
    sample_weights = class_weights[y_train_subset]
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=pin, drop_last=False)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader, test_loader

# ======= Modelo 
def _new_model():
    try:
        return SleepStageModel(num_classes=5, in_ch=X.shape[-1])
    except TypeError:
        return SleepStageModel(num_classes=5)

# ======= Entrenamiento 

# ======= M√©tricas detalladas por run =======
@torch.no_grad()
def evaluate_detailed(model, test_loader, device):
    """Evaluate `model` on `test_loader` with per-stage and global metrics.

    Args:
        model: torch module returning class scores with classes on dim 1.
        test_loader: iterable of ``(inputs, labels)`` tensor pairs.
        device: device the inputs are moved to before the forward pass.

    Returns:
        dict with keys:
          "df"      -- DataFrame with one row per stage (W, N1, N2, N3, REM):
                       precision, recall, f1_score, one-vs-rest accuracy and
                       kappa (rounded to 3 decimals) and integer support;
          "cm"      -- raw confusion matrix (rows = true, cols = predicted);
          "cm_norm" -- row-normalised confusion matrix; rows of stages with
                       no true samples are all zeros;
          "acc", "kappa" -- overall accuracy and multiclass Cohen's kappa;
          "y_true", "y_pred" -- 1-D arrays in loader order.
    """
    labels = ["W","N1","N2","N3","REM"]
    n_classes = len(labels)

    model.eval()
    all_p, all_t = [], []
    for xb, yb in test_loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        p = torch.argmax(logits, dim=1).cpu().numpy()
        all_p.append(p); all_t.append(yb.numpy())
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(n_classes), average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=range(n_classes))
    # Row-normalise (rows = true stage). Clamping the denominator to >= 1
    # makes rows of stages absent from the test split exactly 0. The former
    # np.divide(..., where=row_sums != 0) call had no `out=` array, so those
    # masked entries were left uninitialised (arbitrary values).
    cm_norm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)

    # One-vs-rest accuracy and Cohen's kappa per stage, from the 2x2 table
    # (TP, FN, FP, TN) of "stage k" vs "any other stage". Accuracy and the
    # kappa marginals use the same population (the samples counted in cm).
    total = cm.sum()
    acc_per_class = np.zeros(n_classes, dtype=np.float64)
    kappa_per_class = np.zeros(n_classes, dtype=np.float64)
    if total > 0:  # empty evaluation set: keep zeros instead of 0/0 -> NaN
        for k in range(n_classes):
            TP = cm[k, k]
            FN = cm[k, :].sum() - TP
            FP = cm[:, k].sum() - TP
            TN = total - (TP + FN + FP)
            obs = (TP + TN) / total
            acc_per_class[k] = obs

            # Chance agreement from the true / predicted marginals.
            p_yes_true = (TP + FN) / total
            p_yes_pred = (TP + FP) / total
            p_no_true  = (FP + TN) / total
            p_no_pred  = (FN + TN) / total
            exp = p_yes_true * p_yes_pred + p_no_true * p_no_pred
            # Epsilon avoids 0/0 when exp == 1 (stage neither true nor predicted).
            kappa_per_class[k] = (obs - exp) / (1 - exp + 1e-12)

    df_per_class = pd.DataFrame({
        "etapa": labels,
        "precision": np.round(prec, 3),
        "recall":    np.round(rec, 3),
        "f1_score":  np.round(f1, 3),
        "accuracy":  np.round(acc_per_class, 3),
        "kappa":     np.round(kappa_per_class, 3),
        "soporte":   support.astype(int)
    })

    overall_acc = accuracy_score(y_true, y_pred)
    kappa_global = cohen_kappa_score(y_true, y_pred)

    return {
        "df": df_per_class,
        "cm": cm,
        "cm_norm": cm_norm,
        "acc": overall_acc,
        "kappa": kappa_global,
        "y_true": y_true,
        "y_pred": y_pred
    }

# ========= 4) Run loop =========
# Fail fast if the cell defining train_sleep_model has not been executed.
assert 'train_sleep_model' in globals(), "Falta la funci√≥n train_sleep_model en el entorno."
device = torch.device("cuda" if (CONFIG["use_gpu"] and torch.cuda.is_available()) else "cpu")
pin_mem = (device.type == "cuda")  # pinned host memory only speeds up CUDA copies

print("="*90)
print(f"üöÄ MULTI-RUN sobre dataset: {DATASET_NAME}")
print("="*90)
print(f"Seeds: {[BASE_SEED+i for i in range(N_RUNS)]}")
print(f"Guardar checkpoints: {SAVE_CHECKPOINTS} | Guardar histories: {SAVE_HISTORIES} | Guardar per-run: {SAVE_PER_RUN_FILES}")
print()

# Loaders (fixed for this dataset, shared by all runs).
# NOTE(review): only test_loader is used below (by evaluate_detailed);
# train_sleep_model receives X/y/splits directly, so this class-balanced
# train_loader presumably never drives training -- confirm.
train_loader, val_loader, test_loader = _build_loaders(
    X, y, splits, batch_size=CONFIG["batch_size"], num_workers=CONFIG["num_workers"], pin=pin_mem
)

all_runs_data = []  # one dict per run: history, train results, detailed eval, run_dir
for run_id in range(1, N_RUNS+1):
    # Consecutive seeds BASE_SEED, BASE_SEED+1, ... make the runs reproducible.
    seed = BASE_SEED + (run_id-1)
    set_seed(seed)

    model = _new_model()
    run_dir = RUNS_DIR / f"run_{run_id:02d}"
    if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES):
        run_dir.mkdir(parents=True, exist_ok=True)

    # train_sleep_model saves its best checkpoint to `save_path`. With
    # checkpoints disabled, point it at the platform null device so nothing
    # is written; the former "best_model_tmp.pt" fallback left a stray file
    # in the working directory despite the "nothing saved by default" flags.
    save_path = str(run_dir / "best_model.pt") if SAVE_CHECKPOINTS else os.devnull

    # Train
    model, hist, results = train_sleep_model(
        model=model,
        X=X, y=y, splits=splits,
        save_path=save_path,
        **CONFIG
    )

    # Learning curves of the last run only (on screen)
    if run_id == N_RUNS:
        epochs_arr = range(1, len(hist["train_loss"])+1)
        fig, ax = plt.subplots(1,2, figsize=(12,5))
        ax[0].plot(epochs_arr, hist["train_loss"], 'r-', label='training')
        ax[0].plot(epochs_arr, hist["val_loss"], 'b-', label='validation')
        ax[0].set_title('Loss evolution'); ax[0].set_xlabel('Epoch'); ax[0].set_ylabel('Loss'); ax[0].grid(True, alpha=.3); ax[0].legend()

        ax[1].plot(epochs_arr, hist["train_acc"], 'r-', label='training')
        ax[1].plot(epochs_arr, hist["val_acc"], 'b-', label='validation')
        ax[1].set_title('Accuracy evolution'); ax[1].set_xlabel('Epoch'); ax[1].set_ylabel('Accuracy'); ax[1].grid(True, alpha=.3); ax[1].legend()
        plt.suptitle(f"Learning Curves ‚Äî {DATASET_NAME} (Run {run_id})")
        plt.tight_layout()
        plt.show()

    # Optional save: history arrays and the run's config (+ seed, timestamp)
    if SAVE_HISTORIES:
        np.savez(run_dir / "history.npz",
                 train_loss=hist["train_loss"], val_loss=hist["val_loss"],
                 train_acc=hist["train_acc"], val_acc=hist["val_acc"],
                 lr=hist["lr"])
        with open(run_dir / "config.json","w") as f:
            cfg = copy.deepcopy(CONFIG); cfg.update(seed=seed, run_id=run_id, dataset=DATASET_NAME, ts=datetime.now().isoformat())
            json.dump(cfg, f, indent=2)

    # Detailed test-set evaluation of the model returned by train_sleep_model
    eval_res = evaluate_detailed(model, test_loader, device)

    # Per-stage table on screen (not saved by default)
    print(f"\nüìä RUN {run_id} ‚Äî M√©tricas por etapa")
    display(eval_res["df"].style.set_caption(f"Run {run_id} ‚Äî {DATASET_NAME}"))

    print(f"   ‚û§ Acc={eval_res['acc']:.4f} | Kappa={eval_res['kappa']:.4f}")

    # Confusion-matrix plots (on screen): raw counts and row-normalised
    labels = ["W","N1","N2","N3","REM"]
    fig, axes = plt.subplots(1,2, figsize=(12,5))
    sns.heatmap(eval_res["cm"], annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_title(f"CM Cruda ‚Äî Run {run_id}")
    axes[0].set_xlabel("Predicho"); axes[0].set_ylabel("Real")

    sns.heatmap(eval_res["cm_norm"], annot=True, fmt='.2f', cmap='Blues',
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=axes[1])
    axes[1].set_title(f"CM Normalizada ‚Äî Run {run_id}")
    axes[1].set_xlabel("Predicho"); axes[1].set_ylabel("Real")
    plt.tight_layout(); plt.show()

    # Optional per-run files: per-class CSV, predictions CSV, CM arrays, summary
    if SAVE_PER_RUN_FILES:
        eval_res["df"].to_csv(run_dir / "eval_test_per_class.csv", index=False)
        pd.DataFrame({"y_true": eval_res["y_true"].astype(int),
                      "y_pred": eval_res["y_pred"].astype(int)}).to_csv(run_dir / "eval_test_pred_vs_true.csv", index=False)
        np.save(run_dir / "eval_test_cm.npy", eval_res["cm"])
        np.save(run_dir / "eval_test_cm_norm.npy", eval_res["cm_norm"])
        with open(run_dir / "eval_test_summary.txt","w") as f:
            f.write(f"accuracy_global={eval_res['acc']:.6f}\n")
            f.write(f"kappa_global={eval_res['kappa']:.6f}\n")

    all_runs_data.append({
        "run_id": run_id,
        "history": hist,
        "results": results,
        "eval": eval_res,
        "run_dir": (run_dir if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES) else None)
    })

# ========= 5) Simple summary =========
# Test accuracy / loss as reported by train_sleep_model for each run.
test_accs = [rd["results"]["test_acc"] for rd in all_runs_data]
test_losses = [rd["results"]["test_loss"] for rd in all_runs_data]
print("\n" + "="*90)
print(f"‚úÖ {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*90)
# np.std defaults to the population std (ddof=0), not the sample std.
print(f"Test Acc:  mean={np.mean(test_accs):.4f}  std={np.std(test_accs):.4f}  "
      f"min={np.min(test_accs):.4f}  max={np.max(test_accs):.4f}")
print(f"Test Loss: mean={np.mean(test_losses):.4f}  std={np.std(test_losses):.4f}  "
      f"min={np.min(test_losses):.4f}  max={np.max(test_losses):.4f}")

# ========= 6) Aggregation (mean ± std) and aggregate plots =========
labels = ["W","N1","N2","N3","REM"]
metrics_cols = ["precision", "recall", "f1_score", "accuracy", "kappa"]

# Stack per-class metrics: one (n_stages x n_metrics) table per run, taken
# from the per-run DataFrames (already rounded to 3 decimals).
per_class_list = [rd["eval"]["df"][metrics_cols].to_numpy() for rd in all_runs_data]  # list of (5x5)
per_class_arr  = np.stack(per_class_list, axis=0)  # (n_runs, 5, 5)

# Mean and population std (ddof=0) across runs, per stage and metric.
means = per_class_arr.mean(axis=0)  # (5,5)
stds  = per_class_arr.std(axis=0)   # (5,5)

# Aggregated table of "mean ± std" strings (on screen)
df_agg = pd.DataFrame({"etapa": labels})
for j, col in enumerate(metrics_cols):
    df_agg[col] = [f"{means[i,j]:.3f} ¬± {stds[i,j]:.3f}" for i in range(len(labels))]

print("\nüìä M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
print(df_agg.to_string(index=False))

# Aggregated F1 bar plot with std error bars (on screen)
f1_means = means[:, metrics_cols.index("f1_score")]
f1_stds  = stds[:,  metrics_cols.index("f1_score")]
plt.figure(figsize=(9,5))
x = np.arange(len(labels))  # NOTE: rebinds the global name `x`
plt.bar(x, f1_means, yerr=f1_stds, capsize=4, alpha=.85)
plt.xticks(x, labels)
plt.ylim(0, 1.05)
plt.xlabel("Etapa"); plt.ylabel("F1-Score"); plt.title(f"F1 por etapa (media¬±std) ‚Äî {DATASET_NAME}")
plt.grid(axis='y', alpha=.3)
plt.tight_layout(); plt.show()

# Aggregated row-normalised CM: element-wise mean and std over runs
cm_norm_mean = np.mean([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)
cm_norm_std  = np.std ([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)

fig, ax = plt.subplots(1,2, figsize=(12,5))
sns.heatmap(cm_norm_mean, annot=True, fmt=".3f", cmap='Blues',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=ax[0])
ax[0].set_title(f"CM Normalizada (media) ‚Äî {DATASET_NAME}")
# Std colour scale is capped at 0.2 (vmax); larger spreads saturate.
sns.heatmap(cm_norm_std, annot=True, fmt=".3f", cmap='Reds',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=0.2, ax=ax[1])
ax[1].set_title(f"CM Normalizada (std) ‚Äî {DATASET_NAME}")
for a in ax: a.set_xlabel("Predicho"); a.set_ylabel("Real")
plt.tight_layout(); plt.show()

# Optional aggregate saves (CSV table and .npy matrices)
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    df_agg.to_csv(RUNS_DIR / "metrics_aggregated_per_class.csv", index=False)
    np.save(RUNS_DIR / "cm_norm_mean.npy", cm_norm_mean)
    np.save(RUNS_DIR / "cm_norm_std.npy", cm_norm_std)

print("\n Listo. ")


## Segundo dataset: EEG1, EEG2 y EOG

In [None]:
# ================================================
# CELDA √öNICA: MULTI-RUN + M√âTRICAS 
# ================================================
import os, json, copy, math, random, pickle, warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, cohen_kappa_score
)
warnings.filterwarnings("ignore")  # NOTE: silences every warning (sklearn, torch, ...) for the whole session

# ========= 0) Current dataset selection =========
X, y, splits = x2, y2, splits2
DATASET_NAME = "EEG1+EEG2+EOG"

# ========= 1) FLAGS (nothing is saved by default) =========
SAVE_CHECKPOINTS   = False   # Save best_model.pt per run
SAVE_HISTORIES     = False   # Save history.npz + config.json per run
SAVE_PER_RUN_FILES = False   # Save per-run CSV / NPY / TXT (metrics and CM)
SAVE_AGGREGATES    = False   # Save aggregated tables and plots

# ========= 2) GLOBAL CONFIG (forwarded to train_sleep_model as **CONFIG) =========
N_RUNS = 3
BASE_SEED = 42
CONFIG = {
    "lr": 5e-6,
    "batch_size": 256,
    "epochs": 50,
    "criterion_name": "ce",
    "class_weights": None,
    "weight_clip_range": (0.1, 2.5),
    "grad_clip": 1.0,
    "use_gpu": True,
    "amp": False,
    "num_workers": 0,
    "early_stopping_tolerance": 5,
    "early_stopping_metric": "val_acc"
}

# ======= Paths (only touched on disk when something is saved) =======
OUTPUT_DIR = Path(OUTPUT_DIR) if 'OUTPUT_DIR' in globals() else (Path.cwd() / "outputs")
RUNS_DIR   = OUTPUT_DIR / "multiple_runs" / DATASET_NAME.replace(" ", "_")
# Create the output tree only if some flag will write into it. Per-run folders and
# the aggregate save step create their own directories with parents=True anyway.
if SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

# ========= 3) Utils =========
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs with ``seed``
    and switch cuDNN to deterministic, non-benchmarking mode for repeatable runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        # Current CUDA device first, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

class SpectroDataset(Dataset):
    """Map-style dataset exposing the samples of ``X``/``y`` listed in ``indices``.

    Each ``X[j]`` must be a 3-D channels-last (H, W, C) array; items are returned as
    ``(float32 tensor of shape (C, H, W), int64 scalar label tensor)``.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        sample_id = self.idx[i]
        hwc = np.asarray(self.X[sample_id], dtype=np.float32)
        chw = hwc.transpose(2, 0, 1)  # channels-last -> channels-first
        label = torch.tensor(int(self.y[sample_id]), dtype=torch.long)
        return torch.from_numpy(chw), label

def _build_loaders(X, y, splits, batch_size=256, num_workers=0, pin=True):
    """Build train/val/test DataLoaders over ``SpectroDataset`` views of ``X``/``y``.

    The train loader samples with replacement through a ``WeightedRandomSampler``
    whose per-sample weight is the inverse frequency of that sample's class in the
    TRAIN split (class balancing). Val/test loaders keep the split order.

    Args:
        X, y: full sample and label arrays, indexed by the split indices.
        splits: mapping with 'train', 'val' and 'test' index collections.
        batch_size, num_workers, pin: forwarded to every DataLoader
            (``pin`` -> ``pin_memory``).

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    train_ds = SpectroDataset(X, y, splits['train'])
    val_ds   = SpectroDataset(X, y, splits['val'])
    test_ds  = SpectroDataset(X, y, splits['test'])

    # Weighted sampler (class balancing on TRAIN). Labels are coerced to int64
    # first: np.bincount and the weight lookup reject float labels (and lists),
    # while SpectroDataset already accepts them through int().
    y_int = np.asarray(y).astype(np.int64, copy=False)
    y_train_subset = y_int[splits['train']]
    class_counts = np.bincount(y_train_subset, minlength=int(y_int.max()) + 1)
    class_weights = 1.0 / np.maximum(class_counts, 1)
    sample_weights = class_weights[y_train_subset]
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=pin, drop_last=False)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader, test_loader

# ======= Model =======
def _new_model():
    """Return a freshly initialised 5-class ``SleepStageModel`` for the current ``X``.

    ``in_ch`` is set to ``X.shape[-1]`` (the channel axis of channels-last samples)
    whenever the constructor accepts it. The decision is taken from the constructor
    signature, so a TypeError raised *inside* the model's ``__init__`` propagates
    instead of silently producing a model with its default input channels.
    """
    import inspect

    try:
        params = inspect.signature(SleepStageModel).parameters
    except (TypeError, ValueError):
        params = None  # signature cannot be introspected

    if params is None:
        # No signature available: probe by calling, as before.
        try:
            return SleepStageModel(num_classes=5, in_ch=X.shape[-1])
        except TypeError:
            return SleepStageModel(num_classes=5)

    accepts_in_ch = "in_ch" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_in_ch:
        return SleepStageModel(num_classes=5, in_ch=X.shape[-1])

    # The model fixes its own input channels; make a possible mismatch visible.
    print(f"Aviso: SleepStageModel no acepta 'in_ch'; se construye con sus canales "
          f"de entrada por defecto (el dataset tiene {X.shape[-1]}).")
    return SleepStageModel(num_classes=5)

# ======= Training: uses train_sleep_model defined in an earlier cell =======
# ======= Detailed per-run metrics =======
@torch.no_grad()
def evaluate_detailed(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and compute per-stage and global metrics.

    Args:
        model: classifier whose output per batch is argmax-ed over dim 1, i.e.
            (batch, n_classes) logits.
        test_loader: iterable of ``(inputs, integer labels)`` batches.
        device: device the inputs (and the model) are moved to.

    Returns:
        dict with keys
          "df": per-stage DataFrame (precision, recall, f1_score and one-vs-rest
                accuracy/kappa rounded to 3 decimals; "soporte" = true count),
          "cm": raw confusion matrix (rows = true, columns = predicted),
          "cm_norm": row-normalised confusion matrix (empty rows are all 0),
          "acc", "kappa": global accuracy and Cohen's kappa,
          "y_true", "y_pred": concatenated labels / predictions.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    labels = ["W","N1","N2","N3","REM"]
    n_classes = len(labels)

    # No-op when the model is already on ``device``; avoids a device mismatch otherwise.
    model.to(device)
    model.eval()
    all_p, all_t = [], []
    for xb, yb in test_loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        all_p.append(torch.argmax(logits, dim=1).cpu().numpy())
        all_t.append(yb.numpy())
    if not all_p:
        raise ValueError("evaluate_detailed: test_loader produced no batches")
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    class_ids = list(range(n_classes))
    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    # Row-normalise. Dividing by max(row_sum, 1) already gives 0 for empty rows;
    # np.divide(..., where=...) without ``out=`` would leave those entries uninitialised.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = cm / np.maximum(row_sums, 1)

    # One-vs-rest accuracy and Cohen's kappa per class. All probabilities use the
    # same total as the accuracy (the confusion-matrix count) for consistency.
    total = int(cm.sum())
    denom = max(1, total)
    acc_per_class = np.zeros(n_classes, dtype=np.float64)
    kappa_per_class = np.zeros(n_classes, dtype=np.float64)
    for k in range(n_classes):
        TP = cm[k, k]
        FN = cm[k, :].sum() - TP
        FP = cm[:, k].sum() - TP
        TN = total - (TP + FN + FP)
        obs = (TP + TN) / denom
        acc_per_class[k] = obs

        # Chance agreement of the binary "class k vs rest" problem.
        p_yes_true = (TP + FN) / denom
        p_yes_pred = (TP + FP) / denom
        p_no_true  = (FP + TN) / denom
        p_no_pred  = (FN + TN) / denom
        exp = p_yes_true * p_yes_pred + p_no_true * p_no_pred
        kappa_per_class[k] = (obs - exp) / (1 - exp + 1e-12)

    df_per_class = pd.DataFrame({
        "etapa": labels,
        "precision": np.round(prec, 3),
        "recall":    np.round(rec, 3),
        "f1_score":  np.round(f1, 3),
        "accuracy":  np.round(acc_per_class, 3),
        "kappa":     np.round(kappa_per_class, 3),
        "soporte":   support.astype(int)
    })

    overall_acc = accuracy_score(y_true, y_pred)
    kappa_global = cohen_kappa_score(y_true, y_pred)

    return {
        "df": df_per_class,
        "cm": cm,
        "cm_norm": cm_norm,
        "acc": overall_acc,
        "kappa": kappa_global,
        "y_true": y_true,
        "y_pred": y_pred
    }

# ========= 4) Run loop =========
import contextlib
import tempfile

# train_sleep_model must come from an earlier cell. Raise explicitly instead of
# using ``assert`` so the check is never stripped (python -O).
if 'train_sleep_model' not in globals():
    raise NameError("Falta la funci√≥n train_sleep_model en el entorno.")
device = torch.device("cuda" if (CONFIG["use_gpu"] and torch.cuda.is_available()) else "cpu")
pin_mem = (device.type == "cuda")

print("="*90)
print(f"üöÄ MULTI-RUN sobre dataset: {DATASET_NAME}")
print("="*90)
print(f"Seeds: {[BASE_SEED+i for i in range(N_RUNS)]}")
print(f"Guardar checkpoints: {SAVE_CHECKPOINTS} | Guardar histories: {SAVE_HISTORIES} | Guardar per-run: {SAVE_PER_RUN_FILES}")
print()

# Loaders (fixed per dataset). Only test_loader is used below; training goes
# through train_sleep_model, which receives X / y / splits directly.
train_loader, val_loader, test_loader = _build_loaders(
    X, y, splits, batch_size=CONFIG["batch_size"], num_workers=CONFIG["num_workers"], pin=pin_mem
)

all_runs_data = []
for run_id in range(1, N_RUNS+1):
    # Seed before building the model so initial weights depend only on the seed.
    seed = BASE_SEED + (run_id-1)
    set_seed(seed)

    model = _new_model()
    run_dir = RUNS_DIR / f"run_{run_id:02d}"
    if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES):
        run_dir.mkdir(parents=True, exist_ok=True)

    save_path = (str(run_dir / "best_model.pt")) if SAVE_CHECKPOINTS else None

    # Train. train_sleep_model is always given a checkpoint path; without
    # SAVE_CHECKPOINTS it lives in a temporary directory removed after training,
    # so the default "save nothing" mode leaves no best_model_tmp.pt in the CWD.
    ckpt_dir_ctx = contextlib.nullcontext() if save_path else tempfile.TemporaryDirectory()
    with ckpt_dir_ctx as tmp_dir:
        model, hist, results = train_sleep_model(
            model=model,
            X=X, y=y, splits=splits,
            save_path=(save_path if save_path else str(Path(tmp_dir) / "best_model_tmp.pt")),
            **CONFIG
        )

    # Learning curves of the last run only (on screen).
    if run_id == N_RUNS:
        epochs_arr = range(1, len(hist["train_loss"])+1)
        fig, ax = plt.subplots(1,2, figsize=(12,5))
        ax[0].plot(epochs_arr, hist["train_loss"], 'r-', label='training')
        ax[0].plot(epochs_arr, hist["val_loss"], 'b-', label='validation')
        ax[0].set_title('Loss evolution'); ax[0].set_xlabel('Epoch'); ax[0].set_ylabel('Loss'); ax[0].grid(True, alpha=.3); ax[0].legend()

        ax[1].plot(epochs_arr, hist["train_acc"], 'r-', label='training')
        ax[1].plot(epochs_arr, hist["val_acc"], 'b-', label='validation')
        ax[1].set_title('Accuracy evolution'); ax[1].set_xlabel('Epoch'); ax[1].set_ylabel('Accuracy'); ax[1].grid(True, alpha=.3); ax[1].legend()
        plt.suptitle(f"Learning Curves ‚Äî {DATASET_NAME} (Run {run_id})")
        plt.tight_layout()
        plt.show()

    # Optional: history arrays + run config.
    if SAVE_HISTORIES:
        np.savez(run_dir / "history.npz",
                 train_loss=hist["train_loss"], val_loss=hist["val_loss"],
                 train_acc=hist["train_acc"], val_acc=hist["val_acc"],
                 lr=hist["lr"])
        with open(run_dir / "config.json","w") as f:
            cfg = copy.deepcopy(CONFIG); cfg.update(seed=seed, run_id=run_id, dataset=DATASET_NAME, ts=datetime.now().isoformat())
            json.dump(cfg, f, indent=2)

    # Detailed test-set evaluation for this run.
    eval_res = evaluate_detailed(model, test_loader, device)

    # Per-stage table on screen (not saved unless SAVE_PER_RUN_FILES).
    print(f"\nüìä RUN {run_id} ‚Äî M√©tricas por etapa")
    display(eval_res["df"].style.set_caption(f"Run {run_id} ‚Äî {DATASET_NAME}"))

    print(f"   ‚û§ Acc={eval_res['acc']:.4f} | Kappa={eval_res['kappa']:.4f}")

    # Raw and row-normalised confusion matrices (on screen).
    labels = ["W","N1","N2","N3","REM"]
    fig, axes = plt.subplots(1,2, figsize=(12,5))
    sns.heatmap(eval_res["cm"], annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_title(f"CM Cruda ‚Äî Run {run_id}")
    axes[0].set_xlabel("Predicho"); axes[0].set_ylabel("Real")

    sns.heatmap(eval_res["cm_norm"], annot=True, fmt='.2f', cmap='Blues',
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=axes[1])
    axes[1].set_title(f"CM Normalizada ‚Äî Run {run_id}")
    axes[1].set_xlabel("Predicho"); axes[1].set_ylabel("Real")
    plt.tight_layout(); plt.show()

    # Optional per-run files.
    if SAVE_PER_RUN_FILES:
        eval_res["df"].to_csv(run_dir / "eval_test_per_class.csv", index=False)
        pd.DataFrame({"y_true": eval_res["y_true"].astype(int),
                      "y_pred": eval_res["y_pred"].astype(int)}).to_csv(run_dir / "eval_test_pred_vs_true.csv", index=False)
        np.save(run_dir / "eval_test_cm.npy", eval_res["cm"])
        np.save(run_dir / "eval_test_cm_norm.npy", eval_res["cm_norm"])
        with open(run_dir / "eval_test_summary.txt","w") as f:
            f.write(f"accuracy_global={eval_res['acc']:.6f}\n")
            f.write(f"kappa_global={eval_res['kappa']:.6f}\n")

    all_runs_data.append({
        "run_id": run_id,
        "history": hist,
        "results": results,
        "eval": eval_res,
        "run_dir": (run_dir if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES) else None)
    })

# ========= 5) Simple summary: test accuracy / loss statistics over all runs =========
test_accs = [run["results"]["test_acc"] for run in all_runs_data]
test_losses = [run["results"]["test_loss"] for run in all_runs_data]
banner = "="*90
print("\n" + banner)
print(f"‚úÖ {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print(banner)
# The tags are padded so "mean=" lines up in both rows.
for tag, vals in (("Test Acc: ", test_accs), ("Test Loss:", test_losses)):
    print(f"{tag} mean={np.mean(vals):.4f}  std={np.std(vals):.4f}  "
          f"min={np.min(vals):.4f}  max={np.max(vals):.4f}")

# ========= 6) Aggregation (mean ± std over runs) and aggregated plots =========
labels = ["W","N1","N2","N3","REM"]
metrics_cols = ["precision", "recall", "f1_score", "accuracy", "kappa"]

# Stack per-class metrics of every run -> (n_runs, n_classes, n_metrics).
# NOTE: values come from the per-run DataFrames, already rounded to 3 decimals.
per_class_list = [rd["eval"]["df"][metrics_cols].to_numpy() for rd in all_runs_data]
per_class_arr  = np.stack(per_class_list, axis=0)

means = per_class_arr.mean(axis=0)  # (n_classes, n_metrics)
stds  = per_class_arr.std(axis=0)   # (n_classes, n_metrics), population std (ddof=0)

# Aggregated table shown on screen ("mean ± std" strings per stage and metric).
df_agg = pd.DataFrame({"etapa": labels})
for j, col in enumerate(metrics_cols):
    df_agg[col] = [f"{means[i,j]:.3f} ¬± {stds[i,j]:.3f}" for i in range(len(labels))]

print("\nüìä M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
print(df_agg.to_string(index=False))

# SAVE_AGGREGATES covers plots as well as tables: when it is set, each figure is
# written to RUNS_DIR right before plt.show().
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

# Aggregated F1 bar plot (mean per stage with std error bars), on screen.
f1_means = means[:, metrics_cols.index("f1_score")]
f1_stds  = stds[:,  metrics_cols.index("f1_score")]
fig_f1 = plt.figure(figsize=(9,5))
x = np.arange(len(labels))
plt.bar(x, f1_means, yerr=f1_stds, capsize=4, alpha=.85)
plt.xticks(x, labels)
plt.ylim(0, 1.05)
plt.xlabel("Etapa"); plt.ylabel("F1-Score"); plt.title(f"F1 por etapa (media¬±std) ‚Äî {DATASET_NAME}")
plt.grid(axis='y', alpha=.3)
plt.tight_layout()
if SAVE_AGGREGATES:
    fig_f1.savefig(RUNS_DIR / "f1_per_class_mean_std.png", dpi=150, bbox_inches="tight")
plt.show()

# Aggregated row-normalised confusion matrix: element-wise mean and std over runs.
cm_norm_mean = np.mean([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)
cm_norm_std  = np.std ([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)

fig, ax = plt.subplots(1,2, figsize=(12,5))
sns.heatmap(cm_norm_mean, annot=True, fmt=".3f", cmap='Blues',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=ax[0])
ax[0].set_title(f"CM Normalizada (media) ‚Äî {DATASET_NAME}")
sns.heatmap(cm_norm_std, annot=True, fmt=".3f", cmap='Reds',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=0.2, ax=ax[1])
ax[1].set_title(f"CM Normalizada (std) ‚Äî {DATASET_NAME}")
for a in ax: a.set_xlabel("Predicho"); a.set_ylabel("Real")
plt.tight_layout()
if SAVE_AGGREGATES:
    fig.savefig(RUNS_DIR / "cm_norm_mean_std.png", dpi=150, bbox_inches="tight")
plt.show()

# Optional aggregated tables / arrays.
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    df_agg.to_csv(RUNS_DIR / "metrics_aggregated_per_class.csv", index=False)
    np.save(RUNS_DIR / "cm_norm_mean.npy", cm_norm_mean)
    np.save(RUNS_DIR / "cm_norm_std.npy", cm_norm_std)

print("\n Listo.")


## Tercer dataset: Sólo el EEG1

In [None]:
# ================================================
# CELDA ÚNICA: MULTI-RUN + MÉTRICAS
# ================================================
import os, json, copy, math, random, pickle, warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, cohen_kappa_score
)
warnings.filterwarnings("ignore")  # NOTE: silences every warning (sklearn, torch, ...) for the whole session

# ========= 0) Current dataset selection =========
X, y, splits = x3, y3, splits3
DATASET_NAME = "EEG1"
# ========= 1) FLAGS (nothing is saved by default) =========
SAVE_CHECKPOINTS   = False   # Save best_model.pt per run
SAVE_HISTORIES     = False   # Save history.npz + config.json per run
SAVE_PER_RUN_FILES = False   # Save per-run CSV / NPY / TXT (metrics and CM)
SAVE_AGGREGATES    = False   # Save aggregated tables and plots

# ========= 2) GLOBAL CONFIG (forwarded to train_sleep_model as **CONFIG) =========
N_RUNS = 3
BASE_SEED = 42
CONFIG = {
    "lr": 5e-6,
    "batch_size": 256,
    "epochs": 50,
    "criterion_name": "ce",
    "class_weights": None,
    "weight_clip_range": (0.1, 2.5),
    "grad_clip": 1.0,
    "use_gpu": True,
    "amp": False,
    "num_workers": 0,
    "early_stopping_tolerance": 5,
    "early_stopping_metric": "val_acc"
}

# ======= Paths (only touched on disk when something is saved) =======
OUTPUT_DIR = Path(OUTPUT_DIR) if 'OUTPUT_DIR' in globals() else (Path.cwd() / "outputs")
RUNS_DIR   = OUTPUT_DIR / "multiple_runs" / DATASET_NAME.replace(" ", "_")
# Create the output tree only if some flag will write into it. Per-run folders and
# the aggregate save step create their own directories with parents=True anyway.
if SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

# ========= 3) Utils =========
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs with ``seed``
    and switch cuDNN to deterministic, non-benchmarking mode for repeatable runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        # Current CUDA device first, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

class SpectroDataset(Dataset):
    """Map-style dataset exposing the samples of ``X``/``y`` listed in ``indices``.

    Each ``X[j]`` must be a 3-D channels-last (H, W, C) array; items are returned as
    ``(float32 tensor of shape (C, H, W), int64 scalar label tensor)``.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        sample_id = self.idx[i]
        hwc = np.asarray(self.X[sample_id], dtype=np.float32)
        chw = hwc.transpose(2, 0, 1)  # channels-last -> channels-first
        label = torch.tensor(int(self.y[sample_id]), dtype=torch.long)
        return torch.from_numpy(chw), label

def _build_loaders(X, y, splits, batch_size=256, num_workers=0, pin=True):
    """Build train/val/test DataLoaders over ``SpectroDataset`` views of ``X``/``y``.

    The train loader samples with replacement through a ``WeightedRandomSampler``
    whose per-sample weight is the inverse frequency of that sample's class in the
    TRAIN split (class balancing). Val/test loaders keep the split order.

    Args:
        X, y: full sample and label arrays, indexed by the split indices.
        splits: mapping with 'train', 'val' and 'test' index collections.
        batch_size, num_workers, pin: forwarded to every DataLoader
            (``pin`` -> ``pin_memory``).

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    train_ds = SpectroDataset(X, y, splits['train'])
    val_ds   = SpectroDataset(X, y, splits['val'])
    test_ds  = SpectroDataset(X, y, splits['test'])

    # Weighted sampler (class balancing on TRAIN). Labels are coerced to int64
    # first: np.bincount and the weight lookup reject float labels (and lists),
    # while SpectroDataset already accepts them through int().
    y_int = np.asarray(y).astype(np.int64, copy=False)
    y_train_subset = y_int[splits['train']]
    class_counts = np.bincount(y_train_subset, minlength=int(y_int.max()) + 1)
    class_weights = 1.0 / np.maximum(class_counts, 1)
    sample_weights = class_weights[y_train_subset]
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=pin, drop_last=False)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader, test_loader

# ======= Model =======
def _new_model():
    """Return a freshly initialised 5-class ``SleepStageModel`` for the current ``X``.

    ``in_ch`` is set to ``X.shape[-1]`` (the channel axis of channels-last samples)
    whenever the constructor accepts it. The decision is taken from the constructor
    signature, so a TypeError raised *inside* the model's ``__init__`` propagates
    instead of silently producing a model with its default input channels.
    """
    import inspect

    try:
        params = inspect.signature(SleepStageModel).parameters
    except (TypeError, ValueError):
        params = None  # signature cannot be introspected

    if params is None:
        # No signature available: probe by calling, as before.
        try:
            return SleepStageModel(num_classes=5, in_ch=X.shape[-1])
        except TypeError:
            return SleepStageModel(num_classes=5)

    accepts_in_ch = "in_ch" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_in_ch:
        return SleepStageModel(num_classes=5, in_ch=X.shape[-1])

    # The model fixes its own input channels; make a possible mismatch visible.
    print(f"Aviso: SleepStageModel no acepta 'in_ch'; se construye con sus canales "
          f"de entrada por defecto (el dataset tiene {X.shape[-1]}).")
    return SleepStageModel(num_classes=5)

# ======= Training: uses train_sleep_model defined in an earlier cell =======
# ======= Detailed per-run metrics =======
@torch.no_grad()
def evaluate_detailed(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and compute per-stage and global metrics.

    Args:
        model: classifier whose output per batch is argmax-ed over dim 1, i.e.
            (batch, n_classes) logits.
        test_loader: iterable of ``(inputs, integer labels)`` batches.
        device: device the inputs (and the model) are moved to.

    Returns:
        dict with keys
          "df": per-stage DataFrame (precision, recall, f1_score and one-vs-rest
                accuracy/kappa rounded to 3 decimals; "soporte" = true count),
          "cm": raw confusion matrix (rows = true, columns = predicted),
          "cm_norm": row-normalised confusion matrix (empty rows are all 0),
          "acc", "kappa": global accuracy and Cohen's kappa,
          "y_true", "y_pred": concatenated labels / predictions.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    labels = ["W","N1","N2","N3","REM"]
    n_classes = len(labels)

    # No-op when the model is already on ``device``; avoids a device mismatch otherwise.
    model.to(device)
    model.eval()
    all_p, all_t = [], []
    for xb, yb in test_loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        all_p.append(torch.argmax(logits, dim=1).cpu().numpy())
        all_t.append(yb.numpy())
    if not all_p:
        raise ValueError("evaluate_detailed: test_loader produced no batches")
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    class_ids = list(range(n_classes))
    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    # Row-normalise. Dividing by max(row_sum, 1) already gives 0 for empty rows;
    # np.divide(..., where=...) without ``out=`` would leave those entries uninitialised.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = cm / np.maximum(row_sums, 1)

    # One-vs-rest accuracy and Cohen's kappa per class. All probabilities use the
    # same total as the accuracy (the confusion-matrix count) for consistency.
    total = int(cm.sum())
    denom = max(1, total)
    acc_per_class = np.zeros(n_classes, dtype=np.float64)
    kappa_per_class = np.zeros(n_classes, dtype=np.float64)
    for k in range(n_classes):
        TP = cm[k, k]
        FN = cm[k, :].sum() - TP
        FP = cm[:, k].sum() - TP
        TN = total - (TP + FN + FP)
        obs = (TP + TN) / denom
        acc_per_class[k] = obs

        # Chance agreement of the binary "class k vs rest" problem.
        p_yes_true = (TP + FN) / denom
        p_yes_pred = (TP + FP) / denom
        p_no_true  = (FP + TN) / denom
        p_no_pred  = (FN + TN) / denom
        exp = p_yes_true * p_yes_pred + p_no_true * p_no_pred
        kappa_per_class[k] = (obs - exp) / (1 - exp + 1e-12)

    df_per_class = pd.DataFrame({
        "etapa": labels,
        "precision": np.round(prec, 3),
        "recall":    np.round(rec, 3),
        "f1_score":  np.round(f1, 3),
        "accuracy":  np.round(acc_per_class, 3),
        "kappa":     np.round(kappa_per_class, 3),
        "soporte":   support.astype(int)
    })

    overall_acc = accuracy_score(y_true, y_pred)
    kappa_global = cohen_kappa_score(y_true, y_pred)

    return {
        "df": df_per_class,
        "cm": cm,
        "cm_norm": cm_norm,
        "acc": overall_acc,
        "kappa": kappa_global,
        "y_true": y_true,
        "y_pred": y_pred
    }

# ========= 4) Run loop =========
import contextlib
import tempfile

# train_sleep_model must come from an earlier cell. Raise explicitly instead of
# using ``assert`` so the check is never stripped (python -O).
if 'train_sleep_model' not in globals():
    raise NameError("Falta la funci√≥n train_sleep_model en el entorno.")
device = torch.device("cuda" if (CONFIG["use_gpu"] and torch.cuda.is_available()) else "cpu")
pin_mem = (device.type == "cuda")

print("="*90)
print(f"üöÄ MULTI-RUN sobre dataset: {DATASET_NAME}")
print("="*90)
print(f"Seeds: {[BASE_SEED+i for i in range(N_RUNS)]}")
print(f"Guardar checkpoints: {SAVE_CHECKPOINTS} | Guardar histories: {SAVE_HISTORIES} | Guardar per-run: {SAVE_PER_RUN_FILES}")
print()

# Loaders (fixed per dataset). Only test_loader is used below; training goes
# through train_sleep_model, which receives X / y / splits directly.
train_loader, val_loader, test_loader = _build_loaders(
    X, y, splits, batch_size=CONFIG["batch_size"], num_workers=CONFIG["num_workers"], pin=pin_mem
)

all_runs_data = []
for run_id in range(1, N_RUNS+1):
    # Seed before building the model so initial weights depend only on the seed.
    seed = BASE_SEED + (run_id-1)
    set_seed(seed)

    model = _new_model()
    run_dir = RUNS_DIR / f"run_{run_id:02d}"
    if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES):
        run_dir.mkdir(parents=True, exist_ok=True)

    save_path = (str(run_dir / "best_model.pt")) if SAVE_CHECKPOINTS else None

    # Train. train_sleep_model is always given a checkpoint path; without
    # SAVE_CHECKPOINTS it lives in a temporary directory removed after training,
    # so the default "save nothing" mode leaves no best_model_tmp.pt in the CWD.
    ckpt_dir_ctx = contextlib.nullcontext() if save_path else tempfile.TemporaryDirectory()
    with ckpt_dir_ctx as tmp_dir:
        model, hist, results = train_sleep_model(
            model=model,
            X=X, y=y, splits=splits,
            save_path=(save_path if save_path else str(Path(tmp_dir) / "best_model_tmp.pt")),
            **CONFIG
        )

    # Learning curves of the last run only (on screen).
    if run_id == N_RUNS:
        epochs_arr = range(1, len(hist["train_loss"])+1)
        fig, ax = plt.subplots(1,2, figsize=(12,5))
        ax[0].plot(epochs_arr, hist["train_loss"], 'r-', label='training')
        ax[0].plot(epochs_arr, hist["val_loss"], 'b-', label='validation')
        ax[0].set_title('Loss evolution'); ax[0].set_xlabel('Epoch'); ax[0].set_ylabel('Loss'); ax[0].grid(True, alpha=.3); ax[0].legend()

        ax[1].plot(epochs_arr, hist["train_acc"], 'r-', label='training')
        ax[1].plot(epochs_arr, hist["val_acc"], 'b-', label='validation')
        ax[1].set_title('Accuracy evolution'); ax[1].set_xlabel('Epoch'); ax[1].set_ylabel('Accuracy'); ax[1].grid(True, alpha=.3); ax[1].legend()
        plt.suptitle(f"Learning Curves ‚Äî {DATASET_NAME} (Run {run_id})")
        plt.tight_layout()
        plt.show()

    # Optional: history arrays + run config.
    if SAVE_HISTORIES:
        np.savez(run_dir / "history.npz",
                 train_loss=hist["train_loss"], val_loss=hist["val_loss"],
                 train_acc=hist["train_acc"], val_acc=hist["val_acc"],
                 lr=hist["lr"])
        with open(run_dir / "config.json","w") as f:
            cfg = copy.deepcopy(CONFIG); cfg.update(seed=seed, run_id=run_id, dataset=DATASET_NAME, ts=datetime.now().isoformat())
            json.dump(cfg, f, indent=2)

    # Detailed test-set evaluation for this run.
    eval_res = evaluate_detailed(model, test_loader, device)

    # Per-stage table on screen (not saved unless SAVE_PER_RUN_FILES).
    print(f"\nüìä RUN {run_id} ‚Äî M√©tricas por etapa")
    display(eval_res["df"].style.set_caption(f"Run {run_id} ‚Äî {DATASET_NAME}"))

    print(f"   ‚û§ Acc={eval_res['acc']:.4f} | Kappa={eval_res['kappa']:.4f}")

    # Raw and row-normalised confusion matrices (on screen).
    labels = ["W","N1","N2","N3","REM"]
    fig, axes = plt.subplots(1,2, figsize=(12,5))
    sns.heatmap(eval_res["cm"], annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_title(f"CM Cruda ‚Äî Run {run_id}")
    axes[0].set_xlabel("Predicho"); axes[0].set_ylabel("Real")

    sns.heatmap(eval_res["cm_norm"], annot=True, fmt='.2f', cmap='Blues',
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=axes[1])
    axes[1].set_title(f"CM Normalizada ‚Äî Run {run_id}")
    axes[1].set_xlabel("Predicho"); axes[1].set_ylabel("Real")
    plt.tight_layout(); plt.show()

    # Optional per-run files.
    if SAVE_PER_RUN_FILES:
        eval_res["df"].to_csv(run_dir / "eval_test_per_class.csv", index=False)
        pd.DataFrame({"y_true": eval_res["y_true"].astype(int),
                      "y_pred": eval_res["y_pred"].astype(int)}).to_csv(run_dir / "eval_test_pred_vs_true.csv", index=False)
        np.save(run_dir / "eval_test_cm.npy", eval_res["cm"])
        np.save(run_dir / "eval_test_cm_norm.npy", eval_res["cm_norm"])
        with open(run_dir / "eval_test_summary.txt","w") as f:
            f.write(f"accuracy_global={eval_res['acc']:.6f}\n")
            f.write(f"kappa_global={eval_res['kappa']:.6f}\n")

    all_runs_data.append({
        "run_id": run_id,
        "history": hist,
        "results": results,
        "eval": eval_res,
        "run_dir": (run_dir if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES) else None)
    })

# ========= 5) Simple summary: test accuracy / loss statistics over all runs =========
test_accs = [run["results"]["test_acc"] for run in all_runs_data]
test_losses = [run["results"]["test_loss"] for run in all_runs_data]
banner = "="*90
print("\n" + banner)
print(f"‚úÖ {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print(banner)
# The tags are padded so "mean=" lines up in both rows.
for tag, vals in (("Test Acc: ", test_accs), ("Test Loss:", test_losses)):
    print(f"{tag} mean={np.mean(vals):.4f}  std={np.std(vals):.4f}  "
          f"min={np.min(vals):.4f}  max={np.max(vals):.4f}")

# ========= 6) Aggregation (mean ± std over runs) and aggregated plots =========
labels = ["W","N1","N2","N3","REM"]
metrics_cols = ["precision", "recall", "f1_score", "accuracy", "kappa"]

# Stack per-class metrics of every run -> (n_runs, n_classes, n_metrics).
# NOTE: values come from the per-run DataFrames, already rounded to 3 decimals.
per_class_list = [rd["eval"]["df"][metrics_cols].to_numpy() for rd in all_runs_data]
per_class_arr  = np.stack(per_class_list, axis=0)

means = per_class_arr.mean(axis=0)  # (n_classes, n_metrics)
stds  = per_class_arr.std(axis=0)   # (n_classes, n_metrics), population std (ddof=0)

# Aggregated table shown on screen ("mean ± std" strings per stage and metric).
df_agg = pd.DataFrame({"etapa": labels})
for j, col in enumerate(metrics_cols):
    df_agg[col] = [f"{means[i,j]:.3f} ¬± {stds[i,j]:.3f}" for i in range(len(labels))]

print("\nüìä M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
print(df_agg.to_string(index=False))

# SAVE_AGGREGATES covers plots as well as tables: when it is set, each figure is
# written to RUNS_DIR right before plt.show().
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

# Aggregated F1 bar plot (mean per stage with std error bars), on screen.
f1_means = means[:, metrics_cols.index("f1_score")]
f1_stds  = stds[:,  metrics_cols.index("f1_score")]
fig_f1 = plt.figure(figsize=(9,5))
x = np.arange(len(labels))
plt.bar(x, f1_means, yerr=f1_stds, capsize=4, alpha=.85)
plt.xticks(x, labels)
plt.ylim(0, 1.05)
plt.xlabel("Etapa"); plt.ylabel("F1-Score"); plt.title(f"F1 por etapa (media¬±std) ‚Äî {DATASET_NAME}")
plt.grid(axis='y', alpha=.3)
plt.tight_layout()
if SAVE_AGGREGATES:
    fig_f1.savefig(RUNS_DIR / "f1_per_class_mean_std.png", dpi=150, bbox_inches="tight")
plt.show()

# Aggregated row-normalised confusion matrix: element-wise mean and std over runs.
cm_norm_mean = np.mean([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)
cm_norm_std  = np.std ([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)

fig, ax = plt.subplots(1,2, figsize=(12,5))
sns.heatmap(cm_norm_mean, annot=True, fmt=".3f", cmap='Blues',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=ax[0])
ax[0].set_title(f"CM Normalizada (media) ‚Äî {DATASET_NAME}")
sns.heatmap(cm_norm_std, annot=True, fmt=".3f", cmap='Reds',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=0.2, ax=ax[1])
ax[1].set_title(f"CM Normalizada (std) ‚Äî {DATASET_NAME}")
for a in ax: a.set_xlabel("Predicho"); a.set_ylabel("Real")
plt.tight_layout()
if SAVE_AGGREGATES:
    fig.savefig(RUNS_DIR / "cm_norm_mean_std.png", dpi=150, bbox_inches="tight")
plt.show()

# Optional aggregated tables / arrays.
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    df_agg.to_csv(RUNS_DIR / "metrics_aggregated_per_class.csv", index=False)
    np.save(RUNS_DIR / "cm_norm_mean.npy", cm_norm_mean)
    np.save(RUNS_DIR / "cm_norm_std.npy", cm_norm_std)

print("\n Listo. ")


## Cuarto dataset: EEG1, EEG2 y EMG

In [None]:
# ================================================
# CELDA ÚNICA: MULTI-RUN + MÉTRICAS
# ================================================
import os, json, copy, math, random, pickle, warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, cohen_kappa_score
)
warnings.filterwarnings("ignore")  # NOTE: silences every warning (sklearn, torch, ...) for the whole session

# ========= 0) Current dataset selection =========
X, y, splits = x4, y4, splits4
DATASET_NAME = "EEG1+EEG2+EMG"
# ========= 1) FLAGS (nothing is saved by default) =========
SAVE_CHECKPOINTS   = False   # Save best_model.pt per run
SAVE_HISTORIES     = False   # Save history.npz + config.json per run
SAVE_PER_RUN_FILES = False   # Save per-run CSV / NPY / TXT (metrics and CM)
SAVE_AGGREGATES    = False   # Save aggregated tables and plots

# ========= 2) GLOBAL CONFIG (forwarded to train_sleep_model as **CONFIG) =========
N_RUNS = 3
BASE_SEED = 42
CONFIG = {
    "lr": 5e-6,
    "batch_size": 256,
    "epochs": 50,
    "criterion_name": "ce",
    "class_weights": None,
    "weight_clip_range": (0.1, 2.5),
    "grad_clip": 1.0,
    "use_gpu": True,
    "amp": False,
    "num_workers": 0,
    "early_stopping_tolerance": 5,
    "early_stopping_metric": "val_acc"
}

# ======= Paths (only touched on disk when something is saved) =======
OUTPUT_DIR = Path(OUTPUT_DIR) if 'OUTPUT_DIR' in globals() else (Path.cwd() / "outputs")
RUNS_DIR   = OUTPUT_DIR / "multiple_runs" / DATASET_NAME.replace(" ", "_")
# Create the output tree only if some flag will write into it. Per-run folders and
# the aggregate save step create their own directories with parents=True anyway.
if SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

# ========= 3) Utils =========
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs with ``seed``
    and switch cuDNN to deterministic, non-benchmarking mode for repeatable runs."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        # Current CUDA device first, then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

class SpectroDataset(Dataset):
    """Map-style dataset exposing the samples of ``X``/``y`` listed in ``indices``.

    Each ``X[j]`` must be a 3-D channels-last (H, W, C) array; items are returned as
    ``(float32 tensor of shape (C, H, W), int64 scalar label tensor)``.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        sample_id = self.idx[i]
        hwc = np.asarray(self.X[sample_id], dtype=np.float32)
        chw = hwc.transpose(2, 0, 1)  # channels-last -> channels-first
        label = torch.tensor(int(self.y[sample_id]), dtype=torch.long)
        return torch.from_numpy(chw), label

def _build_loaders(X, y, splits, batch_size=256, num_workers=0, pin=True):
    """Build train/val/test DataLoaders over ``SpectroDataset`` views of ``X``/``y``.

    The train loader samples with replacement through a ``WeightedRandomSampler``
    whose per-sample weight is the inverse frequency of that sample's class in the
    TRAIN split (class balancing). Val/test loaders keep the split order.

    Args:
        X, y: full sample and label arrays, indexed by the split indices.
        splits: mapping with 'train', 'val' and 'test' index collections.
        batch_size, num_workers, pin: forwarded to every DataLoader
            (``pin`` -> ``pin_memory``).

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    train_ds = SpectroDataset(X, y, splits['train'])
    val_ds   = SpectroDataset(X, y, splits['val'])
    test_ds  = SpectroDataset(X, y, splits['test'])

    # Weighted sampler (class balancing on TRAIN). Labels are coerced to int64
    # first: np.bincount and the weight lookup reject float labels (and lists),
    # while SpectroDataset already accepts them through int().
    y_int = np.asarray(y).astype(np.int64, copy=False)
    y_train_subset = y_int[splits['train']]
    class_counts = np.bincount(y_train_subset, minlength=int(y_int.max()) + 1)
    class_weights = 1.0 / np.maximum(class_counts, 1)
    sample_weights = class_weights[y_train_subset]
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.DoubleTensor(sample_weights),
        num_samples=len(sample_weights),
        replacement=True
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler,
                              num_workers=num_workers, pin_memory=pin, drop_last=False)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader, test_loader

# ======= Model =======
def _new_model():
    """Return a freshly initialised 5-class ``SleepStageModel`` for the current ``X``.

    ``in_ch`` is set to ``X.shape[-1]`` (the channel axis of channels-last samples)
    whenever the constructor accepts it. The decision is taken from the constructor
    signature, so a TypeError raised *inside* the model's ``__init__`` propagates
    instead of silently producing a model with its default input channels.
    """
    import inspect

    try:
        params = inspect.signature(SleepStageModel).parameters
    except (TypeError, ValueError):
        params = None  # signature cannot be introspected

    if params is None:
        # No signature available: probe by calling, as before.
        try:
            return SleepStageModel(num_classes=5, in_ch=X.shape[-1])
        except TypeError:
            return SleepStageModel(num_classes=5)

    accepts_in_ch = "in_ch" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_in_ch:
        return SleepStageModel(num_classes=5, in_ch=X.shape[-1])

    # The model fixes its own input channels; make a possible mismatch visible.
    print(f"Aviso: SleepStageModel no acepta 'in_ch'; se construye con sus canales "
          f"de entrada por defecto (el dataset tiene {X.shape[-1]}).")
    return SleepStageModel(num_classes=5)

# ======= Entrenamiento 

# ======= M√©tricas detalladas por run =======
@torch.no_grad()
def evaluate_detailed(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and compute per-stage and global metrics.

    Stages are the 5 classes ``W, N1, N2, N3, REM`` (ids 0..4). For each class
    the table reports precision/recall/F1 plus one-vs-rest accuracy and
    Cohen's kappa.

    Args:
        model: classifier whose argmax over dim 1 is the predicted class id;
            it is switched to ``eval()``. Assumed to already be on ``device``
            (TODO confirm with the caller).
        test_loader: iterable of ``(xb, yb)`` tensor batches.
        device: device each input batch is moved to.

    Returns:
        dict with:
        - ``df``: per-class table, rounded to 3 decimals.
        - ``cm``: raw 5x5 confusion matrix (rows = true class).
        - ``cm_norm``: row-normalized cm; rows with no true samples are 0.
        - ``acc``, ``kappa``: global accuracy and Cohen's kappa.
        - ``y_true``, ``y_pred``: 1-D arrays.
    """
    labels = ["W", "N1", "N2", "N3", "REM"]
    class_ids = list(range(len(labels)))

    model.eval()
    all_p, all_t = [], []
    for xb, yb in test_loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        all_p.append(torch.argmax(logits, dim=1).cpu().numpy())
        all_t.append(yb.numpy())
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    # Row-normalize. Dividing by max(row_sum, 1) already maps empty rows to 0.
    # np.divide(..., where=...) without `out=` would leave those entries
    # uninitialized (arbitrary memory) instead of 0.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = cm / np.maximum(row_sums, 1)

    # One-vs-rest 2x2 tables per class, all derived from `cm`, so observed and
    # expected agreement use the same total. y_true.size would differ from
    # cm.sum() if y_true held ids outside 0..4.
    total = max(1, int(cm.sum()))
    tp = np.diag(cm).astype(np.float64)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = cm.sum() - (tp + fn + fp)

    acc_per_class = (tp + tn) / total
    p_yes_true = (tp + fn) / total
    p_yes_pred = (tp + fp) / total
    p_no_true = (fp + tn) / total
    p_no_pred = (fn + tn) / total
    expected = p_yes_true * p_yes_pred + p_no_true * p_no_pred
    # Small epsilon avoids division by zero when expected agreement is 1.
    kappa_per_class = (acc_per_class - expected) / (1 - expected + 1e-12)

    df_per_class = pd.DataFrame({
        "etapa": labels,
        "precision": np.round(prec, 3),
        "recall":    np.round(rec, 3),
        "f1_score":  np.round(f1, 3),
        "accuracy":  np.round(acc_per_class, 3),
        "kappa":     np.round(kappa_per_class, 3),
        "soporte":   support.astype(int)
    })

    overall_acc = accuracy_score(y_true, y_pred)
    kappa_global = cohen_kappa_score(y_true, y_pred)

    return {
        "df": df_per_class,
        "cm": cm,
        "cm_norm": cm_norm,
        "acc": overall_acc,
        "kappa": kappa_global,
        "y_true": y_true,
        "y_pred": y_pred
    }

# ========= 4) Run loop: setup =========
# train_sleep_model is not defined in this cell; fail fast if it is missing from globals.
assert 'train_sleep_model' in globals(), "Falta la funci√≥n train_sleep_model en el entorno."
# Use CUDA only when CONFIG["use_gpu"] is set AND a GPU is actually available.
device = torch.device("cuda" if (CONFIG["use_gpu"] and torch.cuda.is_available()) else "cpu")
# Pinned host memory only helps host->CUDA copies, so enable it only on CUDA.
pin_mem = (device.type == "cuda")

print("="*90)
print(f"üöÄ MULTI-RUN sobre dataset: {DATASET_NAME}")
print("="*90)
print(f"Seeds: {[BASE_SEED+i for i in range(N_RUNS)]}")
print(f"Guardar checkpoints: {SAVE_CHECKPOINTS} | Guardar histories: {SAVE_HISTORIES} | Guardar per-run: {SAVE_PER_RUN_FILES}")
print()

# Loaders are built once per dataset and shared by every run.
# NOTE(review): the loop below only uses test_loader (train_sleep_model receives
# X/y/splits directly). train_loader/val_loader are built but unused in this cell.
train_loader, val_loader, test_loader = _build_loaders(
    X, y, splits, batch_size=CONFIG["batch_size"], num_workers=CONFIG["num_workers"], pin=pin_mem
)

# One entry per run: history, train_sleep_model results, detailed test evaluation.
all_runs_data = []
for run_id in range(1, N_RUNS+1):
    # Run k (1-based) uses seed BASE_SEED + k - 1.
    seed = BASE_SEED + (run_id-1)
    set_seed(seed)

    model = _new_model()
    run_dir = RUNS_DIR / f"run_{run_id:02d}"
    # Create the per-run directory only if something will be written to disk.
    if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES):
        run_dir.mkdir(parents=True, exist_ok=True)

    save_path = (str(run_dir / "best_model.pt")) if SAVE_CHECKPOINTS else None

    # Train.
    # NOTE(review): with SAVE_CHECKPOINTS=False, "best_model_tmp.pt" (relative to the
    # cwd) is still passed as save_path; presumably train_sleep_model writes there — confirm.
    model, hist, results = train_sleep_model(
        model=model,
        X=X, y=y, splits=splits,
        save_path=(save_path if save_path else "best_model_tmp.pt"),
        **CONFIG
    )

    # Learning curves (loss and accuracy per epoch), shown for the last run only.
    if run_id == N_RUNS:
        epochs_arr = range(1, len(hist["train_loss"])+1)
        fig, ax = plt.subplots(1,2, figsize=(12,5))
        ax[0].plot(epochs_arr, hist["train_loss"], 'r-', label='training')
        ax[0].plot(epochs_arr, hist["val_loss"], 'b-', label='validation')
        ax[0].set_title('Loss evolution'); ax[0].set_xlabel('Epoch'); ax[0].set_ylabel('Loss'); ax[0].grid(True, alpha=.3); ax[0].legend()

        ax[1].plot(epochs_arr, hist["train_acc"], 'r-', label='training')
        ax[1].plot(epochs_arr, hist["val_acc"], 'b-', label='validation')
        ax[1].set_title('Accuracy evolution'); ax[1].set_xlabel('Epoch'); ax[1].set_ylabel('Accuracy'); ax[1].grid(True, alpha=.3); ax[1].legend()
        plt.suptitle(f"Learning Curves ‚Äî {DATASET_NAME} (Run {run_id})")
        plt.tight_layout()
        plt.show()

    # Optional save: training history and the run's config (plus seed/run/timestamp).
    if SAVE_HISTORIES:
        np.savez(run_dir / "history.npz",
                 train_loss=hist["train_loss"], val_loss=hist["val_loss"],
                 train_acc=hist["train_acc"], val_acc=hist["val_acc"],
                 lr=hist["lr"])
        with open(run_dir / "config.json","w") as f:
            cfg = copy.deepcopy(CONFIG); cfg.update(seed=seed, run_id=run_id, dataset=DATASET_NAME, ts=datetime.now().isoformat())
            json.dump(cfg, f, indent=2)

    # Detailed test-set evaluation of the model returned by train_sleep_model.
    eval_res = evaluate_detailed(model, test_loader, device)

    # Per-stage table shown on screen (not saved unless SAVE_PER_RUN_FILES).
    print(f"\nüìä RUN {run_id} ‚Äî M√©tricas por etapa")
    display(eval_res["df"].style.set_caption(f"Run {run_id} ‚Äî {DATASET_NAME}"))

    print(f"   ‚û§ Acc={eval_res['acc']:.4f} | Kappa={eval_res['kappa']:.4f}")

    # Confusion-matrix plots (on-screen): raw counts and row-normalized.
    labels = ["W","N1","N2","N3","REM"]
    fig, axes = plt.subplots(1,2, figsize=(12,5))
    sns.heatmap(eval_res["cm"], annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_title(f"CM Cruda ‚Äî Run {run_id}")
    axes[0].set_xlabel("Predicho"); axes[0].set_ylabel("Real")

    sns.heatmap(eval_res["cm_norm"], annot=True, fmt='.2f', cmap='Blues',
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=axes[1])
    axes[1].set_title(f"CM Normalizada ‚Äî Run {run_id}")
    axes[1].set_xlabel("Predicho"); axes[1].set_ylabel("Real")
    plt.tight_layout(); plt.show()

    # Optional per-run files: per-class CSV, predictions CSV, CM arrays, summary text.
    if SAVE_PER_RUN_FILES:
        eval_res["df"].to_csv(run_dir / "eval_test_per_class.csv", index=False)
        pd.DataFrame({"y_true": eval_res["y_true"].astype(int),
                      "y_pred": eval_res["y_pred"].astype(int)}).to_csv(run_dir / "eval_test_pred_vs_true.csv", index=False)
        np.save(run_dir / "eval_test_cm.npy", eval_res["cm"])
        np.save(run_dir / "eval_test_cm_norm.npy", eval_res["cm_norm"])
        with open(run_dir / "eval_test_summary.txt","w") as f:
            f.write(f"accuracy_global={eval_res['acc']:.6f}\n")
            f.write(f"kappa_global={eval_res['kappa']:.6f}\n")

    # run_dir is recorded only when some SAVE_* flag caused it to be created.
    all_runs_data.append({
        "run_id": run_id,
        "history": hist,
        "results": results,
        "eval": eval_res,
        "run_dir": (run_dir if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES) else None)
    })

# ========= 5) Simple summary =========
# test_acc / test_loss come from the `results` dict returned by train_sleep_model.
test_accs = [rd["results"]["test_acc"] for rd in all_runs_data]
test_losses = [rd["results"]["test_loss"] for rd in all_runs_data]
print("\n" + "="*90)
print(f"‚úÖ {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*90)
# np.std here is the population std (ddof=0) across runs.
print(f"Test Acc:  mean={np.mean(test_accs):.4f}  std={np.std(test_accs):.4f}  "
      f"min={np.min(test_accs):.4f}  max={np.max(test_accs):.4f}")
print(f"Test Loss: mean={np.mean(test_losses):.4f}  std={np.std(test_losses):.4f}  "
      f"min={np.min(test_losses):.4f}  max={np.max(test_losses):.4f}")

# ========= 6) Aggregation (mean +/- std over runs) and aggregated plots =========
labels = ["W","N1","N2","N3","REM"]
metrics_cols = ["precision", "recall", "f1_score", "accuracy", "kappa"]

# Stack the per-class metrics of every run. They come from each run's eval
# DataFrame, whose values evaluate_detailed already rounded to 3 decimals.
per_class_list = [rd["eval"]["df"][metrics_cols].to_numpy() for rd in all_runs_data]  # list of (5x5)
per_class_arr  = np.stack(per_class_list, axis=0)  # (n_runs, 5, 5)

# Mean / population std (ddof=0) over runs; rows = stages, columns = metrics_cols.
means = per_class_arr.mean(axis=0)  # (5,5)
stds  = per_class_arr.std(axis=0)   # (5,5)

# Aggregated table (on-screen), one "mean ± std" string per cell.
# NOTE(review): cells are strings, so the optional CSV below stores text, not numbers.
df_agg = pd.DataFrame({"etapa": labels})
for j, col in enumerate(metrics_cols):
    df_agg[col] = [f"{means[i,j]:.3f} ¬± {stds[i,j]:.3f}" for i in range(len(labels))]

print("\nüìä M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
print(df_agg.to_string(index=False))

# Aggregated F1 bar plot (on-screen); error bars are the std across runs.
f1_means = means[:, metrics_cols.index("f1_score")]
f1_stds  = stds[:,  metrics_cols.index("f1_score")]
plt.figure(figsize=(9,5))
x = np.arange(len(labels))
plt.bar(x, f1_means, yerr=f1_stds, capsize=4, alpha=.85)
plt.xticks(x, labels)
plt.ylim(0, 1.05)
plt.xlabel("Etapa"); plt.ylabel("F1-Score"); plt.title(f"F1 por etapa (media¬±std) ‚Äî {DATASET_NAME}")
plt.grid(axis='y', alpha=.3)
plt.tight_layout(); plt.show()

# Aggregated row-normalized confusion matrix: element-wise mean and std over runs.
cm_norm_mean = np.mean([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)
cm_norm_std  = np.std ([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)

fig, ax = plt.subplots(1,2, figsize=(12,5))
sns.heatmap(cm_norm_mean, annot=True, fmt=".3f", cmap='Blues',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=ax[0])
ax[0].set_title(f"CM Normalizada (media) ‚Äî {DATASET_NAME}")
# The std colour scale is capped at 0.2; larger values saturate.
sns.heatmap(cm_norm_std, annot=True, fmt=".3f", cmap='Reds',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=0.2, ax=ax[1])
ax[1].set_title(f"CM Normalizada (std) ‚Äî {DATASET_NAME}")
for a in ax: a.set_xlabel("Predicho"); a.set_ylabel("Real")
plt.tight_layout(); plt.show()

# Optional aggregated saves (table CSV + mean/std CM arrays).
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    df_agg.to_csv(RUNS_DIR / "metrics_aggregated_per_class.csv", index=False)
    np.save(RUNS_DIR / "cm_norm_mean.npy", cm_norm_mean)
    np.save(RUNS_DIR / "cm_norm_std.npy", cm_norm_std)

print("\n Listo.")


## Quinto dataset: EOG y EMG

In [None]:
# ================================================
# CELDA √öNICA: MULTI-RUN + M√âTRICAS 
# ================================================
import os, json, copy, math, random, pickle, warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, cohen_kappa_score
)
warnings.filterwarnings("ignore")

# ========= 0) Dataset used by this cell =========
X, y, splits = x5, y5, splits5
DATASET_NAME = "EOG+EMG"
# ========= 1) FLAGS (nothing is written to disk by default) =========
SAVE_CHECKPOINTS   = False   # save best_model.pt per run
SAVE_HISTORIES     = False   # save history.npz (+ config.json) per run
SAVE_PER_RUN_FILES = False   # save per-run CSV / NPY / TXT files (metrics and CM)
SAVE_AGGREGATES    = False   # save the aggregated table and mean/std CM arrays

# ========= 2) GLOBAL CONFIG =========
N_RUNS = 3
BASE_SEED = 42
CONFIG = {
    "lr": 5e-6,
    "batch_size": 256,
    "epochs": 50,
    "criterion_name": "ce",
    "class_weights": None,
    "weight_clip_range": (0.1, 2.5),
    "grad_clip": 1.0,
    "use_gpu": True,
    "amp": False,
    "num_workers": 0,
    "early_stopping_tolerance": 5,
    "early_stopping_metric": "val_acc"
}

# ======= Paths (only used when something is saved) =======
OUTPUT_DIR = Path(OUTPUT_DIR) if 'OUTPUT_DIR' in globals() else (Path.cwd() / "outputs")
RUNS_DIR   = OUTPUT_DIR / "multiple_runs" / DATASET_NAME.replace(" ", "_")
# Create the directory only when a SAVE_* flag is on. Every writer in this cell
# creates its own directory before saving, so with all flags False the
# filesystem is left untouched, as the flags promise.
if SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)

# ========= 3) Utils =========
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and, if present, CUDA) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning so
    repeated runs with the same seed are reproducible.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class SpectroDataset(Dataset):
    """Map-style dataset over a subset of spectrogram windows.

    Position ``i`` of the dataset reads row ``indices[i]`` of ``X``/``y``.
    Each item is ``(x, label)``: ``x`` is that window cast to float32 and
    moved from channels-last ``(H, W, C)`` to channels-first ``(C, H, W)``,
    ``label`` is a scalar int64 tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        row = self.idx[i]
        window = np.asarray(self.X[row], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[row]), dtype=torch.long)
        return torch.from_numpy(window), label

def _build_loaders(X, y, splits, batch_size=256, num_workers=0, pin=True):
    """Create train/val/test DataLoaders over ``SpectroDataset`` views of X/y.

    The train loader draws ``len(splits['train'])`` samples with replacement,
    weighting each sample by the inverse frequency of its class within the
    train split (classes with zero count get weight 1). Val and test loaders
    iterate their split in order.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    datasets = {name: SpectroDataset(X, y, splits[name]) for name in ("train", "val", "test")}

    # Inverse-frequency sampling weights computed on the TRAIN split only.
    train_labels = y[splits['train']]
    n_bins = int(np.max(y)) + 1
    per_class = 1.0 / np.maximum(np.bincount(train_labels, minlength=n_bins), 1)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.DoubleTensor(per_class[train_labels]),
        num_samples=len(train_labels),
        replacement=True,
    )

    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin)
    train_loader = DataLoader(datasets["train"], sampler=sampler, drop_last=False, **shared)
    val_loader = DataLoader(datasets["val"], shuffle=False, **shared)
    test_loader = DataLoader(datasets["test"], shuffle=False, **shared)
    return train_loader, val_loader, test_loader

# ======= Modelo 
def _new_model():
    """Build a fresh, untrained ``SleepStageModel`` with 5 output classes.

    The number of input channels is read from the last axis of the global
    ``X`` and passed as ``in_ch`` only when the constructor can accept that
    keyword. The signature is inspected up front, so a ``TypeError`` raised
    *inside* the constructor propagates. It is no longer mistaken for
    "``in_ch`` unsupported" and silently replaced by a default-channel model.
    """
    import inspect

    try:
        params = inspect.signature(SleepStageModel).parameters.values()
    except (TypeError, ValueError):
        # Signature not introspectable: fall back to trial construction.
        try:
            return SleepStageModel(num_classes=5, in_ch=X.shape[-1])
        except TypeError:
            return SleepStageModel(num_classes=5)

    accepts_in_ch = any(
        (p.name == "in_ch" and p.kind is not inspect.Parameter.POSITIONAL_ONLY)
        or p.kind is inspect.Parameter.VAR_KEYWORD
        for p in params
    )
    if accepts_in_ch:
        return SleepStageModel(num_classes=5, in_ch=X.shape[-1])
    return SleepStageModel(num_classes=5)

# ======= Entrenamiento 

# ======= M√©tricas detalladas por run =======
@torch.no_grad()
def evaluate_detailed(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and compute per-stage and global metrics.

    Stages are the 5 classes ``W, N1, N2, N3, REM`` (ids 0..4). For each class
    the table reports precision/recall/F1 plus one-vs-rest accuracy and
    Cohen's kappa.

    Args:
        model: classifier whose argmax over dim 1 is the predicted class id;
            it is switched to ``eval()``. Assumed to already be on ``device``
            (TODO confirm with the caller).
        test_loader: iterable of ``(xb, yb)`` tensor batches.
        device: device each input batch is moved to.

    Returns:
        dict with:
        - ``df``: per-class table, rounded to 3 decimals.
        - ``cm``: raw 5x5 confusion matrix (rows = true class).
        - ``cm_norm``: row-normalized cm; rows with no true samples are 0.
        - ``acc``, ``kappa``: global accuracy and Cohen's kappa.
        - ``y_true``, ``y_pred``: 1-D arrays.
    """
    labels = ["W", "N1", "N2", "N3", "REM"]
    class_ids = list(range(len(labels)))

    model.eval()
    all_p, all_t = [], []
    for xb, yb in test_loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        all_p.append(torch.argmax(logits, dim=1).cpu().numpy())
        all_t.append(yb.numpy())
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    # Row-normalize. Dividing by max(row_sum, 1) already maps empty rows to 0.
    # np.divide(..., where=...) without `out=` would leave those entries
    # uninitialized (arbitrary memory) instead of 0.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = cm / np.maximum(row_sums, 1)

    # One-vs-rest 2x2 tables per class, all derived from `cm`, so observed and
    # expected agreement use the same total. y_true.size would differ from
    # cm.sum() if y_true held ids outside 0..4.
    total = max(1, int(cm.sum()))
    tp = np.diag(cm).astype(np.float64)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = cm.sum() - (tp + fn + fp)

    acc_per_class = (tp + tn) / total
    p_yes_true = (tp + fn) / total
    p_yes_pred = (tp + fp) / total
    p_no_true = (fp + tn) / total
    p_no_pred = (fn + tn) / total
    expected = p_yes_true * p_yes_pred + p_no_true * p_no_pred
    # Small epsilon avoids division by zero when expected agreement is 1.
    kappa_per_class = (acc_per_class - expected) / (1 - expected + 1e-12)

    df_per_class = pd.DataFrame({
        "etapa": labels,
        "precision": np.round(prec, 3),
        "recall":    np.round(rec, 3),
        "f1_score":  np.round(f1, 3),
        "accuracy":  np.round(acc_per_class, 3),
        "kappa":     np.round(kappa_per_class, 3),
        "soporte":   support.astype(int)
    })

    overall_acc = accuracy_score(y_true, y_pred)
    kappa_global = cohen_kappa_score(y_true, y_pred)

    return {
        "df": df_per_class,
        "cm": cm,
        "cm_norm": cm_norm,
        "acc": overall_acc,
        "kappa": kappa_global,
        "y_true": y_true,
        "y_pred": y_pred
    }

# ========= 4) Run loop: setup =========
# train_sleep_model is not defined in this cell; fail fast if it is missing from globals.
assert 'train_sleep_model' in globals(), "Falta la funci√≥n train_sleep_model en el entorno."
# Use CUDA only when CONFIG["use_gpu"] is set AND a GPU is actually available.
device = torch.device("cuda" if (CONFIG["use_gpu"] and torch.cuda.is_available()) else "cpu")
# Pinned host memory only helps host->CUDA copies, so enable it only on CUDA.
pin_mem = (device.type == "cuda")

print("="*90)
print(f"üöÄ MULTI-RUN sobre dataset: {DATASET_NAME}")
print("="*90)
print(f"Seeds: {[BASE_SEED+i for i in range(N_RUNS)]}")
print(f"Guardar checkpoints: {SAVE_CHECKPOINTS} | Guardar histories: {SAVE_HISTORIES} | Guardar per-run: {SAVE_PER_RUN_FILES}")
print()

# Loaders are built once per dataset and shared by every run.
# NOTE(review): the loop below only uses test_loader (train_sleep_model receives
# X/y/splits directly). train_loader/val_loader are built but unused in this cell.
train_loader, val_loader, test_loader = _build_loaders(
    X, y, splits, batch_size=CONFIG["batch_size"], num_workers=CONFIG["num_workers"], pin=pin_mem
)

# One entry per run: history, train_sleep_model results, detailed test evaluation.
all_runs_data = []
for run_id in range(1, N_RUNS+1):
    # Run k (1-based) uses seed BASE_SEED + k - 1.
    seed = BASE_SEED + (run_id-1)
    set_seed(seed)

    model = _new_model()
    run_dir = RUNS_DIR / f"run_{run_id:02d}"
    # Create the per-run directory only if something will be written to disk.
    if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES):
        run_dir.mkdir(parents=True, exist_ok=True)

    save_path = (str(run_dir / "best_model.pt")) if SAVE_CHECKPOINTS else None

    # Train.
    # NOTE(review): with SAVE_CHECKPOINTS=False, "best_model_tmp.pt" (relative to the
    # cwd) is still passed as save_path; presumably train_sleep_model writes there — confirm.
    model, hist, results = train_sleep_model(
        model=model,
        X=X, y=y, splits=splits,
        save_path=(save_path if save_path else "best_model_tmp.pt"),
        **CONFIG
    )

    # Learning curves (loss and accuracy per epoch), shown for the last run only.
    if run_id == N_RUNS:
        epochs_arr = range(1, len(hist["train_loss"])+1)
        fig, ax = plt.subplots(1,2, figsize=(12,5))
        ax[0].plot(epochs_arr, hist["train_loss"], 'r-', label='training')
        ax[0].plot(epochs_arr, hist["val_loss"], 'b-', label='validation')
        ax[0].set_title('Loss evolution'); ax[0].set_xlabel('Epoch'); ax[0].set_ylabel('Loss'); ax[0].grid(True, alpha=.3); ax[0].legend()

        ax[1].plot(epochs_arr, hist["train_acc"], 'r-', label='training')
        ax[1].plot(epochs_arr, hist["val_acc"], 'b-', label='validation')
        ax[1].set_title('Accuracy evolution'); ax[1].set_xlabel('Epoch'); ax[1].set_ylabel('Accuracy'); ax[1].grid(True, alpha=.3); ax[1].legend()
        plt.suptitle(f"Learning Curves ‚Äî {DATASET_NAME} (Run {run_id})")
        plt.tight_layout()
        plt.show()

    # Optional save: training history and the run's config (plus seed/run/timestamp).
    if SAVE_HISTORIES:
        np.savez(run_dir / "history.npz",
                 train_loss=hist["train_loss"], val_loss=hist["val_loss"],
                 train_acc=hist["train_acc"], val_acc=hist["val_acc"],
                 lr=hist["lr"])
        with open(run_dir / "config.json","w") as f:
            cfg = copy.deepcopy(CONFIG); cfg.update(seed=seed, run_id=run_id, dataset=DATASET_NAME, ts=datetime.now().isoformat())
            json.dump(cfg, f, indent=2)

    # Detailed test-set evaluation of the model returned by train_sleep_model.
    eval_res = evaluate_detailed(model, test_loader, device)

    # Per-stage table shown on screen (not saved unless SAVE_PER_RUN_FILES).
    print(f"\nüìä RUN {run_id} ‚Äî M√©tricas por etapa")
    display(eval_res["df"].style.set_caption(f"Run {run_id} ‚Äî {DATASET_NAME}"))

    print(f"   ‚û§ Acc={eval_res['acc']:.4f} | Kappa={eval_res['kappa']:.4f}")

    # Confusion-matrix plots (on-screen): raw counts and row-normalized.
    labels = ["W","N1","N2","N3","REM"]
    fig, axes = plt.subplots(1,2, figsize=(12,5))
    sns.heatmap(eval_res["cm"], annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_title(f"CM Cruda ‚Äî Run {run_id}")
    axes[0].set_xlabel("Predicho"); axes[0].set_ylabel("Real")

    sns.heatmap(eval_res["cm_norm"], annot=True, fmt='.2f', cmap='Blues',
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=axes[1])
    axes[1].set_title(f"CM Normalizada ‚Äî Run {run_id}")
    axes[1].set_xlabel("Predicho"); axes[1].set_ylabel("Real")
    plt.tight_layout(); plt.show()

    # Optional per-run files: per-class CSV, predictions CSV, CM arrays, summary text.
    if SAVE_PER_RUN_FILES:
        eval_res["df"].to_csv(run_dir / "eval_test_per_class.csv", index=False)
        pd.DataFrame({"y_true": eval_res["y_true"].astype(int),
                      "y_pred": eval_res["y_pred"].astype(int)}).to_csv(run_dir / "eval_test_pred_vs_true.csv", index=False)
        np.save(run_dir / "eval_test_cm.npy", eval_res["cm"])
        np.save(run_dir / "eval_test_cm_norm.npy", eval_res["cm_norm"])
        with open(run_dir / "eval_test_summary.txt","w") as f:
            f.write(f"accuracy_global={eval_res['acc']:.6f}\n")
            f.write(f"kappa_global={eval_res['kappa']:.6f}\n")

    # run_dir is recorded only when some SAVE_* flag caused it to be created.
    all_runs_data.append({
        "run_id": run_id,
        "history": hist,
        "results": results,
        "eval": eval_res,
        "run_dir": (run_dir if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES) else None)
    })

# ========= 5) Simple summary =========
# test_acc / test_loss come from the `results` dict returned by train_sleep_model.
test_accs = [rd["results"]["test_acc"] for rd in all_runs_data]
test_losses = [rd["results"]["test_loss"] for rd in all_runs_data]
print("\n" + "="*90)
print(f"‚úÖ {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*90)
# np.std here is the population std (ddof=0) across runs.
print(f"Test Acc:  mean={np.mean(test_accs):.4f}  std={np.std(test_accs):.4f}  "
      f"min={np.min(test_accs):.4f}  max={np.max(test_accs):.4f}")
print(f"Test Loss: mean={np.mean(test_losses):.4f}  std={np.std(test_losses):.4f}  "
      f"min={np.min(test_losses):.4f}  max={np.max(test_losses):.4f}")

# ========= 6) Aggregation (mean +/- std over runs) and aggregated plots =========
labels = ["W","N1","N2","N3","REM"]
metrics_cols = ["precision", "recall", "f1_score", "accuracy", "kappa"]

# Stack the per-class metrics of every run. They come from each run's eval
# DataFrame, whose values evaluate_detailed already rounded to 3 decimals.
per_class_list = [rd["eval"]["df"][metrics_cols].to_numpy() for rd in all_runs_data]  # list of (5x5)
per_class_arr  = np.stack(per_class_list, axis=0)  # (n_runs, 5, 5)

# Mean / population std (ddof=0) over runs; rows = stages, columns = metrics_cols.
means = per_class_arr.mean(axis=0)  # (5,5)
stds  = per_class_arr.std(axis=0)   # (5,5)

# Aggregated table (on-screen), one "mean ± std" string per cell.
# NOTE(review): cells are strings, so the optional CSV below stores text, not numbers.
df_agg = pd.DataFrame({"etapa": labels})
for j, col in enumerate(metrics_cols):
    df_agg[col] = [f"{means[i,j]:.3f} ¬± {stds[i,j]:.3f}" for i in range(len(labels))]

print("\nüìä M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
print(df_agg.to_string(index=False))

# Aggregated F1 bar plot (on-screen); error bars are the std across runs.
f1_means = means[:, metrics_cols.index("f1_score")]
f1_stds  = stds[:,  metrics_cols.index("f1_score")]
plt.figure(figsize=(9,5))
x = np.arange(len(labels))
plt.bar(x, f1_means, yerr=f1_stds, capsize=4, alpha=.85)
plt.xticks(x, labels)
plt.ylim(0, 1.05)
plt.xlabel("Etapa"); plt.ylabel("F1-Score"); plt.title(f"F1 por etapa (media¬±std) ‚Äî {DATASET_NAME}")
plt.grid(axis='y', alpha=.3)
plt.tight_layout(); plt.show()

# Aggregated row-normalized confusion matrix: element-wise mean and std over runs.
cm_norm_mean = np.mean([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)
cm_norm_std  = np.std ([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)

fig, ax = plt.subplots(1,2, figsize=(12,5))
sns.heatmap(cm_norm_mean, annot=True, fmt=".3f", cmap='Blues',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=ax[0])
ax[0].set_title(f"CM Normalizada (media) ‚Äî {DATASET_NAME}")
# The std colour scale is capped at 0.2; larger values saturate.
sns.heatmap(cm_norm_std, annot=True, fmt=".3f", cmap='Reds',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=0.2, ax=ax[1])
ax[1].set_title(f"CM Normalizada (std) ‚Äî {DATASET_NAME}")
for a in ax: a.set_xlabel("Predicho"); a.set_ylabel("Real")
plt.tight_layout(); plt.show()

# Optional aggregated saves (table CSV + mean/std CM arrays).
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    df_agg.to_csv(RUNS_DIR / "metrics_aggregated_per_class.csv", index=False)
    np.save(RUNS_DIR / "cm_norm_mean.npy", cm_norm_mean)
    np.save(RUNS_DIR / "cm_norm_std.npy", cm_norm_std)

print("\n Listo. ")


# Set 1A (Todo filtrado igual)

In [None]:
# ============================================================
# DATASET "Set A": STFT hop 2 s, Hamming nperseg=256 (fixed), 0.5–30 Hz
# Lee ventanas ya guardadas (WINDOWS_DIR) y arma (N, 61, 15, 4)
# ============================================================
import numpy as np
import pandas as pd
import pickle
from pathlib import Path
from scipy.signal import stft, get_window, resample



# Channel aliases for each Set-A input slot. pick_channel_name tries the aliases
# in order: a case-insensitive exact match first, then a substring match.
CHANNEL_PATTERNS = {
    "EEG1": ["EEG Fpz-Cz", "Fpz-Cz"],
    "EEG2": ["EEG Pz-Oz", "Pz-Oz"],
    "EOG" : ["EOG", "EOG horizontal", "EOG horizontal derivation"],
    "EMG" : ["EMG", "EMG submental", "Submental EMG"]
}

# Common 0.5–30 Hz band for ALL channels (Set A requirement).
FMIN, FMAX = 0.5, 30.0

# Output 61x15 spectrogram per channel for each 30 s window.
N_FREQ_OUT, N_TIME_OUT = 61, 15
# NOTE(review): SEG_SEC is forwarded to stft_2s_2s as `seg_sec`, but that function
# does not use it; the STFT segment length is NPERSEG_FIXED samples instead.
WIN_SEC, SEG_SEC, HOP_SEC = 30.0, 2.0, 2.0

# STFT: fixed 256-point Hamming window (independent of fs and SEG_SEC).
NPERSEG_FIXED = 256
WINDOW_TYPE = "hamming"

# Sleep-stage name <-> integer class id.
LABEL2ID = {"W":0, "N1":1, "N2":2, "N3":3, "REM":4}
ID2LABEL = {v:k for k,v in LABEL2ID.items()}

def _load_npz(path: Path):
    """Load one channel's windows from an ``.npz`` archive.

    The archive must contain ``X`` (windows), ``y`` (labels), ``t`` (window
    start times), ``fs`` (sampling rate) and ``canal`` (channel name).
    ``np.load`` on an ``.npz`` returns a lazily-reading ``NpzFile`` that keeps
    the file open, so it is used as a context manager here. Each array is read
    fully before the handle is released, avoiding one leaked file descriptor
    per loaded channel.

    Returns:
        dict with ``X`` (as stored), ``y`` (uint8), ``t`` (float32),
        ``fs`` (float) and ``canal`` (str).
    """
    with np.load(path, allow_pickle=False) as d:
        return {
            "X": d["X"],
            "y": d["y"].astype(np.uint8),
            "t": d["t"].astype(np.float32),
            "fs": float(d["fs"]),
            "canal": str(d["canal"])
        }

def _load_pkl(path: Path):
    """Load one channel's windows from a legacy ``.pkl`` file.

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code. Only call this on
    files produced by this pipeline, never on untrusted input.

    Expected keys:
    - ``ventanas``: windows.
    - ``tiempos_inicio``: start times.
    - ``etiquetas``: labels.
    - Optional ``freq_muestreo`` (default 100.0) and ``nombre_canal``
      (default "CANAL").

    Labels may be integer ids or stage names ("W", "N1", ...). Names are mapped
    through ``LABEL2ID``; unknown names become 255. Both Python lists and
    NumPy string arrays are accepted; a string array used to reach
    ``np.array(..., dtype=np.uint8)`` and raise ``ValueError``.
    """
    with open(path, "rb") as f:
        data = pickle.load(f)

    def _to_id(s):
        # Integer ids pass through; names (str or bytes) go through LABEL2ID.
        if isinstance(s, (int, np.integer)):
            return int(s)
        if isinstance(s, bytes):
            s = s.decode("utf-8", "replace")
        return LABEL2ID.get(str(s), 255)

    raw = data.get("etiquetas", [])
    if isinstance(raw, list) and raw:
        y = np.array([_to_id(s) for s in raw], dtype=np.uint8)
    elif isinstance(raw, np.ndarray) and raw.dtype.kind in "US" and raw.ndim == 1 and raw.size:
        y = np.array([_to_id(s) for s in raw], dtype=np.uint8)
    else:
        y = np.array(raw, dtype=np.uint8)
    return {
        "X": np.asarray(data["ventanas"], dtype=np.float32),
        "y": y,
        "t": np.asarray(data["tiempos_inicio"], dtype=np.float32),
        "fs": float(data.get("freq_muestreo", 100.0)),
        "canal": str(data.get("nombre_canal", "CANAL"))
    }

def load_channel_file(paciente: str, canal_nombre: str, windows_dir: Path):
    """Load ``<paciente>_<canal_nombre with spaces -> _>`` from ``windows_dir``.

    The ``.npz`` file is preferred over ``.pkl``. The file name is built with
    ``Path.with_suffix``.

    Returns:
        ``(data_dict, suffix)`` for the file that was found, or
        ``(None, None)`` when neither exists.
    """
    base = windows_dir / f"{paciente}_{canal_nombre.replace(' ', '_')}"
    npz_candidate = base.with_suffix(".npz")
    if npz_candidate.exists():
        return _load_npz(npz_candidate), ".npz"
    pkl_candidate = base.with_suffix(".pkl")
    if pkl_candidate.exists():
        return _load_pkl(pkl_candidate), ".pkl"
    return None, None

def stft_2s_2s(x, fs,
               fmin=FMIN, fmax=FMAX,
               n_freq_out=N_FREQ_OUT, n_time_out=N_TIME_OUT,
               win_sec=WIN_SEC, seg_sec=SEG_SEC, hop_sec=HOP_SEC):
    """Return the log10 power STFT of one window, resampled to (n_freq_out, n_time_out).

    Steps:
    1. Cast the signal to float32 and truncate or zero-pad it to
       ``round(win_sec * fs)`` samples.
    2. Run an STFT with a fixed ``NPERSEG_FIXED``-point ``WINDOW_TYPE`` window
       and a hop of ``round(hop_sec * fs)`` samples. If that hop exceeds the
       segment, noverlap clamps to 0.
    3. Keep the ``[fmin, fmax]`` Hz bins and take ``log10(|Z|^2 + 1e-12)``.
    4. FFT-resample along frequency and/or time to the requested output size.

    ``seg_sec`` is accepted for interface compatibility but is not used.
    """
    signal = np.asarray(x, dtype=np.float32)

    # Force exactly win_sec seconds: truncate or zero-pad at the end.
    target = int(round(win_sec * fs))
    n = len(signal)
    if n > target:
        signal = signal[:target]
    elif n < target:
        signal = np.pad(signal, (0, target - n))

    seg_len = NPERSEG_FIXED
    step = int(round(hop_sec * fs))
    overlap = max(0, seg_len - step)
    # Smallest power of two >= seg_len (1 for seg_len <= 1).
    fft_len = 1 if seg_len <= 1 else 1 << (seg_len - 1).bit_length()

    freqs, _, Z = stft(
        signal, fs=fs,
        window=get_window(WINDOW_TYPE, seg_len, fftbins=True),
        nperseg=seg_len, noverlap=overlap, nfft=fft_len,
        boundary=None, padded=False, detrend=False, return_onesided=True
    )
    in_band = (freqs >= fmin) & (freqs <= fmax)
    power = (np.abs(Z[in_band, :]) ** 2).astype(np.float32)
    spec = np.log10(power + 1e-12)

    # Resample to the fixed output grid: frequency axis first, then time.
    for axis, size in ((0, n_freq_out), (1, n_time_out)):
        if spec.shape[axis] != size:
            spec = resample(spec, size, axis=axis)
    return spec

def pick_channel_name(df_patient: pd.DataFrame, aliases: list[str]) -> str | None:
    names = list(df_patient["Canal"].unique())
    u_names = [n.upper() for n in names]
    for alias in aliases:
        alias_u = alias.upper()
        for n, u in zip(names, u_names):
            if u == alias_u: return n
        for n, u in zip(names, u_names):
            if alias_u in u: return n
    return None

def _setA_load_patient_channels(dpf: pd.DataFrame, paciente, windows_dir: Path):
    """Match and load the four Set-A channels of one patient.

    Returns ``{"EEG1"|"EEG2"|"EOG"|"EMG": loaded_file_dict}`` (in
    ``CHANNEL_PATTERNS`` order), or ``None`` if a channel cannot be matched in
    ``dpf`` or its window file is missing.
    """
    chosen = {k: pick_channel_name(dpf, v) for k, v in CHANNEL_PATTERNS.items()}
    if any(v is None for v in chosen.values()):
        return None
    loaded = {}
    for k, nm in chosen.items():
        dfile, _ = load_channel_file(paciente, nm, windows_dir)
        if dfile is None:
            return None
        loaded[k] = dfile
    return loaded


def _setA_common_times(loaded: dict):
    """Return per-channel start times rounded to 4 decimals, and the sorted
    float32 array of the rounded times present in every channel."""
    tr = {k: np.round(d["t"], 4) for k, d in loaded.items()}
    common = set.intersection(*(set(v) for v in tr.values()))
    return tr, np.array(sorted(common), dtype=np.float32)


def _setA_patient_spectrograms(loaded: dict, tr: dict, common_sorted: np.ndarray):
    """STFT every aligned window of one patient.

    Returns ``(specs, labels)``. ``specs`` has shape
    ``(n, N_FREQ_OUT, N_TIME_OUT, 4)`` with channels EEG1, EEG2, EOG, EMG on
    the last axis. ``labels`` is taken from the EEG1 file.
    """
    ch_specs = []
    labels_p = None
    for k in ["EEG1", "EEG2", "EOG", "EMG"]:
        # Map each common start time back to this channel's row index.
        t2idx = {float(t): i for i, t in enumerate(tr[k])}
        idx = [t2idx[float(t)] for t in common_sorted]
        d_k = loaded[k]
        fs = d_k["fs"]
        Xraw = d_k["X"][idx]
        yk = d_k["y"][idx]
        if labels_p is None:
            labels_p = yk.copy()

        n = Xraw.shape[0]
        S_k = np.empty((n, N_FREQ_OUT, N_TIME_OUT), dtype=np.float32)
        for i in range(n):
            S_k[i] = stft_2s_2s(Xraw[i], fs)
        ch_specs.append(S_k)
    return np.stack(ch_specs, axis=-1), labels_p


def build_cnn_dataset_setA(
    analysis_csv: Path,
    windows_dir: Path,
    dtype="float32",
    memmap_path: Path | None = None,
    max_patients: int | None = None
):
    """Build the Set-A CNN dataset: one 4-channel log-STFT image per window.

    For each patient listed in ``analysis_csv`` (needs columns ``Paciente`` and
    ``Canal``):
    1. Match the EEG1/EEG2/EOG/EMG channels via ``CHANNEL_PATTERNS``.
    2. Load them from ``windows_dir`` and align them on their common window
       start times.
    3. Convert each window to a ``(N_FREQ_OUT, N_TIME_OUT)`` spectrogram with
       ``stft_2s_2s``.

    Patients missing a channel, or with no common windows, are skipped.

    Args:
        analysis_csv: CSV listing patients and their channel names.
        windows_dir: directory with ``<patient>_<channel>.npz|.pkl`` files.
        dtype: "float16" stores spectrograms as float16, anything else float32.
        memmap_path: if given, X is written to a raw memmap at this path and y
            to a raw uint8 memmap at ``<stem>.labels.npy``. That file has no
            .npy header despite its suffix, so read it with ``np.memmap``, not
            ``np.load``. A counting pre-pass sizes the memmaps; if it finds no
            windows, no files are created and empty in-memory arrays are
            returned (a zero-length ``np.memmap`` raises ValueError).
        max_patients: optional cap on the number of patients processed.

    Returns:
        ``(X, y, meta)``:
        - ``X``: shape ``(N, N_FREQ_OUT, N_TIME_OUT, 4)``.
        - ``y``: uint8, shape ``(N,)``.
        - ``meta``: dict with the shape, labels, per-patient counts and STFT
          config.
    """
    df = pd.read_csv(analysis_csv)
    for col in ["Paciente", "Canal"]:
        assert col in df.columns, f"Falta columna {col} en {analysis_csv}"

    patients = list(df["Paciente"].unique())
    if max_patients:
        patients = patients[:max_patients]

    np_dtype = np.float16 if dtype == "float16" else np.float32
    all_specs, all_labels, counts = [], [], []
    X = y = None
    use_memmap = False
    widx = 0

    # Optional memmap output (nothing is written to disk by default).
    if memmap_path is not None:
        # Pre-pass: count aligned windows so the memmaps can be sized up front.
        total_N = 0
        for p in patients:
            loaded = _setA_load_patient_channels(df[df["Paciente"] == p], p, windows_dir)
            if loaded is not None:
                total_N += len(_setA_common_times(loaded)[1])

        if total_N > 0:
            memmap_path.parent.mkdir(parents=True, exist_ok=True)
            X = np.memmap(memmap_path, dtype=np_dtype,
                          mode='w+', shape=(total_N, N_FREQ_OUT, N_TIME_OUT, 4))
            y = np.memmap(memmap_path.with_suffix(".labels.npy"), dtype=np.uint8,
                          mode='w+', shape=(total_N,))
            use_memmap = True

    for p in patients:
        loaded = _setA_load_patient_channels(df[df["Paciente"] == p], p, windows_dir)
        if loaded is None:
            continue

        # Align channels on their common (rounded) start times.
        tr, common_sorted = _setA_common_times(loaded)
        if common_sorted.size == 0:
            continue

        # STFT per channel (FMIN–FMAX band), stacked on the last axis.
        specs_p, labels_p = _setA_patient_spectrograms(loaded, tr, common_sorted)
        if dtype == "float16":
            specs_p = specs_p.astype(np.float16)

        n_win = specs_p.shape[0]
        if use_memmap:
            X[widx:widx + n_win] = specs_p
            y[widx:widx + n_win] = labels_p
            widx += n_win
        else:
            all_specs.append(specs_p)
            all_labels.append(labels_p)

        counts.append((p, int(n_win)))

    if use_memmap:
        # Make sure everything written so far reaches the files.
        X.flush()
        y.flush()
    elif all_specs:
        X = np.concatenate(all_specs, axis=0)
        y = np.concatenate(all_labels, axis=0)
    else:
        X = np.empty((0, N_FREQ_OUT, N_TIME_OUT, 4), dtype=np_dtype)
        y = np.empty((0,), dtype=np.uint8)

    meta = {
        "shape": tuple(X.shape),
        "labels_unique": sorted(list(map(int, np.unique(y)))) if y.size else [],
        "label_map": ID2LABEL,
        "counts_per_patient": counts,
        "channels_used": CHANNEL_PATTERNS,
        "stft_config": {
            "window_type": WINDOW_TYPE,
            "nperseg": NPERSEG_FIXED,
            "freq_range": (FMIN, FMAX),
            "output_shape": (N_FREQ_OUT, N_TIME_OUT),
            "hop_s": HOP_SEC, "seg_s": SEG_SEC, "win_s": WIN_SEC
        }
    }
    return X, y, meta

# ---------- Run (nothing is written to disk by default) ----------
resumen_csv = ANALYSIS_DIR / "resumen_global.csv"
SAVE_TO_DISK = False  # <- set to True to write a memmap into OUTPUT_DIR
memmap_path = (OUTPUT_DIR / "SetA_X.dat") if SAVE_TO_DISK else None

X_A, y_A, meta_A = build_cnn_dataset_setA(
    analysis_csv=resumen_csv,
    windows_dir=WINDOWS_DIR,
    dtype="float32",
    memmap_path=memmap_path,
    max_patients=None
)

# Report the band actually used (from meta, i.e. FMIN/FMAX). The hard-coded
# "0.5–40 Hz" message contradicted the 0.5–30 Hz configuration.
_band_lo, _band_hi = meta_A["stft_config"]["freq_range"]
print(f"‚úÖ Set A listo ({_band_lo:g}‚Äì{_band_hi:g} Hz, 4 canales)")
# int() so NumPy 2 prints plain ints rather than np.uint8(...) reprs.
print("   X_A:", X_A.shape, "| y_A:", y_A.shape, "| clases:", sorted(int(c) for c in np.unique(y_A)))
print("   Pacientes:", len(meta_A["counts_per_patient"]))
print("   Config STFT:", meta_A["stft_config"])

In [None]:
# ============================================================
# Patient IDs + Split 60/20/20 por paciente (sin leakage)
# ============================================================
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

def make_patient_ids(meta):
    """Expand per-patient window counts into one patient id per window.

    ``meta["counts_per_patient"]`` is a sequence of ``(patient, n_windows)``
    pairs. Each patient id is repeated ``int(n_windows)`` times, in the same
    order as the pairs, so the result lines up with windows stacked in that order.

    Returns:
        np.ndarray holding ``sum(n_windows)`` patient ids.
    """
    return np.array([
        patient
        for patient, n_windows in meta["counts_per_patient"]
        for _ in range(int(n_windows))
    ])

# One patient id per window of Set A.
patient_ids_A = make_patient_ids(meta_A)
# Sanity check: ids, labels and feature rows must all have the same length.
assert len(patient_ids_A) == len(y_A) == X_A.shape[0], "Desalineaci√≥n en Set A"

def split_by_patient(X, y, patient_ids, test_size=0.20, val_size=0.20, random_state=42):
    """Partition sample indices into train / val / test with no patient overlap.

    Two chained ``GroupShuffleSplit`` draws grouped by ``patient_ids`` are used.
    The first holds out ``test_size`` as the test set. The second carves the
    validation set out of the remaining ("dev") samples, rescaling ``val_size``
    by ``1 - test_size`` so that it is, approximately, a fraction of the whole
    dataset rather than of the remainder. For float sizes, scikit-learn
    interprets them as proportions of groups (patients), not of samples.

    Args:
        X: Feature array; not used, kept for call-site symmetry.
        y: Per-sample labels; its length defines the sample count.
        patient_ids: np.ndarray with one group id per sample.
        test_size: Fraction of patients for the test split.
        val_size: Fraction of patients for the val split.
        random_state: Seed of the first split; the second uses ``random_state + 1``.

    Returns:
        dict with keys ``"train"``, ``"val"``, ``"test"`` mapping to index arrays into y.
    """
    n_samples = len(y)
    outer = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    dev_idx, test_idx = next(outer.split(np.zeros(n_samples), y, groups=patient_ids))

    # Make val_size relative to the full dataset instead of to the dev subset.
    inner_test_size = val_size / (1.0 - test_size)
    inner = GroupShuffleSplit(n_splits=1, test_size=inner_test_size, random_state=random_state + 1)
    dev_groups = patient_ids[dev_idx]
    dev_labels = y[dev_idx]
    fit_pos, val_pos = next(inner.split(np.zeros(len(dev_idx)), dev_labels, groups=dev_groups))

    # Inner positions index into dev_idx; map them back to global sample indices.
    return {"train": dev_idx[fit_pos], "val": dev_idx[val_pos], "test": test_idx}

# Patient-level 60/20/20 split of Set A (test held out first, val taken from the rest).
splits_A = split_by_patient(X_A, y_A, patient_ids_A, test_size=0.20, val_size=0.20, random_state=42)

def print_split_summary(y, patient_ids, splits, label_names=None):
    """Print patient counts, patient overlap and per-split class distribution.

    Args:
        y: np.ndarray of integer labels, indexed by the arrays in ``splits``.
        patient_ids: np.ndarray with one patient id per sample.
        splits: dict mapping split name -> index array. Keys ``"train"``,
            ``"val"`` and ``"test"`` must be present; every entry is
            summarized in iteration order.
        label_names: Optional mapping (or sequence) label id -> display name.
            Defaults to ``{0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}``.
            Labels not covered by it are printed as their numeric id.
    """
    if label_names is None:
        # Built per call instead of as a (shared, mutable) default argument.
        label_names = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}

    def _stage_name(label):
        # Unknown labels are shown by id rather than aborting the summary.
        try:
            return label_names[label]
        except (KeyError, IndexError):
            return str(label)

    p_train = set(np.unique(patient_ids[splits["train"]]))
    p_val   = set(np.unique(patient_ids[splits["val"]]))
    p_test  = set(np.unique(patient_ids[splits["test"]]))
    print("====== PACIENTES ======")
    print(f"Train: {len(p_train)} | Val: {len(p_val)} | Test: {len(p_test)}")
    # In a patient-level split no patient may appear in two splits.
    print("Intersecciones (deben ser 0):",
          len(p_train & p_val), len(p_train & p_test), len(p_val & p_test))
    print("\n====== DISTRIBUCI√ìN (ventanas) ======")
    for name, idx in splits.items():
        yy = y[idx]
        uniq, cnt = np.unique(yy, return_counts=True)
        total = len(yy)
        # cnt is empty when total == 0, so the percentage never divides by zero.
        nice = ", ".join(
            f"{_stage_name(int(k))}: {int(v)} ({v / total * 100:.1f}%)"
            for k, v in zip(uniq, cnt)
        )
        print(f"{name:>5} -> N={total} | {nice}")

# Report patient counts, cross-split overlaps and class balance for Set A.
print_split_summary(y_A, patient_ids_A, splits_A)


In [None]:
# ================================================
# CELDA ÚNICA: MULTI-RUN + MÉTRICAS
# ================================================
import os, json, copy, math, random, pickle, warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, cohen_kappa_score
)
warnings.filterwarnings("ignore")  # NOTE: silences *all* warnings from here on.

# ========= 0) Dataset selected for this run =========
X, y, splits = X_A, y_A, splits_A
# Derived from the STFT frequency range Set A was actually built with, so plot
# titles and the run-directory name cannot drift from the data.
_fmin, _fmax = meta_A["stft_config"]["freq_range"]
DATASET_NAME = f"EEG1+EEG2+EOG+EMG Set A ({_fmin:g}-{_fmax:g} Hz)"

# ========= 1) FLAGS (nothing is saved by default) =========
SAVE_CHECKPOINTS   = False   # Keep best_model.pt for each run
SAVE_HISTORIES     = False   # Save history.npz and config.json for each run
SAVE_PER_RUN_FILES = False   # Save per-run CSV / NPY / TXT files (metrics, predictions, CMs)
SAVE_AGGREGATES    = False   # Save the aggregated per-class table and mean/std CMs

# ========= 2) Global config =========
N_RUNS = 3
BASE_SEED = 42
# Passed as keyword arguments to train_sleep_model; "use_gpu" also picks the eval device.
CONFIG = {
    "lr": 5e-6,
    "batch_size": 256,
    "epochs": 50,
    "criterion_name": "ce",
    "class_weights": None,
    "weight_clip_range": (0.1, 2.5),
    "grad_clip": 1.0,
    "use_gpu": True,
    "amp": False,
    "num_workers": 0,
    "early_stopping_tolerance": 5,
    "early_stopping_metric": "val_acc"
}

# ======= Paths =======
# Reuse OUTPUT_DIR from earlier cells when it exists, otherwise ./outputs.
OUTPUT_DIR = Path(OUTPUT_DIR) if 'OUTPUT_DIR' in globals() else (Path.cwd() / "outputs")
RUNS_DIR   = OUTPUT_DIR / "multiple_runs" / DATASET_NAME.replace(" ", "_")
RUNS_DIR.mkdir(parents=True, exist_ok=True)

# ========= 3) Utils =========
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Seeds ``random``, NumPy's global RNG and PyTorch's CPU generator. When
    CUDA is available it also seeds the current and all CUDA devices. Finally
    it enables cuDNN's deterministic mode and disables its benchmark
    (autotuning) mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Reproducibility over speed: no benchmark-driven kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class SpectroDataset(Dataset):
    """Map-style dataset over selected rows of a channels-last array.

    Item ``i`` reads row ``indices[i]`` of ``X``/``y``. It returns
    ``(x, label)``: ``x`` is a float32 tensor with that row transposed from
    (H, W, C) to (C, H, W), and ``label`` is a 0-d ``torch.long`` tensor.
    """

    def __init__(self, X, y, indices):
        self.X = X
        self.y = y
        self.idx = np.asarray(indices)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        row = self.idx[i]
        # Conversion to float32 happens per item, so X itself may be stored in
        # another dtype (e.g. float16) or as a memmap.
        chw = np.asarray(self.X[row], dtype=np.float32).transpose(2, 0, 1)
        label = torch.tensor(int(self.y[row]), dtype=torch.long)
        return torch.from_numpy(chw), label

def _build_loaders(X, y, splits, batch_size=256, num_workers=0, pin=True):
    """Create train / val / test DataLoaders over the index splits of (X, y).

    The training loader draws with replacement through a WeightedRandomSampler.
    Each sample's weight is the inverse frequency of its class in the TRAIN
    split, so every class present in train carries the same total weight.
    One epoch still yields ``len(splits['train'])`` samples. The val and test
    loaders iterate in order without shuffling.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    parts = {name: SpectroDataset(X, y, splits[name]) for name in ("train", "val", "test")}

    # Inverse class-frequency weights, computed on the training labels only.
    train_labels = y[splits['train']]
    per_class = np.bincount(train_labels, minlength=int(np.max(y)) + 1)
    inv_freq = 1.0 / np.maximum(per_class, 1)  # absent classes get 1 but are never drawn
    per_sample = inv_freq[train_labels]
    sampler = torch.utils.data.WeightedRandomSampler(
        weights=torch.DoubleTensor(per_sample),
        num_samples=len(per_sample),
        replacement=True,
    )

    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin)
    train_loader = DataLoader(parts["train"], sampler=sampler, drop_last=False, **shared)
    val_loader = DataLoader(parts["val"], shuffle=False, **shared)
    test_loader = DataLoader(parts["test"], shuffle=False, **shared)
    return train_loader, val_loader, test_loader

# ======= Modelo 
def _new_model():
    """Instantiate a fresh 5-class ``SleepStageModel`` for one run.

    It first passes the channel count (last axis of the global ``X``) as
    ``in_ch``. If that call raises ``TypeError`` (e.g. the constructor takes
    no ``in_ch``), it retries with ``num_classes`` only.
    NOTE(review): a TypeError raised *inside* the constructor also triggers
    the fallback, which can mask genuine errors.
    """
    try:
        model = SleepStageModel(num_classes=5, in_ch=X.shape[-1])
    except TypeError:
        model = SleepStageModel(num_classes=5)
    return model

# ======= Entrenamiento: train_sleep_model se define en otra celda =======

# ======= Métricas detalladas por run =======
@torch.no_grad()
def evaluate_detailed(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and compute per-stage and global metrics.

    Args:
        model: Classifier returning one logit per stage (5 classes) for each
            sample. It is switched to eval mode here but NOT moved to
            ``device``; presumably the caller already did that (confirm).
        test_loader: Iterable of ``(xb, yb)`` batches; ``yb`` holds integer
            stage labels 0..4 (W, N1, N2, N3, REM).
        device: Device the input batches are moved to.

    Returns:
        dict with keys
          * ``"df"``: per-stage DataFrame with precision, recall, F1,
            one-vs-rest accuracy and one-vs-rest Cohen's kappa (rounded to
            3 decimals) plus support;
          * ``"cm"``: raw 5x5 confusion matrix (rows = true, cols = predicted);
          * ``"cm_norm"``: row-normalized confusion matrix, where rows with no
            true samples are all zeros;
          * ``"acc"`` / ``"kappa"``: global accuracy and Cohen's kappa;
          * ``"y_true"`` / ``"y_pred"``: concatenated labels and predictions.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    labels = ["W","N1","N2","N3","REM"]
    n_classes = len(labels)
    class_ids = list(range(n_classes))

    model.eval()
    all_p, all_t = [], []
    for xb, yb in test_loader:
        xb = xb.to(device, non_blocking=True)
        logits = model(xb)
        all_p.append(torch.argmax(logits, dim=1).cpu().numpy())
        all_t.append(yb.numpy())
    if not all_p:
        raise ValueError("evaluate_detailed: test_loader produced no batches")
    y_pred = np.concatenate(all_p)
    y_true = np.concatenate(all_t)

    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    # Row-normalize. Clamping the denominator to 1 makes empty rows exactly 0.
    # Do not use np.divide(..., where=...) without an `out=` array here: the
    # masked-out entries would be left as uninitialized memory.
    cm_norm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)

    # One-vs-rest counts for all classes at once. Every rate below uses the
    # confusion-matrix total, so all counts share one denominator.
    total = cm.sum()
    denom = max(1, total)
    tp = np.diag(cm).astype(np.float64)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = total - (tp + fn + fp)

    acc_per_class = (tp + tn) / denom
    # Binary Cohen's kappa for "class k vs rest":
    # (observed - expected agreement) / (1 - expected agreement).
    expected = ((tp + fn) / denom) * ((tp + fp) / denom) + ((fp + tn) / denom) * ((fn + tn) / denom)
    kappa_per_class = (acc_per_class - expected) / (1 - expected + 1e-12)

    df_per_class = pd.DataFrame({
        "etapa": labels,
        "precision": np.round(prec, 3),
        "recall":    np.round(rec, 3),
        "f1_score":  np.round(f1, 3),
        "accuracy":  np.round(acc_per_class, 3),
        "kappa":     np.round(kappa_per_class, 3),
        "soporte":   support.astype(int)
    })

    overall_acc = accuracy_score(y_true, y_pred)
    kappa_global = cohen_kappa_score(y_true, y_pred)

    return {
        "df": df_per_class,
        "cm": cm,
        "cm_norm": cm_norm,
        "acc": overall_acc,
        "kappa": kappa_global,
        "y_true": y_true,
        "y_pred": y_pred
    }

# ========= 4) Loop over runs =========
# train_sleep_model must already exist in the notebook's global namespace.
assert 'train_sleep_model' in globals(), "Falta la funci√≥n train_sleep_model en el entorno."
# GPU only when requested via CONFIG and actually available.
device = torch.device("cuda" if (CONFIG["use_gpu"] and torch.cuda.is_available()) else "cpu")
pin_mem = (device.type == "cuda")

print("="*90)
print(f"üöÄ MULTI-RUN sobre dataset: {DATASET_NAME}")
print("="*90)
print(f"Seeds: {[BASE_SEED+i for i in range(N_RUNS)]}")
print(f"Guardar checkpoints: {SAVE_CHECKPOINTS} | Guardar histories: {SAVE_HISTORIES} | Guardar per-run: {SAVE_PER_RUN_FILES}")
print()

# Loaders, built once for the dataset and shared by all runs.
# NOTE: in this cell only test_loader is used (for evaluate_detailed);
# train_sleep_model receives X / y / splits directly.
train_loader, val_loader, test_loader = _build_loaders(
    X, y, splits, batch_size=CONFIG["batch_size"], num_workers=CONFIG["num_workers"], pin=pin_mem
)

# One record per run, consumed by the summary and aggregation steps.
all_runs_data = []
for run_id in range(1, N_RUNS+1):
    # Seeds are BASE_SEED, BASE_SEED+1, ..., one per run.
    seed = BASE_SEED + (run_id-1)
    set_seed(seed)

    # Fresh, untrained model for every run.
    model = _new_model()
    run_dir = RUNS_DIR / f"run_{run_id:02d}"
    # The per-run directory is only created when some SAVE_* flag is enabled.
    if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES):
        run_dir.mkdir(parents=True, exist_ok=True)

    save_path = (str(run_dir / "best_model.pt")) if SAVE_CHECKPOINTS else None

    # Train.
    # NOTE(review): with SAVE_CHECKPOINTS=False every run passes the same
    # relative path "best_model_tmp.pt" (i.e. the current working directory);
    # presumably train_sleep_model writes its best checkpoint there, despite the
    # "nothing saved by default" flags. Confirm before relying on it.
    model, hist, results = train_sleep_model(
        model=model,
        X=X, y=y, splits=splits,
        save_path=(save_path if save_path else "best_model_tmp.pt"),
        **CONFIG
    )

    # Learning curves, last run only (shown on screen, never saved).
    if run_id == N_RUNS:
        epochs_arr = range(1, len(hist["train_loss"])+1)
        fig, ax = plt.subplots(1,2, figsize=(12,5))
        ax[0].plot(epochs_arr, hist["train_loss"], 'r-', label='training')
        ax[0].plot(epochs_arr, hist["val_loss"], 'b-', label='validation')
        ax[0].set_title('Loss evolution'); ax[0].set_xlabel('Epoch'); ax[0].set_ylabel('Loss'); ax[0].grid(True, alpha=.3); ax[0].legend()

        ax[1].plot(epochs_arr, hist["train_acc"], 'r-', label='training')
        ax[1].plot(epochs_arr, hist["val_acc"], 'b-', label='validation')
        ax[1].set_title('Accuracy evolution'); ax[1].set_xlabel('Epoch'); ax[1].set_ylabel('Accuracy'); ax[1].grid(True, alpha=.3); ax[1].legend()
        plt.suptitle(f"Learning Curves ‚Äî {DATASET_NAME} (Run {run_id})")
        plt.tight_layout()
        plt.show()

    # Optional: save the training history plus this run's config (CONFIG + seed/run/dataset/timestamp).
    if SAVE_HISTORIES:
        np.savez(run_dir / "history.npz",
                 train_loss=hist["train_loss"], val_loss=hist["val_loss"],
                 train_acc=hist["train_acc"], val_acc=hist["val_acc"],
                 lr=hist["lr"])
        with open(run_dir / "config.json","w") as f:
            cfg = copy.deepcopy(CONFIG); cfg.update(seed=seed, run_id=run_id, dataset=DATASET_NAME, ts=datetime.now().isoformat())
            json.dump(cfg, f, indent=2)

    # Detailed evaluation of this run on the test split.
    eval_res = evaluate_detailed(model, test_loader, device)

    # Per-stage table on screen (written to disk only with SAVE_PER_RUN_FILES).
    print(f"\nüìä RUN {run_id} ‚Äî M√©tricas por etapa")
    display(eval_res["df"].style.set_caption(f"Run {run_id} ‚Äî {DATASET_NAME}"))

    print(f"   ‚û§ Acc={eval_res['acc']:.4f} | Kappa={eval_res['kappa']:.4f}")

    # Confusion-matrix plots on screen: raw counts (left) and row-normalized (right).
    labels = ["W","N1","N2","N3","REM"]
    fig, axes = plt.subplots(1,2, figsize=(12,5))
    sns.heatmap(eval_res["cm"], annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=axes[0])
    axes[0].set_title(f"CM Cruda ‚Äî Run {run_id}")
    axes[0].set_xlabel("Predicho"); axes[0].set_ylabel("Real")

    sns.heatmap(eval_res["cm_norm"], annot=True, fmt='.2f', cmap='Blues',
                xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=axes[1])
    axes[1].set_title(f"CM Normalizada ‚Äî Run {run_id}")
    axes[1].set_xlabel("Predicho"); axes[1].set_ylabel("Real")
    plt.tight_layout(); plt.show()

    # Optional per-run files: per-class CSV, predictions CSV, both CMs as .npy and a text summary.
    if SAVE_PER_RUN_FILES:
        eval_res["df"].to_csv(run_dir / "eval_test_per_class.csv", index=False)
        pd.DataFrame({"y_true": eval_res["y_true"].astype(int),
                      "y_pred": eval_res["y_pred"].astype(int)}).to_csv(run_dir / "eval_test_pred_vs_true.csv", index=False)
        np.save(run_dir / "eval_test_cm.npy", eval_res["cm"])
        np.save(run_dir / "eval_test_cm_norm.npy", eval_res["cm_norm"])
        with open(run_dir / "eval_test_summary.txt","w") as f:
            f.write(f"accuracy_global={eval_res['acc']:.6f}\n")
            f.write(f"kappa_global={eval_res['kappa']:.6f}\n")

    # run_dir is only recorded when some SAVE_* flag could have written into it.
    all_runs_data.append({
        "run_id": run_id,
        "history": hist,
        "results": results,
        "eval": eval_res,
        "run_dir": (run_dir if (SAVE_CHECKPOINTS or SAVE_HISTORIES or SAVE_PER_RUN_FILES or SAVE_AGGREGATES) else None)
    })

# ========= 5) Simple summary =========
# test_acc / test_loss are read from the results dict returned by train_sleep_model.
test_accs = [rd["results"]["test_acc"] for rd in all_runs_data]
test_losses = [rd["results"]["test_loss"] for rd in all_runs_data]
print("\n" + "="*90)
print(f"‚úÖ {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*90)
# np.std uses ddof=0 (population std) across runs.
print(f"Test Acc:  mean={np.mean(test_accs):.4f}  std={np.std(test_accs):.4f}  "
      f"min={np.min(test_accs):.4f}  max={np.max(test_accs):.4f}")
print(f"Test Loss: mean={np.mean(test_losses):.4f}  std={np.std(test_losses):.4f}  "
      f"min={np.min(test_losses):.4f}  max={np.max(test_losses):.4f}")

# ========= 6) Aggregation (mean ± std across runs) and aggregated plots =========
labels = ["W","N1","N2","N3","REM"]
metrics_cols = ["precision", "recall", "f1_score", "accuracy", "kappa"]

# Stack per-class metrics across runs: axis 0 = run, axis 1 = stage, axis 2 = metric.
# NOTE: these come from the per-run DataFrames, whose values evaluate_detailed rounds to 3 decimals.
per_class_list = [rd["eval"]["df"][metrics_cols].to_numpy() for rd in all_runs_data]  # list of (5x5)
per_class_arr  = np.stack(per_class_list, axis=0)  # (n_runs, 5, 5)

means = per_class_arr.mean(axis=0)  # (5,5)
stds  = per_class_arr.std(axis=0)   # (5,5)

# Aggregated table on screen, one "mean ± std" string per stage and metric.
# NOTE(review): the non-ASCII sequences in these strings look like UTF-8 text
# mis-decoded as Mac Roman; they are printed literally.
df_agg = pd.DataFrame({"etapa": labels})
for j, col in enumerate(metrics_cols):
    df_agg[col] = [f"{means[i,j]:.3f} ¬± {stds[i,j]:.3f}" for i in range(len(labels))]

print("\nüìä M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
print(df_agg.to_string(index=False))

# Aggregated F1 bar plot (on screen): per-stage mean with ±1 std error bars.
f1_means = means[:, metrics_cols.index("f1_score")]
f1_stds  = stds[:,  metrics_cols.index("f1_score")]
plt.figure(figsize=(9,5))
x = np.arange(len(labels))
plt.bar(x, f1_means, yerr=f1_stds, capsize=4, alpha=.85)
plt.xticks(x, labels)
plt.ylim(0, 1.05)
plt.xlabel("Etapa"); plt.ylabel("F1-Score"); plt.title(f"F1 por etapa (media¬±std) ‚Äî {DATASET_NAME}")
plt.grid(axis='y', alpha=.3)
plt.tight_layout(); plt.show()

# Aggregated row-normalized confusion matrix: element-wise mean and std across runs.
cm_norm_mean = np.mean([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)
cm_norm_std  = np.std ([rd["eval"]["cm_norm"] for rd in all_runs_data], axis=0)

fig, ax = plt.subplots(1,2, figsize=(12,5))
sns.heatmap(cm_norm_mean, annot=True, fmt=".3f", cmap='Blues',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=1, ax=ax[0])
ax[0].set_title(f"CM Normalizada (media) ‚Äî {DATASET_NAME}")
# Std color scale is capped at 0.2; larger values saturate the color (annotations still show them).
sns.heatmap(cm_norm_std, annot=True, fmt=".3f", cmap='Reds',
            xticklabels=labels, yticklabels=labels, vmin=0, vmax=0.2, ax=ax[1])
ax[1].set_title(f"CM Normalizada (std) ‚Äî {DATASET_NAME}")
for a in ax: a.set_xlabel("Predicho"); a.set_ylabel("Real")
plt.tight_layout(); plt.show()

# Optional aggregated outputs: the per-class "mean ± std" table (CSV) and the
# mean/std normalized confusion matrices (.npy). Plots are not saved.
if SAVE_AGGREGATES:
    RUNS_DIR.mkdir(parents=True, exist_ok=True)
    df_agg.to_csv(RUNS_DIR / "metrics_aggregated_per_class.csv", index=False)
    np.save(RUNS_DIR / "cm_norm_mean.npy", cm_norm_mean)
    np.save(RUNS_DIR / "cm_norm_std.npy", cm_norm_std)

print("\n Listo. ")
