In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from scipy.signal import butter, filtfilt
from scipy.io import loadmat
import wfdb
from wfdb import rdrecord

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from train_ecg_vae_reutilice import ECGDataset, ECG_VAE, loss_fn

In [None]:
PTB_DIR = 'data/ptb-xl'
CH_DIR  = 'data/ChapmanShaoxing'

# Hyperparameters
SEED = 42           # fixes weight init and DataLoader shuffling so Restart & Run All is reproducible
batch_size = 16
epochs = 50
lr = 1e-3
z_dim = 16          # presumably the VAE latent size (passed as ECG_VAE(z_dim=...)) — TODO confirm
seq_len = 5000      # length passed to ECGDataset(sample_length=...) and ECG_VAE(seq_len=...)

# Seed before the loader and model are built: torch's RNG drives both the
# initial weights and the RandomSampler used by shuffle=True.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset & loader
ds = ECGDataset(PTB_DIR, CH_DIR, sample_length=seq_len)
loader = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)

# Model and optimizer (GPU when available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = ECG_VAE(z_dim=z_dim, seq_len=seq_len).to(device)
opt = optim.Adam(vae.parameters(), lr=lr)

# Train
# Fail fast with a clear message instead of a ZeroDivisionError / NameError
# further down when no records were found under the data directories.
if len(loader) == 0:
    raise RuntimeError(
        f"DataLoader is empty: no ECG records loaded from {PTB_DIR!r} / {CH_DIR!r}"
    )

for ep in range(1, epochs + 1):
    vae.train()
    tot_loss = 0.0
    tot_recon = 0.0
    tot_kld = 0.0
    for batch in loader:
        x = batch.to(device)
        x_hat, mu, logv = vae(x)
        loss, recon, kld = loss_fn(x, x_hat, mu, logv)

        opt.zero_grad()
        loss.backward()
        opt.step()

        tot_loss += loss.item()
        # Accumulate the components too, so the log reports epoch averages
        # rather than the last batch only. float() accepts both 0-dim
        # tensors and plain Python numbers.
        tot_recon += float(recon)
        tot_kld += float(kld)

    n_batches = len(loader)
    print(
        f"Epoch {ep}/{epochs} - Loss: {tot_loss / n_batches:.4f} "
        f"(Recon {tot_recon / n_batches:.4f}, KLD {tot_kld / n_batches:.4f})"
    )

# Save the trained weights
torch.save(vae.state_dict(), 'ecg_vae.pth')
print("Modelo guardado en ecg_vae.pth")
