In [1]:
import copy
import os, sys, time
from types import SimpleNamespace
from pathlib import Path

# Absolute location of the project checkout on this machine; edit when running elsewhere.
PROJECT_ROOT = Path(r"C:\Users\Joseph\generative-health-models")
SRC = PROJECT_ROOT / "src"

# Work from the project root so paths built from the current directory land in the checkout.
os.chdir(PROJECT_ROOT)

# Ensure our local 'src' comes before site-packages (avoid clash with HF 'datasets').
# A plain membership test is not enough: if 'src' is already on sys.path but at a
# later position it would stay behind site-packages, so drop any existing entry and
# put it back at the very front.
_src_entry = str(SRC)
sys.path[:] = [entry for entry in sys.path if entry != _src_entry]
sys.path.insert(0, _src_entry)
print("cwd:", os.getcwd())
print("sys.path[0]:", sys.path[0])

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets.wesad import make_loader
import models.tc_multigan as tcm

cwd: C:\Users\Joseph\generative-health-models
sys.path[0]: C:\Users\Joseph\generative-health-models\src


In [2]:
# ---------------- config & seed ----------------
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and torch with the same value.

    When CUDA is available, every CUDA device is seeded as well.
    """
    import random
    import numpy as np

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(42)

# Every tunable of this notebook lives in this one namespace.
cfg = SimpleNamespace(
    # data
    # Derived from PROJECT_ROOT so the checkout location is spelled out in exactly one place.
    data_root        = str(PROJECT_ROOT / "data" / "processed"),
    fold             = "tc_multigan_fold_s10",
    train_split      = "train",
    seq_length_low   = 120,
    condition_dim    = 4,

    # loader
    batch_size       = 32,
    num_workers      = 0,
    weighted_sampling= False,

    # model / opt
    hidden_dim       = 256,
    lr_g             = 1e-3,
    epochs_ae        = 25,

    # device: use the GPU when torch can see one, otherwise fall back to the CPU
    device           = "cuda" if torch.cuda.is_available() else "cpu",
)
print("Config:", cfg)

# Path objects for the processed-data root and the directory of the selected fold.
root = Path(cfg.data_root)
fold_dir = root / cfg.fold


Config: namespace(data_root='C:\\Users\\Joseph\\generative-health-models\\data\\processed', fold='tc_multigan_fold_s10', train_split='train', seq_length_low=120, condition_dim=4, batch_size=32, num_workers=0, weighted_sampling=False, hidden_dim=256, lr_g=0.001, epochs_ae=25, device='cuda')


In [3]:
# ---------------- loaders (Option A: no dataset normalization) ----------------
# Keyword arguments shared by both loaders. Option A: the loader itself applies no
# normalization and is given no stats files.
_loader_common = dict(
    root_dir=root, fold=cfg.fold,
    window_size_low=cfg.seq_length_low,
    batch_size=cfg.batch_size, num_workers=cfg.num_workers,
    condition_dim=cfg.condition_dim,
    normalize=False, normalize_ecg=False,     # <-- Option A
    force_use_stats=False,
    use_split_stats_if_needed=False,
    stats_low_path=None, stats_ecg_path=None,
    debug_print=False,                        # set True if you want one-time prints
)

# Training loader: shuffled and augmented; the split name is taken from the config
# so that cfg.train_split is the single place that selects it.
train_dl = make_loader(split=cfg.train_split, shuffle=True, augment=True, **_loader_common)

# Validation loader: reads the fold's "test" split in a fixed order without
# augmentation; its inputs are standardized on-the-fly at evaluation time.
val_dl = make_loader(split="test", shuffle=False, augment=False, **_loader_common)


In [4]:
# Draw one batch from each loader (train first, then val) and compare the
# per-channel standard deviation of the two splits.
b_tr = next(iter(train_dl))
b_va = next(iter(val_dl))

xL_tr = b_tr["signal_low"].float()
xE_tr = b_tr["signal_ecg"].float()
xL_va = b_va["signal_low"].float()
xE_va = b_va["signal_ecg"].float()

# Std is taken over the first two axes, leaving one value per channel.
for _tag, _low, _ecg in (
    ("TRAIN  std low:", xL_tr, xE_tr),
    ("VAL    std low:", xL_va, xE_va),
):
    print(_tag, _low.std(dim=(0, 1)).tolist(), "ecg:", _ecg.std(dim=(0, 1)).tolist())

TRAIN  std low: [0.7463854551315308, 1.0361276865005493] ecg: [0.9888325929641724]
VAL    std low: [0.22556525468826294, 0.7407307624816895] ecg: [0.5809653401374817]


In [5]:
@torch.no_grad()
def compute_stats_from_loader(dl):
    """Per-channel mean/std of both signal streams over every batch of ``dl``.

    Each batch must be a mapping with tensors under ``"signal_low"`` and
    ``"signal_ecg"``; both are reduced over their first two axes, so one value
    per entry of the last axis is produced. Only running sums are kept, so the
    data is never concatenated in memory.

    Batches are promoted to float64 *before* summing and squaring. Reducing in
    float32 and casting the result afterwards would already have lost the
    precision the float64 accumulators are meant to protect.

    Args:
        dl: iterable of batches (e.g. a DataLoader).

    Returns:
        ``(meanL, stdL, meanE, stdE)`` as float32 NumPy vectors, one entry per
        channel of the respective stream. The variance is floored at 1e-12
        before the square root.

    Raises:
        ValueError: if ``dl`` yields no samples, since mean/std would be 0/0.
    """
    stream_keys = ("signal_low", "signal_ecg")
    sums = dict.fromkeys(stream_keys)       # running sum per channel, created on first batch
    sq_sums = dict.fromkeys(stream_keys)    # running sum of squares per channel
    counts = dict.fromkeys(stream_keys, 0)  # number of (sample, time-step) pairs seen

    for b in dl:
        for key in stream_keys:
            x = b[key].to(dtype=torch.float64)
            # Reduce on the batch's own device, then move the tiny result to the CPU.
            s = x.sum(dim=(0, 1)).reshape(-1).cpu()
            sq = (x * x).sum(dim=(0, 1)).reshape(-1).cpu()
            sums[key] = s if sums[key] is None else sums[key] + s
            sq_sums[key] = sq if sq_sums[key] is None else sq_sums[key] + sq
            counts[key] += x.shape[0] * x.shape[1]

    if any(counts[key] == 0 for key in stream_keys):
        raise ValueError("compute_stats_from_loader: the loader yielded no samples")

    def _finalize(key):
        # var = E[x^2] - E[x]^2, floored to keep the square root (and later divisions) finite.
        mean = sums[key] / counts[key]
        var = (sq_sums[key] / counts[key] - mean ** 2).clamp_min(1e-12)
        return mean.to(torch.float32).numpy(), np.sqrt(var.to(torch.float32).numpy())

    meanL, stdL = _finalize("signal_low")
    meanE, stdE = _finalize("signal_ecg")
    return meanL, stdL, meanE, stdE


In [6]:
# Per-channel mean/std of the validation loader, used to standardize its inputs.
val_muL, val_stdL, val_muE, val_stdE = compute_stats_from_loader(val_dl)

device = torch.device(cfg.device)
_EPS = 1e-8  # added to std before dividing


def _as_broadcastable(stat):
    """Turn a per-channel NumPy vector into a (1, 1, C) tensor on ``device``.

    C is taken from the vector itself (``-1``) rather than hard-coded, so the
    result broadcasts over (B, T, C) batches for any channel count.
    """
    return torch.from_numpy(stat).to(device).view(1, 1, -1)


low_mu  = _as_broadcastable(val_muL)
low_std = _as_broadcastable(val_stdL)
ecg_mu  = _as_broadcastable(val_muE)
ecg_std = _as_broadcastable(val_stdE)

In [7]:
@torch.no_grad()
def eval_stream_losses(dl, model, device, low_mu, low_std, ecg_mu, ecg_std, eps=1e-8):
    """Reconstruction MSE of ``model`` on ``dl`` for the low-rate and ECG streams.

    Every batch is moved to ``device`` and standardized with the supplied
    statistics, ``(x - mu) / (std + eps)``, and the model is called on the
    standardized pair. Each batch's MSE is weighted by its batch size, so the
    returned values are per-sample averages over the whole loader. The model
    is switched to eval mode and left there.

    Returns:
        ``(low_mse, ecg_mse)``; both are 0.0 when the loader yields nothing.
    """
    model.eval()

    weighted_low = 0
    weighted_ecg = 0
    seen = 0
    for batch in dl:
        low = batch["signal_low"].to(device)   # (B, T_low, 2)
        ecg = batch["signal_ecg"].to(device)   # (B, T_ecg, 1)

        low_z = (low - low_mu) / (low_std + eps)
        ecg_z = (ecg - ecg_mu) / (ecg_std + eps)

        low_hat, ecg_hat = model(low_z, ecg_z)

        weighted_low += low.size(0) * F.mse_loss(low_hat, low_z).item()
        weighted_ecg += ecg.size(0) * F.mse_loss(ecg_hat, ecg_z).item()
        seen += low.size(0)

    denom = max(1, seen)  # guards the empty-loader case
    return weighted_low / denom, weighted_ecg / denom

In [8]:
# Read both window lengths off one real training batch (axis 1 of each stream).
_probe = next(iter(train_dl))
T_low, T_ecg = (_probe[key].shape[1] for key in ("signal_low", "signal_ecg"))

# Autoencoder over both streams, placed on the selected device.
_ae_kwargs = dict(
    hidden_dim=cfg.hidden_dim,
    seq_length_low=T_low,
    seq_length_ecg=T_ecg,
    latent_downsample=4,
    use_ecg=True,
)
ae = tcm.AutoencoderER(**_ae_kwargs).to(device)

# Optimizer and reconstruction criterion used by the training cells.
optim = torch.optim.Adam(ae.parameters(), betas=(0.9, 0.99), lr=cfg.lr_g)
mse = nn.MSELoss()

In [12]:
# === QUICK ECG sanity check (run once, right before the training loop) ===
# Push one validation batch through the autoencoder (eval mode, no gradients) and
# compare the spread of the ECG reconstruction with that of its standardized input.
ae.eval()
with torch.no_grad():
    b = next(iter(val_dl))
    xL = b["signal_low"].to(device)
    xE = b["signal_ecg"].to(device)

    # Same standardization as eval_stream_losses: (x - mu) / (std + eps)
    xLz = (xL - low_mu) / (low_std + _EPS)      # (B, T_low, 2)
    xEz = (xE - ecg_mu) / (ecg_std + _EPS)      # (B, T_ecg, 1)

    yL, yE = ae(xLz, xEz)

    for _label, _tensor in (("xEz mean/std:", xEz), ("yE  mean/std:", yE)):
        print(_label, _tensor.mean().item(), _tensor.std().item())

xEz mean/std: -0.00010215160000370815 1.0806077718734741
yE  mean/std: -0.0012157309101894498 0.004450276959687471


In [13]:
# Overfit-one-batch probe: can the architecture drive the ECG reconstruction loss
# down on a single fixed training batch?
#
# The probe optimises a deep copy of `ae`. Optimising `ae` itself would change the
# weights the main training loop starts from whenever this cell runs before it,
# making results depend on cell execution order.
ae_probe = copy.deepcopy(ae)
ae_probe.train()
b = next(iter(train_dl))
xL, xE = b["signal_low"].to(cfg.device), b["signal_ecg"].to(cfg.device)

# Throwaway optimizer bound to the copy's parameters only.
opt = torch.optim.Adam(ae_probe.parameters(), lr=1e-3)
for t in range(200):
    opt.zero_grad(set_to_none=True)
    yL, yE = ae_probe(xL, xE)
    lossE = F.mse_loss(yE, xE)  # focus on ECG only
    lossE.backward()
    opt.step()
    if (t+1) % 20 == 0:
        print(f"step {t+1:03d} ECG loss: {lossE.item():.4f}")

step 020 ECG loss: 1.0550
step 040 ECG loss: 1.0499
step 060 ECG loss: 1.0444
step 080 ECG loss: 1.0383
step 100 ECG loss: 1.0354
step 120 ECG loss: 1.0425
step 140 ECG loss: 1.0281
step 160 ECG loss: 1.0333
step 180 ECG loss: 1.0197
step 200 ECG loss: 1.0193


In [9]:
# Checkpoints go to <current working directory>/results/logs/ae (created if missing).
out_dir = Path.cwd().joinpath("results", "logs", "ae")
out_dir.mkdir(exist_ok=True, parents=True)

# Lowest validation loss seen so far; starts at +inf so the first epoch always wins.
best_loss = float("inf")

def run_epoch(dl, train=True):
    """Run one pass of the module-level ``ae`` over ``dl`` and return the mean loss.

    The loss of a batch is the equally weighted sum of the low-rate and ECG
    reconstruction errors (``mse``). With ``train`` truthy the model is put in
    train mode and ``optim`` takes one step per batch; otherwise the model is
    put in eval mode and gradients are disabled. Batch losses are weighted by
    batch size; an empty loader gives 0.0.
    """
    if train:
        ae.train()
    else:
        ae.eval()

    weighted_sum = 0.0
    seen = 0
    with torch.set_grad_enabled(train):
        for batch in dl:
            x_low, x_ecg = (batch[key].to(device) for key in ("signal_low", "signal_ecg"))

            low_hat, ecg_hat = ae(x_low, x_ecg)
            loss = 0.5 * mse(low_hat, x_low) + 0.5 * mse(ecg_hat, x_ecg)

            if train:
                optim.zero_grad(set_to_none=True)
                loss.backward()
                optim.step()

            bs = x_low.size(0)
            weighted_sum += bs * loss.item()
            seen += bs

    return weighted_sum / max(1, seen)


In [10]:
# ---------------- training loop ----------------
# Architecture metadata stored inside every checkpoint; it does not change between epochs.
_ckpt_meta = {
    "hidden_dim": cfg.hidden_dim,
    "T_low": T_low, "T_ecg": T_ecg, "latent_downsample": 4,
    "note": "Option A: train pre-standardized on disk; val standardized on-the-fly.",
}
_last_path = out_dir / "ae_last.pth"
_best_path = out_dir / "ae_best.pth"

for epoch in range(1, cfg.epochs_ae + 1):
    t0 = time.time()
    train_loss = run_epoch(train_dl, train=True)

    # Validation runs on inputs standardized with the validation statistics.
    val_low_mse, val_ecg_mse = eval_stream_losses(
        val_dl, ae, device, low_mu, low_std, ecg_mu, ecg_std, eps=_EPS
    )
    val_loss = 0.5 * val_low_mse + 0.5 * val_ecg_mse
    dt = time.time() - t0

    print(
        f"[{epoch:03}/{cfg.epochs_ae}] train={train_loss:.6f} | "
        f"val={val_loss:.6f} (low={val_low_mse:.6f}, ecg={val_ecg_mse:.6f}) | {dt:.1f}s"
    )

    # The latest state is always written; the best copy only when validation improves.
    ckpt = dict(
        epoch=epoch,
        state_dict=ae.state_dict(),
        optim=optim.state_dict(),
        val_loss=val_loss,
        meta=_ckpt_meta,
    )
    torch.save(ckpt, _last_path)
    if not (val_loss < best_loss):
        continue
    best_loss = val_loss
    torch.save(ckpt, _best_path)
    print("  ↳ saved ae_best.pth")

print("Done. Best val recon:", best_loss, "\nSaved to:", out_dir)



[001/25] train=0.730792 | val=0.588896 (low=0.177710, ecg=1.000082) | 2.0s
  ↳ saved ae_best.pth
[002/25] train=0.552198 | val=0.547312 (low=0.094618, ecg=1.000007) | 1.7s
  ↳ saved ae_best.pth
[003/25] train=0.546111 | val=0.535876 (low=0.071755, ecg=0.999998) | 1.7s
  ↳ saved ae_best.pth
[004/25] train=0.524946 | val=0.522702 (low=0.045403, ecg=1.000000) | 1.7s
  ↳ saved ae_best.pth
[005/25] train=0.519193 | val=0.516938 (low=0.033881, ecg=0.999995) | 1.7s
  ↳ saved ae_best.pth
[006/25] train=0.518167 | val=0.514371 (low=0.028746, ecg=0.999996) | 1.7s
  ↳ saved ae_best.pth
[007/25] train=0.524018 | val=0.531066 (low=0.062135, ecg=0.999996) | 1.7s
[008/25] train=0.560235 | val=0.540460 (low=0.080914, ecg=1.000006) | 1.8s
[009/25] train=0.566189 | val=0.614011 (low=0.228021, ecg=1.000002) | 1.7s
[010/25] train=0.575208 | val=0.541975 (low=0.083958, ecg=0.999992) | 1.7s
[011/25] train=0.534368 | val=0.513979 (low=0.027961, ecg=0.999998) | 1.7s
  ↳ saved ae_best.pth
[012/25] train=0.5742

In [11]:
# Inspect the best checkpoint: list the first dotted component of every parameter name.
# NOTE: `sd` is only inspected here; it is not loaded into `ae`, so the forward pass
# below runs with whatever weights `ae` currently holds in memory.
sd = torch.load(out_dir / "ae_best.pth", map_location="cpu")["state_dict"]
top_prefixes = sorted(set(name.partition(".")[0] for name in sd))
print("AE keys top-level prefixes:", top_prefixes)  # expect {'E','R'}

# Shape smoke test: one validation batch through the model in eval mode, no gradients.
ae.eval()
with torch.no_grad():
    batch = next(iter(val_dl))
    low_hat, ecg_hat = ae(*(batch[key].to(device) for key in ("signal_low", "signal_ecg")))
print(low_hat.shape, ecg_hat.shape)  # expect (B, 120, 2), (B, 5250, 1)

AE keys top-level prefixes: ['E', 'R']
torch.Size([32, 120, 2]) torch.Size([32, 5250, 1])
