# Pipeline Lead II: Detección de anomalías con CVAE + Clasificador
**Flujo completo:** preprocesado → entrenamiento CVAE → extracción de errores & latentes → entrenamiento clasificador → métricas ROC/PR.

In [ ]:
# Instala dependencias una sola vez
!pip install neurokit2 xgboost torch wfdb

In [ ]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Importa tus scripts
from scripts.preprocess import load_all_signals, preprocess_and_segment
from scripts.model import CVAE, loss_function
from scripts.test_functions import (
    extract_reconstruction_errors,
    train_classifier,
    compute_metrics
)
import torch


In [ ]:
# General parameters
# The data root defaults to the original local checkout; set the
# TP_FINAL_ML_DATA_DIR environment variable to run on another machine
# without editing this cell.
DATA_DIR   = os.environ.get(
    'TP_FINAL_ML_DATA_DIR',
    r"C:\Users\anapt\Repositorios\TP-final_ML\TP-final_ML\data",
)
BEATS_DIR  = os.path.join(DATA_DIR, 'processed', 'beats')
# Created eagerly; exist_ok makes re-running the cell safe.
os.makedirs(BEATS_DIR, exist_ok=True)
EPOCHS     = 18    # number of CVAE training epochs
BATCH_SIZE = 32    # batch size for the train and validation DataLoaders
LR         = 1e-3  # learning rate for the Adam optimizer
LATENT_DIM = 60    # passed to the CVAE as latent_dim
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [ ]:
# 1) Recursive load, then preprocessing of every recording
# load_all_signals yields (signal, fs) pairs; each pair is turned into one
# array by preprocess_and_segment and the results are stacked on a new axis 0.
signals = load_all_signals(DATA_DIR)
print(f"Señales encontradas: {len(signals)}")
beats = np.stack(
    [preprocess_and_segment(sig, fs) for sig, fs in signals],
    axis=0,
)  # expected shape (N, 1, 2048) — TODO confirm against preprocess_and_segment
print(f"Beats procesados: {beats.shape}")


In [ ]:
# 2) Train/validation split and DataLoaders
from torch.utils.data import DataLoader, TensorDataset, random_split

# Fixed seed so the hold-out set is identical on every run. Without it, a
# checkpoint reloaded in a fresh session would be evaluated on a different
# "hold-out" that can contain samples the model was trained on.
SPLIT_SEED = 42

dataset = TensorDataset(torch.tensor(beats, dtype=torch.float32))
n_val = int(0.2 * len(dataset))  # 20 % hold-out, rounded down
n_train = len(dataset) - n_val   # remainder, so both lengths sum to len(dataset)
train_set, val_set = random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE)  # order kept fixed for evaluation


In [ ]:
# 3) Build the CVAE and train it, checkpointing the best validation epoch
model = CVAE(in_channels=1, latent_dim=LATENT_DIM, input_length=2048).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
best_mae = float('inf')  # lowest validation error seen so far

from tqdm.auto import tqdm
for epoch in range(1, EPOCHS+1):
    # --- training pass over the whole training loader ---
    model.train()
    losses = []
    progress = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}")
    for (batch,) in progress:  # TensorDataset yields 1-tuples
        batch = batch.to(DEVICE)
        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        batch_loss = loss_function(recon, batch, mu, logvar)
        batch_loss.backward()
        optimizer.step()
        losses.append(batch_loss.item())

    # --- validation: mean of the reconstruction errors on the hold-out ---
    # (reported as MAE; the exact error metric is defined by
    # extract_reconstruction_errors — TODO confirm)
    model.eval()
    val_mae = extract_reconstruction_errors(model, val_loader, DEVICE).mean()
    print(f"→ Ep {epoch}: train_loss={np.mean(losses):.4f}, val_mae={val_mae:.4f}")

    # Only strictly better epochs overwrite the checkpoint.
    if not val_mae < best_mae:
        continue
    best_mae = val_mae
    torch.save(model.state_dict(), 'best_cvae.pth')


In [ ]:
# 4) Reconstruction errors for the healthy hold-out and the anomaly sets (PTB-XL + Chapman)
def _anomaly_errors(subfolder):
    """Return reconstruction errors for one subfolder of DATA_DIR.

    The loader argument is passed as None; the data root and subfolder name
    are supplied instead (presumably the helper then reads the records from
    disk itself — verify against extract_reconstruction_errors).
    """
    return extract_reconstruction_errors(model, None, DEVICE, DATA_DIR, subfolder)


# Restore the weights of the best checkpoint written during training.
best_state = torch.load('best_cvae.pth', map_location=DEVICE)
model.load_state_dict(best_state)

# Healthy baseline: errors on the validation hold-out.
healthy_err = extract_reconstruction_errors(model, val_loader, DEVICE)
print(f"Errores sanos (hold-out): mean={healthy_err.mean():.4f}, std={healthy_err.std():.4f}")

# Anomaly sets, one call per dataset subfolder.
ptb_err  = _anomaly_errors('ptb-xl')
chap_err = _anomaly_errors('ChapmanShaoxing')
print(f"PTB-XL anomalías: {len(ptb_err)}, Chapman anomalías: {len(chap_err)}")


In [ ]:
# 5) Train the classifier on the three error sets and get its probabilities
probs, y_true = train_classifier(healthy_err, ptb_err, chap_err)
# Single braces so the expression is evaluated; the previous doubled braces
# were an f-string escape and printed the expression text literally.
print(f"Clases: {np.unique(y_true, return_counts=True)}")


In [ ]:
# 6) ROC/PR metrics and optimal decision threshold
metrics = compute_metrics(y_true, probs)
print(f"AUC-ROC: {metrics['auc_roc']:.3f}, AUC-PR: {metrics['auc_pr']:.3f}")

# ROC and precision-recall curves
from sklearn.metrics import roc_curve, precision_recall_curve
fpr, tpr, thresholds = roc_curve(y_true, probs)
prec, rec, _ = precision_recall_curve(y_true, probs)

# Youden's J statistic is TPR - FPR. argmax gives the *position* of the best
# operating point, so it must be mapped back through `thresholds` to obtain
# the actual probability cut-off (previously the index itself was reported).
best_idx = (tpr - fpr).argmax()
best_thr = thresholds[best_idx]

# ROC plot with the chance diagonal for reference
plt.figure()
plt.plot(fpr, tpr, label=f"AUC={metrics['auc_roc']:.3f}")
plt.plot([0,1],[0,1],'k--')
plt.legend()
plt.title('ROC')
plt.show()

# Precision-recall plot
plt.figure()
plt.plot(rec, prec, label=f"AUC-PR={metrics['auc_pr']:.3f}")
plt.legend()
plt.title('PR')
plt.show()

print(f"Umbral óptimo (Youden): {best_thr:.4f}")
