#  Setup

In [None]:
# ================================================
#  DEPENDENCY INSTALLATION AND ENVIRONMENT SETUP
# ================================================
# Installs the packages used by the notebook with pip (through the running
# interpreter), then imports the main ones and prints their versions.

import sys
import subprocess

print(" CONFIGURANDO ENTORNO DE DEPENDENCIAS")
print("="*80)

# ================================================
# STEP 1: Install the core libraries
# ================================================
# check_call raises CalledProcessError if pip exits with a non-zero status,
# so a failure here stops the cell.
print("\n[1/4] Instalando librer√≠as de procesamiento y datos...")
subprocess.check_call([
    sys.executable, '-m', 'pip', 'install', '-q',
    'mne', 'pyedflib', 'numpy', 'pandas', 'scipy', 'tqdm', 'pydrive2'
])
print("    Librer√≠as principales instaladas")

# ================================================
# STEP 2: Install visual support (ipywidgets)
# ================================================
# This step is best-effort: any failure is reported and the cell continues.
print("\n[2/4] Instalando soporte visual para tqdm y widgets interactivos...")
try:
    subprocess.check_call([
        sys.executable, '-m', 'pip', 'install', '-q', '--upgrade', 'ipywidgets'
    ])
    print("    ipywidgets instalado/actualizado")

    try:
        subprocess.check_call([
            sys.executable, '-m', 'jupyter', 'nbextension',
            'enable', '--py', 'widgetsnbextension', '--sys-prefix'
        ])
        print("    Extensi√≥n widgetsnbextension habilitada")
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; only ordinary errors are tolerated here.
        print("    No se pudo habilitar widgetsnbextension (Colab normalmente no lo requiere)")
except Exception as e:
    print(f"    Advertencia durante instalaci√≥n: {e}")

# ================================================
# STEP 3: Install optional utilities
# ================================================
print("\n[3/4] Instalando utilidades adicionales (matplotlib, seaborn, openpyxl)...")
subprocess.check_call([
    sys.executable, '-m', 'pip', 'install', '-q',
    'matplotlib', 'seaborn', 'openpyxl'
])
print("   Utilidades adicionales instaladas")

# ================================================
# STEP 4: Import the packages and report versions
# ================================================
print("\n[4/4] Importando m√≥dulos y verificando entorno...")

# pickle, tqdm and ipywidgets are imported only to confirm they are importable;
# nothing below in this cell uses them.
import mne
import pandas as pd
import numpy as np
import pickle
import tqdm
import ipywidgets

print("    Dependencias cargadas correctamente")
print(f"\n Versi√≥n de componentes:")
print(f"   ‚Ä¢ Python:   {sys.version.split()[0]}")
print(f"   ‚Ä¢ MNE:      {mne.__version__}")
print(f"   ‚Ä¢ pandas:   {pd.__version__}")
print(f"   ‚Ä¢ numpy:    {np.__version__}")

print("\n Entorno completamente configurado. Puedes continuar con el procesamiento.")
print("="*80)


# Configuración

In [None]:
# ============================================================
#  PATHS, PARAMETERS AND OUTPUT DIRECTORIES
# ============================================================
# Defines the module-level constants used by the rest of the notebook
# (RAW_DIR, OUTPUT_DIR, WINDOWS_DIR, ANALYSIS_DIR, WINDOW_SIZE, OVERLAP,
# STRIDE, ZIP_PATH) and creates the output folders on disk.

from pathlib import Path
import os

print("  CONFIGURANDO RUTAS Y PAR√ÅMETROS DE PROCESAMIENTO")
print("="*80)

# ------------------------------------------------------------
# STEP 1: base directories
# ------------------------------------------------------------
print("\n[1/4] Estableciendo rutas principales...")

# Local folder holding the EDF files (PSG + hypnogram).
RAW_DIR = Path(
    r"C:\Users\shipa\OneDrive\Escritorio\Inteligencia Computacional\sleep-edf-database-expanded-1.0.0\sleep-edf-database-expanded-1.0.0\sleep-cassette"
)

# Root folder for every result; it lives inside RAW_DIR and is created
# (with any missing parents) if it does not exist yet.
OUTPUT_DIR = RAW_DIR.joinpath("ventanas_out_mas_corto")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"    RAW_DIR:    {RAW_DIR}")
print(f"    OUTPUT_DIR: {OUTPUT_DIR}")

# ------------------------------------------------------------
# STEP 2: output sub-folders
# ------------------------------------------------------------
print("\n[2/4] Creando subdirectorios para ventanas y an√°lisis...")

WINDOWS_DIR = OUTPUT_DIR.joinpath("ventanas_extraidas_mas_corto")
WINDOWS_DIR.mkdir(exist_ok=True)

ANALYSIS_DIR = OUTPUT_DIR.joinpath("analisis_canales_mas_corto")
ANALYSIS_DIR.mkdir(exist_ok=True)

print(f"    WINDOWS_DIR:  {WINDOWS_DIR}")
print(f"    ANALYSIS_DIR: {ANALYSIS_DIR}")

# ------------------------------------------------------------
# STEP 3: windowing parameters (all in seconds)
# ------------------------------------------------------------
print("\n[3/4] Estableciendo par√°metros de ventaneo...")

WINDOW_SIZE = 30.0                   # window length
OVERLAP = 15.0                       # overlap between consecutive windows
STRIDE = WINDOW_SIZE - OVERLAP       # step between consecutive window starts

print(f"     Ventana:  {WINDOW_SIZE}s")
print(f"     Overlap:  {OVERLAP}s")
print(f"     Stride:   {STRIDE}s")

# ------------------------------------------------------------
# STEP 4: path of the export ZIP archive
# ------------------------------------------------------------
print("\n[4/4] Configurando ruta para archivo ZIP de exportaci√≥n...")

ZIP_PATH = OUTPUT_DIR.joinpath("sleep_edf_ventanas.zip")
print(f"    ZIP_PATH: {ZIP_PATH}")

# ------------------------------------------------------------
# Final banner
# ------------------------------------------------------------
print("\n" + "="*80)
print(" CONFIGURACI√ìN COMPLETADA")
print("="*80)


# Funciones auxiliares + recorte de Wakes

In [None]:
# ================================================================
#  FUNCIONES AUXILIARES Y PREPROCESAMIENTO DE SUEÑO
# ================================================================

import numpy as np
import pandas as pd
import mne
from pathlib import Path
from datetime import datetime, timedelta

print(" Cargando funciones auxiliares, extracci√≥n de ventanas y recorte de sue√±o")
print("="*80)

# ================================================================
# PASO 1: FUNCIONES PARA DETECCIÓN DE PARES PSG/HYPNOGRAM
# ================================================================

def key_from_psg(name: str) -> str:
    """Return the recording key of a PSG file name.

    The key is the first 7 characters of the base name's text before the
    first ``-``, e.g. ``SC4001E0-PSG.edf`` -> ``SC4001E``.
    """
    prefijo, _, _ = Path(name).name.partition('-')
    return prefijo[:7]

def key_from_hyp(name: str) -> str:
    """Return the recording key of a hypnogram file name.

    The key is the first 7 characters of the base name's text before the
    first ``-``, e.g. ``SC4001EC-Hypnogram.edf`` -> ``SC4001E``.
    """
    base = Path(name).name
    corte = base.find('-')
    if corte != -1:
        base = base[:corte]
    return base[:7]

def encontrar_pares(raw_dir: Path):
    """
    Find PSG/hypnogram file pairs directly inside ``raw_dir``.

    Only ``*.edf`` files whose name starts with ``SC`` are considered. Files
    ending in ``-PSG.edf`` and ``-Hypnogram.edf`` are indexed by their
    recording key; a pair is reported only when both files exist for a key.

    Returns a list of ``(key, psg_path, hyp_path)`` tuples sorted by key.
    """
    psg_por_clave = {}
    hyp_por_clave = {}

    for ruta in raw_dir.glob("*.edf"):
        nombre = ruta.name
        es_cassette = nombre.startswith("SC")

        if es_cassette and nombre.endswith("-PSG.edf"):
            psg_por_clave[key_from_psg(nombre)] = ruta
        elif es_cassette and nombre.endswith("-Hypnogram.edf"):
            hyp_por_clave[key_from_hyp(nombre)] = ruta

    # Keep only the keys that have both a PSG and a hypnogram.
    comunes = sorted(psg_por_clave.keys() & hyp_por_clave.keys())
    return [
        (clave, psg_por_clave[clave], hyp_por_clave[clave])
        for clave in comunes
    ]

print("   Funciones para emparejar PSG/Hypnogram listas")


# ================================================================
# PASO 2: LECTURA DEL HYPNOGRAMA
# ================================================================

def leer_hypnograma_mne(hyp_path: Path) -> pd.DataFrame:
    """
    Read the sleep-stage annotations of a hypnogram file.

    Only annotations whose description contains ``"Sleep stage"`` and whose
    remaining label is one of W, 1, 2, 3, 4 or R are kept; every other
    annotation is dropped. Labels 3 and 4 are both stored as ``"3"``.

    Returns a DataFrame with columns ``inicio`` (onset), ``duracion``
    (duration) and ``etapa`` (stage label: W, 1, 2, 3 or R), taken from the
    annotation onset/duration values as floats.

    Raises RuntimeError when no valid stage annotation is found.
    """
    prefijo = "Sleep stage"
    validas = {"W", "1", "2", "3", "4", "R"}

    anotaciones = mne.read_annotations(str(hyp_path))
    registros = zip(
        anotaciones.description, anotaciones.onset, anotaciones.duration
    )

    filas = []
    for descripcion, comienzo, duracion in registros:
        if prefijo not in descripcion:
            continue

        etiqueta = descripcion.replace(prefijo, "").strip()
        if etiqueta not in validas:
            continue

        filas.append({
            "inicio": float(comienzo),
            "duracion": float(duracion),
            # Stages 3 and 4 are merged into a single "3" label.
            "etapa": "3" if etiqueta in {"3", "4"} else etiqueta,
        })

    tabla = pd.DataFrame(filas)
    if tabla.empty:
        raise RuntimeError(f"Hypnograma vac√≠o o sin etiquetas v√°lidas: {hyp_path}")

    return tabla


print("   Lector de hypnograma disponible")


# ================================================================
# PASO 3: CÓMPUTO DE OFFSET (ALINEACIÓN PSG–HYPNO)
# ================================================================

def _to_datetime(x):
    if x is None:
        return None
    if isinstance(x, tuple) and len(x) == 2:
        return datetime.fromtimestamp(x[0]) + timedelta(microseconds=x[1])
    try:
        return pd.to_datetime(x).to_pydatetime()
    except Exception:
        return x

def calcular_offset_segundos(psg_path: Path, hyp_path: Path) -> float:
    """
    Return the hypnogram start minus the PSG start, in seconds.

    Both files are opened with ``mne.io.read_raw_edf`` (without loading the
    signal data) and their ``meas_date`` is converted with ``_to_datetime``.
    The result is positive when the hypnogram starts after the PSG. When
    either start time is missing, 0.0 is returned.
    """
    inicios = []
    for ruta in (psg_path, hyp_path):
        registro = mne.io.read_raw_edf(str(ruta), preload=False, verbose=False)
        inicios.append(_to_datetime(registro.info.get('meas_date')))

    inicio_psg, inicio_hyp = inicios
    if inicio_psg is None or inicio_hyp is None:
        return 0.0

    diferencia = inicio_hyp - inicio_psg
    return diferencia.total_seconds()

print("   Funci√≥n de alineaci√≥n (offset) lista")


# ================================================================
# PASO 4: ANÁLISIS DE CANALES DEL PSG
# ================================================================

def analizar_canales(psg_path: Path):
    """
    Describe every channel of a PSG file.

    Returns a dict ``{channel_name: {'tipo', 'freq', 'n_samples', 'duracion'}}``.
    ``tipo`` is derived from the upper-cased channel name (EEG, EOG, EMG, ECG,
    EVENTO, otherwise OTRO; the first matching rule wins). ``freq``,
    ``n_samples`` and ``duracion`` are read from the Raw object as a whole
    (``info['sfreq']``, ``n_times``, last value of ``times``), so they are
    the same for every channel of the file.
    """
    registro = mne.io.read_raw_edf(str(psg_path), preload=False, verbose=False)

    # Ordered rules: (type, substrings that identify it).
    reglas = (
        ('EEG', ('EEG',)),
        ('EOG', ('EOG',)),
        ('EMG', ('EMG',)),
        ('ECG', ('ECG', 'EKG')),
        ('EVENTO', ('EVENT', 'MARKER')),
    )

    def clasificar(nombre):
        mayus = nombre.upper()
        for tipo, claves in reglas:
            if any(clave in mayus for clave in claves):
                return tipo
        return 'OTRO'

    # File-level values, identical for all channels.
    frecuencia = float(registro.info['sfreq'])
    muestras = int(registro.n_times)
    duracion = float(registro.times[-1]) if registro.n_times > 0 else 0.0

    return {
        canal: {
            'tipo': clasificar(canal),
            'freq': frecuencia,
            'n_samples': muestras,
            'duracion': duracion,
        }
        for canal in registro.ch_names
    }

print("   Clasificador de canales cargado")


# ================================================================
# PASO 5: EXTRACCIÓN DE VENTANAS POR CANAL (CORREGIDA)
# ================================================================

def extraer_ventanas_por_canal(
    psg_path: Path,
    hyp_df: pd.DataFrame,
    canal_nombre: str,
    window_size: float,
    stride: float,
    t_ini_psg: float,
    t_fin_psg: float,
    hyp_offset: float = 0.0
):
    """
    Extract overlapping, labelled windows from one channel of a PSG file.

    Window ``i`` starts at ``i * stride`` seconds on the PSG time axis and is
    kept only when its midpoint lies inside ``[t_ini_psg, t_fin_psg]``. Each
    kept window is labelled with the stage of the hypnogram interval
    (shifted by ``hyp_offset``) that contains its midpoint; when no interval
    contains it, the label falls back to ``"W"``.

    Parameters
    ----------
    psg_path : path of the PSG EDF file.
    hyp_df : hypnogram with columns ``inicio``, ``duracion`` and ``etapa``.
        Intervals must not overlap (required by ``IntervalIndex.get_indexer``).
    canal_nombre : name of the channel to read.
    window_size, stride : window length and step, in seconds.
    t_ini_psg, t_fin_psg : allowed range for window midpoints (PSG seconds).
    hyp_offset : seconds added to the hypnogram times to express them on the
        PSG time axis.

    Returns
    -------
    dict with keys ``ventanas`` (float32 array, always 2-D with shape
    ``(n_windows, samples_per_window)``, also when ``n_windows`` is 0),
    ``etiquetas`` (list of labels), ``tiempos_inicio`` (list of start times
    in seconds), ``freq_muestreo`` (float) and ``nombre_canal``.
    """
    # Without preloading, indexing reads just the requested channel instead
    # of loading every channel of the recording into memory on each call.
    raw = mne.io.read_raw_edf(str(psg_path), preload=False, verbose=False)
    data, _ = raw[canal_nombre, :]
    x = np.ravel(data)
    fs = float(raw.info['sfreq'])

    win_samps = int(round(window_size * fs))
    stride_samp = int(round(stride * fs))

    def _empaquetar(ventanas, etiquetas, tiempos_inicio):
        # Single place that builds the result dict, so every exit path
        # returns the same keys.
        return {
            "ventanas": ventanas,
            "etiquetas": etiquetas,
            "tiempos_inicio": tiempos_inicio,
            "freq_muestreo": fs,
            "nombre_canal": canal_nombre
        }

    def _sin_ventanas():
        # max(..., 0) guards against a negative window size, for which
        # np.empty would raise.
        vacio = np.empty((0, max(win_samps, 0)), dtype=np.float32)
        return _empaquetar(vacio, [], [])

    # Invalid parameters or a signal shorter than one window: nothing to do.
    if win_samps <= 0 or stride_samp <= 0 or len(x) < win_samps:
        return _sin_ventanas()

    # Hypnogram intervals expressed on the PSG time axis.
    starts = hyp_df["inicio"].to_numpy(float) + hyp_offset
    ends = (hyp_df["inicio"] + hyp_df["duracion"]).to_numpy(float) + hyp_offset
    intervals = pd.IntervalIndex.from_arrays(starts, ends, closed="left")

    # Upper bound on the number of windows that fit in the signal.
    n_max = 1 + (len(x) - win_samps) // stride_samp
    mitad = window_size / 2

    muestras_inicio = []
    tiempos_inicio = []
    for i in range(n_max):
        t_ini = i * stride

        # Keep the window only if its midpoint is inside the allowed range.
        if not (t_ini_psg <= t_ini + mitad <= t_fin_psg):
            continue

        s = int(round(t_ini * fs))
        if s + win_samps > len(x):
            break

        muestras_inicio.append(s)
        tiempos_inicio.append(t_ini)

    # No window passed the time filter: still return a 2-D (0, win_samps)
    # array so callers always see the same array rank.
    if not tiempos_inicio:
        return _sin_ventanas()

    ventanas = np.array(
        [x[s:s + win_samps] for s in muestras_inicio],
        dtype=np.float32
    )

    # Label every window with one vectorised lookup of its midpoint.
    t_mids = np.asarray(tiempos_inicio, dtype=float) + mitad
    posiciones = intervals.get_indexer(t_mids)
    etapas = hyp_df["etapa"].to_numpy()
    etiquetas = [
        etapas[pos] if pos != -1 else "W"   # "W" is the fallback label
        for pos in posiciones
    ]

    return _empaquetar(ventanas, etiquetas, tiempos_inicio)


print("extraer_ventanas_por_canal redefinida correctamente")



# ================================================================
# PASO 6: RECORTE A INTERVALO DE SUEÑO (onset → offset)
# ================================================================

def recortar_hyp_a_sueno(hyp_df: pd.DataFrame,
                          etiqueta_wake: str = "W",
                          margen_pre: float = 0.0,
                          margen_post: float = 0.0):
    """
    Trim a hypnogram to the span between the first and last non-wake epoch.

    The span starts at the earliest onset of a non-wake epoch minus
    ``margen_pre`` (clamped at 0) and ends at the latest end of a non-wake
    epoch plus ``margen_post``. An epoch is kept when its midpoint lies
    inside that span (both ends inclusive).

    Returns ``(trimmed_df, t_ini, t_fin)`` with times on the hypnogram time
    axis and the trimmed frame re-indexed from 0. When there is no non-wake
    epoch at all, returns ``(copy_of_hyp_df, None, None)``.
    """
    es_sueno = hyp_df['etapa'].ne(etiqueta_wake)

    # No sleep at all: hand back an untouched copy.
    if not es_sueno.any():
        return hyp_df.copy(), None, None

    epocas_sueno = hyp_df[es_sueno]
    comienzo_sueno = epocas_sueno['inicio'].min()
    fin_sueno = (epocas_sueno['inicio'] + epocas_sueno['duracion']).max()

    # Widen the sleep span by the requested margins.
    t_ini = max(0.0, comienzo_sueno - margen_pre)
    t_fin = fin_sueno + margen_post

    # Decide per epoch using its midpoint.
    centros = hyp_df['inicio'] + 0.5 * hyp_df['duracion']
    dentro = centros.between(t_ini, t_fin)

    return hyp_df[dentro].reset_index(drop=True), t_ini, t_fin


print("  Funci√≥n de recorte de sue√±o lista")

# ================================================================
print("\n Todas las funciones auxiliares cargadas")
print("="*80)


# Procesamiento principal

In [None]:
# ======================================================
# DETECTAR NÚMERO DE WORKERS DEL EQUIPO
# ======================================================

import multiprocessing as mp

# Number of CPUs reported by multiprocessing, and the worker count derived
# from it: three quarters of the cores (rounded down), never fewer than one.
TOTAL_CORES = mp.cpu_count()
N_WORKERS = max(1, (3 * TOTAL_CORES) // 4)

print(
    " Informaci√≥n del sistema:",
    f"   ‚Ä¢ N√∫cleos disponibles : {TOTAL_CORES}",
    f"   ‚Ä¢ Workers a utilizar  : {N_WORKERS}  (~75%)",
    sep="\n",
)


In [None]:
# ============================================================
# TEST MULTICANAL ‚Äî 1 PACIENTE COMPLETO 
# ============================================================

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Smoke test on a single patient: extracts windows for every non-event
# channel, saves one *_TEST.npz per channel, compares the window counts
# across channels and plots the label timeline of each channel.
print("Buscando pares PSG-Hypnogram...")
pares = encontrar_pares(RAW_DIR)

# Stop the cell when RAW_DIR contains no PSG/hypnogram pair.
if len(pares) == 0:
    print("No se encontraron pares en RAW_DIR")
    raise SystemExit

# ------------------------------------------------------------------
# Pick one patient for the test.
# The preferred index is 40; fall back to index 0 when the dataset
# has too few pairs for that.
# ------------------------------------------------------------------
idx_paciente = 40
if len(pares) <= idx_paciente:
    idx_paciente = 0

key, psg_path, hyp_path = pares[idx_paciente]

print(f"\nProbando todos los canales para paciente: {key}")
print(f"   PSG: {psg_path.name}")
print(f"   HYP: {hyp_path.name}")

# ============================================================
# 1) Read the hypnogram
# ============================================================
hyp_df = leer_hypnograma_mne(hyp_path)

# ============================================================
# 2) Trim to the sleep interval (same margins as the global run)
# ============================================================
# Margins in seconds. Both are currently 0 (0 * 60), i.e. no margin is kept
# around the sleep interval.
MARGEN_PRE  = 0* 60   # seconds kept before sleep onset
MARGEN_POST = 0* 60   # seconds kept after sleep offset

hyp_df_rec, t_ini_hyp, t_fin_hyp = recortar_hyp_a_sueno(
    hyp_df,
    etiqueta_wake="W",
    margen_pre=MARGEN_PRE,
    margen_post=MARGEN_POST
)

# recortar_hyp_a_sueno returns None limits when there is no sleep epoch;
# in that case use the whole hypnogram.
if t_ini_hyp is None:
    t_ini_hyp = float(hyp_df["inicio"].min())
    t_fin_hyp = float((hyp_df["inicio"] + hyp_df["duracion"]).max())
    hyp_df_uso = hyp_df
else:
    hyp_df_uso = hyp_df_rec

# Express the hypnogram limits on the PSG time axis.
offset = calcular_offset_segundos(psg_path, hyp_path)
t_ini_psg = t_ini_hyp + offset
t_fin_psg = t_fin_hyp + offset

print(f"\nIntervalo de sue√±o en eje PSG: [{t_ini_psg:.1f} s, {t_fin_psg:.1f} s]")

# ============================================================
# 3) Analyse the channels
# ============================================================
canales = analizar_canales(psg_path)
# Channels classified as EVENTO (event/marker) are excluded.
canales_validos = [c for c, meta in canales.items() if meta["tipo"] != "EVENTO"]

print(f"\nCanales detectados ({len(canales_validos)}):")
for c in canales_validos:
    print("   ‚Ä¢", c)

resultados = {}  # per channel: windows (X), labels (y) and start times (t)

# ============================================================
# 4) Process channel by channel
# ============================================================
for canal in canales_validos:
    print(f"\nProcesando canal: {canal}")

    # A failure in one channel is reported and the loop moves on.
    try:
        res = extraer_ventanas_por_canal(
            psg_path=psg_path,
            hyp_df=hyp_df_uso,
            canal_nombre=canal,
            window_size=WINDOW_SIZE,
            stride=STRIDE,
            t_ini_psg=t_ini_psg,
            t_fin_psg=t_fin_psg,
            hyp_offset=offset
        )
    except Exception as e:
        print("   Error:", e)
        continue

    X = res["ventanas"]
    y = np.array(res["etiquetas"])
    t0 = np.array(res["tiempos_inicio"])

    print(f"   Ventanas: {X.shape}")
    if len(y) > 0:
        # Absolute and relative label counts for this channel.
        vc_abs = pd.Series(y).value_counts().to_dict()
        vc_rel = (pd.Series(y).value_counts(normalize=True)
                  .round(3).to_dict())
        print(f"   Distribuci√≥n absoluta:   {vc_abs}")
        print(f"   Distribuci√≥n proporcional: {vc_rel}")
    else:
        print("   Sin ventanas extra√≠das para este canal")

    # Keep the arrays for the comparison and the plot below.
    resultados[canal] = {"X": X, "y": y, "t": t0}

    # Save a test file for this channel (spaces in the name become '_').
    test_file = WINDOWS_DIR / f"{key}_{canal.replace(' ','_')}_TEST.npz"
    np.savez_compressed(
        test_file,
        ventanas=X,
        etiquetas=y,
        tiempos_inicio=t0,
        freq_muestreo=res["freq_muestreo"],
        nombre_canal=res["nombre_canal"]
    )
    print(f"   Guardado test ‚Üí {test_file.name}")

# ============================================================
# 5) Compare the number of windows across channels
# ============================================================
print("\n" + "="*60)
print("COMPARACI√ìN ENTRE CANALES (N¬∫ DE VENTANAS)")
print("="*60)

if not resultados:
    print("No se obtuvieron resultados en ning√∫n canal.")
else:
    # The first processed channel is the reference for the window count.
    can_base = list(resultados.keys())[0]
    t_base = resultados[can_base]["t"]
    print(f"Canal base para comparaci√≥n: {can_base}")

    todo_consistente = True
    for canal, data in resultados.items():
        # Only the counts are compared, not the start times themselves.
        dt = len(data["t"]) - len(t_base)
        if dt == 0:
            print(f"{canal}: mismo n√∫mero de ventanas que {can_base} ({len(data['t'])})")
        else:
            todo_consistente = False
            print(f"{canal}: {len(data['t'])} vs {len(t_base)} ventanas (diferencia {dt})")

    if todo_consistente:
        print("\n TODOS los canales de este paciente tienen el MISMO n√∫mero de ventanas.")
    else:
        print("\n Hay canales con distinto n√∫mero de ventanas. Revisar arriba.")

# ============================================================
# 6) Plot the label timeline of every channel
# ============================================================
if resultados:
    # Colour per stage label; aliases of the same stage share a colour.
    # Labels missing from this table are drawn in gray.
    colores = {
        "W":"#f4d03f","1":"#e67e22","2":"#3498db",
        "N1":"#e67e22","N2":"#3498db",
        "3":"#2ecc71","4":"#27ae60","N3":"#2ecc71",
        "R":"#9b59b6","REM":"#9b59b6"
    }

    n_channels = len(resultados)
    fig, axes = plt.subplots(n_channels, 1, figsize=(16, 2*n_channels), sharex=True)

    # With a single subplot, plt.subplots returns one Axes, not an array.
    if n_channels == 1:
        axes = [axes]

    for ax, (canal, data) in zip(axes, resultados.items()):
        y = data["y"]
        cols = [colores.get(e, "gray") for e in y]
        # One unit-height bar per window, coloured by its label.
        ax.bar(range(len(y)), np.ones(len(y)), color=cols, width=1.0)
        ax.set_title(canal)
        ax.set_yticks([])

    # Compact legend with only the labels that actually occur.
    labels_usadas = sorted({et for data in resultados.values() for et in data["y"]})
    patches = []
    for l in labels_usadas:
        c = colores.get(l, "gray")
        patches.append(mpatches.Patch(color=c, label=l))
    axes[0].legend(handles=patches, loc="upper right", bbox_to_anchor=(1.15, 1))

    plt.tight_layout()
    plt.show()
else:
    print("\n No hay resultados para graficar.")


In [None]:
# ============================================================
# PROCESAMIENTO PRINCIPAL 
# ============================================================

import warnings
import numpy as np
from tqdm.notebook import tqdm
from collections import defaultdict

# Main run: for every PSG/hypnogram pair, trims the hypnogram to the sleep
# interval, extracts windows for each non-event channel, writes one .npz per
# patient/channel and finally prints a summary, a per-patient consistency
# check, the global stage distribution and the collected errors.

# RuntimeWarning is silenced for the rest of the session.
warnings.filterwarnings("ignore", category=RuntimeWarning)

print("Iniciando procesamiento de registros PSG + Hypnogram")
print("="*80)
print("Nota: Este script SOBRESCRIBE los archivos .npz existentes y el resumen_global_mas_corto.csv")

# Trimming margins in seconds (both currently 0, i.e. no margin).
MARGEN_PRE  = 0 * 60   # seconds kept before sleep onset
MARGEN_POST = 0 * 60   # seconds kept after sleep offset

print(f"    Recorte de sue√±o activado:")
print(f"      ‚Ä¢ Margen antes  : {MARGEN_PRE/60:.1f} min")
print(f"      ‚Ä¢ Margen despu√©s: {MARGEN_POST/60:.1f} min")

pares = encontrar_pares(RAW_DIR)
print(f"\n Pares PSG-Hypnogram encontrados: {len(pares)}")

resumen_rows = []   # one summary row per saved patient/channel file
errores = []        # (patient, channel, message) for every failure

# Window count per channel, used for the per-patient consistency check.
patient_channel_windows = defaultdict(dict)  # {patient: {channel: n_windows}}

# Global stage distribution: only ONE channel per patient is counted.
stage_counts = defaultdict(int)             # {stage: count}
pacientes_etapas_contadas = set()          # patients already included in stage_counts

# Main loop: one patient at a time (no multiprocessing).
for key, psg_path, hyp_path in tqdm(
    pares, desc="Procesando pacientes", unit="pac", leave=True
):
    try:
        # 1) Read the hypnogram
        hyp_df = leer_hypnograma_mne(hyp_path)

        # 2) Trim to the sleep interval (on the HYPNOGRAM time axis)
        hyp_df_rec, t_ini_hyp, t_fin_hyp = recortar_hyp_a_sueno(
            hyp_df,
            etiqueta_wake="W",
            margen_pre=MARGEN_PRE,
            margen_post=MARGEN_POST
        )

        if t_ini_hyp is None:
            # No sleep epoch found: use the whole hypnogram.
            hyp_df_uso = hyp_df
            t_ini_hyp = float(hyp_df['inicio'].min())
            t_fin_hyp = float((hyp_df['inicio'] + hyp_df['duracion']).max())
        else:
            hyp_df_uso = hyp_df_rec

        # 3) Align the PSG and hypnogram time axes
        offset = calcular_offset_segundos(psg_path, hyp_path)
        # Express the limits on the PSG time axis.
        t_ini_psg = t_ini_hyp + offset
        t_fin_psg = t_fin_hyp + offset

        # 4) Analyse the channels
        canales = analizar_canales(psg_path)

        # 5) Process channel by channel
        desc_inner = f"[{key}] Procesando canales"
        for canal, meta in tqdm(
            canales.items(), desc=desc_inner, unit="canal", leave=False
        ):
            # Skip event/marker channels.
            if meta["tipo"] == "EVENTO":
                continue

            # A failure in one channel is recorded and the loop moves on.
            try:
                resultado = extraer_ventanas_por_canal(
                    psg_path=psg_path,
                    hyp_df=hyp_df_uso,
                    canal_nombre=canal,
                    window_size=WINDOW_SIZE,
                    stride=STRIDE,
                    t_ini_psg=t_ini_psg,
                    t_fin_psg=t_fin_psg,
                    hyp_offset=offset
                )

                ventanas = resultado["ventanas"]
                n_vent = len(ventanas)

                # Record the window count for the consistency check.
                patient_channel_windows[key][canal] = n_vent

                # Count the stage labels only once per patient: the first
                # channel that is extracted successfully is the one counted.
                if key not in pacientes_etapas_contadas:
                    for et in resultado["etiquetas"]:
                        stage_counts[str(et)] += 1
                    pacientes_etapas_contadas.add(key)

                # ================================
                # SAVE AS COMPRESSED NPZ
                # (overwrites an existing file)
                # ================================
                # Labels are stored with dtype=object, so np.load needs
                # allow_pickle=True to read them back.
                out_file = WINDOWS_DIR / f"{key}_{canal.replace(' ', '_')}.npz"
                np.savez_compressed(
                    out_file,
                    ventanas=ventanas,
                    etiquetas=np.array(resultado["etiquetas"], dtype=object),
                    tiempos_inicio=np.array(resultado["tiempos_inicio"]),
                    freq_muestreo=resultado["freq_muestreo"],
                    nombre_canal=resultado["nombre_canal"]
                )

                resumen_rows.append({
                    "Paciente": key,
                    "Canal": canal,
                    "Tipo": meta["tipo"],
                    "Fs": meta["freq"],
                    "N_Ventanas": n_vent,
                    "Archivo": str(out_file)
                })

            except Exception as e:
                errores.append((key, canal, str(e)))

    except Exception as e:
        # Failure before the channel loop (reading, trimming, alignment...).
        errores.append((key, "__paciente__", str(e)))


# ============================================================
# GLOBAL SUMMARY
# ============================================================
df_resumen = pd.DataFrame(resumen_rows)
out_csv = ANALYSIS_DIR / "resumen_global_mas_corto.csv"
df_resumen.to_csv(out_csv, index=False)

print("\n" + "="*80)
print("Resumen guardado (SOBREESCRITO) en:", out_csv)
print("Totales:")
print("  ‚Ä¢ Pacientes procesados:", len(set([r["Paciente"] for r in resumen_rows])) if resumen_rows else 0)
print("  ‚Ä¢ Canales procesados:  ", len(resumen_rows))
print("  ‚Ä¢ Ventanas totales:    ", int(df_resumen["N_Ventanas"].sum()) if not df_resumen.empty else 0)

if not df_resumen.empty:
    # `display` is not imported in this cell; presumably the notebook
    # (IPython) built-in - TODO confirm when running outside a notebook.
    display(df_resumen.head(10))
    try:
        print("\n Ventanas por tipo de canal:")
        display(df_resumen.groupby("Tipo")["N_Ventanas"].describe())
    except Exception:
        # The per-type table is optional; any failure here is ignored.
        pass

# ============================================================
# PER-PATIENT CONSISTENCY (WINDOW COUNT PER CHANNEL)
# ============================================================
print("\n" + "="*80)
print(" Consistencia interna por paciente (n¬∫ de ventanas por canal)")
consistent_keys = []
inconsistent_info = []

# A patient is consistent when all of its channels have the same count.
for pac, canales_dict in patient_channel_windows.items():
    counts = list(canales_dict.values())
    if len(counts) == 0:
        continue
    if len(set(counts)) == 1:
        consistent_keys.append(pac)
    else:
        inconsistent_info.append((pac, canales_dict))

print(f"  Pacientes con todos los canales consistentes: {len(consistent_keys)}")

if inconsistent_info:
    print(f"  ‚Ä¢ Pacientes con inconsistencias: {len(inconsistent_info)}")
    print("    (mostrando hasta 5 ejemplos)")
    for pac, canales_dict in inconsistent_info[:5]:
        print(f"    - {pac}:")
        for canal, n_vent in canales_dict.items():
            print(f"        {canal}: {n_vent} ventanas")
else:
    print("   Todos los pacientes tienen el MISMO n√∫mero de ventanas en todos los canales (ignorando EVENTO).")

# ============================================================
# GLOBAL STAGE DISTRIBUTION (W, N1, N2, N3, REM)
# ============================================================
if stage_counts:
    print("\n" + "="*80)
    print(" Distribuci√≥n global de etapas de sue√±o (por ventana,")
    print("   contando SOLO un canal por paciente):\n")

    # Print order and display names of the stage labels.
    etapa_order = ["W", "1", "2", "3", "R"]
    etapa_name = {
        "W": "Wake",
        "1": "N1",
        "2": "N2",
        "3": "N3",   # stages 3 and 4 were already merged when reading the hypnogram
        "R": "REM"
    }

    # Percentages are relative to all counted windows.
    total_windows_stages = sum(stage_counts.values())
    for et in etapa_order:
        if et in stage_counts:
            cnt = stage_counts[et]
            pct = 100.0 * cnt / total_windows_stages if total_windows_stages > 0 else 0.0
            print(f"   ‚Ä¢ {etapa_name[et]:>3} ({et}): {cnt:5d} ventanas  ‚Üí {pct:5.1f}%")

else:
    print("\n No se pudo calcular la distribuci√≥n de etapas (sin datos de ventanas).")

# ============================================================
#  ERRORS (IF ANY)
# ============================================================
if errores:
    print("\n" + "="*80)
    print(f" Errores encontrados ({len(errores)}): (mostrando hasta 10)")
    for i, (k, c, msg) in enumerate(errores[:10], 1):
        print(f"  {i:02d}. {k} | {c} -> {msg}")
    if len(errores) > 10:
        print(f"  ... y {len(errores)-10} m√°s")
else:
    print("\n Sin errores reportados. Todo OK.")


# Cargar datos

In [None]:
# ================================================
# Cargar y visualizar ventanas (.npz / .pkl)
# ================================================
import numpy as np
import pickle
from collections import Counter
import matplotlib.pyplot as plt
from pathlib import Path

# Label mappings.
# ID2LABEL: integer class id -> stage name.
ID2LABEL = dict(enumerate(("W", "N1", "N2", "N3", "REM")))

# LABEL_MAP: every accepted label spelling -> integer class id.
# Labels "3" and "4" both map to class 3.
LABEL_MAP = {
    "W": 0,
    "0": 0,
    "1": 1,
    "N1": 1,
    "2": 2,
    "N2": 2,
    "3": 3,
    "4": 3,
    "N3": 3,
    "R": 4,
    "REM": 4,
}

def cargar_ventanas(
    paciente: str,
    canal: str,
    return_ids: bool = False,
    mmap_npz: bool = True,
    windows_dir: Path | None = None
):
    """
    Carga ventanas de un paciente y canal espec√≠fico desde:
      - .npz nuevo o viejo
      - .pkl antiguo (fallback)

    windows_dir:
      - Si es None -> usa la variable global WINDOWS_DIR.
      - Si es Path -> usa esa carpeta (sirve para tener carpeta 'larga'
        y carpeta 'corta' a la vez).

    Devuelve un dict con:
      - 'ventanas': array (N, L) o (N, L, C)
      - 'etiquetas': array de IDs 0-4 o strings, seg√∫n return_ids
      - 'tiempos_inicio': array (N,)
      - 'freq_muestreo': float
      - 'nombre_canal': str
    """
    base_dir = Path(windows_dir) if windows_dir is not None else WINDOWS_DIR
    base = base_dir / f"{paciente}_{canal.replace(' ', '_')}"
    npz_path = base.with_suffix(".npz")
    pkl_path = base.with_suffix(".pkl")

    # --------------------
    # Preferir .npz
    # --------------------
    if npz_path.exists():
        d = np.load(
            npz_path,
            allow_pickle=True,   # <- IMPORTANTE para etiquetas dtype=object
            mmap_mode='r' if mmap_npz else None
        )
        files = set(d.files)

        # Formato viejo: X, y, t, fs, canal
        if {"X", "y", "t", "fs", "canal"}.issubset(files):
            X = d["X"].astype(np.float32, copy=False)
            labels_raw = d["y"]
            t = d["t"].astype(np.float32, copy=False)
            fs = float(d["fs"])
            canal_name = str(d["canal"])

        # Formato nuevo: ventanas, etiquetas, tiempos_inicio, freq_muestreo, nombre_canal
        elif {"ventanas", "etiquetas", "tiempos_inicio",
              "freq_muestreo", "nombre_canal"}.issubset(files):
            X = d["ventanas"].astype(np.float32, copy=False)
            labels_raw = d["etiquetas"]
            t = d["tiempos_inicio"].astype(np.float32, copy=False)
            fs = float(d["freq_muestreo"])
            canal_name = str(d["nombre_canal"])

        else:
            print(f" Formato .npz desconocido: {npz_path.name}, claves={d.files}")
            return None

        labels_raw = np.asarray(labels_raw)
        if labels_raw.dtype.kind in {"U", "S", "O"}:
            y = np.array([LABEL_MAP.get(str(s), 255) for s in labels_raw],
                         dtype=np.uint8)
        else:
            y = labels_raw.astype(np.uint8, copy=False)
        fmt = ".npz"

    # --------------------
    # Fallback: .pkl antiguo
    # --------------------
    elif pkl_path.exists():
        with open(pkl_path, "rb") as f:
            raw = pickle.load(f)

        labels_raw = raw.get("etiquetas", [])
        if labels_raw and isinstance(labels_raw[0], str):
            y = np.array([LABEL_MAP.get(s, 255) for s in labels_raw], dtype=np.uint8)
        else:
            y = np.asarray(labels_raw, dtype=np.uint8)

        X = np.asarray(raw.get("ventanas", []), dtype=np.float32)
        t = np.asarray(raw.get("tiempos_inicio", []), dtype=np.float32)
        fs = float(raw.get("freq_muestreo", 100.0))
        canal_name = raw.get("nombre_canal", "CANAL")
        fmt = ".pkl"

    else:
        print(f" No se encontr√≥ ni {npz_path.name} ni {pkl_path.name}")
        return None

    # --------------------
    # Preparar etiquetas para salida
    # --------------------
    n = X.shape[0]
    total = max(1, n)

    if return_ids:
        y_out = y
        dist_keys = [ID2LABEL.get(int(v), f"id{int(v)}") for v in y]
    else:
        y_out = np.array([ID2LABEL.get(int(v), f"id{int(v)}") for v in y],
                         dtype=object)
        dist_keys = y_out

    dist = Counter(dist_keys)
    print(f" Cargado ({fmt}): {paciente} - {canal_name}")
    print(f"   ‚Ä¢ Ventanas: {X.shape} | Fs: {fs} Hz")
    print(f"   ‚Ä¢ Distribuci√≥n:")
    for k, c in sorted(dist.items()):
        print(f"     {k}: {c} ({100.0*c/total:.1f}%)")

    return {
        "ventanas": X,
        "etiquetas": y_out,
        "tiempos_inicio": t,
        "freq_muestreo": fs,
        "nombre_canal": canal_name
    }


In [None]:
# ================================================
# Recuento global de etiquetas 
# ================================================
import pandas as pd
import numpy as np
from collections import Counter
from contextlib import redirect_stdout
import io


from pathlib import Path

# Base folder of the "mas corto" (shortened) window export.
BASE_DIR = Path(
    r"C:\Users\shipa\OneDrive\Escritorio\Inteligencia Computacional"
    r"\sleep-edf-database-expanded-1.0.0\sleep-edf-database-expanded-1.0.0"
    r"\sleep-cassette\ventanas_out_mas_corto"
)

# Folder holding the new .npz files (the trimmed windows).
WINDOWS_DIR = BASE_DIR.joinpath("ventanas_extraidas_mas_corto")

# Folder holding the "mas corto" summary, and the summary CSV itself.
ANALYSIS_DIR = BASE_DIR.joinpath("analisis_canales_mas_corto")
RESUMEN_CSV = ANALYSIS_DIR.joinpath("resumen_global_mas_corto.csv")

df = pd.read_csv(RESUMEN_CSV)

# Unique PATIENT-CHANNEL pairs (rows missing either field are dropped).
columnas_par = ["Paciente", "Canal"]
pares = df[columnas_par].dropna()
pares = pares.drop_duplicates()
pares = pares.reset_index(drop=True)

# Accumulates label counts across every pair.
conteo_global = Counter()

def extraer_etiquetas(data: dict):
    """Return the 'etiquetas' entry of a cargar_ventanas() dict as a list of str.

    A missing 'etiquetas' key yields an empty list.
    """
    etiquetas = data.get("etiquetas", [])
    return list(map(str, np.asarray(etiquetas)))


# Pairs that could not be loaded, kept so the percentages below are never
# silently computed on a subset: (patient, channel, reason).
fallidos = []

for paciente, canal in zip(pares["Paciente"], pares["Canal"]):
    buf = io.StringIO()
    try:
        # Silence the summary printed by cargar_ventanas
        with redirect_stdout(buf):
            data = cargar_ventanas(paciente, canal, return_ids=False)
    except Exception as exc:
        # Best effort: keep going, but remember what failed and why.
        fallidos.append((paciente, canal, f"{type(exc).__name__}: {exc}"))
        continue

    if data is None:
        # cargar_ventanas explains the reason on stdout, which was captured.
        fallidos.append((paciente, canal, buf.getvalue().strip() or "sin datos"))
        continue

    etiquetas = extraer_etiquetas(data)
    conteo_global.update(etiquetas)

if fallidos:
    print(f"Aviso: {len(fallidos)} par(es) paciente-canal omitidos (mostrando hasta 5):")
    for pac_f, canal_f, motivo in fallidos[:5]:
        print(f"  - {pac_f} | {canal_f}: {motivo}")


# ---- Print results ----
total = sum(conteo_global.values())

if total == 0:
    print("No se encontraron etiquetas.")
else:
    # Canonical stages first, then any other label in alphabetical order.
    orden = ["W", "N1", "N2", "N3", "REM"]
    orden += [e for e in sorted(conteo_global) if e not in orden]

    print(" | ".join(
        f"{etapa}: {conteo_global.get(etapa,0) / total * 100:.2f}%"
        for etapa in orden
    ))


# Gráfico de cada señal

In [None]:
from pathlib import Path

# "Short" windows: the folder currently configured as WINDOWS_DIR.
WINDOWS_DIR_CORTO = WINDOWS_DIR
# "Long" windows: the export that covers the whole recording.
WINDOWS_DIR_LARGO = Path(
    r"C:\Users\shipa\OneDrive\Escritorio\Inteligencia Computacional"
    r"\sleep-edf-database-expanded-1.0.0\sleep-edf-database-expanded-1.0.0"
    r"\sleep-cassette\ventanas_out\ventanas_extraidas"
)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import mne

def plot_examen_y_ventanas(
    idx_paciente=0,
    canal_preferido="EEG Fpz-Cz",
    win_sec=30.0,
    decim_factor=50
):
    """
    Plot one recording as three panels sharing the same time range.

    Panel 1: full PSG (decimated signal of the chosen channel).
    Panel 2: hypnogram (one horizontal bar per annotated stage segment).
    Panel 3: extracted windows:
        - upper row: 'long' windows (WINDOWS_DIR_LARGO).
        - lower row: 'short' windows (WINDOWS_DIR_CORTO).
    The x axis is in hours; all three panels get the same x-limits, taken
    from the PSG time vector.

    Args:
        idx_paciente: index into the list returned by encontrar_pares(RAW_DIR).
        canal_preferido: channel to plot; if absent, the first channel whose
            name contains "EEG" is used.
        win_sec: width, in seconds, drawn for every window rectangle.
        decim_factor: keep one sample out of this many for the signal panel
            (values <= 1 disable decimation).

    Returns:
        None. Prints a message and returns early when the pair, the channel
        or either window file cannot be found.
    """
    # -----------------------
    # 0) Locate the PSG/HYP pair
    # -----------------------
    pares = encontrar_pares(RAW_DIR)
    if not pares:
        print("No hay pares PSG/Hyp en RAW_DIR")
        return
    if idx_paciente >= len(pares):
        print(f"idx_paciente={idx_paciente} fuera de rango (hay {len(pares)} pares)")
        return

    paciente, psg_path, hyp_path = pares[idx_paciente]
    print(f"Paciente: {paciente}")
    print(f"  PSG:  {psg_path.name}")
    print(f"  HYP:  {hyp_path.name}")

    # -----------------------
    # 1) Read the PSG and pick the channel
    # -----------------------
    raw = mne.io.read_raw_edf(str(psg_path), preload=True, verbose=False)

    if canal_preferido in raw.ch_names:
        canal = canal_preferido
    else:
        eeg_cands = [ch for ch in raw.ch_names if "EEG" in ch.upper()]
        if not eeg_cands:
            print("No se encontraron canales EEG en este archivo.")
            return
        canal = eeg_cands[0]

    print(f"Canal usado: {canal}")

    data, _ = raw[canal, :]
    x_full = data.flatten()
    t_full = raw.times  # seconds

    if decim_factor > 1:
        x_plot = x_full[::decim_factor]
        t_plot = t_full[::decim_factor]
    else:
        x_plot = x_full
        t_plot = t_full

    # -----------------------
    # 2) Hypnogram
    # -----------------------
    hyp_df = leer_hypnograma_mne(hyp_path)
    # NOTE(review): presumably the hypnogram start relative to the PSG start,
    # in seconds - confirm against calcular_offset_segundos.
    offset = calcular_offset_segundos(psg_path, hyp_path)

    starts = hyp_df["inicio"].to_numpy(float) + offset
    ends   = (hyp_df["inicio"] + hyp_df["duracion"]).to_numpy(float) + offset
    etapas = hyp_df["etapa"].tolist()

    # Vertical position, tick label and colour of every raw stage code.
    stage_order = ["W", "1", "2", "3", "4", "R"]
    stage_to_y = {st: i for i, st in enumerate(stage_order)}
    stage_to_label = {
        "W": "Wake (W)",
        "1": "N1",
        "2": "N2",
        "3": "N3",
        "4": "N3/N4",
        "R": "REM",
    }
    colors_stages = {
        "W":  "#f4d03f",
        "1":  "#e67e22",
        "2":  "#3498db",
        "3":  "#2ecc71",
        "4":  "#27ae60",
        "R":  "#9b59b6",
    }

    # -----------------------
    # 3) Long and short windows
    # -----------------------
    # "Short" dataset (folder WINDOWS_DIR_CORTO)
    data_corto = cargar_ventanas(
        paciente, canal, return_ids=False, windows_dir=WINDOWS_DIR_CORTO
    )
    if data_corto is None:
        print("No se pudieron cargar ventanas 'cortas'")
        return

    etiquetas_corto = np.asarray(data_corto["etiquetas"])
    t_corto = np.asarray(data_corto["tiempos_inicio"])  # seconds

    # "Long" dataset (folder WINDOWS_DIR_LARGO)
    data_largo = cargar_ventanas(
        paciente, canal, return_ids=False, windows_dir=WINDOWS_DIR_LARGO
    )
    if data_largo is None:
        print("No se pudieron cargar ventanas 'largas' (examen completo)")
        return

    etiquetas_largo = np.asarray(data_largo["etiquetas"])
    t_largo = np.asarray(data_largo["tiempos_inicio"])

    # Same palette, keyed by the merged label names (W, N1, N2, N3, REM)
    colors_labels = {
        "W":   "#f4d03f",
        "N1":  "#e67e22",
        "N2":  "#3498db",
        "N3":  "#2ecc71",
        "REM": "#9b59b6",
    }

    # -----------------------
    # 4) Figure
    # -----------------------
    fig, (ax_sig, ax_hyp, ax_win) = plt.subplots(
        3, 1,
        figsize=(16, 10),
        sharex=False,
        gridspec_kw={"height_ratios": [3, 1.5, 1.8]}
    )

    # Panel 1: full signal
    ax_sig.plot(t_plot / 3600.0, x_plot, linewidth=0.4, color="black", alpha=0.7)
    ax_sig.set_ylabel("Amplitud")
    ax_sig.set_title(f"{paciente} ‚Äî {canal}")
    ax_sig.grid(True, alpha=0.3)

    # Panel 2: hypnogram (segments with an unknown stage code are skipped)
    for s, e, st in zip(starts, ends, etapas):
        if st not in stage_to_y:
            continue
        y = stage_to_y[st]
        ax_hyp.hlines(
            y,
            s / 3600.0,
            e / 3600.0,
            colors=colors_stages.get(st, "gray"),
            linewidth=8,
            alpha=0.9,
        )

    ax_hyp.set_yticks([stage_to_y[st] for st in stage_order])
    ax_hyp.set_yticklabels([stage_to_label.get(st, st) for st in stage_order])
    ax_hyp.set_ylabel("Etapas\n(hipnograma)")
    ax_hyp.grid(True, axis="x", alpha=0.2)

    # Hypnogram legend
    legend_handles = [
        mpatches.Patch(color=colors_stages[st], label=stage_to_label.get(st, st))
        for st in stage_order
        if st in colors_stages
    ]
    ax_hyp.legend(handles=legend_handles, title="Etapas de Sue√±o",
                  loc="upper right", bbox_to_anchor=(1.02, 1))

    # Panel 3: long vs short windows
    ax_win.set_title("Ventanas recortadas vs completas (alineadas en tiempo)")
    ax_win.set_xlabel("Tiempo (horas desde inicio PSG)")
    ax_win.set_yticks([0.25, 0.75])
    ax_win.set_yticklabels(["CORTO", "LARGO"])

    # Helper that draws one rectangle per window
    def _plot_windows(ax, t_starts, labels, y_center, height, label):
        for ti, lab in zip(t_starts, labels):
            lab_str = str(lab)
            # labels arrive already mapped to 'W', 'N1', etc.
            color = colors_labels.get(lab_str, "gray")
            ax.add_patch(
                mpatches.Rectangle(
                    (ti / 3600.0, y_center - height / 2),
                    width=win_sec / 3600.0,
                    height=height,
                    color=color,
                    alpha=0.9 if label == "corto" else 0.3,  # long ones more transparent
                    linewidth=0,
                )
            )

    _plot_windows(ax_win, t_largo, etiquetas_largo, y_center=0.75, height=0.35, label="largo")
    _plot_windows(ax_win, t_corto, etiquetas_corto, y_center=0.25, height=0.35, label="corto")

    # Legend for the window stages (reuses colors_labels)
    leg_handles2 = [
        mpatches.Patch(color=c, label=l) for l, c in colors_labels.items()
    ]
    ax_win.legend(handles=leg_handles2, title="Etapas (ventanas)",
                  loc="upper right", bbox_to_anchor=(1.02, 1))

    # X limits: the whole PSG, applied to every panel. The axes are not
    # shared, so without this each panel would autoscale on its own and the
    # three time axes would not line up.
    x_lim = (t_plot[0] / 3600.0, t_plot[-1] / 3600.0)
    for ax in (ax_sig, ax_hyp, ax_win):
        ax.set_xlim(*x_lim)

    plt.tight_layout()
    plt.show()

# Example usage: plot the pair at index 0 of encontrar_pares(RAW_DIR),
# preferring the "EEG Fpz-Cz" channel.
plot_examen_y_ventanas(idx_paciente=0, canal_preferido="EEG Fpz-Cz")



In [None]:
# ============================================================
# RECONSTRUIR resumen_global_mas_corto.csv DESDE LOS .NPZ NUEVOS
# ============================================================
import numpy as np
import pandas as pd
from pathlib import Path

# One summary row per .npz file found in WINDOWS_DIR.
rows = []

for path in sorted(WINDOWS_DIR.glob("*.npz")):
    # File stems look like "SC4001E_EEG_Fpz-Cz": patient, then the channel
    # with underscores in place of spaces.
    paciente, canal_stub = path.stem.split("_", 1)
    canal = canal_stub.replace("_", " ")          # "EEG Fpz-Cz"

    data = cargar_ventanas(paciente, canal, return_ids=True, mmap_npz=True)
    if data is None:
        continue

    # Quick channel type: first known modality found in the channel name.
    cname_up = canal.upper()
    tipo = next(
        (modalidad for modalidad in ("EEG", "EOG", "EMG") if modalidad in cname_up),
        "OTRO",
    )

    rows.append(dict(
        Paciente=paciente,
        Canal=canal,
        Tipo=tipo,
        Fs=float(data["freq_muestreo"]),
        N_Ventanas=data["ventanas"].shape[0],
        Archivo=str(path),
    ))

df_new = pd.DataFrame(rows)
out_csv = ANALYSIS_DIR / "resumen_global_mas_corto.csv"
df_new.to_csv(out_csv, index=False)

print("Nuevo resumen reconstruido desde los .npz cortos")
print("Guardado en:", out_csv)
display(df_new.head())


In [None]:
from collections import Counter
import numpy as np

# Label distribution of a single patient/channel file (quick sanity check).
data = cargar_ventanas("SC4001E", "EEG Fpz-Cz", return_ids=False)

if data is None:
    # cargar_ventanas returns None (after printing the reason) when the file
    # is missing or has an unknown layout; avoid the TypeError that indexing
    # None would raise.
    print("No se pudieron cargar las ventanas de SC4001E (EEG Fpz-Cz).")
else:
    y = np.asarray(data["etiquetas"])

    dist = Counter(y)
    total = len(y)
    print("Distribuci√≥n sobre ESTE .npz:")
    for k, v in sorted(dist.items()):
        print(f"  {k}: {v} ({100*v/total:.2f}%)")


# Separación de datos

In [None]:
# ============================================================
# DATASETS CoSleepNet: RAW + DCT 
# - Genera NPZ por paciente y un √≠ndice CSV por dataset
# ============================================================

import numpy as np
import pandas as pd
import gc
from pathlib import Path
from scipy.fft import dct   # DCT tipo-II

# ---------- GENERAL CONFIG ----------
# Summary CSV read further down to list patients, channels and .npz files.
ANALYSIS_FILE = ANALYSIS_DIR / "resumen_global_mas_corto.csv"

# EDF stage codes W, 1, 2, 3, 4, R  -->  classes W, N1, N2, N3 (3+4), REM
LABEL2ID = dict(W=0, N1=1, N2=2, N3=3, REM=4)
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))

def _etiquetas_to_ids(etiquetas_raw):
    """Map raw stage codes to class ids as a uint8 array.

    'W'->0, '1'->1, '2'->2, '3' and '4'->3 (N3), 'R'->4; any other code
    becomes 255 so the caller can filter it out.
    """
    codigo_a_id = {
        "W": LABEL2ID["W"],
        "1": LABEL2ID["N1"],
        "2": LABEL2ID["N2"],
        "3": LABEL2ID["N3"],
        "4": LABEL2ID["N3"],
        "R": LABEL2ID["REM"],
    }
    codigos = np.array(etiquetas_raw, dtype=str)
    ids = [codigo_a_id.get(codigo, 255) for codigo in codigos]
    return np.array(ids, dtype=np.uint8)

# Channel aliases: logical channel key -> candidate names, tried in list
# order by pick_channel_name (exact match first, then substring).
CHANNEL_PATTERNS = {
    "EEG1": ["EEG Fpz-Cz", "Fpz-Cz"],
    "EEG2": ["EEG Pz-Oz", "Pz-Oz"],
    "EOG" : ["EOG horizontal", "EOG", "EOG horizontal derivation"],
}

# CoSleepNet datasets: id -> (logical channel keys, file-name tag)
# 1: EEG1 + EEG2   (paper)
# 2: EEG1 + EOG
# 3: EEG1 + EEG2 + EOG
RECIPES = {
    1: (["EEG1", "EEG2"],              "cosleep_ds1_eeg1_eeg2"),
    2: (["EEG1", "EOG"],               "cosleep_ds2_eeg1_eog"),
    3: (["EEG1", "EEG2", "EOG"],       "cosleep_ds3_eeg1_eeg2_eog"),
}

LIMIT_PATIENTS   = None     # None, or an int to cap the number of patients while debugging
# Output folder (relative path, so it resolves against the current working
# directory); created here if it does not exist yet.
DATASET_DIR      = Path("datasets_cosleepnet_on_disk")
DATASET_DIR.mkdir(parents=True, exist_ok=True)

# ---------- Utils ----------

def pick_channel_name(dfp: pd.DataFrame, aliases: list[str]) -> str | None:
    """Elige el nombre real de canal que matchee alg√∫n alias."""
    names = list(dfp["Canal"].unique())
    u_names = [n.upper() for n in names]
    for alias in aliases:
        alias_u = alias.upper()
        # match exacto
        for n, u in zip(names, u_names):
            if u == alias_u:
                return n
        # substring
        for n, u in zip(names, u_names):
            if alias_u in u:
                return n
    return None

def load_windows_npz(path: Path):
    """Load one window .npz written by the main preprocessing step.

    Expects the keys 'ventanas', 'etiquetas', 'tiempos_inicio',
    'freq_muestreo' and 'nombre_canal'.

    Returns:
        (X, etiquetas, t, fs, canal): windows as float32, labels as str
        array, start times as float32, sampling rate as float, channel name
        as str.
    """
    # Context manager so the zip file handle is released as soon as the
    # arrays have been copied out (one handle per call would otherwise stay
    # open until garbage collection).
    with np.load(path, allow_pickle=True) as d:
        X = d["ventanas"].astype(np.float32)
        etiquetas = d["etiquetas"].astype(str)
        t = d["tiempos_inicio"].astype(np.float32)
        fs = float(d["freq_muestreo"])
        canal = str(d["nombre_canal"])
    return X, etiquetas, t, fs, canal

def compute_dct_per_channel(X_raw):
    """
    Apply an orthonormal type-II DCT along the time axis of every channel.

    X_raw is indexed as (N, L, C); the result has the same shape and is
    returned as float32.
    """
    # Swapping the last two axes is its own inverse, so the same permutation
    # moves time to the last axis and back again.
    perm = (0, 2, 1)
    canales_primero = np.transpose(X_raw, perm)            # (N, C, L)
    coeficientes = dct(canales_primero, type=2, axis=2, norm="ortho")
    return np.transpose(coeficientes, perm).astype(np.float32)

# ---------- Load the summary and build the patient list ----------
df = pd.read_csv(ANALYSIS_FILE)
# Slicing with None keeps the whole list, so LIMIT_PATIENTS=None means
# "no limit" and an int keeps only that many patients.
patients_all = sorted(df["Paciente"].unique())[:LIMIT_PATIENTS]

print(f" Pacientes disponibles en resumen_global_mas_corto: {len(patients_all)}")

# ============================================================
# ON-DISK CONSTRUCTION (ONE NPZ PER PATIENT + CSV INDEX)
# ============================================================
# Fixed column layout for the index CSV, so that a dataset with no usable
# patient still produces a file with a header that pd.read_csv can parse.
INDEX_COLUMNS = ["dataset_id", "tag", "paciente", "n_ventanas", "file"]

for ds_id, (ch_keys, tag) in RECIPES.items():
    print("\n" + "="*70)
    print(f" Construyendo dataset {ds_id}: canales {ch_keys}  [{tag}]")
    print("="*70)

    index_rows = []   # rows of the index CSV
    total_by_class = np.zeros(5, dtype=np.int64)

    for p in patients_all:
        dfp = df[df["Paciente"] == p]

        # 1) real channel names for every logical channel of the recipe
        chosen = {}
        for ck in ch_keys:
            ch_name = pick_channel_name(dfp, CHANNEL_PATTERNS[ck])
            if ch_name is None:
                chosen = None
                break
            chosen[ck] = ch_name

        if chosen is None:
            continue  # patient lacks at least one required channel

        # 2) load the windows of every channel and align them by start time
        times_rounded = {}
        X_by_ck = {}
        y_raw_by_ck = {}

        for ck, ch_name in chosen.items():
            row = dfp[dfp["Canal"] == ch_name].iloc[0]
            npz_path = Path(row["Archivo"])
            X, etiquetas_raw, t, fs, canal_real = load_windows_npz(npz_path)

            X_by_ck[ck] = X
            y_raw_by_ck[ck] = np.array(etiquetas_raw, dtype=str)
            times_rounded[ck] = np.round(t.astype(np.float32), 4)

        # start times present in every channel
        common = set(times_rounded[ch_keys[0]])
        for ck in ch_keys[1:]:
            common &= set(times_rounded[ck])
        if not common:
            continue

        common_sorted = np.array(sorted(common), dtype=np.float32)

        # per channel: position of every common start time
        idx_maps = {}
        for ck in ch_keys:
            t = times_rounded[ck]
            t2idx = {float(tt): i for i, tt in enumerate(t)}
            idx_maps[ck] = [t2idx[float(tt)] for tt in common_sorted]

        # 3) select the common windows; labels come from the first channel
        X_aligned_list = []
        y_ref = None
        for ck in ch_keys:
            X_ch = X_by_ck[ck][idx_maps[ck]]  # (n, L_ck)
            if y_ref is None:
                y_ref = y_raw_by_ck[ck][idx_maps[ck]]
            X_aligned_list.append(X_ch)

        # truncate every channel to the shortest window length
        L_min = min(X.shape[1] for X in X_aligned_list)
        X_aligned_list = [X[:, :L_min] for X in X_aligned_list]

        # 4) RAW tensor (n, L_min, C)
        X_raw_p = np.stack(X_aligned_list, axis=-1).astype(np.float32)

        # z-score per patient and channel (statistics over all aligned
        # windows, computed before the label filter below)
        mu = X_raw_p.mean(axis=(0, 1), keepdims=True)          # (1,1,C)
        sigma = X_raw_p.std(axis=(0, 1), keepdims=True) + 1e-6
        X_raw_p = (X_raw_p - mu) / sigma

        # 5) labels -> ids, dropping windows with an unknown label (255).
        # Done before the DCT: the transform is applied window by window, so
        # filtering first gives the same values without transforming windows
        # that are thrown away.
        y_p = _etiquetas_to_ids(y_ref)
        valid_mask = y_p != 255
        X_raw_p = X_raw_p[valid_mask]
        y_p     = y_p[valid_mask]

        n_p = y_p.shape[0]
        if n_p == 0:
            continue

        # 6) DCT per channel
        X_dct_p = compute_dct_per_channel(X_raw_p)

        # 7) write THIS patient to disk
        out_npz = DATASET_DIR / f"{tag}_{p}.npz"
        np.savez_compressed(
            out_npz,
            X_raw=X_raw_p.astype(np.float32),
            X_dct=X_dct_p.astype(np.float32),
            y=y_p.astype(np.uint8),
        )

        # 8) update index and class counts
        counts_by_class = np.bincount(y_p, minlength=5)
        total_by_class += counts_by_class

        index_rows.append({
            "dataset_id": ds_id,
            "tag": tag,
            "paciente": p,
            "n_ventanas": int(n_p),
            "file": str(out_npz),
        })

        # free this patient's arrays before loading the next one
        del X_raw_p, X_dct_p, y_p, X_by_ck, y_raw_by_ck
        gc.collect()

    # ---- Write the index CSV of this dataset ----
    index_df = pd.DataFrame(index_rows, columns=INDEX_COLUMNS)
    index_csv = DATASET_DIR / f"{tag}_index.csv"
    index_df.to_csv(index_csv, index=False)

    print(f"\n Dataset {ds_id} listo (on-disk):")
    print(f"   ‚Ä¢ Pacientes con datos: {len(index_rows)}")
    print(f"   ‚Ä¢ Index CSV: {index_csv}")
    print(f"   ‚Ä¢ Ventanas por clase:")
    for cls_id, count in enumerate(total_by_class):
        print(f"      - {ID2LABEL[cls_id]}: {int(count)}")

print("\n Estructura generada en disco (por dataset):")
print("   ‚Ä¢ NPZ por paciente:  cosleep_dsX_..._SCxxxxE.npz")
print("   ‚Ä¢ √çndice:            cosleep_dsX_..._index.csv")


# Creación de datasets

In [None]:
# ============================================================
# SPLIT POR PACIENTE para datasets 
#  - Reporta n¬∫ de ventanas y distribuci√≥n de etapas por split
# ============================================================

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split

# Folder with the per-patient NPZ files and the per-dataset index CSVs
# (relative path: resolved against the current working directory).
DATASET_DIR = Path("datasets_cosleepnet_on_disk")

# Dataset id -> (logical channel keys, file-name tag of the dataset).
RECIPES = {
    1: (["EEG1", "EEG2"],              "cosleep_ds1_eeg1_eeg2"),
    2: (["EEG1", "EOG"],               "cosleep_ds2_eeg1_eog"),
    3: (["EEG1", "EEG2", "EOG"],       "cosleep_ds3_eeg1_eeg2_eog"),
}

# Define the id -> label mapping only if it is not already in the session
try:
    ID2LABEL
except NameError:
    ID2LABEL = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}

# ------------------------------------------------------------------
# 1) Load the index of one dataset (real columns: paciente, file)
# ------------------------------------------------------------------
def load_index(tag):
    """Load the index CSV of a dataset and resolve the NPZ paths.

    Adds an 'npz_path' column: absolute paths are kept; relative paths that
    start with the DATASET_DIR folder name are reduced to their file name
    under DATASET_DIR; any other relative path is joined to DATASET_DIR.

    Raises:
        FileNotFoundError: the index CSV does not exist.
        ValueError: the CSV lacks the 'paciente' or 'file' column.
    """
    # Local import: the file only imports Path at module level.
    from pathlib import PureWindowsPath

    index_csv = DATASET_DIR / f"{tag}_index.csv"
    if not index_csv.exists():
        raise FileNotFoundError(f"No existe: {index_csv}")

    df = pd.read_csv(index_csv)

    if "paciente" not in df.columns or "file" not in df.columns:
        raise ValueError("El √≠ndice debe contener columnas 'paciente' y 'file'.")

    # normalize paths
    corrected_paths = []
    for fp in df["file"]:
        fp = Path(fp)
        if fp.is_absolute():
            corrected_paths.append(fp)
        elif str(fp).startswith(str(DATASET_DIR.name)):
            # PureWindowsPath splits on both '/' and '\\', so an index
            # written on Windows ("dir\\file.npz") still yields the bare
            # file name when it is read on a POSIX system.
            corrected_paths.append(DATASET_DIR / PureWindowsPath(str(fp)).name)
        else:
            corrected_paths.append(DATASET_DIR / fp)

    df["npz_path"] = corrected_paths
    return df

# ------------------------------------------------------------------
# 2) Patient-level split
# ------------------------------------------------------------------
def split_patients(patient_list, test_size=0.2, val_size=0.2, random_state=42):
    """Split the unique patients into train / val / test lists.

    test_size and val_size are fractions of the whole patient set; the
    validation fraction is rescaled because it is taken from what remains
    after the test patients are removed.

    Returns a dict with keys 'train_patients', 'val_patients' and
    'test_patients'.
    """
    unique_patients = np.array(sorted(set(patient_list)))
    positions = np.arange(len(unique_patients))

    # First cut: hold out the test patients.
    remaining_pos, test_pos = train_test_split(
        positions,
        test_size=test_size,
        random_state=random_state,
        shuffle=True,
    )

    # Second cut: validation share relative to the remaining patients.
    val_fraction = val_size / (1.0 - test_size)
    train_pos, val_pos = train_test_split(
        remaining_pos,
        test_size=val_fraction,
        random_state=random_state + 1,
        shuffle=True,
    )

    result = {}
    result["train_patients"] = unique_patients[train_pos].tolist()
    result["val_patients"] = unique_patients[val_pos].tolist()
    result["test_patients"] = unique_patients[test_pos].tolist()
    return result

# ------------------------------------------------------------------
# 3) Summary of patients + windows + stages
# ------------------------------------------------------------------
def print_patient_split_with_stages(df_index: pd.DataFrame, splits: dict, name: str):
    """Print, for each split, patient count, window count and stage shares.

    Window totals come from the index column 'n_ventanas'; the stage
    distribution is recomputed from the labels ('y', or 'labels' as a
    fallback) stored in every NPZ listed in 'npz_path'. Files that cannot be
    read are reported and skipped.
    """
    grupos = {
        "train": set(splits["train_patients"]),
        "val": set(splits["val_patients"]),
        "test": set(splits["test_patients"]),
    }
    train_p, val_p, test_p = grupos["train"], grupos["val"], grupos["test"]

    # Pairwise overlaps; a correct patient-level split has none.
    inter_tv = train_p & val_p
    inter_tt = train_p & test_p
    inter_vt = val_p & test_p

    print(f"\n===== SPLIT {name} =====")
    print(f"Pacientes -> Train: {len(train_p)} | Val: {len(val_p)} | Test: {len(test_p)}")
    print("Intersecciones pacientes (deben ser 0): "
          f"Train‚à©Val={len(inter_tv)}, Train‚à©Test={len(inter_tt)}, Val‚à©Test={len(inter_vt)}")

    for split_name, pats in grupos.items():
        sub = df_index[df_index["paciente"].isin(pats)]

        total_ventanas = int(sub["n_ventanas"].sum())
        counts = np.zeros(len(ID2LABEL), dtype=np.int64)

        for npz_path in map(Path, sub["npz_path"]):
            try:
                with np.load(npz_path, allow_pickle=True) as d:
                    clave = next((k for k in ("y", "labels") if k in d.files), None)
                    if clave is None:
                        continue
                    y = d[clave].astype(int)
                valores, frecuencias = np.unique(y, return_counts=True)
                # Only ids that fall inside the known classes are counted.
                en_rango = (valores >= 0) & (valores < len(counts))
                counts[valores[en_rango]] += frecuencias[en_rango]
            except Exception as e:
                print(f"  Error leyendo {npz_path.name}: {e}")

        total_etapas = int(counts.sum())

        print(f"\n‚Üí {split_name.upper()}:")
        print(f"   Pacientes: {len(pats)}")
        print(f"   Ventanas (seg√∫n √≠ndice): {total_ventanas}")
        if total_etapas <= 0:
            print("   (No se pudieron leer etiquetas 'y' en los NPZ)")
            continue
        print(f"   Ventanas (sumando y):  {total_etapas}")
        for cls_id, c in enumerate(counts):
            etiqueta = ID2LABEL.get(cls_id, str(cls_id))
            perc = c / total_etapas * 100.0
            print(f"   - {etiqueta}: {c} ({perc:.1f}%)")

# ============================================================
#             SPLITS FOR DS1, DS2 AND DS3
# ============================================================

# dataset id -> dict with 'train_patients' / 'val_patients' / 'test_patients'
splits = {}

for ds_id, (ch_keys, tag) in RECIPES.items():
    separador = "=" * 70
    print("\n" + separador, f"Creando SPLIT para dataset {ds_id}: {tag}", separador, sep="\n")

    df_index = load_index(tag)
    patients = df_index["paciente"].unique().tolist()

    splits[ds_id] = split_patients(
        patients, test_size=0.20, val_size=0.20, random_state=42
    )
    print_patient_split_with_stages(df_index, splits[ds_id], name=tag)

print("\n Splits listos. Puedes usarlos para cargar NPZ por paciente durante entrenamiento.")


In [None]:
# Debug aid: show the absolute location where the DS1 index CSV is expected.
print((DATASET_DIR / "cosleep_ds1_eeg1_eeg2_index.csv").resolve())

# DataLoader de los datasets

In [None]:
# ============================================================
# Dataset PyTorch para CoSleepNet 
#   - Devuelve (C, T, 1) por muestra
#   - Batch queda (N, C, T, 1)
#   - Compatible con Conv2d kernel=(3,1)
# ============================================================

import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torch

# Folder with the per-patient NPZ files and index CSVs (relative path).
DATASET_DIR = Path("datasets_cosleepnet_on_disk")
# Class id -> stage name, used when printing class distributions.
ID2LABEL = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}

# ------------------------------------------------------------
# Load the index and normalize paths
# ------------------------------------------------------------
def load_index_pytorch(tag: str) -> pd.DataFrame:
    """Load the index CSV of a dataset and add a resolved 'npz_path' column.

    Absolute paths are kept; relative paths that start with the DATASET_DIR
    folder name are reduced to their file name under DATASET_DIR; any other
    relative path is joined to DATASET_DIR.

    Raises:
        FileNotFoundError: the index CSV does not exist.
    """
    # Local import: the file only imports Path at module level.
    from pathlib import PureWindowsPath

    index_csv = DATASET_DIR / f"{tag}_index.csv"
    if not index_csv.exists():
        raise FileNotFoundError(f"No existe: {index_csv}")

    df = pd.read_csv(index_csv)
    corrected_paths = []
    for fp in df["file"]:
        fp = Path(fp)
        if fp.is_absolute():
            corrected_paths.append(fp)
        elif str(fp).startswith(str(DATASET_DIR.name)):
            # PureWindowsPath splits on both '/' and '\\', so an index
            # written on Windows ("dir\\file.npz") still yields the bare
            # file name when it is read on a POSIX system.
            corrected_paths.append(DATASET_DIR / PureWindowsPath(str(fp)).name)
        else:
            corrected_paths.append(DATASET_DIR / fp)
    df["npz_path"] = corrected_paths
    return df


# ------------------------------------------------------------
# In-RAM dataset
# ------------------------------------------------------------
class CoSleepNPZDataset(Dataset):
    """
    Dataset that PRELOADS every window of one split into RAM on construction.

    Each item is a tuple:
      xr: (C, T, 1)  raw signal
      xd: (C, T, 1)  DCT of the signal
      y : scalar long class id
    """
    def __init__(self, tag: str, split_info: dict, split_name: str):
        """Load all NPZ files of the patients listed for `split_name`.

        Args:
            tag: dataset tag used to locate the index CSV.
            split_info: dict holding the '<split_name>_patients' list.
            split_name: 'train', 'val' or 'test'.

        Raises:
            ValueError: no indexed patient belongs to the requested split.
        """
        super().__init__()
        self.tag = tag
        self.split_name = split_name

        df_index = load_index_pytorch(tag)
        key_patients = f"{split_name}_patients"
        patients_split = set(split_info[key_patients])
        sub = df_index[df_index["paciente"].isin(patients_split)].copy()

        if sub.empty:
            raise ValueError(f"No hay pacientes para split '{split_name}'")

        file_paths = sub["npz_path"].tolist()
        n_files = len(file_paths)

        print(f"Cargando {split_name} en RAM ({n_files} archivos)...")

        all_X_raw, all_X_dct, all_y = [], [], []

        for i, path in enumerate(file_paths):
            if (i + 1) % 20 == 0 or i == 0 or (i + 1) == n_files:
                print(f"   Archivo {i+1}/{n_files}...", end="\r")

            # Close every archive once its arrays are in memory; otherwise
            # one file handle per patient stays open. copy=False skips a
            # second copy when the stored dtype already matches.
            with np.load(path, allow_pickle=False) as d:
                all_X_raw.append(d["X_raw"].astype(np.float32, copy=False))  # (n_win, T, C)
                all_X_dct.append(d["X_dct"].astype(np.float32, copy=False))  # (n_win, T, C)
                all_y.append(d["y"].astype(np.int64, copy=False))            # (n_win,)

        print()

        self.X_raw = np.concatenate(all_X_raw, axis=0)  # (N_total, T, C)
        self.X_dct = np.concatenate(all_X_dct, axis=0)  # (N_total, T, C)
        self.y = np.concatenate(all_y, axis=0)          # (N_total,)

        del all_X_raw, all_X_dct, all_y

        self.total_windows = len(self.y)

        mem_gb = (self.X_raw.nbytes + self.X_dct.nbytes + self.y.nbytes) / 1e9
        print(f" {split_name}: {self.total_windows} ventanas cargadas ({mem_gb:.2f} GB RAM)")

        # ======================================================
        # DEBUG: class distribution of this split
        # ======================================================
        uniq, cnts = np.unique(self.y, return_counts=True)
        print(f"   Distribuci√≥n de clases en split '{self.split_name}':")
        for k, c in zip(uniq, cnts):
            nombre = ID2LABEL.get(int(k), f"id{int(k)}")
            pct = 100.0 * c / max(1, self.total_windows)
            print(f"      - {nombre} (id={int(k)}): {c} ventanas ({pct:.1f}%)")
        print("-" * 60)

    def __len__(self):
        """Number of windows in the split."""
        return self.total_windows

    def __getitem__(self, idx):
        """Return (raw, dct, label) tensors for window `idx`."""
        xr = self.X_raw[idx]   # (T, C)
        xd = self.X_dct[idx]   # (T, C)
        y  = self.y[idx]

        # --------------------------------------------------
        # (T, C) -> (C, T, 1)
        # Final batch: (N, C, T, 1)
        # --------------------------------------------------
        xr = np.transpose(xr, (1, 0))[:, :, np.newaxis]   # (C, T, 1)
        xd = np.transpose(xd, (1, 0))[:, :, np.newaxis]   # (C, T, 1)

        return (
            torch.from_numpy(xr),
            torch.from_numpy(xd),
            torch.tensor(int(y), dtype=torch.long)
        )


# ------------------------------------------------------------
# Loader factory
# ------------------------------------------------------------
def make_pytorch_loaders(
    dataset_id: int,
    splits: dict,
    batch_size: int = 128,
    num_workers: int = 0,
    pin_memory: bool = True,
):
    """Build the train/val/test ``DataLoader`` trio for one channel set.

    Args:
        dataset_id: Channel configuration: 1 (EEG1+EEG2), 2 (EEG1+EOG) or
            3 (EEG1+EEG2+EOG).
        splits: Container indexed by ``dataset_id``; the selected entry is
            passed unchanged to ``CoSleepNPZDataset``.
        batch_size: Batch size shared by the three loaders.
        num_workers: Worker count shared by the three loaders.
        pin_memory: ``pin_memory`` flag shared by the three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, input_shape, name)``.
        ``input_shape`` is the shape of the first training sample's raw
        tensor and ``name`` is the printable channel-set label.

    Raises:
        ValueError: If ``dataset_id`` is not 1, 2 or 3.
    """
    known_sets = (
        (1, "cosleep_ds1_eeg1_eeg2", "EEG1+EEG2"),
        (2, "cosleep_ds2_eeg1_eog", "EEG1+EOG"),
        (3, "cosleep_ds3_eeg1_eeg2_eog", "EEG1+EEG2+EOG"),
    )
    chosen = next(
        ((file_tag, label) for key, file_tag, label in known_sets if dataset_id == key),
        None,
    )
    if chosen is None:
        raise ValueError("dataset_id debe ser 1, 2 o 3")
    tag, dataset_name = chosen

    # Pick the split description that belongs to this channel set.
    split_info = splits[dataset_id]

    banner = "=" * 50
    print(f"\n{banner}")
    print(f"Precargando dataset {dataset_name} en RAM...")
    print(f"{banner}\n")

    # Datasets are constructed in train -> val -> test order.
    datasets = {
        part: CoSleepNPZDataset(tag, split_info, part)
        for part in ("train", "val", "test")
    }

    # Only the training loader shuffles; everything else is shared.
    loaders = {
        part: DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(part == "train"),
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        for part, dataset in datasets.items()
    }

    first_raw, _, _ = datasets["train"][0]
    input_shape = tuple(first_raw.shape)  # (C, T, 1)

    sizes = {part: len(dataset) for part, dataset in datasets.items()}
    print(f"\n Loaders listos para {dataset_name}:")
    print(f"   Input shape por muestra (C,T,1): {input_shape}")
    print(f"   Train: {sizes['train']} | Val: {sizes['val']} | Test: {sizes['test']}")

    return loaders["train"], loaders["val"], loaders["test"], input_shape, dataset_name



# Modelo

In [None]:
# ============================================================
# Modelo CoSleepNet en PyTorch
#  - Entrada: (N, C, T, 1)
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, 
    confusion_matrix, cohen_kappa_score, classification_report
)
import seaborn as sns
import pandas as pd

class CNNBlock(nn.Module):
    """Two conv -> batch-norm -> ReLU stages, then max-pooling and dropout.

    Every kernel is ``(3, 1)``. The convolutions pad axis 2 by one on each
    side, so its length is preserved; the pool uses kernel 3 / stride 3 on
    that axis and therefore shrinks it by a factor of three.
    """

    def __init__(self, in_ch, out_ch, dropout=0.3):
        super().__init__()
        kernel = (3, 1)
        same_pad = (1, 0)

        # Attribute names and creation order are kept: they determine the
        # state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=kernel, padding=same_pad)
        self.bn1 = nn.BatchNorm2d(out_ch)

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=kernel, padding=same_pad)
        self.bn2 = nn.BatchNorm2d(out_ch)

        self.pool = nn.MaxPool2d(kernel_size=kernel, stride=kernel)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(N, C, T, 1)``."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(norm(conv(x)))
        # Pooling reduces T; dropout is applied last.
        return self.drop(self.pool(x))


class CoSleepBranch(nn.Module):
    """Four stacked ``CNNBlock``s followed by an LSTM over the pooled axis.

    ``forward`` returns the LSTM output at the last step, one vector of
    ``lstm_units`` features per batch element.
    """

    def __init__(self, in_ch, lstm_units=64, dropout=0.3):
        super().__init__()
        # Channel widths: in_ch -> 32 -> 64 -> 64 -> 128.
        self.block1 = CNNBlock(in_ch, 32, dropout)
        self.block2 = CNNBlock(32, 64, dropout)
        self.block3 = CNNBlock(64, 64, dropout)
        self.block4 = CNNBlock(64, 128, dropout)

        self.lstm = nn.LSTM(input_size=128, hidden_size=lstm_units, batch_first=True)

    def forward(self, x):
        """Encode ``x`` of shape ``(N, C, T, 1)``."""
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)

        # (N, 128, T', 1) -> (N, T', 128): drop the singleton axis and put
        # the sequence axis second, as batch_first=True expects.
        sequence = x.squeeze(3).permute(0, 2, 1)

        outputs, _ = self.lstm(sequence)
        return outputs[:, -1, :]


class CoSleepNetTorch(nn.Module):
    """Classifier built from one or two ``CoSleepBranch`` encoders.

    A raw-signal branch is always present. When ``use_dct`` is true a second
    branch encodes the DCT input and both feature vectors are concatenated
    before the two-layer head.

    NOTE(review): the head dropout is fixed at 0.3 and does not follow the
    ``dropout`` argument, which only reaches the branches.
    """

    def __init__(self, in_ch, num_classes=5, lstm_units=64, dropout=0.3, use_dct=True):
        super().__init__()
        self.use_dct = use_dct

        self.raw_branch = CoSleepBranch(in_ch, lstm_units, dropout)
        if use_dct:
            self.dct_branch = CoSleepBranch(in_ch, lstm_units, dropout)

        n_branches = 2 if use_dct else 1
        self.fc1 = nn.Linear(lstm_units * n_branches, 128)
        self.drop = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, xr, xd=None):
        """Return class logits; ``xd`` is only read when ``use_dct`` is set."""
        features = [self.raw_branch(xr)]
        if self.use_dct:
            features.append(self.dct_branch(xd))

        fused = torch.cat(features, dim=1) if self.use_dct else features[0]

        hidden = self.drop(F.relu(self.fc1(fused)))
        return self.fc2(hidden)


# Entrenamiento

In [None]:
# ============================================================
# FUNCIONES DE ENTRENAMIENTO
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import f1_score

# =====================================================
#  Focal Loss 
# =====================================================
class FocalLoss(nn.Module):
    """Multi-class focal loss: ``mean((1 - p_t) ** gamma * CE)``.

    Args:
        gamma: Focusing exponent. ``gamma=0`` without weights reduces to the
            plain mean cross-entropy.
        alpha: Optional per-class weights (tensor or array-like of length
            ``num_classes``) applied to the cross-entropy term.

    ``alpha`` is stored as a non-persistent buffer, so it follows the module
    through ``.to(device)`` / ``.double()`` while ``state_dict()`` stays
    empty. It remains readable as ``self.alpha``.
    """

    def __init__(self, gamma=2.0, alpha=None):
        super().__init__()
        self.gamma = gamma
        if alpha is not None and not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.register_buffer("alpha", alpha, persistent=False)

    def forward(self, logits, targets):
        """Return the scalar focal loss.

        Args:
            logits: ``(N, num_classes)`` unnormalised scores.
            targets: ``(N,)`` integer class ids.
        """
        # One log-softmax feeds both the CE term and p_t, instead of running
        # a separate softmax next to cross_entropy.
        log_probs = F.log_softmax(logits, dim=1)
        ce_loss = F.nll_loss(log_probs, targets, weight=self.alpha, reduction='none')

        # Probability assigned to the true class (never weighted by alpha).
        p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).exp()

        focal_weight = (1 - p_t) ** self.gamma
        return (focal_weight * ce_loss).mean()

# =====================================================
#  Train con progreso + Macro F1
# =====================================================
def train_one_epoch_pytorch(model, loader, optimizer, criterion, device, grad_clip=1.0):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(xr, xd)``.
        loader: Iterable of ``(xr, xd, y)`` batches that supports ``len()``.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss called as ``criterion(logits, y)``.
        device: Device the batches are moved to.
        grad_clip: Max gradient norm; a falsy value disables clipping.

    Returns:
        ``(avg_loss, acc, macro_f1)`` with the loss averaged per sample.
    """
    model.train()
    n_batches = len(loader)
    loss_sum = 0.0
    seen = 0
    pred_chunks, label_chunks = [], []

    for step, (xr, xd, y) in enumerate(loader, start=1):
        xr, xd, y = (t.to(device, non_blocking=True) for t in (xr, xd, y))

        optimizer.zero_grad(set_to_none=True)
        logits = model(xr, xd)
        loss = criterion(logits, y)
        loss.backward()

        if grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        # Weight the batch loss by its size so the epoch average is per
        # sample even when the last batch is smaller.
        batch_n = y.size(0)
        seen += batch_n
        loss_sum += loss.item() * batch_n

        pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
        label_chunks.append(y.cpu().numpy())

        if step % 100 == 0 or step == n_batches:
            print(f"  [Train] {step}/{n_batches} ({100*step/n_batches:.0f}%)", end="\r")

    # Terminate the carriage-return progress line.
    print()

    predictions = np.concatenate(pred_chunks)
    references = np.concatenate(label_chunks)

    avg_loss = loss_sum / seen
    macro_f1 = f1_score(references, predictions, average='macro', zero_division=0)
    acc = (predictions == references).mean()

    return avg_loss, acc, macro_f1

# =====================================================
#  Eval con progreso + Macro F1
# =====================================================
@torch.no_grad()
def eval_pytorch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Args:
        model: Module called as ``model(xr, xd)``; put in eval mode here.
        loader: Iterable of ``(xr, xd, y)`` batches that supports ``len()``.
        criterion: Loss called as ``criterion(logits, y)``.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, acc, macro_f1, all_labels, all_preds)``; the last two
        are NumPy arrays concatenated over every batch.
    """
    model.eval()
    n_batches = len(loader)
    loss_sum = 0.0
    seen = 0
    pred_chunks, label_chunks = [], []

    for step, (xr, xd, y) in enumerate(loader, start=1):
        xr, xd, y = (t.to(device, non_blocking=True) for t in (xr, xd, y))

        logits = model(xr, xd)
        loss = criterion(logits, y)

        # Size-weighted accumulation gives a per-sample average loss.
        batch_n = y.size(0)
        seen += batch_n
        loss_sum += loss.item() * batch_n

        pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
        label_chunks.append(y.cpu().numpy())

        if step % 100 == 0 or step == n_batches:
            print(f"  [Eval]  {step}/{n_batches} ({100*step/n_batches:.0f}%)", end="\r")

    # Terminate the carriage-return progress line.
    print()

    predictions = np.concatenate(pred_chunks)
    references = np.concatenate(label_chunks)

    avg_loss = loss_sum / seen
    macro_f1 = f1_score(references, predictions, average='macro', zero_division=0)
    acc = (predictions == references).mean()

    return avg_loss, acc, macro_f1, references, predictions

# =====================================================
#  Class weights
# =====================================================
def get_class_weights(loader, n_classes=5, method='sqrt_inverse', device='cpu'):
    """Compute per-class loss weights from the labels yielded by ``loader``.

    The result always has exactly ``n_classes`` entries, indexed by class
    id, even when some class never appears in ``loader``. Counting only the
    observed labels would return a shorter vector whose positions no longer
    match the class ids.

    Args:
        loader: Iterable of batches whose last element is a label tensor.
        n_classes: Number of classes; labels must lie in ``[0, n_classes)``.
        method: ``'sqrt_inverse'`` (1/sqrt(count)), ``'inverse'`` (1/count);
            any other value yields uniform weights.
        device: Device of the returned tensor.

    Returns:
        ``float32`` tensor of shape ``(n_classes,)`` whose entries sum to
        ``n_classes``.

    Raises:
        ValueError: If a label falls outside ``[0, n_classes)`` (or, from
            ``np.concatenate``, if ``loader`` yields no batches).
    """
    print("  Calculando class weights...")
    all_labels = np.concatenate(
        [batch[-1].cpu().numpy() for batch in loader]
    ).astype(np.int64)

    if all_labels.size and (all_labels.min() < 0 or all_labels.max() >= n_classes):
        raise ValueError(
            f"labels must be in [0, {n_classes}); got range "
            f"[{all_labels.min()}, {all_labels.max()}]"
        )

    # bincount keeps a slot for every class id, including absent ones.
    counts = np.bincount(all_labels, minlength=n_classes)

    # An absent class is treated as if it had one sample: this avoids a
    # division by zero and gives it the largest weight rather than none.
    safe_counts = np.maximum(counts, 1)

    if method == 'sqrt_inverse':
        weights = 1.0 / np.sqrt(safe_counts)
    elif method == 'inverse':
        weights = 1.0 / safe_counts
    else:
        weights = np.ones(n_classes)

    # Rescale so the weights average to 1 across classes.
    weights = weights / weights.sum() * n_classes
    print(f"  Distribuci√≥n: {counts}")
    return torch.tensor(weights, dtype=torch.float32, device=device)

print(" Celda 1 cargada: Modelo y funciones de entrenamiento")

# Ejecuciones

# DCT ON   

## Set 1: EEG1 + EEG2

In [None]:
# ============================================================
# CELDA: MULTI-RUN (3–5 corridas) + MEDIA ± DESV. ESTÁNDAR
# + VISUALIZACIÓN DE LA ÚLTIMA CORRIDA
# ============================================================

import numpy as np
import pandas as pd
from sklearn.metrics import (
    precision_recall_fscore_support, confusion_matrix,
    cohen_kappa_score
)
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Experiment configuration ----------
DATASET_ID = 1        # channel set passed to make_pytorch_loaders: 1, 2 or 3
LR = 5e-5             # learning rate given to Adam
BATCH_SIZE = 128
EPOCHS = 50           # upper bound on epochs per run
PATIENCE = 5          # epochs without a validation macro-F1 improvement before stopping
USE_FOCAL = True      # FocalLoss when True, nn.CrossEntropyLoss otherwise
FOCAL_GAMMA = 1.5
USE_CLASS_WEIGHTS = False  # when True, get_class_weights(train_loader) feeds the loss

N_RUNS = 5            # number of independent runs
BASE_SEED = 42        # run i is seeded with BASE_SEED + i
labels = ["W", "N1", "N2", "N3", "REM"]  # display names for class ids 0..4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Device:", device)
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ---------- Utils ----------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN behaviour.

    CUDA generators are seeded only when CUDA is available. The cuDNN flags
    are set unconditionally: deterministic on, benchmark (autotuner) off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def kappa_per_class(y_true, y_pred, n_classes=5):
    """Return the confusion matrix and a one-vs-rest kappa for each class.

    Each class is reduced to a 2x2 table (this class vs. all others) and
    scored as ``(observed - expected) / (1 - expected)``; a ``1e-12`` term
    in the denominator guards against ``expected == 1``.

    Returns:
        ``(cm, kappas)``: the ``n_classes`` x ``n_classes`` confusion matrix
        over labels ``0..n_classes-1`` and a float64 array with one kappa
        per class.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(n_classes))
    total = cm.sum()
    denom = max(1, total)  # keeps an empty matrix from dividing by zero

    # Per-class counts for all classes at once.
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = total - (tp + fn + fp)

    observed = (tp + tn) / denom

    # Chance agreement from the marginals of each binarised table.
    true_pos_rate = (tp + fn) / denom
    pred_pos_rate = (tp + fp) / denom
    true_neg_rate = (fp + tn) / denom
    pred_neg_rate = (fn + tn) / denom
    expected = true_pos_rate * pred_pos_rate + true_neg_rate * pred_neg_rate

    kappas = np.zeros(n_classes, dtype=np.float64)
    kappas[:] = (observed - expected) / (1 - expected + 1e-12)

    return cm, kappas

# ---------- Fixed loaders (the same splits are reused by every run) ----------
train_loader, val_loader, test_loader, input_shape, DATASET_NAME = make_pytorch_loaders(
    dataset_id=DATASET_ID,
    splits=splits,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=True,
)

# make_pytorch_loaders reports input_shape as (C, T, 1), so H receives the
# time length and L the trailing singleton axis.
C_in, H, L = input_shape
print(f"\n Dataset: {DATASET_NAME} | Shape: C={C_in}, H={H}, L={L}")
print(f" Ejecutando {N_RUNS} corridas independientes...\n")

# ---------- Run loop accumulators ----------
run_summaries = []      # one dict of global test metrics per run
per_class_tables = []   # one per-class metrics DataFrame per run

# Filled by the final run only; read by the plotting code after the loop.
last_run_history = None
last_run_cm = None
last_run_test_f1 = None
last_run_kappa = None
last_run_best_val_f1 = None

# One full train / early-stop / test cycle per seed.
for run in range(N_RUNS):
    seed = BASE_SEED + run
    set_seed(seed)
    print(f"\n================ RUN {run+1}/{N_RUNS} ‚Äî seed={seed} ================")

    # Fresh model for every run
    model = CoSleepNetTorch(
        in_ch=C_in,
        num_classes=5,
        lstm_units=64,
        dropout=0.3,
        use_dct=True,
    ).to(device)

    # Loss function (optionally class-weighted)
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = get_class_weights(train_loader, device=device)
        print(f"  Class weights: {class_weights.cpu().numpy().round(3)}")

    if USE_FOCAL:
        criterion = FocalLoss(gamma=FOCAL_GAMMA, alpha=class_weights)
        print(f" Loss: Focal (Œ≥={FOCAL_GAMMA})")
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f" Loss: CrossEntropy")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Per-epoch training history
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }
    
    # Training with early stopping on validation macro-F1
    best_val_f1 = 0.0
    best_state = None
    patience_counter = 0

    for epoch in range(1, EPOCHS + 1):
        print(f"‚îÅ‚îÅ‚îÅ Epoch {epoch}/{EPOCHS} ‚îÅ‚îÅ‚îÅ")
        tr_loss, tr_acc, tr_f1 = train_one_epoch_pytorch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_acc, val_f1, _, _ = eval_pytorch(
            model, val_loader, criterion, device
        )
        
        # Record history
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(tr_f1)
        history['val_f1'].append(val_f1)

        # Strict improvement only: a tie counts against patience.
        improved = val_f1 > best_val_f1
        if improved:
            best_val_f1 = val_f1
            # CPU copy of the weights so the snapshot is not overwritten by
            # later optimizer steps.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        star = " ‚≠ê" if improved else ""
        print(f"  Train | Loss: {tr_loss:.4f} | Acc: {tr_acc:.4f} | F1: {tr_f1:.4f}")
        print(f"  Val   | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}{star}\n")

        if patience_counter >= PATIENCE:
            print(f" Early stopping en epoch {epoch}")
            break

    # Restore the best checkpoint (best_state stays None if validation
    # macro-F1 never rose above 0.0, in which case the last weights are kept).
    if best_state is not None:
        model.load_state_dict(best_state)
        model = model.to(device)

    # Test-set evaluation
    print("‚îÅ‚îÅ‚îÅ Evaluando en TEST ‚îÅ‚îÅ‚îÅ")
    test_loss, test_acc, test_f1, y_true, y_pred = eval_pytorch(
        model, test_loader, criterion, device
    )
    kappa_global = cohen_kappa_score(y_true, y_pred)
    cm, kappas = kappa_per_class(y_true, y_pred, n_classes=5)

    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(5), zero_division=0
    )

    # Per-class table for this run; rows follow the order of `labels`.
    df_metrics = pd.DataFrame({
        "Etapa": labels,
        "Precision": prec,
        "Recall": rec,
        "F1-Score": f1,
        "Kappa": kappas,
        "Soporte": support
    })

    print(f"\n RUN {run+1} ‚Äî Resultados globales:")
    print(f"   Test Loss:     {test_loss:.4f}")
    print(f"   Test Acc:      {test_acc:.4f}")
    print(f"   Test Macro F1: {test_f1:.4f}")
    print(f"   Kappa global:  {kappa_global:.4f}")
    print(f"   Best Val F1:   {best_val_f1:.4f}")

    run_summaries.append({
        "loss": test_loss,
        "acc": test_acc,
        "macro_f1": test_f1,
        "kappa": kappa_global,
        "best_val_f1": best_val_f1
    })
    per_class_tables.append(df_metrics)
    
    # Keep the final run's data for the plots drawn after the loop
    if run == N_RUNS - 1:
        last_run_history = history
        last_run_cm = cm
        last_run_test_f1 = test_f1
        last_run_kappa = kappa_global
        last_run_best_val_f1 = best_val_f1

    # Release GPU memory between runs (precautionary)
    del model, optimizer, criterion
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------- Aggregates across runs: mean +/- std ----------
print("\n" + "="*70)
print(f" {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*70)

# One value per run for each global test metric.
losses  = np.array([r["loss"] for r in run_summaries])
accs    = np.array([r["acc"] for r in run_summaries])
f1s     = np.array([r["macro_f1"] for r in run_summaries])
kappasg = np.array([r["kappa"] for r in run_summaries])

def ms(x):
    """Format ``x.mean()`` and ``x.std()`` (default arguments) with 4 decimals."""
    return f"{x.mean():.4f} ¬± {x.std():.4f}"

print(f"Test Loss:      {ms(losses)}")
print(f"Test Accuracy:  {ms(accs)}")
print(f"Test Macro F1:  {ms(f1s)}")
print(f"Kappa global:   {ms(kappasg)}")

# ---------- Per-stage aggregates ----------
# Each stacked array has one row per run and one column per stage.
prec_arr  = np.stack([df["Precision"].values for df in per_class_tables], axis=0)
rec_arr   = np.stack([df["Recall"].values    for df in per_class_tables], axis=0)
f1_arr    = np.stack([df["F1-Score"].values  for df in per_class_tables], axis=0)
kappa_arr = np.stack([df["Kappa"].values     for df in per_class_tables], axis=0)

df_agg = pd.DataFrame({"Etapa": labels})
for name, arr in [("Precision", prec_arr),
                  ("Recall", rec_arr),
                  ("F1-Score", f1_arr),
                  ("Kappa", kappa_arr)]:
    # Reduce over the run axis; cells are stored as preformatted strings.
    means = arr.mean(axis=0)
    stds  = arr.std(axis=0)
    df_agg[name] = [f"{m:.3f} ¬± {s:.3f}" for m, s in zip(means, stds)]

print("\n M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
display(df_agg)

# ---------- Plots for the last run ----------
# 2x2 figure: loss, macro-F1 and accuracy curves plus the row-normalised
# test confusion matrix. Reads the last_run_* values saved by the run loop.
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

# 1. Loss
ax1 = fig.add_subplot(gs[0, 0])
epochs_range = range(1, len(last_run_history['train_loss']) + 1)
ax1.plot(epochs_range, last_run_history['train_loss'], 'b-', label='Train', linewidth=2)
ax1.plot(epochs_range, last_run_history['val_loss'], 'r-', label='Val', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=11)
ax1.set_ylabel('Loss', fontsize=11)
ax1.set_title('Loss', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# 2. Macro F1 score, with the best validation value as a reference line
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(epochs_range, last_run_history['train_f1'], 'b-', label='Train F1', linewidth=2)
ax2.plot(epochs_range, last_run_history['val_f1'], 'r-', label='Val F1', linewidth=2)
ax2.axhline(y=last_run_best_val_f1, color='g', linestyle='--', 
            label=f'Best Val F1: {last_run_best_val_f1:.3f}', linewidth=1.5, alpha=0.7)
ax2.set_xlabel('Epoch', fontsize=11)
ax2.set_ylabel('Macro F1 Score', fontsize=11)
ax2.set_title('Macro F1 Score (m√©trica principal)', fontsize=12, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

# 3. Accuracy
ax3 = fig.add_subplot(gs[1, 0])
ax3.plot(epochs_range, last_run_history['train_acc'], 'b-', label='Train Acc', linewidth=2)
ax3.plot(epochs_range, last_run_history['val_acc'], 'r-', label='Val Acc', linewidth=2)
ax3.set_xlabel('Epoch', fontsize=11)
ax3.set_ylabel('Accuracy', fontsize=11)
ax3.set_title('Accuracy (secundaria)', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)
ax3.set_ylim([0, 1])

# 4. Row-normalised confusion matrix
ax4 = fig.add_subplot(gs[1, 1])
# A stage with no test samples has an all-zero row; clamping its total to 1
# keeps that row at 0.0 instead of producing NaN cells (0/0) in the heatmap.
row_totals = last_run_cm.sum(axis=1)[:, np.newaxis]
cm_normalized = last_run_cm.astype('float') / np.maximum(row_totals, 1)
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues', 
            xticklabels=labels, yticklabels=labels, 
            cbar_kws={'label': 'Proporci√≥n'}, ax=ax4, vmin=0, vmax=1)
ax4.set_xlabel('Predicho', fontsize=11)
ax4.set_ylabel('Real', fontsize=11)
ax4.set_title('Matriz de Confusi√≥n (normalizada)', fontsize=12, fontweight='bold')

fig.suptitle(
    f'CoSleepNet ‚Äî {DATASET_NAME}\n'
    f'Test F1: {last_run_test_f1:.3f} | Kappa: {last_run_kappa:.3f}',
    fontsize=14, fontweight='bold', y=0.98
)

# Leave the top 4% of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()


## Set 2: EEG1 + EOG

In [None]:
# ============================================================
# CELDA: MULTI-RUN (3–5 corridas) + MEDIA ± DESV. ESTÁNDAR
# + VISUALIZACIÓN DE LA ÚLTIMA CORRIDA
# ============================================================

import numpy as np
import pandas as pd
from sklearn.metrics import (
    precision_recall_fscore_support, confusion_matrix,
    cohen_kappa_score
)
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Experiment configuration ----------
DATASET_ID = 2        # channel set passed to make_pytorch_loaders: 1, 2 or 3
LR = 5e-5             # learning rate given to Adam
BATCH_SIZE = 128
EPOCHS = 50           # upper bound on epochs per run
PATIENCE = 5          # epochs without a validation macro-F1 improvement before stopping
USE_FOCAL = True      # FocalLoss when True, nn.CrossEntropyLoss otherwise
FOCAL_GAMMA = 1.5
USE_CLASS_WEIGHTS = False  # when True, get_class_weights(train_loader) feeds the loss

N_RUNS = 5            # number of independent runs
BASE_SEED = 42        # run i is seeded with BASE_SEED + i
labels = ["W", "N1", "N2", "N3", "REM"]  # display names for class ids 0..4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Device:", device)
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ---------- Utils ----------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN behaviour.

    CUDA generators are seeded only when CUDA is available. The cuDNN flags
    are set unconditionally: deterministic on, benchmark (autotuner) off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def kappa_per_class(y_true, y_pred, n_classes=5):
    """Return the confusion matrix and a one-vs-rest kappa for each class.

    Each class is reduced to a 2x2 table (this class vs. all others) and
    scored as ``(observed - expected) / (1 - expected)``; a ``1e-12`` term
    in the denominator guards against ``expected == 1``.

    Returns:
        ``(cm, kappas)``: the ``n_classes`` x ``n_classes`` confusion matrix
        over labels ``0..n_classes-1`` and a float64 array with one kappa
        per class.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(n_classes))
    total = cm.sum()
    denom = max(1, total)  # keeps an empty matrix from dividing by zero

    # Per-class counts for all classes at once.
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = total - (tp + fn + fp)

    observed = (tp + tn) / denom

    # Chance agreement from the marginals of each binarised table.
    true_pos_rate = (tp + fn) / denom
    pred_pos_rate = (tp + fp) / denom
    true_neg_rate = (fp + tn) / denom
    pred_neg_rate = (fn + tn) / denom
    expected = true_pos_rate * pred_pos_rate + true_neg_rate * pred_neg_rate

    kappas = np.zeros(n_classes, dtype=np.float64)
    kappas[:] = (observed - expected) / (1 - expected + 1e-12)

    return cm, kappas

# ---------- Fixed loaders (the same splits are reused by every run) ----------
train_loader, val_loader, test_loader, input_shape, DATASET_NAME = make_pytorch_loaders(
    dataset_id=DATASET_ID,
    splits=splits,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=True,
)

# make_pytorch_loaders reports input_shape as (C, T, 1), so H receives the
# time length and L the trailing singleton axis.
C_in, H, L = input_shape
print(f"\n Dataset: {DATASET_NAME} | Shape: C={C_in}, H={H}, L={L}")
print(f" Ejecutando {N_RUNS} corridas independientes...\n")

# ---------- Run loop accumulators ----------
run_summaries = []      # one dict of global test metrics per run
per_class_tables = []   # one per-class metrics DataFrame per run

# Filled by the final run only; read by the plotting code after the loop.
last_run_history = None
last_run_cm = None
last_run_test_f1 = None
last_run_kappa = None
last_run_best_val_f1 = None

# One full train / early-stop / test cycle per seed.
for run in range(N_RUNS):
    seed = BASE_SEED + run
    set_seed(seed)
    print(f"\n================ RUN {run+1}/{N_RUNS} ‚Äî seed={seed} ================")

    # Fresh model for every run
    model = CoSleepNetTorch(
        in_ch=C_in,
        num_classes=5,
        lstm_units=64,
        dropout=0.3,
        use_dct=True,
    ).to(device)

    # Loss function (optionally class-weighted)
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = get_class_weights(train_loader, device=device)
        print(f"  Class weights: {class_weights.cpu().numpy().round(3)}")

    if USE_FOCAL:
        criterion = FocalLoss(gamma=FOCAL_GAMMA, alpha=class_weights)
        print(f" Loss: Focal (Œ≥={FOCAL_GAMMA})")
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f" Loss: CrossEntropy")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Per-epoch training history
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }
    
    # Training with early stopping on validation macro-F1
    best_val_f1 = 0.0
    best_state = None
    patience_counter = 0

    for epoch in range(1, EPOCHS + 1):
        print(f"‚îÅ‚îÅ‚îÅ Epoch {epoch}/{EPOCHS} ‚îÅ‚îÅ‚îÅ")
        tr_loss, tr_acc, tr_f1 = train_one_epoch_pytorch(
            model, train_loader, optimizer, criterion, device
        )
        val_loss, val_acc, val_f1, _, _ = eval_pytorch(
            model, val_loader, criterion, device
        )
        
        # Record history
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(tr_f1)
        history['val_f1'].append(val_f1)

        # Strict improvement only: a tie counts against patience.
        improved = val_f1 > best_val_f1
        if improved:
            best_val_f1 = val_f1
            # CPU copy of the weights so the snapshot is not overwritten by
            # later optimizer steps.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        star = " ‚≠ê" if improved else ""
        print(f"  Train | Loss: {tr_loss:.4f} | Acc: {tr_acc:.4f} | F1: {tr_f1:.4f}")
        print(f"  Val   | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}{star}\n")

        if patience_counter >= PATIENCE:
            print(f" Early stopping en epoch {epoch}")
            break

    # Restore the best checkpoint (best_state stays None if validation
    # macro-F1 never rose above 0.0, in which case the last weights are kept).
    if best_state is not None:
        model.load_state_dict(best_state)
        model = model.to(device)

    # Test-set evaluation
    print("‚îÅ‚îÅ‚îÅ Evaluando en TEST ‚îÅ‚îÅ‚îÅ")
    test_loss, test_acc, test_f1, y_true, y_pred = eval_pytorch(
        model, test_loader, criterion, device
    )
    kappa_global = cohen_kappa_score(y_true, y_pred)
    cm, kappas = kappa_per_class(y_true, y_pred, n_classes=5)

    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(5), zero_division=0
    )

    # Per-class table for this run; rows follow the order of `labels`.
    df_metrics = pd.DataFrame({
        "Etapa": labels,
        "Precision": prec,
        "Recall": rec,
        "F1-Score": f1,
        "Kappa": kappas,
        "Soporte": support
    })

    print(f"\n RUN {run+1} ‚Äî Resultados globales:")
    print(f"   Test Loss:     {test_loss:.4f}")
    print(f"   Test Acc:      {test_acc:.4f}")
    print(f"   Test Macro F1: {test_f1:.4f}")
    print(f"   Kappa global:  {kappa_global:.4f}")
    print(f"   Best Val F1:   {best_val_f1:.4f}")

    run_summaries.append({
        "loss": test_loss,
        "acc": test_acc,
        "macro_f1": test_f1,
        "kappa": kappa_global,
        "best_val_f1": best_val_f1
    })
    per_class_tables.append(df_metrics)
    
    # Keep the final run's data for the plots drawn after the loop
    if run == N_RUNS - 1:
        last_run_history = history
        last_run_cm = cm
        last_run_test_f1 = test_f1
        last_run_kappa = kappa_global
        last_run_best_val_f1 = best_val_f1

    # Release GPU memory between runs (precautionary)
    del model, optimizer, criterion
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------- Aggregates across runs: mean +/- std ----------
print("\n" + "="*70)
print(f" {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*70)

# One value per run for each global test metric.
losses  = np.array([r["loss"] for r in run_summaries])
accs    = np.array([r["acc"] for r in run_summaries])
f1s     = np.array([r["macro_f1"] for r in run_summaries])
kappasg = np.array([r["kappa"] for r in run_summaries])

def ms(x):
    """Format ``x.mean()`` and ``x.std()`` (default arguments) with 4 decimals."""
    return f"{x.mean():.4f} ¬± {x.std():.4f}"

print(f"Test Loss:      {ms(losses)}")
print(f"Test Accuracy:  {ms(accs)}")
print(f"Test Macro F1:  {ms(f1s)}")
print(f"Kappa global:   {ms(kappasg)}")

# ---------- Per-stage aggregates ----------
# Each stacked array has one row per run and one column per stage.
prec_arr  = np.stack([df["Precision"].values for df in per_class_tables], axis=0)
rec_arr   = np.stack([df["Recall"].values    for df in per_class_tables], axis=0)
f1_arr    = np.stack([df["F1-Score"].values  for df in per_class_tables], axis=0)
kappa_arr = np.stack([df["Kappa"].values     for df in per_class_tables], axis=0)

df_agg = pd.DataFrame({"Etapa": labels})
for name, arr in [("Precision", prec_arr),
                  ("Recall", rec_arr),
                  ("F1-Score", f1_arr),
                  ("Kappa", kappa_arr)]:
    # Reduce over the run axis; cells are stored as preformatted strings.
    means = arr.mean(axis=0)
    stds  = arr.std(axis=0)
    df_agg[name] = [f"{m:.3f} ¬± {s:.3f}" for m, s in zip(means, stds)]

print("\n M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
display(df_agg)

# ---------- Plots for the last run ----------
# 2x2 figure: loss, macro-F1 and accuracy curves plus the row-normalised
# test confusion matrix. Reads the last_run_* values saved by the run loop.
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

# 1. Loss
ax1 = fig.add_subplot(gs[0, 0])
epochs_range = range(1, len(last_run_history['train_loss']) + 1)
ax1.plot(epochs_range, last_run_history['train_loss'], 'b-', label='Train', linewidth=2)
ax1.plot(epochs_range, last_run_history['val_loss'], 'r-', label='Val', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=11)
ax1.set_ylabel('Loss', fontsize=11)
ax1.set_title('Loss', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# 2. Macro F1 score, with the best validation value as a reference line
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(epochs_range, last_run_history['train_f1'], 'b-', label='Train F1', linewidth=2)
ax2.plot(epochs_range, last_run_history['val_f1'], 'r-', label='Val F1', linewidth=2)
ax2.axhline(y=last_run_best_val_f1, color='g', linestyle='--', 
            label=f'Best Val F1: {last_run_best_val_f1:.3f}', linewidth=1.5, alpha=0.7)
ax2.set_xlabel('Epoch', fontsize=11)
ax2.set_ylabel('Macro F1 Score', fontsize=11)
ax2.set_title('Macro F1 Score (m√©trica principal)', fontsize=12, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

# 3. Accuracy
ax3 = fig.add_subplot(gs[1, 0])
ax3.plot(epochs_range, last_run_history['train_acc'], 'b-', label='Train Acc', linewidth=2)
ax3.plot(epochs_range, last_run_history['val_acc'], 'r-', label='Val Acc', linewidth=2)
ax3.set_xlabel('Epoch', fontsize=11)
ax3.set_ylabel('Accuracy', fontsize=11)
ax3.set_title('Accuracy (secundaria)', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)
ax3.set_ylim([0, 1])

# 4. Row-normalised confusion matrix
ax4 = fig.add_subplot(gs[1, 1])
# A stage with no test samples has an all-zero row; clamping its total to 1
# keeps that row at 0.0 instead of producing NaN cells (0/0) in the heatmap.
row_totals = last_run_cm.sum(axis=1)[:, np.newaxis]
cm_normalized = last_run_cm.astype('float') / np.maximum(row_totals, 1)
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues', 
            xticklabels=labels, yticklabels=labels, 
            cbar_kws={'label': 'Proporci√≥n'}, ax=ax4, vmin=0, vmax=1)
ax4.set_xlabel('Predicho', fontsize=11)
ax4.set_ylabel('Real', fontsize=11)
ax4.set_title('Matriz de Confusi√≥n (normalizada)', fontsize=12, fontweight='bold')

fig.suptitle(
    f'CoSleepNet ‚Äî {DATASET_NAME}\n'
    f'Test F1: {last_run_test_f1:.3f} | Kappa: {last_run_kappa:.3f}',
    fontsize=14, fontweight='bold', y=0.98
)

# Leave the top 4% of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()


## Set 3: EEG1 + EEG2 + EOG

In [None]:
# ============================================================
# CELDA: MULTI-RUN (3–5 corridas) + MEDIA ± DESV. ESTÁNDAR
# + VISUALIZACIÓN DE LA ÚLTIMA CORRIDA
# ============================================================

import numpy as np
import pandas as pd
from sklearn.metrics import (
    precision_recall_fscore_support, confusion_matrix,
    cohen_kappa_score
)
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Experiment configuration ----------
DATASET_ID = 3        # channel set passed to make_pytorch_loaders: 1, 2 or 3
LR = 5e-5             # learning rate given to Adam
BATCH_SIZE = 128
EPOCHS = 50           # upper bound on epochs per run
PATIENCE = 5          # epochs without a validation macro-F1 improvement before stopping
USE_FOCAL = True      # FocalLoss when True, nn.CrossEntropyLoss otherwise
FOCAL_GAMMA = 1.5
USE_CLASS_WEIGHTS = False  # when True, get_class_weights(train_loader) feeds the loss

N_RUNS = 5            # number of independent runs
BASE_SEED = 42        # run i is seeded with BASE_SEED + i
labels = ["W", "N1", "N2", "N3", "REM"]  # display names for class ids 0..4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Device:", device)
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ---------- Utils ----------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN behaviour.

    CUDA generators are seeded only when CUDA is available. The cuDNN flags
    are set unconditionally: deterministic on, benchmark (autotuner) off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def kappa_per_class(y_true, y_pred, n_classes=5):
    """Compute the confusion matrix and a one-vs-rest Cohen's kappa per class.

    Each class ``k`` is reduced to a binary table ("k" vs. "not k") taken from
    the multi-class confusion matrix. Kappa is
    ``(observed - expected) / (1 - expected)`` with ``1e-12`` added to the
    denominator.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        n_classes: Number of classes; labels ``0 .. n_classes - 1`` are used.

    Returns:
        Tuple ``(cm, kappas)``: the confusion matrix and a float64 array
        holding one kappa value per class.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(n_classes))
    total = cm.sum()
    denom = max(1, total)  # keeps the divisions defined when there are no samples

    # Binary counts for all classes at once (vectors of length n_classes)
    true_pos = np.diag(cm)
    false_neg = cm.sum(axis=1) - true_pos
    false_pos = cm.sum(axis=0) - true_pos
    true_neg = total - (true_pos + false_neg + false_pos)

    observed = (true_pos + true_neg) / denom

    # Chance agreement from the marginals of each binary table
    frac_true_yes = (true_pos + false_neg) / denom
    frac_pred_yes = (true_pos + false_pos) / denom
    frac_true_no = (false_pos + true_neg) / denom
    frac_pred_no = (false_neg + true_neg) / denom
    expected = frac_true_yes * frac_pred_yes + frac_true_no * frac_pred_no

    kappas = ((observed - expected) / (1 - expected + 1e-12)).astype(np.float64)
    return cm, kappas

# ---------- Fixed loaders (always the same splits) ----------
# Built once, outside the run loop, so every run uses the same loaders.
train_loader, val_loader, test_loader, input_shape, DATASET_NAME = make_pytorch_loaders(
    dataset_id=DATASET_ID,
    splits=splits,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=True,
)

# input_shape unpacks into three values. C_in is later passed to the model as
# in_ch; H and L are only printed in this cell.
C_in, H, L = input_shape
print(f"\n Dataset: {DATASET_NAME} | Shape: C={C_in}, H={H}, L={L}")
print(f" Ejecutando {N_RUNS} corridas independientes...\n")

# ---------- Run loop ----------
run_summaries = []      # one dict of global test metrics per run
per_class_tables = []   # one per-stage metrics DataFrame per run

# Data of the last run, kept for the plots drawn after the loop
last_run_history = None
last_run_cm = None
last_run_test_f1 = None
last_run_kappa = None
last_run_best_val_f1 = None

for run in range(N_RUNS):
    # Each run gets its own seed (BASE_SEED + run), applied before the model
    # is constructed.
    seed = BASE_SEED + run
    set_seed(seed)
    print(f"\n================ RUN {run+1}/{N_RUNS} ‚Äî seed={seed} ================")

    # Fresh model for this run
    model = CoSleepNetTorch(
        in_ch=C_in,
        num_classes=5,
        lstm_units=64,
        dropout=0.3,
        use_dct=True,
    ).to(device)

    # Loss: optional class weights, then focal loss or cross-entropy
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = get_class_weights(train_loader, device=device)
        print(f"  Class weights: {class_weights.cpu().numpy().round(3)}")

    if USE_FOCAL:
        criterion = FocalLoss(gamma=FOCAL_GAMMA, alpha=class_weights)
        print(f" Loss: Focal (Œ≥={FOCAL_GAMMA})")
    else:
        # NOTE(review): `nn` is not imported in this cell; this branch relies
        # on an earlier cell having defined it — confirm.
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f" Loss: CrossEntropy")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Training history: one value appended per epoch to every list
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }

    # Training with early stopping on validation F1
    best_val_f1 = 0.0
    best_state = None       # CPU copy of the weights at the best validation F1
    patience_counter = 0    # epochs since the last validation-F1 improvement

    for epoch in range(1, EPOCHS + 1):
        print(f"‚îÅ‚îÅ‚îÅ Epoch {epoch}/{EPOCHS} ‚îÅ‚îÅ‚îÅ")
        tr_loss, tr_acc, tr_f1 = train_one_epoch_pytorch(
            model, train_loader, optimizer, criterion, device
        )
        # The last two values returned by eval_pytorch are not used for validation
        val_loss, val_acc, val_f1, _, _ = eval_pytorch(
            model, val_loader, criterion, device
        )

        # Record history
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(tr_f1)
        history['val_f1'].append(val_f1)

        # Strict comparison: matching the best F1 still counts against patience
        improved = val_f1 > best_val_f1
        if improved:
            best_val_f1 = val_f1
            # Detached CPU copy so further training cannot alter the snapshot
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        star = " ‚≠ê" if improved else ""
        print(f"  Train | Loss: {tr_loss:.4f} | Acc: {tr_acc:.4f} | F1: {tr_f1:.4f}")
        print(f"  Val   | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}{star}\n")

        if patience_counter >= PATIENCE:
            print(f" Early stopping en epoch {epoch}")
            break

    # Restore the best weights (skipped when validation F1 never exceeded 0.0)
    if best_state is not None:
        model.load_state_dict(best_state)
        model = model.to(device)

    # Test-set evaluation
    print("‚îÅ‚îÅ‚îÅ Evaluando en TEST ‚îÅ‚îÅ‚îÅ")
    test_loss, test_acc, test_f1, y_true, y_pred = eval_pytorch(
        model, test_loader, criterion, device
    )
    kappa_global = cohen_kappa_score(y_true, y_pred)
    cm, kappas = kappa_per_class(y_true, y_pred, n_classes=5)

    # Per-class precision/recall/F1/support; zero_division=0 reports 0 where a
    # ratio is undefined
    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(5), zero_division=0
    )

    # Per-stage metrics table for this run
    df_metrics = pd.DataFrame({
        "Etapa": labels,
        "Precision": prec,
        "Recall": rec,
        "F1-Score": f1,
        "Kappa": kappas,
        "Soporte": support
    })

    print(f"\n RUN {run+1} ‚Äî Resultados globales:")
    print(f"   Test Loss:     {test_loss:.4f}")
    print(f"   Test Acc:      {test_acc:.4f}")
    print(f"   Test Macro F1: {test_f1:.4f}")
    print(f"   Kappa global:  {kappa_global:.4f}")
    print(f"   Best Val F1:   {best_val_f1:.4f}")

    run_summaries.append({
        "loss": test_loss,
        "acc": test_acc,
        "macro_f1": test_f1,
        "kappa": kappa_global,
        "best_val_f1": best_val_f1
    })
    per_class_tables.append(df_metrics)

    # Keep the data of the last run for the plots after the loop
    if run == N_RUNS - 1:
        last_run_history = history
        last_run_cm = cm
        last_run_test_f1 = test_f1
        last_run_kappa = kappa_global
        last_run_best_val_f1 = best_val_f1

    # Free GPU memory between runs (just in case)
    del model, optimizer, criterion
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------- Aggregates: mean ± std ----------
print("\n" + "="*70)
print(f" {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*70)

# Collect each global test metric across runs into its own 1-D array
losses  = np.array([r["loss"] for r in run_summaries])
accs    = np.array([r["acc"] for r in run_summaries])
f1s     = np.array([r["macro_f1"] for r in run_summaries])
kappasg = np.array([r["kappa"] for r in run_summaries])

def ms(x):
    """Format ``x.mean()`` and ``x.std()`` as one string, four decimals each."""
    center = x.mean()
    spread = x.std()
    return "{:.4f} ¬± {:.4f}".format(center, spread)

# Global test metrics across runs, formatted by ms()
print(f"Test Loss:      {ms(losses)}")
print(f"Test Accuracy:  {ms(accs)}")
print(f"Test Macro F1:  {ms(f1s)}")
print(f"Kappa global:   {ms(kappasg)}")

# ---------- Per-stage aggregate ----------
# Stack every per-class metric across runs: one row per run, one column per stage
prec_arr  = np.stack([df["Precision"].values for df in per_class_tables], axis=0)
rec_arr   = np.stack([df["Recall"].values    for df in per_class_tables], axis=0)
f1_arr    = np.stack([df["F1-Score"].values  for df in per_class_tables], axis=0)
kappa_arr = np.stack([df["Kappa"].values     for df in per_class_tables], axis=0)

df_agg = pd.DataFrame({"Etapa": labels})
for name, arr in [("Precision", prec_arr),
                  ("Recall", rec_arr),
                  ("F1-Score", f1_arr),
                  ("Kappa", kappa_arr)]:
    # Mean and std over the run axis (ndarray.std defaults to ddof=0)
    means = arr.mean(axis=0)
    stds  = arr.std(axis=0)
    df_agg[name] = [f"{m:.3f} ¬± {s:.3f}" for m, s in zip(means, stds)]

print("\n M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
# NOTE(review): `display` is not imported in this cell; presumably the
# IPython/Jupyter built-in — confirm.
display(df_agg)

# ---------- VISUALIZATION OF THE LAST RUN ----------
# 2x2 grid of panels on one figure
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

# 1. Loss
ax1 = fig.add_subplot(gs[0, 0])
# Epochs are numbered from 1; the length follows the recorded history
epochs_range = range(1, len(last_run_history['train_loss']) + 1)
ax1.plot(epochs_range, last_run_history['train_loss'], 'b-', label='Train', linewidth=2)
ax1.plot(epochs_range, last_run_history['val_loss'], 'r-', label='Val', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=11)
ax1.set_ylabel('Loss', fontsize=11)
ax1.set_title('Loss', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# 2. Macro F1 score
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(epochs_range, last_run_history['train_f1'], 'b-', label='Train F1', linewidth=2)
ax2.plot(epochs_range, last_run_history['val_f1'], 'r-', label='Val F1', linewidth=2)
# Horizontal reference line at the best validation F1 of the plotted run
ax2.axhline(y=last_run_best_val_f1, color='g', linestyle='--',
            label=f'Best Val F1: {last_run_best_val_f1:.3f}', linewidth=1.5, alpha=0.7)
ax2.set_xlabel('Epoch', fontsize=11)
ax2.set_ylabel('Macro F1 Score', fontsize=11)
ax2.set_title('Macro F1 Score (m√©trica principal)', fontsize=12, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

# 3. Accuracy
ax3 = fig.add_subplot(gs[1, 0])
ax3.plot(epochs_range, last_run_history['train_acc'], 'b-', label='Train Acc', linewidth=2)
ax3.plot(epochs_range, last_run_history['val_acc'], 'r-', label='Val Acc', linewidth=2)
ax3.set_xlabel('Epoch', fontsize=11)
ax3.set_ylabel('Accuracy', fontsize=11)
ax3.set_title('Accuracy (secundaria)', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)
ax3.set_ylim([0, 1])

# 4. Row-normalized confusion matrix
ax4 = fig.add_subplot(gs[1, 1])
# Divide every row by its own total so each cell is the fraction of that
# row's samples that landed in each column. A class with no samples has a
# row total of 0; plain division would fill that row with NaN (and raise a
# RuntimeWarning), so such rows are left at 0 instead.
row_totals = last_run_cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    last_run_cm.astype('float'),
    row_totals,
    out=np.zeros(last_run_cm.shape, dtype=float),
    where=row_totals != 0,
)
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=labels, yticklabels=labels,
            cbar_kws={'label': 'Proporci√≥n'}, ax=ax4, vmin=0, vmax=1)
ax4.set_xlabel('Predicho', fontsize=11)
ax4.set_ylabel('Real', fontsize=11)
ax4.set_title('Matriz de Confusi√≥n (normalizada)', fontsize=12, fontweight='bold')

# Figure-level title with the plotted run's test F1 and kappa
fig.suptitle(
    f'CoSleepNet ‚Äî {DATASET_NAME}\n'
    f'Test F1: {last_run_test_f1:.3f} | Kappa: {last_run_kappa:.3f}',
    fontsize=14, fontweight='bold', y=0.98
)

# Keep the top 4% of the figure free for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()


# DCT OFF

## Set 1: EEG1 + EEG2

In [None]:
# ============================================================
# CELL: MULTI-RUN (3–5 runs) + MEAN ± STANDARD DEVIATION
# + VISUALIZATION OF THE LAST RUN
# ============================================================

import numpy as np
import pandas as pd
from sklearn.metrics import (
    precision_recall_fscore_support, confusion_matrix,
    cohen_kappa_score
)
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Config ----------
DATASET_ID = 1        # 1, 2 or 3; forwarded to make_pytorch_loaders (this section: Set 1)
LR = 5e-5             # learning rate given to torch.optim.Adam
BATCH_SIZE = 128      # batch size forwarded to make_pytorch_loaders
EPOCHS = 50           # upper bound on training epochs per run
PATIENCE = 5          # epochs without a validation-F1 improvement before early stopping
USE_FOCAL = True      # True -> FocalLoss, False -> nn.CrossEntropyLoss
FOCAL_GAMMA = 1.5     # gamma argument of FocalLoss
USE_CLASS_WEIGHTS = False  # True -> get_class_weights(train_loader) output is passed to the loss

N_RUNS = 5            # number of runs; run i is seeded with BASE_SEED + i
BASE_SEED = 42
labels = ["W", "N1", "N2", "N3", "REM"]  # stage names, paired with class indices 0..4

# Use CUDA when PyTorch reports it as available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Device:", device)
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ---------- Utils ----------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN's behaviour.

    Args:
        seed: Value handed to every seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators are only seeded when a CUDA device is available
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    # Request deterministic cuDNN kernels and switch off its auto-tuner
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def kappa_per_class(y_true, y_pred, n_classes=5):
    """Compute the confusion matrix and a one-vs-rest Cohen's kappa per class.

    Each class ``k`` is reduced to a binary table ("k" vs. "not k") taken from
    the multi-class confusion matrix. Kappa is
    ``(observed - expected) / (1 - expected)`` with ``1e-12`` added to the
    denominator.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        n_classes: Number of classes; labels ``0 .. n_classes - 1`` are used.

    Returns:
        Tuple ``(cm, kappas)``: the confusion matrix and a float64 array
        holding one kappa value per class.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(n_classes))
    total = cm.sum()
    denom = max(1, total)  # keeps the divisions defined when there are no samples

    # Binary counts for all classes at once (vectors of length n_classes)
    true_pos = np.diag(cm)
    false_neg = cm.sum(axis=1) - true_pos
    false_pos = cm.sum(axis=0) - true_pos
    true_neg = total - (true_pos + false_neg + false_pos)

    observed = (true_pos + true_neg) / denom

    # Chance agreement from the marginals of each binary table
    frac_true_yes = (true_pos + false_neg) / denom
    frac_pred_yes = (true_pos + false_pos) / denom
    frac_true_no = (false_pos + true_neg) / denom
    frac_pred_no = (false_neg + true_neg) / denom
    expected = frac_true_yes * frac_pred_yes + frac_true_no * frac_pred_no

    kappas = ((observed - expected) / (1 - expected + 1e-12)).astype(np.float64)
    return cm, kappas

# ---------- Fixed loaders (always the same splits) ----------
# Built once, outside the run loop, so every run uses the same loaders.
train_loader, val_loader, test_loader, input_shape, DATASET_NAME = make_pytorch_loaders(
    dataset_id=DATASET_ID,
    splits=splits,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=True,
)

# input_shape unpacks into three values. C_in is later passed to the model as
# in_ch; H and L are only printed in this cell.
C_in, H, L = input_shape
print(f"\n Dataset: {DATASET_NAME} | Shape: C={C_in}, H={H}, L={L}")
print(f" Ejecutando {N_RUNS} corridas independientes...\n")

# ---------- Run loop ----------
run_summaries = []      # one dict of global test metrics per run
per_class_tables = []   # one per-stage metrics DataFrame per run

# Data of the last run, kept for the plots drawn after the loop
last_run_history = None
last_run_cm = None
last_run_test_f1 = None
last_run_kappa = None
last_run_best_val_f1 = None

for run in range(N_RUNS):
    # Each run gets its own seed (BASE_SEED + run), applied before the model
    # is constructed.
    seed = BASE_SEED + run
    set_seed(seed)
    print(f"\n================ RUN {run+1}/{N_RUNS} ‚Äî seed={seed} ================")

    # Fresh model for this run
    model = CoSleepNetTorch(
        in_ch=C_in,
        num_classes=5,
        lstm_units=64,
        dropout=0.3,
        use_dct=False,
    ).to(device)

    # Loss: optional class weights, then focal loss or cross-entropy
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = get_class_weights(train_loader, device=device)
        print(f"  Class weights: {class_weights.cpu().numpy().round(3)}")

    if USE_FOCAL:
        criterion = FocalLoss(gamma=FOCAL_GAMMA, alpha=class_weights)
        print(f" Loss: Focal (Œ≥={FOCAL_GAMMA})")
    else:
        # NOTE(review): `nn` is not imported in this cell; this branch relies
        # on an earlier cell having defined it — confirm.
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f" Loss: CrossEntropy")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Training history: one value appended per epoch to every list
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }

    # Training with early stopping on validation F1
    best_val_f1 = 0.0
    best_state = None       # CPU copy of the weights at the best validation F1
    patience_counter = 0    # epochs since the last validation-F1 improvement

    for epoch in range(1, EPOCHS + 1):
        print(f"‚îÅ‚îÅ‚îÅ Epoch {epoch}/{EPOCHS} ‚îÅ‚îÅ‚îÅ")
        tr_loss, tr_acc, tr_f1 = train_one_epoch_pytorch(
            model, train_loader, optimizer, criterion, device
        )
        # The last two values returned by eval_pytorch are not used for validation
        val_loss, val_acc, val_f1, _, _ = eval_pytorch(
            model, val_loader, criterion, device
        )

        # Record history
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(tr_f1)
        history['val_f1'].append(val_f1)

        # Strict comparison: matching the best F1 still counts against patience
        improved = val_f1 > best_val_f1
        if improved:
            best_val_f1 = val_f1
            # Detached CPU copy so further training cannot alter the snapshot
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        star = " ‚≠ê" if improved else ""
        print(f"  Train | Loss: {tr_loss:.4f} | Acc: {tr_acc:.4f} | F1: {tr_f1:.4f}")
        print(f"  Val   | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}{star}\n")

        if patience_counter >= PATIENCE:
            print(f" Early stopping en epoch {epoch}")
            break

    # Restore the best weights (skipped when validation F1 never exceeded 0.0)
    if best_state is not None:
        model.load_state_dict(best_state)
        model = model.to(device)

    # Test-set evaluation
    print("‚îÅ‚îÅ‚îÅ Evaluando en TEST ‚îÅ‚îÅ‚îÅ")
    test_loss, test_acc, test_f1, y_true, y_pred = eval_pytorch(
        model, test_loader, criterion, device
    )
    kappa_global = cohen_kappa_score(y_true, y_pred)
    cm, kappas = kappa_per_class(y_true, y_pred, n_classes=5)

    # Per-class precision/recall/F1/support; zero_division=0 reports 0 where a
    # ratio is undefined
    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(5), zero_division=0
    )

    # Per-stage metrics table for this run
    df_metrics = pd.DataFrame({
        "Etapa": labels,
        "Precision": prec,
        "Recall": rec,
        "F1-Score": f1,
        "Kappa": kappas,
        "Soporte": support
    })

    print(f"\n RUN {run+1} ‚Äî Resultados globales:")
    print(f"   Test Loss:     {test_loss:.4f}")
    print(f"   Test Acc:      {test_acc:.4f}")
    print(f"   Test Macro F1: {test_f1:.4f}")
    print(f"   Kappa global:  {kappa_global:.4f}")
    print(f"   Best Val F1:   {best_val_f1:.4f}")

    run_summaries.append({
        "loss": test_loss,
        "acc": test_acc,
        "macro_f1": test_f1,
        "kappa": kappa_global,
        "best_val_f1": best_val_f1
    })
    per_class_tables.append(df_metrics)

    # Keep the data of the last run for the plots after the loop
    if run == N_RUNS - 1:
        last_run_history = history
        last_run_cm = cm
        last_run_test_f1 = test_f1
        last_run_kappa = kappa_global
        last_run_best_val_f1 = best_val_f1

    # Free GPU memory between runs (just in case)
    del model, optimizer, criterion
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------- Aggregates: mean ± std ----------
print("\n" + "="*70)
print(f" {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*70)

# Collect each global test metric across runs into its own 1-D array
losses  = np.array([r["loss"] for r in run_summaries])
accs    = np.array([r["acc"] for r in run_summaries])
f1s     = np.array([r["macro_f1"] for r in run_summaries])
kappasg = np.array([r["kappa"] for r in run_summaries])

def ms(x):
    """Format ``x.mean()`` and ``x.std()`` as one string, four decimals each."""
    center = x.mean()
    spread = x.std()
    return "{:.4f} ¬± {:.4f}".format(center, spread)

# Global test metrics across runs, formatted by ms()
print(f"Test Loss:      {ms(losses)}")
print(f"Test Accuracy:  {ms(accs)}")
print(f"Test Macro F1:  {ms(f1s)}")
print(f"Kappa global:   {ms(kappasg)}")

# ---------- Per-stage aggregate ----------
# Stack every per-class metric across runs: one row per run, one column per stage
prec_arr  = np.stack([df["Precision"].values for df in per_class_tables], axis=0)
rec_arr   = np.stack([df["Recall"].values    for df in per_class_tables], axis=0)
f1_arr    = np.stack([df["F1-Score"].values  for df in per_class_tables], axis=0)
kappa_arr = np.stack([df["Kappa"].values     for df in per_class_tables], axis=0)

df_agg = pd.DataFrame({"Etapa": labels})
for name, arr in [("Precision", prec_arr),
                  ("Recall", rec_arr),
                  ("F1-Score", f1_arr),
                  ("Kappa", kappa_arr)]:
    # Mean and std over the run axis (ndarray.std defaults to ddof=0)
    means = arr.mean(axis=0)
    stds  = arr.std(axis=0)
    df_agg[name] = [f"{m:.3f} ¬± {s:.3f}" for m, s in zip(means, stds)]

print("\n M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
# NOTE(review): `display` is not imported in this cell; presumably the
# IPython/Jupyter built-in — confirm.
display(df_agg)

# ---------- VISUALIZATION OF THE LAST RUN ----------
# 2x2 grid of panels on one figure
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

# 1. Loss
ax1 = fig.add_subplot(gs[0, 0])
# Epochs are numbered from 1; the length follows the recorded history
epochs_range = range(1, len(last_run_history['train_loss']) + 1)
ax1.plot(epochs_range, last_run_history['train_loss'], 'b-', label='Train', linewidth=2)
ax1.plot(epochs_range, last_run_history['val_loss'], 'r-', label='Val', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=11)
ax1.set_ylabel('Loss', fontsize=11)
ax1.set_title('Loss', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# 2. Macro F1 score
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(epochs_range, last_run_history['train_f1'], 'b-', label='Train F1', linewidth=2)
ax2.plot(epochs_range, last_run_history['val_f1'], 'r-', label='Val F1', linewidth=2)
# Horizontal reference line at the best validation F1 of the plotted run
ax2.axhline(y=last_run_best_val_f1, color='g', linestyle='--',
            label=f'Best Val F1: {last_run_best_val_f1:.3f}', linewidth=1.5, alpha=0.7)
ax2.set_xlabel('Epoch', fontsize=11)
ax2.set_ylabel('Macro F1 Score', fontsize=11)
ax2.set_title('Macro F1 Score (m√©trica principal)', fontsize=12, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

# 3. Accuracy
ax3 = fig.add_subplot(gs[1, 0])
ax3.plot(epochs_range, last_run_history['train_acc'], 'b-', label='Train Acc', linewidth=2)
ax3.plot(epochs_range, last_run_history['val_acc'], 'r-', label='Val Acc', linewidth=2)
ax3.set_xlabel('Epoch', fontsize=11)
ax3.set_ylabel('Accuracy', fontsize=11)
ax3.set_title('Accuracy (secundaria)', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)
ax3.set_ylim([0, 1])

# 4. Row-normalized confusion matrix
ax4 = fig.add_subplot(gs[1, 1])
# Divide every row by its own total so each cell is the fraction of that
# row's samples that landed in each column. A class with no samples has a
# row total of 0; plain division would fill that row with NaN (and raise a
# RuntimeWarning), so such rows are left at 0 instead.
row_totals = last_run_cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    last_run_cm.astype('float'),
    row_totals,
    out=np.zeros(last_run_cm.shape, dtype=float),
    where=row_totals != 0,
)
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=labels, yticklabels=labels,
            cbar_kws={'label': 'Proporci√≥n'}, ax=ax4, vmin=0, vmax=1)
ax4.set_xlabel('Predicho', fontsize=11)
ax4.set_ylabel('Real', fontsize=11)
ax4.set_title('Matriz de Confusi√≥n (normalizada)', fontsize=12, fontweight='bold')

# Figure-level title with the plotted run's test F1 and kappa
fig.suptitle(
    f'CoSleepNet ‚Äî {DATASET_NAME}\n'
    f'Test F1: {last_run_test_f1:.3f} | Kappa: {last_run_kappa:.3f}',
    fontsize=14, fontweight='bold', y=0.98
)

# Keep the top 4% of the figure free for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()


## Set 2: EEG1 + EOG

In [None]:
# ============================================================
# CELL: MULTI-RUN (3–5 runs) + MEAN ± STANDARD DEVIATION
# + VISUALIZATION OF THE LAST RUN
# ============================================================

import numpy as np
import pandas as pd
from sklearn.metrics import (
    precision_recall_fscore_support, confusion_matrix,
    cohen_kappa_score
)
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Config ----------
DATASET_ID = 2        # 1, 2 or 3; forwarded to make_pytorch_loaders (this section: Set 2)
LR = 5e-5             # learning rate given to torch.optim.Adam
BATCH_SIZE = 128      # batch size forwarded to make_pytorch_loaders
EPOCHS = 50           # upper bound on training epochs per run
PATIENCE = 5          # epochs without a validation-F1 improvement before early stopping
USE_FOCAL = True      # True -> FocalLoss, False -> nn.CrossEntropyLoss
FOCAL_GAMMA = 1.5     # gamma argument of FocalLoss
USE_CLASS_WEIGHTS = False  # True -> get_class_weights(train_loader) output is passed to the loss

N_RUNS = 5            # number of runs; run i is seeded with BASE_SEED + i
BASE_SEED = 42
labels = ["W", "N1", "N2", "N3", "REM"]  # stage names, paired with class indices 0..4

# Use CUDA when PyTorch reports it as available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Device:", device)
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ---------- Utils ----------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN's behaviour.

    Args:
        seed: Value handed to every seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators are only seeded when a CUDA device is available
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    # Request deterministic cuDNN kernels and switch off its auto-tuner
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def kappa_per_class(y_true, y_pred, n_classes=5):
    """Compute the confusion matrix and a one-vs-rest Cohen's kappa per class.

    Each class ``k`` is reduced to a binary table ("k" vs. "not k") taken from
    the multi-class confusion matrix. Kappa is
    ``(observed - expected) / (1 - expected)`` with ``1e-12`` added to the
    denominator.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        n_classes: Number of classes; labels ``0 .. n_classes - 1`` are used.

    Returns:
        Tuple ``(cm, kappas)``: the confusion matrix and a float64 array
        holding one kappa value per class.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(n_classes))
    total = cm.sum()
    denom = max(1, total)  # keeps the divisions defined when there are no samples

    # Binary counts for all classes at once (vectors of length n_classes)
    true_pos = np.diag(cm)
    false_neg = cm.sum(axis=1) - true_pos
    false_pos = cm.sum(axis=0) - true_pos
    true_neg = total - (true_pos + false_neg + false_pos)

    observed = (true_pos + true_neg) / denom

    # Chance agreement from the marginals of each binary table
    frac_true_yes = (true_pos + false_neg) / denom
    frac_pred_yes = (true_pos + false_pos) / denom
    frac_true_no = (false_pos + true_neg) / denom
    frac_pred_no = (false_neg + true_neg) / denom
    expected = frac_true_yes * frac_pred_yes + frac_true_no * frac_pred_no

    kappas = ((observed - expected) / (1 - expected + 1e-12)).astype(np.float64)
    return cm, kappas

# ---------- Fixed loaders (always the same splits) ----------
# Built once, outside the run loop, so every run uses the same loaders.
train_loader, val_loader, test_loader, input_shape, DATASET_NAME = make_pytorch_loaders(
    dataset_id=DATASET_ID,
    splits=splits,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=True,
)

# input_shape unpacks into three values. C_in is later passed to the model as
# in_ch; H and L are only printed in this cell.
C_in, H, L = input_shape
print(f"\n Dataset: {DATASET_NAME} | Shape: C={C_in}, H={H}, L={L}")
print(f" Ejecutando {N_RUNS} corridas independientes...\n")

# ---------- Run loop ----------
run_summaries = []      # one dict of global test metrics per run
per_class_tables = []   # one per-stage metrics DataFrame per run

# Data of the last run, kept for the plots drawn after the loop
last_run_history = None
last_run_cm = None
last_run_test_f1 = None
last_run_kappa = None
last_run_best_val_f1 = None

for run in range(N_RUNS):
    # Each run gets its own seed (BASE_SEED + run), applied before the model
    # is constructed.
    seed = BASE_SEED + run
    set_seed(seed)
    print(f"\n================ RUN {run+1}/{N_RUNS} ‚Äî seed={seed} ================")

    # Fresh model for this run
    model = CoSleepNetTorch(
        in_ch=C_in,
        num_classes=5,
        lstm_units=64,
        dropout=0.3,
        use_dct=False,
    ).to(device)

    # Loss: optional class weights, then focal loss or cross-entropy
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = get_class_weights(train_loader, device=device)
        print(f"  Class weights: {class_weights.cpu().numpy().round(3)}")

    if USE_FOCAL:
        criterion = FocalLoss(gamma=FOCAL_GAMMA, alpha=class_weights)
        print(f" Loss: Focal (Œ≥={FOCAL_GAMMA})")
    else:
        # NOTE(review): `nn` is not imported in this cell; this branch relies
        # on an earlier cell having defined it — confirm.
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f" Loss: CrossEntropy")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Training history: one value appended per epoch to every list
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }

    # Training with early stopping on validation F1
    best_val_f1 = 0.0
    best_state = None       # CPU copy of the weights at the best validation F1
    patience_counter = 0    # epochs since the last validation-F1 improvement

    for epoch in range(1, EPOCHS + 1):
        print(f"‚îÅ‚îÅ‚îÅ Epoch {epoch}/{EPOCHS} ‚îÅ‚îÅ‚îÅ")
        tr_loss, tr_acc, tr_f1 = train_one_epoch_pytorch(
            model, train_loader, optimizer, criterion, device
        )
        # The last two values returned by eval_pytorch are not used for validation
        val_loss, val_acc, val_f1, _, _ = eval_pytorch(
            model, val_loader, criterion, device
        )

        # Record history
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(tr_f1)
        history['val_f1'].append(val_f1)

        # Strict comparison: matching the best F1 still counts against patience
        improved = val_f1 > best_val_f1
        if improved:
            best_val_f1 = val_f1
            # Detached CPU copy so further training cannot alter the snapshot
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        star = " ‚≠ê" if improved else ""
        print(f"  Train | Loss: {tr_loss:.4f} | Acc: {tr_acc:.4f} | F1: {tr_f1:.4f}")
        print(f"  Val   | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}{star}\n")

        if patience_counter >= PATIENCE:
            print(f" Early stopping en epoch {epoch}")
            break

    # Restore the best weights (skipped when validation F1 never exceeded 0.0)
    if best_state is not None:
        model.load_state_dict(best_state)
        model = model.to(device)

    # Test-set evaluation
    print("‚îÅ‚îÅ‚îÅ Evaluando en TEST ‚îÅ‚îÅ‚îÅ")
    test_loss, test_acc, test_f1, y_true, y_pred = eval_pytorch(
        model, test_loader, criterion, device
    )
    kappa_global = cohen_kappa_score(y_true, y_pred)
    cm, kappas = kappa_per_class(y_true, y_pred, n_classes=5)

    # Per-class precision/recall/F1/support; zero_division=0 reports 0 where a
    # ratio is undefined
    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(5), zero_division=0
    )

    # Per-stage metrics table for this run
    df_metrics = pd.DataFrame({
        "Etapa": labels,
        "Precision": prec,
        "Recall": rec,
        "F1-Score": f1,
        "Kappa": kappas,
        "Soporte": support
    })

    print(f"\n RUN {run+1} ‚Äî Resultados globales:")
    print(f"   Test Loss:     {test_loss:.4f}")
    print(f"   Test Acc:      {test_acc:.4f}")
    print(f"   Test Macro F1: {test_f1:.4f}")
    print(f"   Kappa global:  {kappa_global:.4f}")
    print(f"   Best Val F1:   {best_val_f1:.4f}")

    run_summaries.append({
        "loss": test_loss,
        "acc": test_acc,
        "macro_f1": test_f1,
        "kappa": kappa_global,
        "best_val_f1": best_val_f1
    })
    per_class_tables.append(df_metrics)

    # Keep the data of the last run for the plots after the loop
    if run == N_RUNS - 1:
        last_run_history = history
        last_run_cm = cm
        last_run_test_f1 = test_f1
        last_run_kappa = kappa_global
        last_run_best_val_f1 = best_val_f1

    # Free GPU memory between runs (just in case)
    del model, optimizer, criterion
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------- Aggregates: mean ± std ----------
print("\n" + "="*70)
print(f" {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print("="*70)

# Collect each global test metric across runs into its own 1-D array
losses  = np.array([r["loss"] for r in run_summaries])
accs    = np.array([r["acc"] for r in run_summaries])
f1s     = np.array([r["macro_f1"] for r in run_summaries])
kappasg = np.array([r["kappa"] for r in run_summaries])

def ms(x):
    """Format ``x.mean()`` and ``x.std()`` as one string, four decimals each."""
    center = x.mean()
    spread = x.std()
    return "{:.4f} ¬± {:.4f}".format(center, spread)

# Global test metrics across runs, formatted by ms()
print(f"Test Loss:      {ms(losses)}")
print(f"Test Accuracy:  {ms(accs)}")
print(f"Test Macro F1:  {ms(f1s)}")
print(f"Kappa global:   {ms(kappasg)}")

# ---------- Per-stage aggregate ----------
# Stack every per-class metric across runs: one row per run, one column per stage
prec_arr  = np.stack([df["Precision"].values for df in per_class_tables], axis=0)
rec_arr   = np.stack([df["Recall"].values    for df in per_class_tables], axis=0)
f1_arr    = np.stack([df["F1-Score"].values  for df in per_class_tables], axis=0)
kappa_arr = np.stack([df["Kappa"].values     for df in per_class_tables], axis=0)

df_agg = pd.DataFrame({"Etapa": labels})
for name, arr in [("Precision", prec_arr),
                  ("Recall", rec_arr),
                  ("F1-Score", f1_arr),
                  ("Kappa", kappa_arr)]:
    # Mean and std over the run axis (ndarray.std defaults to ddof=0)
    means = arr.mean(axis=0)
    stds  = arr.std(axis=0)
    df_agg[name] = [f"{m:.3f} ¬± {s:.3f}" for m, s in zip(means, stds)]

print("\n M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
# NOTE(review): `display` is not imported in this cell; presumably the
# IPython/Jupyter built-in — confirm.
display(df_agg)

# ---------- VISUALIZATION OF THE LAST RUN ----------
# 2x2 grid of panels on one figure
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

# 1. Loss
ax1 = fig.add_subplot(gs[0, 0])
# Epochs are numbered from 1; the length follows the recorded history
epochs_range = range(1, len(last_run_history['train_loss']) + 1)
ax1.plot(epochs_range, last_run_history['train_loss'], 'b-', label='Train', linewidth=2)
ax1.plot(epochs_range, last_run_history['val_loss'], 'r-', label='Val', linewidth=2)
ax1.set_xlabel('Epoch', fontsize=11)
ax1.set_ylabel('Loss', fontsize=11)
ax1.set_title('Loss', fontsize=12, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# 2. Macro F1 score
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(epochs_range, last_run_history['train_f1'], 'b-', label='Train F1', linewidth=2)
ax2.plot(epochs_range, last_run_history['val_f1'], 'r-', label='Val F1', linewidth=2)
# Horizontal reference line at the best validation F1 of the plotted run
ax2.axhline(y=last_run_best_val_f1, color='g', linestyle='--',
            label=f'Best Val F1: {last_run_best_val_f1:.3f}', linewidth=1.5, alpha=0.7)
ax2.set_xlabel('Epoch', fontsize=11)
ax2.set_ylabel('Macro F1 Score', fontsize=11)
ax2.set_title('Macro F1 Score (m√©trica principal)', fontsize=12, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0, 1])

# 3. Accuracy
ax3 = fig.add_subplot(gs[1, 0])
ax3.plot(epochs_range, last_run_history['train_acc'], 'b-', label='Train Acc', linewidth=2)
ax3.plot(epochs_range, last_run_history['val_acc'], 'r-', label='Val Acc', linewidth=2)
ax3.set_xlabel('Epoch', fontsize=11)
ax3.set_ylabel('Accuracy', fontsize=11)
ax3.set_title('Accuracy (secundaria)', fontsize=12, fontweight='bold')
ax3.legend(fontsize=10)
ax3.grid(True, alpha=0.3)
ax3.set_ylim([0, 1])

# 4. Row-normalized confusion matrix
ax4 = fig.add_subplot(gs[1, 1])
# Divide every row by its own total so each cell is the fraction of that
# row's samples that landed in each column. A class with no samples has a
# row total of 0; plain division would fill that row with NaN (and raise a
# RuntimeWarning), so such rows are left at 0 instead.
row_totals = last_run_cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    last_run_cm.astype('float'),
    row_totals,
    out=np.zeros(last_run_cm.shape, dtype=float),
    where=row_totals != 0,
)
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=labels, yticklabels=labels,
            cbar_kws={'label': 'Proporci√≥n'}, ax=ax4, vmin=0, vmax=1)
ax4.set_xlabel('Predicho', fontsize=11)
ax4.set_ylabel('Real', fontsize=11)
ax4.set_title('Matriz de Confusi√≥n (normalizada)', fontsize=12, fontweight='bold')

# Figure-level title with the plotted run's test F1 and kappa
fig.suptitle(
    f'CoSleepNet ‚Äî {DATASET_NAME}\n'
    f'Test F1: {last_run_test_f1:.3f} | Kappa: {last_run_kappa:.3f}',
    fontsize=14, fontweight='bold', y=0.98
)

# Keep the top 4% of the figure free for the suptitle
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()


## Set 3: EEG1 + EEG2 + EOG

In [None]:
# ============================================================
# CELL: MULTI-RUN (3–5 runs) + MEAN ± STANDARD DEVIATION
# + VISUALIZATION OF THE LAST RUN
# ============================================================

import numpy as np
import pandas as pd
from sklearn.metrics import (
    precision_recall_fscore_support, confusion_matrix,
    cohen_kappa_score
)
import torch
import random
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Config ----------
DATASET_ID = 3        # 1, 2 or 3; forwarded to make_pytorch_loaders (this section: Set 3)
LR = 5e-5             # learning rate (no use is visible in the shown part of this cell)
BATCH_SIZE = 128      # batch size forwarded to make_pytorch_loaders
EPOCHS = 50           # presumably the per-run epoch limit — confirm against the training loop
PATIENCE = 5          # presumably the early-stopping patience — confirm against the training loop
USE_FOCAL = True      # True -> FocalLoss is used as the criterion
FOCAL_GAMMA = 1.5     # gamma argument of FocalLoss
USE_CLASS_WEIGHTS = False  # True -> get_class_weights(train_loader) output is passed to the loss

N_RUNS = 5            # number of runs; run i is seeded with BASE_SEED + i
BASE_SEED = 42
labels = ["W", "N1", "N2", "N3", "REM"]  # stage names; presumably class-index order 0..4 — confirm

# Use CUDA when PyTorch reports it as available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(" Device:", device)
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

# ---------- Utils ----------
def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN's behaviour.

    Args:
        seed: Value handed to every seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators are only seeded when a CUDA device is available
    if torch.cuda.is_available():
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)

    # Request deterministic cuDNN kernels and switch off its auto-tuner
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def kappa_per_class(y_true, y_pred, n_classes=5):
    """Return the confusion matrix and a one-vs-rest Cohen's kappa per class.

    For each class k the multi-class problem is collapsed to "k" vs "not k",
    and kappa = (observed - expected) / (1 - expected) is computed from the
    resulting 2x2 counts.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        n_classes: Number of classes; labels 0..n_classes-1 are used.

    Returns:
        Tuple `(cm, kappas)`: the confusion matrix built with labels
        0..n_classes-1 and a float64 array holding one kappa per class.
    """
    cm = confusion_matrix(y_true, y_pred, labels=range(n_classes))
    total = cm.sum()
    denom = max(1, total)  # avoids 0/0 when there are no samples at all

    # One-vs-rest counts for every class at once
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = total - (tp + fn + fp)

    observed = (tp + tn) / denom

    # Chance agreement from the marginals of the collapsed 2x2 table
    yes_true = (tp + fn) / denom
    yes_pred = (tp + fp) / denom
    no_true = (fp + tn) / denom
    no_pred = (fn + tn) / denom
    expected = yes_true * yes_pred + no_true * no_pred

    # The small epsilon keeps the ratio finite when expected agreement is 1
    kappas = ((observed - expected) / (1 - expected + 1e-12)).astype(np.float64)

    return cm, kappas

# ---------- Fixed loaders (the same splits are reused by every run) ----------
# Pinned host memory only speeds up host-to-GPU copies, so request it only when
# training on CUDA. NOTE(review): this assumes make_pytorch_loaders forwards
# `pin_memory` to torch's DataLoader, which warns when pinning is requested
# without an accelerator - confirm against the helper.
train_loader, val_loader, test_loader, input_shape, DATASET_NAME = make_pytorch_loaders(
    dataset_id=DATASET_ID,
    splits=splits,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=(device.type == "cuda"),
)

# input_shape is unpacked as three values; the first is later passed to the
# model as its number of input channels (`in_ch`).
C_in, H, L = input_shape
print(f"\n Dataset: {DATASET_NAME} | Shape: C={C_in}, H={H}, L={L}")
print(f" Ejecutando {N_RUNS} corridas independientes...\n")

# ---------- Run loop: accumulators ----------
# One summary dict and one per-class metrics table are appended per run
run_summaries, per_class_tables = [], []

# Artefacts of the last run only; filled inside the loop, used for plotting
last_run_history = last_run_cm = None
last_run_test_f1 = last_run_kappa = last_run_best_val_f1 = None

# Train and evaluate N_RUNS independent models on the same fixed loaders.
# Each run: seed -> fresh model -> train with early stopping on validation
# macro F1 -> restore best weights -> evaluate on test -> record metrics.
for run in range(N_RUNS):
    # Runs use consecutive seeds: BASE_SEED, BASE_SEED + 1, ...
    seed = BASE_SEED + run
    set_seed(seed)
    print(f"\n================ RUN {run+1}/{N_RUNS} ‚Äî seed={seed} ================")

    # Fresh model for this run
    model = CoSleepNetTorch(
        in_ch=C_in,
        num_classes=5,
        lstm_units=64,
        dropout=0.3,
        use_dct=False,
    ).to(device)

    # Loss function: optional class weights, then focal or cross-entropy
    class_weights = None
    if USE_CLASS_WEIGHTS:
        class_weights = get_class_weights(train_loader, device=device)
        print(f"  Class weights: {class_weights.cpu().numpy().round(3)}")

    if USE_FOCAL:
        criterion = FocalLoss(gamma=FOCAL_GAMMA, alpha=class_weights)
        print(f" Loss: Focal (Œ≥={FOCAL_GAMMA})")
    else:
        # NOTE(review): `nn` is not imported in this cell; presumably
        # `torch.nn` was imported as `nn` in an earlier cell - confirm.
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f" Loss: CrossEntropy")

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # Per-epoch training history (one entry appended per completed epoch)
    history = {
        'train_loss': [], 'val_loss': [],
        'train_acc': [], 'val_acc': [],
        'train_f1': [], 'val_f1': []
    }

    # Training state for early stopping on validation F1
    best_val_f1 = 0.0
    best_state = None
    patience_counter = 0

    for epoch in range(1, EPOCHS + 1):
        print(f"‚îÅ‚îÅ‚îÅ Epoch {epoch}/{EPOCHS} ‚îÅ‚îÅ‚îÅ")
        tr_loss, tr_acc, tr_f1 = train_one_epoch_pytorch(
            model, train_loader, optimizer, criterion, device
        )
        # The last two values returned by eval_pytorch are not needed here
        val_loss, val_acc, val_f1, _, _ = eval_pytorch(
            model, val_loader, criterion, device
        )

        # Record history
        history['train_loss'].append(tr_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(val_acc)
        history['train_f1'].append(tr_f1)
        history['val_f1'].append(val_f1)

        # Strict improvement only: a tie with the best score counts against patience
        improved = val_f1 > best_val_f1
        if improved:
            best_val_f1 = val_f1
            # Snapshot the weights on CPU so the copy is independent of the live model
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        star = " ‚≠ê" if improved else ""
        print(f"  Train | Loss: {tr_loss:.4f} | Acc: {tr_acc:.4f} | F1: {tr_f1:.4f}")
        print(f"  Val   | Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}{star}\n")

        if patience_counter >= PATIENCE:
            print(f" Early stopping en epoch {epoch}")
            break

    # Restore the best weights; if validation F1 never rose above 0.0,
    # best_state is still None and the final-epoch weights are kept.
    if best_state is not None:
        model.load_state_dict(best_state)
        model = model.to(device)

    # Test evaluation with the restored model
    print("‚îÅ‚îÅ‚îÅ Evaluando en TEST ‚îÅ‚îÅ‚îÅ")
    test_loss, test_acc, test_f1, y_true, y_pred = eval_pytorch(
        model, test_loader, criterion, device
    )
    kappa_global = cohen_kappa_score(y_true, y_pred)
    cm, kappas = kappa_per_class(y_true, y_pred, n_classes=5)

    # Per-class precision / recall / F1 / support for labels 0..4;
    # undefined ratios are reported as 0 (zero_division=0)
    prec, rec, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=range(5), zero_division=0
    )

    # One row per sleep stage, in the order given by `labels`
    df_metrics = pd.DataFrame({
        "Etapa": labels,
        "Precision": prec,
        "Recall": rec,
        "F1-Score": f1,
        "Kappa": kappas,
        "Soporte": support
    })

    print(f"\n RUN {run+1} ‚Äî Resultados globales:")
    print(f"   Test Loss:     {test_loss:.4f}")
    print(f"   Test Acc:      {test_acc:.4f}")
    print(f"   Test Macro F1: {test_f1:.4f}")
    print(f"   Kappa global:  {kappa_global:.4f}")
    print(f"   Best Val F1:   {best_val_f1:.4f}")

    # Global metrics of this run, aggregated after the loop
    run_summaries.append({
        "loss": test_loss,
        "acc": test_acc,
        "macro_f1": test_f1,
        "kappa": kappa_global,
        "best_val_f1": best_val_f1
    })
    per_class_tables.append(df_metrics)

    # Keep the artefacts of the last run for the plots below
    if run == N_RUNS - 1:
        last_run_history = history
        last_run_cm = cm
        last_run_test_f1 = test_f1
        last_run_kappa = kappa_global
        last_run_best_val_f1 = best_val_f1

    # Release GPU memory between runs (just in case)
    del model, optimizer, criterion
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------- Aggregates: mean +/- std across runs ----------
_rule = "=" * 70
print("\n" + _rule)
print(f" {N_RUNS} corridas completadas ‚Äî {DATASET_NAME}")
print(_rule)

# One 1-D array per global metric, with one entry per run
losses, accs, f1s, kappasg = (
    np.array([summary[field] for summary in run_summaries])
    for field in ("loss", "acc", "macro_f1", "kappa")
)

def ms(x):
    """Format the mean and standard deviation of `x` with four decimals.

    Calls `x.mean()` and `x.std()` with their default arguments and joins the
    two formatted values into a single "mean, separator, std" string.
    """
    center = x.mean()
    spread = x.std()
    return f"{center:.4f} ¬± {spread:.4f}"

# Global test metrics across runs, one line per metric
print(
    f"Test Loss:      {ms(losses)}",
    f"Test Accuracy:  {ms(accs)}",
    f"Test Macro F1:  {ms(f1s)}",
    f"Kappa global:   {ms(kappasg)}",
    sep="\n",
)

# ---------- Per-stage aggregate ----------
# Stack each metric column of the per-run tables into one array per metric,
# with runs along axis 0 and stages along axis 1.
_metric_columns = ("Precision", "Recall", "F1-Score", "Kappa")
prec_arr, rec_arr, f1_arr, kappa_arr = (
    np.stack([table[column].values for table in per_class_tables], axis=0)
    for column in _metric_columns
)

# One row per stage; each metric cell is a formatted mean/std pair over runs
df_agg = pd.DataFrame({"Etapa": labels})
for name, arr in zip(_metric_columns, (prec_arr, rec_arr, f1_arr, kappa_arr)):
    means, stds = arr.mean(axis=0), arr.std(axis=0)
    df_agg[name] = [f"{m:.3f} ¬± {s:.3f}" for m, s in zip(means, stds)]

print("\n M√âTRICAS AGREGADAS POR ETAPA (media ¬± std):")
display(df_agg)

# ---------- VISUALIZATION OF THE LAST RUN ----------
def _draw_history_panel(ax, keys, legend, ylabel, title,
                        best_line=False, unit_ylim=False):
    """Draw the train/validation curves of one metric on `ax`.

    `keys` and `legend` are (train, val) pairs: history keys and curve labels.
    `best_line` adds the horizontal best-validation-F1 marker; `unit_ylim`
    fixes the y axis to [0, 1].
    """
    train_key, val_key = keys
    train_label, val_label = legend
    ax.plot(epochs_range, last_run_history[train_key], 'b-', label=train_label, linewidth=2)
    ax.plot(epochs_range, last_run_history[val_key], 'r-', label=val_label, linewidth=2)
    if best_line:
        ax.axhline(y=last_run_best_val_f1, color='g', linestyle='--',
                   label=f'Best Val F1: {last_run_best_val_f1:.3f}', linewidth=1.5, alpha=0.7)
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel(ylabel, fontsize=11)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    if unit_ylim:
        ax.set_ylim([0, 1])


fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

# 1. Loss
ax1 = fig.add_subplot(gs[0, 0])
# x axis: one point per epoch recorded in the last run's history
epochs_range = range(1, len(last_run_history['train_loss']) + 1)
_draw_history_panel(ax1, ('train_loss', 'val_loss'), ('Train', 'Val'),
                    'Loss', 'Loss')

# 2. Macro F1 score, with the best validation value marked
ax2 = fig.add_subplot(gs[0, 1])
_draw_history_panel(ax2, ('train_f1', 'val_f1'), ('Train F1', 'Val F1'),
                    'Macro F1 Score', 'Macro F1 Score (m√©trica principal)',
                    best_line=True, unit_ylim=True)

# 3. Accuracy
ax3 = fig.add_subplot(gs[1, 0])
_draw_history_panel(ax3, ('train_acc', 'val_acc'), ('Train Acc', 'Val Acc'),
                    'Accuracy', 'Accuracy (secundaria)',
                    unit_ylim=True)

# 4. Row-normalised confusion matrix
ax4 = fig.add_subplot(gs[1, 1])
# Each row is divided by its own total (number of true samples of that class).
# A class with no samples has a zero row total; clamping the divisor to 1 keeps
# that row at 0 instead of producing NaN cells and a divide RuntimeWarning.
# Rows with a non-zero total are unaffected.
row_totals = last_run_cm.sum(axis=1, keepdims=True)
cm_normalized = last_run_cm.astype('float') / np.maximum(row_totals, 1)
sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=labels, yticklabels=labels,
            cbar_kws={'label': 'Proporci√≥n'}, ax=ax4, vmin=0, vmax=1)
ax4.set_xlabel('Predicho', fontsize=11)
ax4.set_ylabel('Real', fontsize=11)
ax4.set_title('Matriz de Confusi√≥n (normalizada)', fontsize=12, fontweight='bold')

# Figure headline: dataset name plus the test metrics of the last run
headline = (
    f'CoSleepNet ‚Äî {DATASET_NAME}\n'
    f'Test F1: {last_run_test_f1:.3f} | Kappa: {last_run_kappa:.3f}'
)
fig.suptitle(headline, fontsize=14, fontweight='bold', y=0.98)

# Lay the panels out inside the lower 96% of the canvas (room for the headline)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
