In [11]:
# 1) Put your repo's src on sys.path
import importlib
import os
import sys
from pathlib import Path
from types import SimpleNamespace

# Set the GHM_SRC environment variable to point at another checkout; defaults to the original local path.
SRC = Path(os.environ.get("GHM_SRC", r"C:\Users\Joseph\generative-health-models\src"))
if not SRC.is_dir():
    # Fail loudly here rather than with a confusing ImportError below.
    print(f"⚠️  Source directory not found: {SRC} (set GHM_SRC to your repo's src/)")
if str(SRC) not in sys.path:
    # Prepend so imports resolve against the repo before site-packages
    # (e.g. an installed package that is also named `datasets`).
    sys.path.insert(0, str(SRC))

print(SRC)

# 2) (Optional) reload project modules if you just edited them (tc_multigan.py / wesad.py).
#    Reload BEFORE the `from ... import` lines below so they bind the fresh objects.
import models.tc_multigan as tcm
import datasets.wesad as wesad_mod
importlib.reload(tcm)
importlib.reload(wesad_mod)

# 3) Import the parts you need
from models.tc_multigan import Embedder, Supervisor
from datasets.wesad import make_loader
import torch

C:\Users\Joseph\generative-health-models\src


In [12]:
# --- Config: mirror AE pretrain ---
cfg = SimpleNamespace(
    # data
    data_root        = r"C:\Users\Joseph\generative-health-models\data\processed",
    fold             = "tc_multigan_fold_s10",
    train_split      = "train",
    # NOTE: the held-out "test" split is reused for validation here; adjust to your
    # dataset's naming if needed (e.g., "valid" or "dev").
    val_split        = "test",
    seq_length_low   = 120,             # passed to make_loader as window_size_low (length of `signal_low`)
    condition_dim    = 4,

    # loader
    batch_size       = 32,
    num_workers      = 0,
    weighted_sampling= False,

    # model-ish
    hidden_dim       = 256,

    # reproducibility
    seed             = 42,

    # device
    device           = "cuda" if torch.cuda.is_available() else "cpu",
)

# Seed torch's global RNG. DataLoader shuffling draws from it by default, so the
# batch inspected below is reproducible across Restart & Run All.
torch.manual_seed(cfg.seed)

fold_dir = Path(cfg.data_root) / cfg.fold
stats_low_path = fold_dir / "norm_low.npz"
stats_ecg_path = fold_dir / "norm_ecg.npz"

print("Using stats:", stats_low_path, stats_ecg_path)
# Report exactly which stats file(s) are missing instead of a generic failure.
missing_stats = [str(p) for p in (stats_low_path, stats_ecg_path) if not p.exists()]
assert not missing_stats, \
    f"Normalization stats not found: {missing_stats}. Make sure they were saved during AE pretrain."

Using stats: C:\Users\Joseph\generative-health-models\data\processed\tc_multigan_fold_s10\norm_low.npz C:\Users\Joseph\generative-health-models\data\processed\tc_multigan_fold_s10\norm_ecg.npz


In [13]:
def build_loader(split: str, shuffle: bool, **overrides):
    """Build a loader for ``split`` using the same settings as AE pretrain.

    All settings come from the notebook-level ``cfg``. Every split is built with
    ``normalize``/``normalize_ecg`` on, ``force_use_stats=True``, and the same
    saved stats files (``stats_low_path`` / ``stats_ecg_path``). Augmentation is off.

    Args:
        split: Split name forwarded to ``make_loader`` (e.g. ``cfg.train_split``).
        shuffle: Whether the loader shuffles samples.
        **overrides: Extra or overriding keyword arguments for ``make_loader``.
            These take precedence over the defaults below. Example, if the loader
            supports it: ``window_size_ecg=cfg.seq_length_low * 44``, which gives an
            integer ECG:low ratio (5280 for a 120-step low window).

    Returns:
        The object returned by ``make_loader``. The cells below iterate it as
        batches of dicts.
    """
    loader_kwargs = dict(
        root_dir          = Path(cfg.data_root),
        fold              = cfg.fold,
        split             = split,
        window_size_low   = cfg.seq_length_low,
        batch_size        = cfg.batch_size,
        shuffle           = shuffle,
        num_workers       = cfg.num_workers,
        weighted_sampling = cfg.weighted_sampling,
        condition_dim     = cfg.condition_dim,
        normalize         = True,
        normalize_ecg     = True,
        force_use_stats   = True,
        stats_low_path    = str(stats_low_path),
        stats_ecg_path    = str(stats_ecg_path),
        augment           = False,      # keep off for S pretrain
    )
    # Caller-supplied kwargs win, so any default above can be changed per call.
    loader_kwargs.update(overrides)
    return make_loader(**loader_kwargs)

In [14]:

train_dl = build_loader(cfg.train_split, shuffle=True)
# If the configured val split can't be built (e.g. your dataset names it "valid"/"dev"),
# fall back to the train split so the quick checks below still run.
# NOTE: in that case any "validation" numbers are computed on TRAINING data;
# `val_is_train` records this so downstream cells can tell.
val_is_train = False
try:
    val_dl = build_loader(cfg.val_split, shuffle=False)
except Exception as e:
    print(f"Couldn't build val loader with split='{cfg.val_split}': {e}\n"
          "Falling back to train split without shuffling for quick checks.")
    val_dl = build_loader(cfg.train_split, shuffle=False)
    val_is_train = True

# --- One-batch shape & ratio sanity check (important) ---
batch = next(iter(train_dl))
x_low  = batch["signal_low"]        # (B, T_low,  C=2)
x_ecg  = batch["signal_ecg"]        # (B, T_ecg, C=1)
cond   = batch["condition"]         # (B, T_low, K=condition_dim)

print("x_low:", tuple(x_low.shape), "| x_ecg:", tuple(x_ecg.shape), "| cond:", tuple(cond.shape))

# The condition sequence should share (B, T_low) with signal_low and carry cfg.condition_dim features.
assert tuple(cond.shape[:2]) == tuple(x_low.shape[:2]), \
    f"condition not aligned with signal_low: {tuple(cond.shape)} vs {tuple(x_low.shape)}"
assert cond.shape[-1] == cfg.condition_dim, \
    f"condition has {cond.shape[-1]} features, expected cfg.condition_dim={cfg.condition_dim}"

T_low = x_low.shape[1]
T_ecg = x_ecg.shape[1]
ratio = T_ecg / T_low
print(f"ECG:low ratio = {T_ecg} / {T_low} = {ratio:.4f}")
# Exact integer check (no float tolerance). Suggest the two smallest integer
# multiples of T_low that are >= T_ecg.
if T_ecg % T_low != 0:
    k_up = -(-T_ecg // T_low)  # ceil(T_ecg / T_low)
    print("⚠️  Non-integer ratio detected. For best stability, consider re-windowing to an integer multiple, "
          f"e.g., {T_low}→{k_up * T_low} (×{k_up}) or {T_low}→{(k_up + 1) * T_low} (×{k_up + 1}).")

x_low: (32, 120, 2) | x_ecg: (32, 5250, 1) | cond: (32, 120, 4)
ECG:low ratio = 5250 / 120 = 43.7500
⚠️  Non-integer ratio detected. For best stability, consider re-windowing to an integer multiple, e.g., 120→5280 (×44) or 120→5400 (×45).


In [None]:
# 1) Infer lengths from your loader (keeps dims identical to AE pretrain)
# Draws a fresh batch from the (shuffled) train loader; the lines below only read its shapes.
_batch = next(iter(train_dl))
T_low = _batch["signal_low"].shape[1]       # 120 in the run above (matches cfg.seq_length_low)
T_ecg = _batch["signal_ecg"].shape[1]       # 5250 with the current dataset (not an integer multiple of T_low)
H      = cfg.hidden_dim                     # 256 (from cfg.hidden_dim)