
# Two‑Stream Class‑Conditional Diffusion — Training Notebook

This notebook mirrors the functionality of `train_diffusion.py` and integrates with your existing project:
- **Data**: uses `datasets.wesad.make_loader(...)`
- **Models**: `DiffusionLow` and `DiffusionECG` from `models/diffusion.py`
- **Evaluation**: calls your unchanged `evaluator.py` via an adapter so the interface matches the GAN.

> **Tip:** Run cell‑by‑cell the first time to confirm paths and environment. Then you can run the final **Training** cell in one go.



## 1) Environment & Imports

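The cell below locates the repository root automatically: it walks up from the current working directory until it finds `src/utils/config.py`, then adds that `src/` directory to `sys.path` so the project modules can be imported.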


In [1]:
import sys, os
from pathlib import Path

# Locate the repository root: walk upward from the current working directory
# until a directory containing src/utils/config.py is found, then make src/
# importable. The for/else raises when no candidate directory matched.
here = Path.cwd()
REPO_ROOT = None
for parent in (here, *here.parents):
    src_dir = parent / "src"
    if (src_dir / "utils" / "config.py").exists():
        # Avoid stacking duplicate entries when the cell is re-run.
        if str(src_dir) not in sys.path:
            sys.path.insert(0, str(src_dir))
        REPO_ROOT = parent.resolve()
        print(f"Using source path: {src_dir}")
        break
else:
    # The message names the file that is actually probed above.
    raise RuntimeError(
        f"Couldn't find src/utils/config.py in {here} or any parent directory"
    )

def as_abs(p: str | os.PathLike):
    p = Path(p)
    return p if p.is_absolute() else (REPO_ROOT / p).resolve()

Using source path: c:\Users\Joseph\generative-health-models\src


In [2]:

import numpy as np
import torch
import math
import torch.nn as nn
import torch
import time
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lrs
import re, json, shutil

# Project imports (must be importable from REPO_ROOT)
from utils.config import parse_args as parse_base_args
from datasets.wesad import make_loader
from evaluation.evaluator import eval_distribution_epoch, eval_label_probe
from models.tc_multigan import TwoStreamDiscriminator
from models.diffusion import DiffusionLow, DiffusionECG
from models.diffusion_adapter import TwoStreamDiffusionAdapter
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import deque
from types import SimpleNamespace as NS

print('Imports OK. Torch:', torch.__version__)


Imports OK. Torch: 2.7.1+cu118



## 2) Configuration

We reuse the project's existing `utils.config.parse_args()` and then apply notebook‑friendly defaults and overrides.  
Edit the `overrides` dictionary below to point to your fold and set hyperparameters.


In [3]:
def parse_args_notebook(overrides=None):
    """Build the run configuration for the notebook.

    Steps, in order:
      1. Call ``parse_base_args()`` with ``sys.argv`` temporarily reduced to the
         program name, so argparse does not see Jupyter's own arguments.
      2. Mirror alias pairs (``workers``/``num_workers``, ``seeds``/``seed``,
         ``ema``/``use_ema``) so a value present under either name is visible
         under both *before* any default is filled in.
      3. Fill in diffusion and core defaults for attributes that are missing.
      4. Apply ``overrides`` (a mapping of attribute name -> value) and mirror
         overridden values to their alias, unless the alias was overridden too.
      5. Convert ``data_root`` and, when present, ``ckpt_dir`` / ``sample_dir`` /
         ``log_dir`` to absolute paths via ``as_abs``.

    Args:
        overrides: optional mapping of attribute names to values; ``None`` is
            treated as an empty mapping.

    Returns:
        The configuration object returned by ``parse_base_args()``, mutated in place.
    """
    # Strip Jupyter argv so argparse won’t see --f=...
    argv_backup = sys.argv
    sys.argv = [argv_backup[0]]
    try:
        cfg = parse_base_args()
    finally:
        sys.argv = argv_backup

    alias_pairs = (("workers", "num_workers"), ("seeds", "seed"), ("ema", "use_ema"))

    def sync_aliases(explicit=()):
        """Copy a value to its alias; names listed in *explicit* take precedence."""
        for first, second in alias_pairs:
            if first in explicit and second not in explicit:
                setattr(cfg, second, getattr(cfg, first))
            elif second in explicit and first not in explicit:
                setattr(cfg, first, getattr(cfg, second))
            elif hasattr(cfg, first) and not hasattr(cfg, second):
                setattr(cfg, second, getattr(cfg, first))
            elif hasattr(cfg, second) and not hasattr(cfg, first):
                setattr(cfg, first, getattr(cfg, second))

    # Resolve aliases first: otherwise the defaults below would always define
    # workers/seeds/ema and a parsed num_workers/seed/use_ema would be ignored.
    sync_aliases()

    defaults = {
        # Diffusion defaults
        "diffusion_steps": 1000,
        "beta_schedule": "cosine",
        "sampling_steps": 50,
        "sampling_method": "ddim",
        "lr": 2e-4,
        "ema": True,
        "ema_decay": 0.999,
        "eval_interval": 5,
        "eval_n_batches": 8,
        "probe_n_train_batches": 32,
        "probe_n_val_batches": 16,
        "epochs_gan": 50,
        "hidden_dim": 128,
        "sample_n": 16,
        # Core fallbacks
        "data_root": "./data/processed",
        "fold": "default_fold",
        "ckpt_dir": "./results/checkpoints",
        "sample_dir": "./results/samples",
        "device": "cpu",
        "seq_length_low": 120,
        "seq_length_ecg": 5250,
        "condition_dim": 4,
        "batch_size": 8,
        "workers": 0,
        "seeds": 123,
    }
    for name, value in defaults.items():
        if not hasattr(cfg, name):
            setattr(cfg, name, value)

    # Propagate freshly-filled defaults to aliases that are still missing.
    sync_aliases()

    # Apply overrides, then keep alias pairs consistent with what was overridden.
    overrides = overrides or {}
    for name, value in overrides.items():
        setattr(cfg, name, value)
    sync_aliases(explicit=overrides)

    # Canonicalize to absolute paths relative to repo root
    cfg.data_root = str(as_abs(cfg.data_root))
    for attr in ("ckpt_dir", "sample_dir", "log_dir"):
        if hasattr(cfg, attr):
            setattr(cfg, attr, str(as_abs(getattr(cfg, attr))))

    return cfg

# ---- EDIT THESE OVERRIDES FOR YOUR ENV ----
# Notebook-level overrides applied on top of the parsed configuration.
overrides = {
    "data_root": r"./data/processed",                        # relative to repo root, OK
    "fold": "tc_multigan_fold_S10",
    "device": "cuda",
    "epochs_gan": 182,
    "batch_size": 8,
    "workers": 2,
    "seeds": 123,
    "diffusion_steps": 1000,
    "sampling_steps": 100,
    "sampling_method": "ddim",
    "beta_schedule": "cosine",
    "lr": 2e-4,
    "ema": True,
    "ema_decay": 0.999,
    "ckpt_dir": "./results/checkpoints",
    "resume": r"C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\ckpt_epoch_041.pt",
    "resume_mode": "resume",
    "sample_dir": "./results/samples",
    "probe_n_train_batches": 32,
    "probe_n_val_batches": 16,
    "cond_drop_prob": 0.15,
    "cfg_scale": 0.0,
    "x0_clip_q": 0.995,
    "lambda_low": 1.0,
    "lambda_ecg": 0.5,
}

cfg = parse_args_notebook(overrides)

# Fall back to CPU when CUDA was requested but is not available.
device = torch.device("cuda" if (cfg.device == "cuda" and torch.cuda.is_available()) else "cpu")

# Mixed precision is only enabled on CUDA.
use_amp = (device.type == "cuda")
if use_amp:
    torch.backends.cudnn.benchmark = True

# torch.cuda.amp.GradScaler is deprecated in recent PyTorch releases; prefer the
# device-agnostic torch.amp.GradScaler and keep the old spelling as a fallback
# for older installations that do not provide it.
if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

print('Device:', device)
print('Data root:', cfg.data_root)
print('Fold:', cfg.fold)
print('Epochs:', cfg.epochs_gan)
print('workers:', getattr(cfg, 'workers', None), '| num_workers:', getattr(cfg, 'num_workers', None))
print('seed(s):', getattr(cfg, 'seed', None), '|', getattr(cfg, 'seeds', None))
print('ema/use_ema:', getattr(cfg, 'ema', None), '|', getattr(cfg, 'use_ema', None))

Configuration:
  aug_jitter: 0.01
  aug_scale: 0.1
  batch_size: 16
  boundary_margin_low: 0.92
  ckpt_dir: results\checkpoints
  ckpt_interval: 5
  condition_dim: 4
  config_json: None
  d_steps: 1
  data_root: C:\Users\Joseph\generative-health-models\src\train\data\processed
  device: cuda
  ecg_boundary_margin: 0.95
  ecg_margin: 0.98
  ema_decay: 0.999
  epochs_ae: 20
  epochs_gan: 50
  fake_label: 0.2
  fm_warmup_epochs: 15
  fold: tc_multigan_fold_S10
  fs_ecg: 175
  fs_low: 4
  g_steps: 2
  gen_label: 0.8
  hidden_dim: 256
  inst_noise_std: 0.01
  inst_noise_warm_epochs: 20
  lambda_adv: 1.0
  lambda_boundary_ecg: 0.08
  lambda_boundary_low: 0.003
  lambda_boundary_low_eda: None
  lambda_boundary_low_resp: None
  lambda_fm: 10.0
  lambda_mismatch: 0.5
  lambda_mm: 0.0
  lambda_rec: 100.0
  lambda_spec_ecg: 1.0
  lambda_spec_low: 0.5
  lambda_spike: 0.4
  lambda_tc: 2.0
  lambda_tv_ecg: 0.0
  lambda_tv_eda: None
  lambda_tv_low: 0.0
  lambda_tv_resp: None
  log_dir: results\logs


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)



## 3) Helpers (EMA, plotting, loaders, models)


In [4]:

# -------------------------
# EMA helper (shadow models)
# -------------------------
import copy

def _copy_model_like(m: nn.Module) -> nn.Module:
    m_ema = copy.deepcopy(m)
    for p in m_ema.parameters():
        p.requires_grad_(False)
    return m_ema

@torch.no_grad()
def update_ema(target: nn.Module, source: nn.Module, decay: float):
    """Blend *source*'s state into *target* in place.

    Floating-point entries become ``decay * target + (1 - decay) * source``;
    every other entry is replaced by the source value. The result is loaded
    back into *target* with ``load_state_dict``.
    """
    shadow_state = target.state_dict()
    live_state = source.state_dict()
    blend = 1.0 - decay
    for name in list(shadow_state):
        if not shadow_state[name].dtype.is_floating_point:
            # Non-float entries are copied over verbatim.
            shadow_state[name] = live_state[name]
            continue
        shadow_state[name].mul_(decay).add_(live_state[name], alpha=blend)
    target.load_state_dict(shadow_state)


def _inner(m):
    """Return the actual diffusion object whether you pass a wrapper or the raw model."""
    return m.diff if hasattr(m, "diff") else m


# -------------------------
# Small plotting helpers
# -------------------------
def _plot_low_panel(x_low: np.ndarray, save_path: Path, title: str = ""):
    """Plot up to 8 rows of *x_low* (channel 0 as "EDA", channel 1 as "RESP") and save as an image.

    Only the bottom row keeps x tick labels; the figure is closed after saving.
    """
    n_rows = min(8, x_low.shape[0])
    steps = np.arange(x_low.shape[1])
    plt.figure(figsize=(10, 6))
    for row in range(n_rows):
        plt.subplot(n_rows, 1, row + 1)
        for channel, label in enumerate(("EDA", "RESP")):
            plt.plot(steps, x_low[row, :, channel].astype(float), label=label)
        if row == 0:
            plt.legend(loc="upper right", fontsize=8)
        if row == n_rows - 1:
            plt.xticks(ticks=[0, 30, 60, 90, 120], labels=[0, 7.5, 15, 22.5, 30])
        else:
            plt.xticks([])
    if title:
        plt.suptitle(title, fontsize=10)
    plt.tight_layout()
    plt.savefig(save_path, dpi=120)
    plt.close()

def _plot_ecg_strip(x_ecg: np.ndarray, save_path: Path, title: str = ""):
    """Plot channel 0 of up to 4 rows of *x_ecg* against time and save as an image.

    The time axis is the sample index divided by 175.0; only the bottom row
    carries x tick labels and the "Time (s)" label.
    """
    n_rows = min(4, x_ecg.shape[0])
    seconds = np.arange(x_ecg.shape[1]) / 175.0  # seconds
    plt.figure(figsize=(12, 6))
    bottom = n_rows - 1
    for row in range(n_rows):
        axis = plt.subplot(n_rows, 1, row + 1)
        axis.plot(seconds, x_ecg[row, :, 0].astype(float))
        axis.set_xlim([0, seconds[-1]])
        axis.set_ylabel("ECG (z)")
        if row == bottom:
            axis.set_xlabel("Time (s)")
        else:
            axis.set_xticklabels([])
    if title:
        plt.suptitle(title, fontsize=10)
    plt.tight_layout()
    plt.savefig(save_path, dpi=120)
    plt.close()


# -------------------------
# Training / Evaluation utils
# -------------------------
def maybe_make_dirs(cfg):
    """Create ``<ckpt_dir>/diffusion`` and ``<sample_dir>/diffusion`` and return both paths."""
    ckpt_dir, sample_root = (
        Path(base) / "diffusion" for base in (cfg.ckpt_dir, cfg.sample_dir)
    )
    for folder in (ckpt_dir, sample_root):
        folder.mkdir(parents=True, exist_ok=True)
    return ckpt_dir, sample_root

def build_loaders(cfg):
    """Create the train and validation loaders for ``cfg.fold`` and sanity-check one train batch.

    Both loaders are built with ``make_loader`` using the same normalization /
    statistics keyword arguments; only the split name, shuffling and weighted
    sampling differ. One batch is pulled from the train loader to verify the
    tensor shapes against the configured sequence lengths and condition size,
    and its overall mean/std are printed.

    Returns:
        (train_loader, val_loader)
    """
    root = Path(cfg.data_root)
    fold_path = root / cfg.fold
    assert fold_path.exists(), f"Fold directory not found: {fold_path}"

    n_workers = int(getattr(cfg, "workers", getattr(cfg, "num_workers", 0)))

    def _make(split, *, shuffle, weighted):
        # Everything except split/shuffle/weighted_sampling is shared by both loaders.
        return make_loader(
            root_dir=str(root),                      # parent of the fold
            fold=cfg.fold,                           # fold name
            split=split,
            window_size_low=int(cfg.seq_length_low),
            batch_size=int(cfg.batch_size),
            shuffle=shuffle,
            num_workers=n_workers,
            weighted_sampling=weighted,
            condition_dim=int(cfg.condition_dim),
            normalize=False,
            normalize_ecg=False,
            stats_low_path=str(fold_path / "norm_low.npz"),
            stats_ecg_path=str(fold_path / "norm_ecg.npz"),
            force_use_stats=False,
            use_split_stats_if_needed=False,
            expected_ecg_len=int(cfg.seq_length_ecg),
            debug_print=False,
        )

    # Split names come from the config, with "train" / "val" as fallbacks.
    train_loader = _make(
        getattr(cfg, "train_split", "train"),
        shuffle=True,
        weighted=bool(getattr(cfg, "weighted_sampling", False)),
    )
    val_loader = _make(
        getattr(cfg, "val_split", "val"),
        shuffle=False,
        weighted=False,
    )

    # ---- Sanity check shapes and overall statistics on one train batch ----
    first = next(iter(train_loader))
    xlow = first["signal_low"]
    xecg = first["signal_ecg"]
    cond = first["condition"]
    shapes = {
        "signal_low": tuple(xlow.shape),
        "signal_ecg": tuple(xecg.shape),
        "condition": tuple(cond.shape),
    }
    print("train batch shapes:", shapes)
    assert xlow.shape[1:] == (cfg.seq_length_low, 2),  f"signal_low shape mismatch: {tuple(xlow.shape)}"
    assert xecg.shape[1:] == (cfg.seq_length_ecg, 1),  f"signal_ecg shape mismatch: {tuple(xecg.shape)}"
    assert cond.shape[1] == cfg.seq_length_low and cond.shape[-1] == cfg.condition_dim, \
        f"condition shape mismatch: {tuple(cond.shape)}"

    flat = torch.cat([xlow.reshape(-1), xecg.reshape(-1)], dim=0).float()
    print(f"[sanity] batch mean={flat.mean().item():.3f}, std={flat.std().item():.3f} (expect ~0/1)")
    return train_loader, val_loader

def build_models(cfg, device):
    """Instantiate both diffusion models plus their optimizer, scheduler and optional EMA copies.

    The two models share every constructor argument except ``downs``. A single
    AdamW optimizer covers the parameters of both models, and a
    CosineAnnealingLR scheduler is attached to it. EMA copies are created only
    when ``cfg.ema`` is truthy.

    Returns:
        (diff_low, diff_ecg, opt, ema_low, ema_ecg, scheduler)
    """
    shared_kwargs = dict(
        condition_dim=cfg.condition_dim,
        base_channels=32,
        num_res_blocks=2,
        channel_mults=None,
        diffusion_steps=cfg.diffusion_steps,
        beta_schedule=cfg.beta_schedule,
        device=device,
        cond_drop_prob=float(getattr(cfg, "cond_drop_prob", 0.0)),
        x0_clip_q=float(getattr(cfg, "x0_clip_q", 0.0)),
    )
    diff_low = DiffusionLow(downs=(2, 2, 2), **shared_kwargs).to(device)
    diff_ecg = DiffusionECG(downs=(5, 5, 3), **shared_kwargs).to(device)

    all_params = [*diff_low.parameters(), *diff_ecg.parameters()]
    opt = optim.AdamW(all_params, lr=cfg.lr, betas=(0.9, 0.999), weight_decay=0.0)
    total_epochs = int(getattr(cfg, "epochs_gan", 50))
    scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_epochs, eta_min=1e-6)

    if cfg.ema:
        ema_low = _copy_model_like(diff_low)
        ema_ecg = _copy_model_like(diff_ecg)
    else:
        ema_low = None
        ema_ecg = None

    return diff_low, diff_ecg, opt, ema_low, ema_ecg, scheduler

def get_fixed_conditions(val_loader, condition_dim: int, seq_length_low: int, sample_n: int, device: torch.device):
    """Return a fixed set of ``sample_n`` conditions for repeatable sampling.

    The conditions are taken from the ``"condition"`` entry of the first batch
    of *val_loader*; if that batch holds fewer than ``sample_n`` rows it is
    tiled along the batch dimension and truncated. If anything goes wrong
    (including an empty loader), the fallback is reported and synthetic
    one-hot labels cycling through ``range(condition_dim)`` are used instead.

    Returns:
        (cond_low, y): ``cond_low`` is the per-timestep condition sequence and
        ``y`` is its first timestep, both on *device*.
    """
    try:
        batch = next(iter(val_loader))
        cond = batch["condition"]
        if cond.size(0) < sample_n:
            reps = math.ceil(sample_n / cond.size(0))
            cond = cond.repeat(reps, 1, 1)
        cond_low = cond[:sample_n].to(device)
        y = cond_low[:, 0, :].contiguous()
        return cond_low, y
    except Exception as exc:
        # Deliberate best-effort fallback, but never a silent one: synthetic
        # labels change what the saved samples are conditioned on.
        print(
            f"[cond] could not take fixed conditions from val_loader "
            f"({type(exc).__name__}: {exc}); falling back to synthetic one-hot labels"
        )

    # Rows of the identity matrix are the one-hot vectors for classes 0..C-1.
    reps = math.ceil(sample_n / condition_dim)
    y = torch.eye(condition_dim).repeat(reps, 1)[:sample_n].to(device)
    cond_low = y.unsqueeze(1).repeat(1, seq_length_low, 1).contiguous()
    return cond_low, y



## 4) Build Dirs & Load Data


In [5]:

# Seed torch and numpy from cfg.seeds (0 when the attribute is absent).
seed_value = int(getattr(cfg, "seeds", 0))
torch.manual_seed(seed_value)
np.random.seed(seed_value)

# Output folders first, then the data loaders.
ckpt_dir, sample_root = maybe_make_dirs(cfg)
train_loader, val_loader = build_loaders(cfg)

print('Checkpoint dir:', ckpt_dir)
print('Sample root:', sample_root)


train batch shapes: {'signal_low': (8, 120, 2), 'signal_ecg': (8, 5250, 1), 'condition': (8, 120, 4)}
[sanity] batch mean=-0.001, std=1.086 (expect ~0/1)
Checkpoint dir: C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion
Sample root: C:\Users\Joseph\generative-health-models\results\samples\diffusion



## 5) Models, Optimizer, Fixed Conditions, Adapter & Feature Extractor


In [6]:

# Models, optimizer and optional EMA copies; the scheduler returned by
# build_models is discarded in favor of the plateau scheduler created below.
diff_low, diff_ecg, opt, ema_low, ema_ecg, _ = build_models(cfg, device)

# --- LR scheduler (created immediately after the optimizer) ---
scheduler = ReduceLROnPlateau(
    opt,
    mode="min",
    factor=0.5,
    patience=10,
    threshold=1e-3,
    threshold_mode="rel",
    min_lr=1e-6,
)

cond_low_fixed, y_fixed = get_fixed_conditions(
    val_loader, cfg.condition_dim, cfg.seq_length_low, cfg.sample_n, device
)

# Frozen feature extractor (TwoStreamDiscriminator): no gradients, eval mode.
D = TwoStreamDiscriminator(
    condition_dim=cfg.condition_dim,
    hidden_dim=getattr(cfg, "hidden_dim", 128),
).to(device)
D.requires_grad_(False)
D.eval()


def current_adapter():
    """Wrap the current generators (EMA copies when enabled and present) in a TwoStreamDiffusionAdapter."""
    ema_enabled = bool(getattr(cfg, "ema", True))
    if ema_enabled and ema_low is not None:
        G_low = ema_low
    else:
        G_low = diff_low
    if ema_enabled and ema_ecg is not None:
        G_ecg = ema_ecg
    else:
        G_ecg = diff_ecg
    return TwoStreamDiffusionAdapter(G_low, G_ecg, cfg, device)


print('Models ready.')


Models ready.


In [7]:
# Report CUDA availability and, when available, where tensors and models live.
cuda_ok = torch.cuda.is_available()
print("torch.cuda.is_available():", cuda_ok)
if cuda_ok:
    print("GPU count:", torch.cuda.device_count())
    print("GPU 0:", torch.cuda.get_device_name(0))
    # Verify model & tensor devices
    print("Selected device:", device)
    x = torch.randn(2, 3, device=device)
    print("Sample tensor on:", x.device)
    for label, model in (("Low", diff_low), ("ECG", diff_ecg)):
        print(f"{label} model device:", next(model.parameters()).device)

torch.cuda.is_available(): True
GPU count: 1
GPU 0: NVIDIA GeForce RTX 3080
Selected device: cuda
Sample tensor on: cuda:0
Low model device: cuda:0
ECG model device: cuda:0



## 6) Training, Sampling, NLL, Evaluation


In [8]:

@torch.no_grad()
def sample_and_save(
    epoch,
    model_low, model_ecg,
    cfg, device, sample_root,
    cond_low_fixed, y_fixed,
    *,
    # plotting/sampler knobs
    plot_in_z=True,                 # if False, de-normalize to raw units for plots
    fixed_z_ylim=None,              # e.g. 3.0 to force ±3; None = adaptive limits
    use_full_ddpm=False,            # True → ddpm + full steps for debugging
    denorm_for_plots=False          # kept for backward compatibility (alias of plot_in_z=False)
):
    """Sample both streams with fixed conditions, save the arrays and two preview plots.

    Writes into ``<sample_root>[/diffusion]/epoch_XXX``:
      • fake_low_epoch_XXX.npy        samples as returned by ``model_low.sample``
      • fake_ecg_epoch_XXX.npy        samples as returned by ``model_ecg.sample``
      • fake_low_real_epoch_XXX.npy   low samples multiplied by std and shifted by mean
      • fake_ecg_real_epoch_XXX.npy   ECG samples multiplied by std and shifted by mean
      • low_panel.png                 (N<=8 rows)
      • ecg_strip.png                 (N<=4 rows)

    The mean/std statistics are read from ``norm_low.npz`` / ``norm_ecg.npz``
    in ``<cfg.data_root>/<cfg.fold>``; a missing file raises from ``np.load``.
    Both models are switched to eval mode for sampling and restored to their
    previous mode afterwards, even if sampling raises.

    Returns:
        The output directory (a ``Path``).
    """
    # --- pick sampler ---
    if use_full_ddpm:
        method = "ddpm"
        steps  = int(getattr(cfg, "diffusion_steps", 1000))
    else:
        method = getattr(cfg, "sampling_method", "ddim")
        steps  = int(getattr(cfg, "sampling_steps", 50))

    # --- output dir (avoid double "diffusion") ---
    sample_root = Path(sample_root)
    if sample_root.name.lower() == "diffusion":
        out_dir = sample_root / f"epoch_{epoch:03d}"
    else:
        out_dir = sample_root / "diffusion" / f"epoch_{epoch:03d}"
    out_dir.mkdir(parents=True, exist_ok=True)

    fold_dir = Path(cfg.data_root) / cfg.fold  # e.g. data/processed/tc_multigan_fold_S10

    # paths first (so .exists() is valid)
    stats_low_path = (fold_dir / "norm_low.npz").resolve()
    stats_ecg_path = (fold_dir / "norm_ecg.npz").resolve()

    print(f"[stats] fold_dir:  {fold_dir.resolve()}")
    print(f"[stats] low  path: {stats_low_path}  exists={stats_low_path.exists()}")
    print(f"[stats] ecg  path: {stats_ecg_path}  exists={stats_ecg_path.exists()}")

    # numpy copies for plotting; the archives are closed as soon as the arrays are read
    with np.load(stats_low_path) as low_stats:
        mu_low_np, sd_low_np = low_stats["mean"], low_stats["std"]
    with np.load(stats_ecg_path) as ecg_stats:
        mu_ecg_np, sd_ecg_np = ecg_stats["mean"], ecg_stats["std"]

    print(f"[stats] low  μ[:2]={mu_low_np[:2]}, σ[:2]={sd_low_np[:2]}")
    print(f"[stats] ecg  μ[0]={float(mu_ecg_np[0]):.6f}, σ[0]={float(sd_ecg_np[0]):.6f}")

    # torch tensors for fast invert before saving
    MU_LOW = torch.from_numpy(mu_low_np).float().view(1, 1, 2).to(device)
    SD_LOW = torch.from_numpy(sd_low_np).float().view(1, 1, 2).to(device)
    MU_ECG = torch.from_numpy(mu_ecg_np).float().view(1, 1, 1).to(device)
    SD_ECG = torch.from_numpy(sd_ecg_np).float().view(1, 1, 1).to(device)

    # NOTE(review): the diagnostics below run on the *condition* tensors
    # (cond_low_fixed / y_fixed), not on signals, yet compare them against the
    # signal mean/std — confirm this is the intended check.
    eps = 1e-12
    low_dev = cond_low_fixed.to(device=device, dtype=torch.float32)
    ecg_dev = y_fixed.to(device=device, dtype=torch.float32)

    def _per_channel_stats(x: torch.Tensor):
        m = x.mean(dim=(0, 1))
        s = x.std(dim=(0, 1), unbiased=False)
        return m, s

    # Always print shapes so we know what we're dealing with
    print(f"[cond] shapes -> low={tuple(low_dev.shape)}, ecg={tuple(ecg_dev.shape)}")

    C_low = low_dev.shape[-1]

    # As-is stats for ALL low channels (no normalization)
    m_low_as, s_low_as = _per_channel_stats(low_dev)
    print(f"[cond:as_is]   low(all {C_low}) mean={m_low_as.tolist()}, std={s_low_as.tolist()}")

    # If we have at least the first two channels, check them vs MU/SD
    if C_low >= 2:
        low_first2 = low_dev[..., :2]
        m_low2_as, s_low2_as = _per_channel_stats(low_first2)
        low2_z = (low_first2 - MU_LOW) / SD_LOW.clamp_min(eps)
        m_low2_z, s_low2_z = _per_channel_stats(low2_z)
        print(f"[cond:as_is]   low(first2) mean={m_low2_as.tolist()}, std={s_low2_as.tolist()}")
        print(f"[cond:→zspace] low(first2) mean={m_low2_z.tolist()}, std={s_low2_z.tolist()}")

    # If there are extra channels beyond the first two, just report them
    if C_low > 2:
        low_extra = low_dev[..., 2:]
        m_extra, s_extra = _per_channel_stats(low_extra)
        print(f"[cond:extra]   low(extra {C_low-2}) mean={m_extra.tolist()}, std={s_extra.tolist()}")

    m_ecg_as, s_ecg_as = _per_channel_stats(ecg_dev)
    ecg_z  = (ecg_dev - MU_ECG) / SD_ECG.clamp_min(eps)
    m_ecg_z, s_ecg_z = _per_channel_stats(ecg_z)
    print(f"[cond:as_is]   ecg  mean={m_ecg_as.tolist()}, std={s_ecg_as.tolist()}")
    print(f"[cond:→zspace] ecg  mean={m_ecg_z.tolist()}, std={s_ecg_z.tolist()}")

    # --- inference mode, eval(), no autocast for sampling ---
    was_train_low  = model_low.training
    was_train_ecg  = model_ecg.training
    model_low.eval(); model_ecg.eval()

    guidance = float(getattr(cfg, "cfg_scale", 0.0))
    try:
        with torch.inference_mode():
            # IMPORTANT: no autocast during sampling
            fake_low = model_low.sample(
                y_or_seq=cond_low_fixed.to(device),
                num_steps=steps, method=method, cfg_scale=guidance
            )
            fake_ecg = model_ecg.sample(
                y_or_seq=y_fixed.to(device),
                num_steps=steps, method=method, cfg_scale=guidance
            )
    finally:
        # Restore the callers' modes even when sampling fails, so a failed
        # preview cannot leave the training models stuck in eval mode.
        if was_train_low: model_low.train()
        if was_train_ecg: model_ecg.train()

    fake_low_real = fake_low * SD_LOW + MU_LOW      # (N, 120, 2)  [EDA, RESP]
    fake_ecg_real = fake_ecg * SD_ECG + MU_ECG      # (N, 5250, 1) [ECG]

    np.save(out_dir / f"fake_low_epoch_{epoch:03d}.npy", fake_low.detach().cpu().numpy())
    np.save(out_dir / f"fake_ecg_epoch_{epoch:03d}.npy", fake_ecg.detach().cpu().numpy())
    np.save(out_dir / f"fake_low_real_epoch_{epoch:03d}.npy", fake_low_real.detach().cpu().numpy())
    np.save(out_dir / f"fake_ecg_real_epoch_{epoch:03d}.npy", fake_ecg_real.detach().cpu().numpy())

    # ---- print sanity stats on the sampled tensors (before any de-norm) ----
    def _stats(x):
        x = x.detach()
        finite = torch.isfinite(x)
        frac_finite = finite.float().mean().item()
        x = x[finite] if frac_finite > 0 else x
        return {
            "finite": frac_finite,
            "mean": x.mean().item() if x.numel() else float("nan"),
            "std":  x.std().item()  if x.numel() else float("nan"),
            "min":  x.min().item()  if x.numel() else float("nan"),
            "max":  x.max().item()  if x.numel() else float("nan"),
            "frac>|3|": (x.abs() > 3).float().mean().item() if x.numel() else float("nan"),
        }

    s_low = _stats(fake_low)
    s_ecg = _stats(fake_ecg)
    print(f"[samples:z] low  stats: {s_low}")
    print(f"[samples:z] ecg  stats: {s_ecg}")

    # ---- prepare arrays for plotting ----
    low_np = fake_low.detach().cpu().numpy()        # (N, 120, 2)
    ecg_np = fake_ecg.detach().cpu().numpy()        # (N, 5250, 1)

    # Optional: if you trained on z/clip in [-1,1], multiply back before plotting in z
    z_clip = float(getattr(cfg, "z_clip", 3.0))
    if plot_in_z and getattr(cfg, "scale_to_unit", False):
        low_np = low_np * z_clip
        ecg_np = ecg_np * z_clip

    # De-normalize for plots if requested
    if (not plot_in_z) or denorm_for_plots:
        # use the *_np arrays loaded above; no reloading
        low_np = low_np * sd_low_np[None, None, :] + mu_low_np[None, None, :]
        ecg_np = ecg_np * sd_ecg_np[None, None, :] + mu_ecg_np[None, None, :]

    def _robust_ylim(values: np.ndarray):
        """Padded 0.5–99.5 percentile limits over finite values, or None if there are none."""
        values = values[np.isfinite(values)]
        if values.size == 0:
            return None
        lo, hi = np.percentile(values, [0.5, 99.5])
        pad = 0.05 * (hi - lo + 1e-6)
        return lo - pad, hi + pad

    def _apply_ylim(ax, values: np.ndarray):
        # Matplotlib rejects NaN/Inf axis limits, so non-finite samples must
        # not reach set_ylim; in that case the axis keeps its automatic limits.
        if not plot_in_z:
            return
        if fixed_z_ylim is not None:
            ax.set_ylim(-fixed_z_ylim, fixed_z_ylim)
            return
        limits = _robust_ylim(values)
        if limits is not None:
            ax.set_ylim(*limits)

    # Time axes (defaults 4 Hz low, 175 Hz ECG)
    fs_low = float(getattr(cfg, "fs_low", 4))
    fs_ecg = float(getattr(cfg, "fs_ecg", 175))
    t_low = np.arange(low_np.shape[1]) / fs_low
    t_ecg = np.arange(ecg_np.shape[1]) / fs_ecg

    in_z_label = plot_in_z and not denorm_for_plots

    # ---- Plot LOW panel ----
    n_show_low = int(min(getattr(cfg, "sample_n", 8), low_np.shape[0], 8))
    fig, axes = plt.subplots(n_show_low, 1, figsize=(12, 1.8 * n_show_low), sharex=True)
    if n_show_low == 1: axes = [axes]
    for i in range(n_show_low):
        ax = axes[i]
        ax.plot(t_low, low_np[i, :, 0], linewidth=1.0, label="EDA")
        ax.plot(t_low, low_np[i, :, 1], linewidth=1.0, label="RESP")
        ax.set_ylabel("z" if in_z_label else "raw")
        _apply_ylim(ax, low_np[i].reshape(-1))
        if i == 0:
            ax.legend(loc="upper right", fontsize=9)
    axes[-1].set_xlabel("Time (s)")
    fig.suptitle(f"Low stream samples @ epoch {epoch}")
    fig.tight_layout()
    fig.savefig(out_dir / "low_panel.png", dpi=140)
    plt.close(fig)

    # ---- Plot ECG strip ----
    n_show_ecg = int(min(4, ecg_np.shape[0]))
    fig, axes = plt.subplots(n_show_ecg, 1, figsize=(12, 1.4 * n_show_ecg), sharex=True)
    if n_show_ecg == 1: axes = [axes]
    for i in range(n_show_ecg):
        ax = axes[i]
        ax.plot(t_ecg, ecg_np[i, :, 0], linewidth=0.8)
        ax.set_ylabel("ECG (z)" if in_z_label else "ECG")
        _apply_ylim(ax, ecg_np[i, :, 0])
    axes[-1].set_xlabel("Time (s)")
    fig.suptitle(f"ECG samples @ epoch {epoch}")
    fig.tight_layout()
    fig.savefig(out_dir / "ecg_strip.png", dpi=140)
    plt.close(fig)

    print(f"[samples] saved to {out_dir}")
    return out_dir

def save_checkpoint(ckpt_path: Path, epoch: int, diff_low, diff_ecg, opt, ema_low=None, ema_ecg=None):
    """Write epoch number, model/optimizer state dicts and optional EMA state dicts to *ckpt_path*."""

    def _state_or_none(module):
        # EMA copies are optional; a missing one is stored as None.
        return None if module is None else module.state_dict()

    payload = dict(
        epoch=epoch,
        diff_low=diff_low.state_dict(),
        diff_ecg=diff_ecg.state_dict(),
        optimizer=opt.state_dict(),
        ema_low=_state_or_none(ema_low),
        ema_ecg=_state_or_none(ema_ecg),
    )
    torch.save(payload, ckpt_path)
    print(f"[ckpt] saved to {ckpt_path}")

def train_one_epoch(epoch: int,
                    diff_low, diff_ecg, opt,
                    train_loader, device,
                    *,
                    lam_low: float = 1.0, lam_ecg: float = 1.0,
                    ema_low=None, ema_ecg=None, ema_decay: float = 0.999,
                    log_every: int = 10,                 # <— print every N batches
                    max_batches: int | None = None,      # optional early stop for quick tests
                    scaler=None, use_amp: bool = False):
    diff_low.train()
    diff_ecg.train()

    n_batches = len(train_loader)
    params = list(diff_low.parameters()) + list(diff_ecg.parameters())
    total_loss = 0.0

    t_epoch0 = time.perf_counter()
    t_last_log = t_epoch0

    for b_idx, batch in enumerate(train_loader, start=1):
        x_low = batch["signal_low"].to(device, non_blocking=True).float()
        x_ecg = batch["signal_ecg"].to(device, non_blocking=True).float()
        cond_seq = batch["condition"].to(device, non_blocking=True).float()
        y = cond_seq[:, 0, :]

        # forward
        lam_low = float(getattr(cfg, "lambda_low", 1.0))
        lam_ecg = float(getattr(cfg, "lambda_ecg", 1.0))
        if use_amp:
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                loss_low = diff_low.loss(x_low, cond_seq)
                loss_ecg = diff_ecg.loss(x_ecg, y)
                loss = lam_low * loss_low + lam_ecg * loss_ecg

        else:
            loss_low = diff_low.loss(x_low, cond_seq)
            loss_ecg = diff_ecg.loss(x_ecg, y)
            loss = lam_low * loss_low + lam_ecg * loss_ecg

        # backward + step
        opt.zero_grad(set_to_none=True)
        if use_amp and scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
            opt.step()

        # EMA (optional)
        if ema_low is not None and ema_ecg is not None:
            with torch.no_grad():
                for p_ema, p_src in zip(ema_low.parameters(), diff_low.parameters()):
                    if p_ema.data.dtype.is_floating_point:
                        p_ema.data.mul_(ema_decay).add_(p_src.data, alpha=1.0 - ema_decay)
                for p_ema, p_src in zip(ema_ecg.parameters(), diff_ecg.parameters()):
                    if p_ema.data.dtype.is_floating_point:
                        p_ema.data.mul_(ema_decay).add_(p_src.data, alpha=1.0 - ema_decay)

        # logging
        total_loss += float(loss.item())
        do_log = (b_idx == 1) or (log_every and b_idx % log_every == 0) or (b_idx == n_batches)
        if do_log:
            now = time.perf_counter()
            # average seconds per batch over the last logging window
            window_batches = 1 if b_idx == 1 else log_every
            dt_per_batch = (now - t_last_log) / max(1, window_batches)
            t_last_log = now

            avg_loss = total_loss / b_idx
            bs = x_low.size(0)
            sps = bs / dt_per_batch if dt_per_batch > 0 else float("nan")
            remaining = n_batches - b_idx
            eta_min = (remaining * dt_per_batch) / 60.0

            mem_str = ""
            if device.type == "cuda":
                cur = torch.cuda.memory_allocated() / (1024**3)
                peak = torch.cuda.max_memory_allocated() / (1024**3)
                mem_str = f" | mem {cur:.2f}/{peak:.2f} GB"

            print(
                f"[ep {epoch:03d}] {b_idx:4d}/{n_batches} "
                f"loss={loss.item():.4f} (low={loss_low.item():.4f}, ecg={loss_ecg.item():.4f}) "
                f"| avg={avg_loss:.4f} | {sps:.1f} samp/s | ETA {eta_min:.1f}m{mem_str}",
                flush=True
            )

        if max_batches is not None and b_idx >= max_batches:
            print(f"[ep {epoch:03d}] early stop at {b_idx} batches (debug)", flush=True)
            break

    avg_loss = total_loss / max(1, min(n_batches, b_idx))
    print(f"[epoch {epoch:03d}] train loss: {avg_loss:.4f} ({time.perf_counter() - t_epoch0:.1f}s)", flush=True)
    return avg_loss


@torch.no_grad()
def maybe_compute_nll(epoch: int, diff_low, diff_ecg, val_loader, cfg, device):
    """Print an approximate bits/dim figure for both streams.

    At most ``cfg.eval_n_batches`` (default 8) batches are taken from
    ``val_loader``. Each model's ``nll_bound`` is called with up to 128
    evaluation steps (fewer when ``cfg.diffusion_steps`` is smaller). The low
    stream receives the full ``condition`` tensor; the ECG stream receives its
    slice at index 0 along dim 1. The per-batch average of the returned
    ``"bits_per_dim"`` entries is printed; nothing is returned.
    """
    budget = min(getattr(cfg, "eval_n_batches", 8), len(val_loader))
    if budget <= 0:
        return

    # Same step count for both streams and for every batch.
    eval_steps = min(128, getattr(cfg, "diffusion_steps", 1000))

    sum_low = 0.0
    sum_ecg = 0.0
    n_seen = 0
    for idx, batch in enumerate(val_loader):
        if idx >= budget:
            break
        x_low = batch["signal_low"].to(device, non_blocking=True).float()
        x_ecg = batch["signal_ecg"].to(device, non_blocking=True).float()
        cond_seq = batch["condition"].to(device, non_blocking=True).float()
        y = cond_seq[:, 0, :]

        sum_low += diff_low.nll_bound(x_low, cond_seq, num_steps_eval=eval_steps)["bits_per_dim"]
        sum_ecg += diff_ecg.nll_bound(x_ecg, y, num_steps_eval=eval_steps)["bits_per_dim"]
        n_seen += 1

    if n_seen == 0:
        return
    print(f"[epoch {epoch:03d}] approx bits/dim → low: {sum_low / n_seen:.4f}, ecg: {sum_ecg / n_seen:.4f}")

def maybe_eval(epoch: int, adapter, D_feature, val_loader, cfg, device):
    """Run the two project evaluators for ``epoch`` and log start/end markers.

    ``eval_distribution_epoch`` runs with gradient tracking disabled.
    ``eval_label_probe`` runs with gradients explicitly enabled, so it still
    trains correctly when the caller is inside a no-grad context. The probe
    uses ``val_loader`` for both its train and validation splits. Outputs go
    under ``cfg.sample_dir``.
    """
    print(f"[eval] Running evaluator at epoch {epoch} ...")

    # Distribution metrics: no gradients required.
    with torch.no_grad():
        eval_distribution_epoch(
            G=adapter,
            val_loader=val_loader,
            cfg=cfg,
            device=device,
            epoch=epoch,
            out_root=cfg.sample_dir,
        )

    # Label probe: needs autograd, so re-enable it locally.
    with torch.enable_grad():
        probe_train_batches = getattr(cfg, "probe_n_train_batches", 32)
        probe_val_batches = getattr(cfg, "probe_n_val_batches", 16)
        eval_label_probe(
            G=adapter,
            D=D_feature,
            train_loader=val_loader,  # the validation loader doubles as the probe's train split
            val_loader=val_loader,
            cfg=cfg,
            device=device,
            epoch=epoch,
            out_root=cfg.sample_dir,
            n_train_batches=probe_train_batches,
            n_val_batches=probe_val_batches,
        )

    print(f"[eval] Done epoch {epoch}.")

In [9]:
@torch.no_grad()
def validate_one_epoch(
    diff_low, diff_ecg, val_loader, device, *, lam_low: float = 1.0, lam_ecg: float = 1.0
):
    """Average each stream's ``loss`` over ``val_loader`` (no AMP, no EMA update).

    Both models are switched to eval mode and left there. The low stream is
    conditioned on the full ``condition`` tensor, the ECG stream on its slice
    at index 0 along dim 1.

    Returns:
        ``(total, avg_low, avg_ecg)`` where
        ``total = lam_low * avg_low + lam_ecg * avg_ecg``. An empty loader
        yields zeros rather than a division by zero.
    """
    for model in (diff_low, diff_ecg):
        model.eval()

    low_losses = []
    ecg_losses = []
    for batch in val_loader:
        x_low = batch["signal_low"].to(device, non_blocking=True).float()
        x_ecg = batch["signal_ecg"].to(device, non_blocking=True).float()
        cond_seq = batch["condition"].to(device, non_blocking=True).float()
        y = cond_seq[:, 0, :]

        batch_low = diff_low.loss(x_low, cond_seq)
        batch_ecg = diff_ecg.loss(x_ecg, y)
        low_losses.append(float(batch_low.item()))
        ecg_losses.append(float(batch_ecg.item()))

    denom = max(1, len(low_losses))
    avg_low = sum(low_losses) / denom
    avg_ecg = sum(ecg_losses) / denom
    total = lam_low * avg_low + lam_ecg * avg_ecg
    print(f"[val] loss: {total:.4f} (low={avg_low:.4f}, ecg={avg_ecg:.4f})")
    return total, avg_low, avg_ecg

def step_plateau_and_log(scheduler, metric, optimizer):
    """Step a metric-driven scheduler and report any learning-rate drop.

    The per-group learning rates are read before and after
    ``scheduler.step(metric)``. A line is printed only if at least one group's
    rate went down.
    """
    def _group_rates():
        return [group['lr'] for group in optimizer.param_groups]

    before = _group_rates()
    scheduler.step(metric)  # call after the validation metric has been computed
    after = _group_rates()

    lowered = False
    for old_lr, new_lr in zip(before, after):
        if old_lr > new_lr:
            lowered = True
            break
    if lowered:
        print(f"[LR] reduced: {before} -> {after} (monitored={metric:.6f})")

In [10]:
def _log_first_batch(loader, name):
    """Print per-channel mean, std and |x| > 3 fraction for the loader's first batch.

    Statistics are reduced over dims 0 and 1 of ``signal_low`` and
    ``signal_ecg`` (the original note gives the layout as [B, T, C]), leaving
    one value per channel. ``name`` only labels the output lines.
    """
    first = next(iter(loader))

    def _as_numpy(t):
        return t.detach().cpu().numpy()

    # Compute both summaries before printing, so a missing key prints nothing.
    summaries = []
    for label, key in (("low", "signal_low"), ("ecg", "signal_ecg")):
        x = first[key]
        mean = _as_numpy(x.mean(dim=(0, 1)))
        std = _as_numpy(x.std(dim=(0, 1)))
        outlier_frac = _as_numpy((x.abs() > 3).float().mean(dim=(0, 1)))
        summaries.append((label, mean, std, outlier_frac))

    for label, mean, std, outlier_frac in summaries:
        print(f"[{name}:first-batch] {label}  μ={mean} σ={std} frac>|3|={outlier_frac}")

# One-off sanity check, run after the loaders are built: print first-batch
# per-channel statistics for the training loader, and for the validation
# loader when one exists.
_log_first_batch(train_loader, "train")
if val_loader is not None:
    _log_first_batch(val_loader, "val")

[train:first-batch] low  μ=[-0.0299504  -0.00910311] σ=[0.24292353 1.0678109 ] frac>|3|=[0.         0.02291667]
[train:first-batch] ecg  μ=[-0.00093535] σ=[1.008093] frac>|3|=[0.03138095]
[val:first-batch] low  μ=[0.03340743 0.00134355] σ=[0.4482208 0.8882488] frac>|3|=[0.00833333 0.00833333]
[val:first-batch] ecg  μ=[8.561466e-05] σ=[0.5488843] frac>|3|=[0.]


In [11]:
def save_milestone(ckpt_path: Path,
                   milestones_dir: Path,
                   *,
                   epoch: int,
                   val_total: float, val_low: float, val_ecg: float,
                   cfg,
                   m_low, m_ecg):
    """Snapshot a checkpoint as a self-contained milestone.

    Writes three things into ``milestones_dir`` (created if missing):
      * a copy of ``ckpt_path`` named ``milestone_<tag>.pt``;
      * copies of ``norm_low.npz`` / ``norm_ecg.npz`` from
        ``<cfg.data_root>/<cfg.fold>`` when those files exist;
      * ``milestone_<tag>.json`` with the validation losses and sampler settings.

    ``<tag>`` encodes the epoch, total validation loss, each model's
    ``ddim_alpha_bar_start`` and ``cfg.cfg_scale``. Either model may be
    ``None``; its sampler fields then fall back to ``0.0`` / ``False``.
    """
    milestones_dir.mkdir(parents=True, exist_ok=True)

    def _alpha_bar_start(model):
        if model is None:
            return 0.0
        return float(getattr(model, "ddim_alpha_bar_start", 0.0))

    def _clip_enabled(model):
        if model is None:
            return False
        return bool(getattr(model, "debug_clip_x0", False))

    # Sampler/schedule state as currently set on the models and cfg.
    ab_low = _alpha_bar_start(m_low)
    ab_ecg = _alpha_bar_start(m_ecg)
    cfg_used = float(getattr(cfg, "cfg_scale", 0.0))

    tag = f"e{epoch:03d}_val{val_total:.4f}_abL{ab_low:.0e}_abE{ab_ecg:.0e}_cfg{cfg_used:g}"

    # Full checkpoint copy.
    dst_ckpt = milestones_dir / f"milestone_{tag}.pt"
    shutil.copy2(ckpt_path, dst_ckpt)

    # Normalization files, so the milestone can be used without the data folder.
    fold_dir = Path(cfg.data_root) / cfg.fold
    for norm_name in ["norm_low.npz", "norm_ecg.npz"]:
        norm_src = fold_dir / norm_name
        if not norm_src.exists():
            continue
        shutil.copy2(norm_src, milestones_dir / norm_name)

    # Small JSON manifest next to the checkpoint copy.
    manifest = {
        "epoch": epoch,
        "fold": cfg.fold,
        "data_root": str(cfg.data_root),
        "val_total": float(val_total),
        "val_low": float(val_low),
        "val_ecg": float(val_ecg),
        "sampling_method": getattr(cfg, "sampling_method", "ddim"),
        "sampling_steps": int(getattr(cfg, "sampling_steps", 50)),
        "ddim_alpha_bar_start_low": ab_low,
        "ddim_alpha_bar_start_ecg": ab_ecg,
        "debug_clip_x0_low": _clip_enabled(m_low),
        "debug_clip_x0_ecg": _clip_enabled(m_ecg),
        "cfg_scale": cfg_used,
    }
    manifest_path = milestones_dir / f"milestone_{tag}.json"
    with open(manifest_path, "w") as f:
        json.dump(manifest, f, indent=2)

    print(f"[milestone] saved {dst_ckpt.name} (+norms)")

In [12]:
# Hysteresis knobs for switching the x0 clamp off (used by can_relax_clamp and
# apply_sampling_schedule).
CLEAN_STREAK_TO_RELAX = 2    # number of most-recent history entries (epochs) that must all be clean
CLAMP_OFF_BAND       = 3e-3  # a clip fraction must stay strictly below this to count as clean

def _hist(state, name, maxlen=16):
    """Return the bounded rolling history stored at ``state.<name>``.

    A missing (or ``None``) attribute is created as an empty
    ``deque(maxlen=maxlen)``. An existing non-deque container is converted to
    a bounded deque that keeps its most recent ``maxlen`` entries. This
    matters because the training cell pre-seeds the histories as plain lists,
    which would otherwise be returned as-is and grow by one entry per epoch.
    An existing deque is returned untouched, whatever its ``maxlen``.

    The deque is stored back on ``state``, so repeated calls return the same
    object.
    """
    dq = getattr(state, name, None)
    if not isinstance(dq, deque):
        seed = () if dq is None else dq
        dq = deque(seed, maxlen=maxlen)
        setattr(state, name, dq)
    return dq

def _push_hist(dq, val):
    """Append ``val`` to the history ``dq`` after coercing it to ``float``."""
    entry = float(val)
    dq.append(entry)

def _last_n_max(dq, n):
    """Return the maximum of the last ``n`` entries of ``dq``.

    With fewer than ``n`` entries available the result is ``+inf``, so any
    "below threshold" comparison fails until enough history exists.
    """
    if len(dq) >= n:
        recent = list(dq)[-n:]
        return max(recent)
    return float("inf")

def _last_n_any(dq, n):
    """Return whether any of the last ``n`` entries of ``dq`` is truthy.

    With fewer than ``n`` entries the answer is ``True``: too little history
    is treated as unsafe.
    """
    if len(dq) >= n:
        recent = list(dq)[-n:]
        return any(bool(flag) for flag in recent)
    return True

def can_relax_clamp(state):
    """Decide whether the x0 clamp may be switched off this epoch.

    Returns ``False`` while ``state.epoch`` is at or before
    ``state.recovery_cooldown_until``. Otherwise every rolling history on
    ``state`` must be clean over its last ``CLEAN_STREAK_TO_RELAX`` entries:
    clip fractions below ``CLAMP_OFF_BAND``, std values below 3.0, and no
    NaN/Inf flag. A history shorter than the streak counts as not clean.
    """
    cooldown_until = getattr(state, "recovery_cooldown_until", -1)
    if state.epoch <= cooldown_until:
        # Never relax in the same epoch as (or before the end of) a recovery cooldown.
        return False

    streak = CLEAN_STREAK_TO_RELAX
    upper_bounds = (
        ("hist_clip_frac_low", CLAMP_OFF_BAND),
        ("hist_clip_frac_ecg", CLAMP_OFF_BAND),
        ("hist_std_low", 3.0),
        ("hist_std_ecg", 3.0),
    )
    # Evaluate every check eagerly: _hist creates a missing history as a side effect.
    checks = [_last_n_max(_hist(state, key), streak) < bound for key, bound in upper_bounds]
    checks.append(not _last_n_any(_hist(state, "hist_had_nan"), streak))
    return all(checks)

In [13]:

def _set_attr_for_sampling(obj, name, value):
    """Set ``obj.<name> = value`` and mirror it onto ``obj.diff`` when present.

    Per the original note, the DiffusionLow/ECG wrappers keep the real sampler
    at ``.diff``; writing both keeps the wrapper and the sampler in agreement.
    """
    setattr(obj, name, value)
    if not hasattr(obj, "diff"):
        return
    setattr(obj.diff, name, value)


def _show(m, tag, cfg=None):
    """Print a one-line summary of a model's sampling knobs.

    Values are read from ``m.diff`` when that attribute exists, otherwise from
    ``m``; a missing knob prints as ``None``. ``cfg.cfg_scale`` is appended
    only when ``cfg`` is given. ``m is None`` prints a ``<none>`` placeholder.
    """
    if m is None:
        print(f"[sched:{tag}] <none>")
        return

    sampler = m.diff if hasattr(m, "diff") else m
    fields = [
        f"thr={getattr(sampler, 'ddim_alpha_bar_start', None)}",
        f"clip_on={getattr(sampler, 'debug_clip_x0', None)}",
        f"x0_clip_q={getattr(sampler, 'x0_clip_q', None)}",
        f"x0_clip_value={getattr(sampler, 'x0_clip_value', None)}",
    ]
    if cfg is not None:
        fields.append(f"cfg_scale={getattr(cfg, 'cfg_scale', None)}")
    print(f"[sched:{tag}] " + "  ".join(fields))


def _last_frac(m):
    """Return the sampler's ``_last_frac_clipped`` as a float, or ``None`` if unset.

    The value is read from ``m.diff`` when present, otherwise from ``m``.
    """
    sampler = m.diff if hasattr(m, "diff") else m
    recorded = getattr(sampler, "_last_frac_clipped", None)
    if recorded is None:
        return None
    return float(recorded)


def _last_std_bad(m):
    """Return the sampler's ``_last_return_std_bad`` flag, or ``None`` if unset.

    The flag is read from ``m.diff`` when present, otherwise from ``m``.
    """
    if hasattr(m, "diff"):
        sampler = m.diff
    else:
        sampler = m
    return getattr(sampler, "_last_return_std_bad", None)

def _quiet_from_hist(inner):
    """
    Report whether the two most recent entries of ``inner._clip_hist`` are
    both below 0.01 (1% clipped).

    Entries may be plain numbers or ``(epoch, fraction)`` pairs; for pairs the
    second element is used. A missing, empty or single-entry history is not quiet.
    """
    hist = getattr(inner, "_clip_hist", [])
    if not hist:
        return False

    # Reduce the last two entries to plain floats.
    recent = [
        float(entry[1]) if isinstance(entry, (list, tuple)) and len(entry) >= 2 else float(entry)
        for entry in hist[-2:]
    ]
    return len(recent) == 2 and max(recent) < 0.01

def apply_sampling_schedule(epoch, cfg, m_low, m_ecg, state):
    """
    Runs every epoch *before* sampling.

    Adjusts the sampler knobs on the given models (either may be ``None``) and
    ``cfg.cfg_scale`` in place, in this order:

    1. Push last epoch's stats from ``state`` onto its rolling histories.
    2. Cooldown / hard-recovery guards: on bad std, excessive clipping, NaN/Inf
       or a validation spike, force the conservative configuration (clamp on
       at 3.0, ``ddim_alpha_bar_start`` at least 5e-4, CFG 0) and return early.
    3. Otherwise: (A) set the alpha-bar threshold from the epoch, (B) decide
       per stream whether the clamp may be switched off, (C) set the CFG scale.

    Returns a dict whose ``"recovery_fired"`` entry tells the caller whether
    one of the guards in step 2 took over.
    """
    models = [m for m in (m_low, m_ecg) if m is not None]

    # ---- update rolling history from last epoch's sampler stats ----
    # Missing stats default to pessimistic values (clip 1.0, std 10.0) so they
    # never count towards a clean streak.
    _push_hist(_hist(state, "hist_clip_frac_low"), getattr(state, "last_fracclamp_low", 1.0))
    _push_hist(_hist(state, "hist_clip_frac_ecg"), getattr(state, "last_fracclamp_ecg", 1.0))
    _push_hist(_hist(state, "hist_std_low"),       getattr(state, "last_std_low", 10.0))
    _push_hist(_hist(state, "hist_std_ecg"),       getattr(state, "last_std_ecg", 10.0))
    _push_hist(_hist(state, "hist_had_nan"), 1.0 if getattr(state, "had_nan_or_inf", False) else 0.0)

    # Global hysteresis gate, evaluated once on the freshly updated histories.
    allow_relax = can_relax_clamp(state)

    # ---------- RECOVERY COOLDOWN ----------
    # NOTE(review): the hard-recovery branch below sets recovery_cooldown_until
    # to `epoch + 1`, so with one call per epoch this strict `>` test is not
    # satisfied on the following epoch -- confirm whether `>=` was intended.
    if getattr(state, "recovery_cooldown_until", -1) > epoch:
        for m in models:
            # Never lower an already-higher threshold; only raise it to the 5e-4 floor.
            _set_attr_for_sampling(m, "ddim_alpha_bar_start",
                                max(float(getattr(m, "ddim_alpha_bar_start", 0.0)), 5e-4))
            inn = _inner(m)  # helper defined elsewhere; presumably returns m.diff when wrapped
            setattr(inn, "debug_clip_x0", True)
            setattr(inn, "x0_clip_value", 3.0)
            # (optional) mirror to wrapper for consistent logs
            if hasattr(m, "diff"):
                m.debug_clip_x0 = True
                m.x0_clip_value = 3.0

        cfg.cfg_scale = 0.0
        return {"recovery_fired": True, "cooldown": True}

    # ---------- HARD RECOVERY GUARD ----------
    # Missing stats default to 0.0 / False here, i.e. they do not trigger recovery.
    bad_std  = (getattr(state, "last_std_low", 0.0) >= 3.0) or (getattr(state, "last_std_ecg", 0.0) >= 3.0)
    bad_clip = (getattr(state, "last_fracclamp_low", 0.0) > 0.02) or (getattr(state, "last_fracclamp_ecg", 0.0) > 0.02)
    had_nan  = bool(getattr(state, "had_nan_or_inf", False))
    blew_up  = bool(getattr(state, "val_loss_spike", False))

    if bad_std or bad_clip or had_nan or blew_up:
        state.recovery_cooldown_until = epoch + 1
        # Same conservative configuration as in the cooldown branch above.
        for m in models:
            _set_attr_for_sampling(m, "ddim_alpha_bar_start",
                                max(float(getattr(m, "ddim_alpha_bar_start", 0.0)), 5e-4))
            inn = _inner(m)
            setattr(inn, "debug_clip_x0", True)
            setattr(inn, "x0_clip_value", 3.0)
            if hasattr(m, "diff"):
                m.debug_clip_x0 = True
                m.x0_clip_value = 3.0

        cfg.cfg_scale = 0.0
        
        return {
            "recovery_fired": True,
            "clip_on": True,
            "x0_clip_value": 3.0,
            "cfg_scale": 0.0,
            "ddim_alpha_bar_start": 5e-4,  # optional, for callers that read the dict
        }

    # NOTE(review): this reset runs on every non-recovery call, so the ramp in
    # section C always starts from 0.0. Within this function the 0.5 -> 1.0 and
    # 1.0 -> 2.0 branches therefore cannot be reached, and the hasattr
    # initialisation in section C is redundant. Confirm this cap is intended.
    cfg.cfg_scale = 0.0
    # ------------------------------
    # A) Relax alpha-bar threshold gradually
    # ------------------------------
    if epoch <= 3:
        thr = 1e-3
    elif epoch <= 5:
        thr = 5e-4
    elif epoch <= 8:
        thr = 1e-4
    else:
        thr = 0.0

    # **Re-apply the floor here**
    # NOTE(review): with this floor the 1e-4 and 0.0 stages above have no
    # effect; the threshold is 1e-3 up to epoch 3 and 5e-4 afterwards.
    thr = max(thr, 5e-4)

    for m in models:
        # NaN default: a model without the attribute always logs the change below.
        prev = float(getattr(m, "ddim_alpha_bar_start", float("nan")))
        _set_attr_for_sampling(m, "ddim_alpha_bar_start", float(thr))
        if prev != thr:
            print(f"[schedule] {m.__class__.__name__}: ddim_alpha_bar_start {prev} → {thr}")

    # --------------------------------------------
    # B) Clamp stays ON until quiet for 2 epochs
    # --------------------------------------------
    for m in models:
        inner = _inner(m)

        # ensure flags exist
        if not hasattr(inner, "debug_clip_x0"):
            inner.debug_clip_x0 = True
        if not hasattr(inner, "x0_clip_value"):
            inner.x0_clip_value = 3.0

        # Record the latest clip fraction; an unknown value counts as fully clipped.
        frac = getattr(inner, "_last_frac_clipped", None)
        frac_for_hist = 1.0 if frac is None else max(0.0, float(frac))
        hist = getattr(inner, "_clip_hist", [])
        # Keep only the most recent CLEAN_STREAK_TO_RELAX entries.
        inner._clip_hist = (hist + [frac_for_hist])[-CLEAN_STREAK_TO_RELAX:]

        # local (per-stream) quiet test + global hysteresis gate
        wants_off = (
            len(inner._clip_hist) == CLEAN_STREAK_TO_RELAX
            and max(inner._clip_hist) < CLAMP_OFF_BAND
        )

        if wants_off and allow_relax:
            if inner.debug_clip_x0:
                print(f"[schedule] turning clamp OFF (hist={inner._clip_hist})")
            inner.debug_clip_x0 = False
        else:
            # Either not quiet yet, or hysteresis says "not yet" -> keep ON
            inner.debug_clip_x0 = True
            if wants_off and not allow_relax:
                print("[schedule] clamp wants OFF, but hysteresis keeps it ON")

        # mirror to wrapper for your logs
        if hasattr(m, "diff"):
            m.debug_clip_x0 = inner.debug_clip_x0
            m.x0_clip_value = inner.x0_clip_value
            m._clip_hist    = inner._clip_hist
            m._last_frac_clipped = getattr(inner, "_last_frac_clipped", None)

    # -----------------------------------
    # C) Re-introduce CFG slowly
    # -----------------------------------
    if not hasattr(cfg, "cfg_scale"):
        cfg.cfg_scale = 0.0  # init only

    inn_low = _inner(m_low) if m_low is not None else None
    inn_ecg = _inner(m_ecg) if m_ecg is not None else None
    # An absent stream never blocks CFG.
    low_quiet = True if inn_low is None else _quiet_from_hist(inn_low)
    ecg_quiet = True if inn_ecg is None else _quiet_from_hist(inn_ecg)

    safe_to_cfg = low_quiet and ecg_quiet and (thr <= 5e-4)

    if safe_to_cfg:
        # One stage per call, each keyed on the current scale and a minimum epoch.
        if cfg.cfg_scale == 0.0 and epoch >= 6:
            cfg.cfg_scale = 0.5
            print("[schedule] CFG -> 0.5")
        elif cfg.cfg_scale == 0.5 and epoch >= 8:
            cfg.cfg_scale = 1.0
            print("[schedule] CFG -> 1.0")
        elif cfg.cfg_scale == 1.0 and epoch >= 10:
            cfg.cfg_scale = 2.0
            print("[schedule] CFG -> 2.0")

    # Auto-rollback on sampler instability
    bad_low = bool(getattr(inn_low, "_last_return_std_bad", False)) if inn_low is not None else False
    bad_ecg = bool(getattr(inn_ecg, "_last_return_std_bad", False)) if inn_ecg is not None else False
    if (bad_low or bad_ecg) and cfg.cfg_scale > 0.0:
        cfg.cfg_scale = 0.0
        print("[schedule] CFG rolled back to 0.0 due to unstable sampler std")

    return {"recovery_fired": False}

In [14]:
def _parse_epoch(p: Path) -> int:
    # works for "ckpt_epoch_123.pt" or "milestone_e123_...pt"
    m = re.search(r'[eE]poch[_\-]?(\d+)|\be(\d{1,5})\b', p.stem)
    if not m: return -1
    g = m.group(1) or m.group(2)
    try: return int(g)
    except: return -1

def find_latest_ckpt(ckpt_dir: Path) -> Path | None:
    cand = sorted(ckpt_dir.glob("ckpt_epoch_*.pt"), key=_parse_epoch)
    return cand[-1] if cand else None

def set_sampling_defaults(m, cfg):
    """Initialise sampler knobs and reset runtime diagnostics on a model.

    Operates on ``m.diff`` when ``m`` is a wrapper, otherwise on ``m`` itself;
    does nothing for ``None``. Each numeric knob is taken from ``cfg`` if set
    there, else from the value already on the sampler, else from a built-in
    default. Diagnostics start conservatively: clip fraction 1.0, empty
    history, no bad-std flag. For a wrapper, all values are mirrored onto ``m``.
    """
    if m is None:
        return

    wrapped = hasattr(m, "diff")
    inner = m.diff if wrapped else m

    # --- knobs: cfg value > existing sampler value > default ---
    inner.debug_clip_x0 = bool(getattr(inner, "debug_clip_x0", True))
    knob_defaults = (
        ("x0_clip_q", 0.995),
        ("x0_clip_value", 3.0),
        ("ddim_alpha_bar_start", 1e-3),
        ("ret_std_max", 3.0),
    )
    for knob, fallback in knob_defaults:
        setattr(inner, knob, float(getattr(cfg, knob, getattr(inner, knob, fallback))))

    # --- runtime diagnostics, fresh for every run/resume ---
    inner._last_frac_clipped = 1.0      # conservative: clamp stays on until real data arrives
    inner._clip_hist = []               # history consumed by the sampling schedule
    inner._last_return_std_bad = False
    inner._last_return_std_value = None
    if hasattr(inner, "_printed_clip_once"):
        delattr(inner, "_printed_clip_once")  # re-allow the one-time clamp log

    if not wrapped:
        return

    # Mirror onto the wrapper so prints that read the wrapper agree.
    mirrored = (
        "debug_clip_x0",
        "x0_clip_q",
        "x0_clip_value",
        "ddim_alpha_bar_start",
        "ret_std_max",
        "_last_frac_clipped",
        "_clip_hist",
        "_last_return_std_bad",
        "_last_return_std_value",
    )
    for attr in mirrored:
        setattr(m, attr, getattr(inner, attr))

def _std_val(m):
    """Return the sampler's ``_last_return_std_value``, or ``None`` if unavailable.

    The value is read from ``m.diff`` when present, otherwise from ``m``.
    ``m is None`` returns ``None``.
    """
    if m is None:
        return None
    sampler = m.diff if hasattr(m, "diff") else m
    return getattr(sampler, "_last_return_std_value", None)


def load_training_state(ckpt_path: Path,
                        device,
                        *,
                        mode: str,                 # "resume" or "restart"
                        diff_low, diff_ecg,
                        opt=None, scheduler=None, scaler=None,
                        ema_low=None, ema_ecg=None):
    """Restore training objects from a checkpoint and return the next epoch.

    Model weights are always loaded (strictly). EMA weights are loaded when
    both the EMA model and its checkpoint entry exist. Optimizer, scheduler
    and scaler state are restored only for ``mode == "resume"``. A scheduler
    or scaler mismatch is reported and that object keeps its fresh state.

    Returns:
        The checkpoint's ``"epoch"`` (0 if absent) plus one.
    """
    payload = torch.load(ckpt_path, map_location=device)

    # Model weights: required entries, strict key matching.
    for key, model in (("diff_low", diff_low), ("diff_ecg", diff_ecg)):
        model.load_state_dict(payload[key], strict=True)

    # EMA copies: optional on both sides.
    for key, ema in (("ema_low", ema_low), ("ema_ecg", ema_ecg)):
        if ema is not None and payload.get(key) is not None:
            ema.load_state_dict(payload[key], strict=True)

    # Optimizer / scheduler / scaler only when truly resuming.
    if mode == "resume":
        if opt is not None and payload.get("optimizer") is not None:
            opt.load_state_dict(payload["optimizer"])
        for key, obj in (("scheduler", scheduler), ("scaler", scaler)):
            if obj is None or payload.get(key) is None:
                continue
            try:
                obj.load_state_dict(payload[key])
            except Exception as e:
                print(f"[resume] {key} state mismatch, continuing fresh: {e}")

    last_epoch = int(payload.get("epoch", 0))
    print(f"[resume] loaded {ckpt_path.name} (epoch={last_epoch}, mode={mode})")
    return last_epoch + 1  # continue after the checkpointed epoch

def best_val_from_milestones(milestones_dir: Path) -> float:
    """Return the smallest ``val_total`` recorded in the milestone manifests.

    Scans ``milestone_*.json`` in ``milestones_dir``. Unreadable, malformed or
    non-numeric manifests are skipped on a best-effort basis. Returns ``+inf``
    when no usable manifest exists.
    """
    best = float("inf")
    for js in milestones_dir.glob("milestone_*.json"):
        try:
            with open(js, "r") as f:
                val = float(json.load(f).get("val_total", float("inf")))
        except (OSError, ValueError, TypeError, AttributeError):
            # OSError: unreadable file. ValueError: invalid JSON or number.
            # TypeError/AttributeError: manifest is not a dict of numbers.
            # Narrow catch so KeyboardInterrupt and real bugs still propagate.
            continue
        if val < best:
            best = val
    return best

In [15]:
def _collect_sampler_stats(state, m, tag):
    """Copy the last sampling pass's diagnostics from ``m`` onto ``state``.

    Writes ``state.last_fracclamp_<tag>`` and ``state.last_std_<tag>``, which
    the sampling schedule reads on the next epoch. A missing model or missing
    diagnostic is recorded as ``0.0``.
    """
    if m is None:
        setattr(state, f"last_fracclamp_{tag}", 0.0)
        setattr(state, f"last_std_{tag}", 0.0)
        return

    inn = _inner(m)  # same helper the schedule uses

    # Fraction of x0 elements clipped in the last sampling pass, if tracked.
    frac = getattr(inn, "_last_frac_clipped", None)
    setattr(state, f"last_fracclamp_{tag}", float(frac) if frac is not None else 0.0)

    # Stability/scale indicator. The rest of this file stores it as
    # `_last_return_std_value` (set_sampling_defaults, _std_val), so try that
    # first. Without it the recorded std stayed at 0.0 and the schedule's std
    # guard could never fire. The two older names remain as fallbacks.
    std = None
    for attr in ("_last_return_std_value", "_last_return_std", "_last_latent_std"):
        std = getattr(inn, attr, None)
        if std is not None:
            break
    setattr(state, f"last_std_{tag}", float(std) if std is not None else 0.0)


## 7) Run Training

The cell below will:
- Train both streams jointly
- Save samples under `sample_dir/diffusion/epoch_XXX/`
- (Optionally) run evaluation every `cfg.eval_interval` epochs
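- Save checkpoints under `ckpt_dir/diffusion/` as `ckpt_epoch_XXX_WEIGHTS.pt` (models, EMA, optimizer, scheduler, scaler) plus `ckpt_epoch_XXX_STATE.pt` (sampling-schedule state), and copy improved epochs to `ckpt_dir/diffusion/milestones/`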

> **Smoke test**: with the default overrides above (`epochs_gan=2`, CPU, batch size 2), this should run end‑to‑end.


In [17]:
# --- Resume (optional) ---
# cfg.resume may be unset/empty (fresh run), "auto" (newest checkpoint under
# <ckpt_dir>/diffusion) or an explicit checkpoint path. The path is resolved
# once and reused for both loading and logging.
start_epoch = 1
resume_arg = getattr(cfg, "resume", None)
if resume_arg:
    if resume_arg == "auto":
        resume_path = find_latest_ckpt(Path(cfg.ckpt_dir) / "diffusion")
        if resume_path is None:
            # Fail with a clear message instead of the opaque TypeError that
            # Path(None) would raise.
            raise FileNotFoundError(
                f"resume='auto' but no checkpoint found under {Path(cfg.ckpt_dir) / 'diffusion'}"
            )
    else:
        resume_path = Path(resume_arg)

    start_epoch = load_training_state(
        resume_path,
        device,
        mode=str(getattr(cfg, "resume_mode", "resume")).lower(),   # "resume" or "restart"
        diff_low=diff_low, diff_ecg=diff_ecg,
        opt=opt, scheduler=scheduler, scaler=scaler,
        ema_low=ema_low, ema_ecg=ema_ecg
    )
    print(f"[resume] mode={getattr(cfg, 'resume_mode', 'resume')}  path={resume_path}")

# --- 3) Set sampling defaults once (after models exist & after resume) ---
set_sampling_defaults(diff_low, cfg)
set_sampling_defaults(diff_ecg, cfg)
if ema_low is not None: set_sampling_defaults(ema_low, cfg)
if ema_ecg is not None: set_sampling_defaults(ema_ecg, cfg)

print("[defaults] low:", diff_low.debug_clip_x0, diff_low.x0_clip_q, diff_low.x0_clip_value, diff_low.ddim_alpha_bar_start)
print("[defaults] ecg:", diff_ecg.debug_clip_x0, diff_ecg.x0_clip_q, diff_ecg.x0_clip_value, diff_ecg.ddim_alpha_bar_start)

# --- 4) Dirs / milestone tracking (once, before loop) ---
ckpt_dir = Path(cfg.ckpt_dir) / "diffusion"
ckpt_dir.mkdir(parents=True, exist_ok=True)
milestones_dir = ckpt_dir / "milestones"
milestones_dir.mkdir(parents=True, exist_ok=True)   # ensure it exists

# When resuming, start from the best validation loss already recorded in the
# milestone manifests. Otherwise the first resumed epoch would always be
# saved as an "improvement". Fresh runs start from +inf.
best_val   = best_val_from_milestones(milestones_dir) if resume_arg else float("inf")
if resume_arg:
    print(f"[resume] best val_total from existing milestones: {best_val:.6f}")
min_delta  = 1e-6
total_epochs = int(getattr(cfg, "epochs_gan", 50))  # epoch budget reuses the GAN config key

if not hasattr(cfg, "schedule_state"):
    cfg.schedule_state = NS(
        # last-epoch sampler stats (quiet defaults so we don't trigger recovery on epoch 0)
        last_fracclamp_low=0.0,
        last_fracclamp_ecg=0.0,
        last_std_low=0.0,
        last_std_ecg=0.0,
        had_nan_or_inf=False,
        val_loss_spike=False,

        # cooldown gate
        recovery_cooldown_until=-1,

        # rolling histories (pre-seeded in case _hist/_push_hist don't auto-create)
        hist_clip_frac_low=[],
        hist_clip_frac_ecg=[],
        hist_std_low=[],
        hist_std_ecg=[],
        hist_had_nan=[],
    )

# --- 5) Training loop ---
# Main loop. Each epoch runs: train -> validate -> LR scheduler ->
# sampling schedule -> sample (+ NLL estimate) -> collect sampler stats ->
# periodic evaluation -> checkpoint -> milestone if validation improved.
for epoch in range(start_epoch, total_epochs + 1):
    # ---------------------------
    # Train
    # ---------------------------
    train_one_epoch(
        epoch, diff_low, diff_ecg, opt, train_loader, device,
        lam_low=getattr(cfg, "lambda_low", 1.0),
        lam_ecg=getattr(cfg, "lambda_ecg", 1.0),
        ema_low=ema_low, ema_ecg=ema_ecg,
        ema_decay=getattr(cfg, "ema_decay", 0.999),
        log_every=10,
        scaler=scaler, use_amp=use_amp
    )

    # ---------------------------
    # Validate (no grad)
    # ---------------------------
    # The EMA copies are validated when they exist, otherwise the raw models.
    with torch.inference_mode():
        val_total, val_low, val_ecg = validate_one_epoch(
            ema_low if ema_low is not None else diff_low,
            ema_ecg if ema_ecg is not None else diff_ecg,
            val_loader, device,
            lam_low=getattr(cfg, "lambda_low", 1.0),
            lam_ecg=getattr(cfg, "lambda_ecg", 1.0),
        )

    # ---------------------------
    # Scheduler (step once, correctly)
    # ---------------------------
    if scheduler is not None:
        if isinstance(scheduler, lrs.ReduceLROnPlateau):
            scheduler.step(val_total)            # pass validation metric
        elif isinstance(scheduler, lrs.CosineAnnealingWarmRestarts):
            scheduler.step(epoch)                # expects epoch (float), not a metric
        else:
            scheduler.step()
        print(f"[lr] epoch {epoch} lr=" + ", ".join(f"{pg['lr']:.6g}" for pg in opt.param_groups), flush=True)

    # ---------------------------
    # Apply sampling schedule to the *active* models you will sample
    # ---------------------------
    active_low = ema_low if ema_low is not None else diff_low
    active_ecg = ema_ecg if ema_ecg is not None else diff_ecg
    state = cfg.schedule_state
    state.epoch = epoch  # can_relax_clamp compares this against the cooldown epoch
    apply_sampling_schedule(epoch, cfg, active_low, active_ecg, state)

    # Verbose schedule diagnostics for the first five epochs, then every fifth.
    if epoch <= 5 or epoch % 5 == 0:
        _show(active_low, "low", cfg)
        _show(active_ecg, "ecg", cfg)
        print(f"[sched:clip before] low={_last_frac(active_low)}  ecg={_last_frac(active_ecg)}")
        print(f"[sched:std flag ] low={_last_std_bad(active_low)}  ecg={_last_std_bad(active_ecg)}")
        print(f"[sched:std last] low={_std_val(active_low)}  ecg={_std_val(active_ecg)}  "
          f"(thr={getattr(_inner(active_low),'ret_std_max',None)})")

    # ---------------------------
    # Sampling & optional NLL (no grad)
    # ---------------------------
    with torch.inference_mode():
        _ = sample_and_save(
            epoch,
            active_low, active_ecg,                 # sample with the active pair
            cfg, device, cfg.sample_dir,
            cond_low_fixed, y_fixed,
            use_full_ddpm=False,
            plot_in_z=True,
            fixed_z_ylim=None
        )
    

        maybe_compute_nll(
            epoch,
            active_low, active_ecg,
            val_loader, cfg, device
        )
    
    # Record this epoch's sampler diagnostics; next epoch's schedule reads them.
    _collect_sampler_stats(state, active_low, "low")
    _collect_sampler_stats(state, active_ecg, "ecg")

    # (optional) one-shot flag from numerical checks this epoch.
    # NOTE(review): state.val_loss_spike is read by the schedule but is not
    # updated anywhere in this loop.
    state.had_nan_or_inf = bool(
        getattr(active_low, "had_nan_or_inf", False) or
        getattr(active_ecg, "had_nan_or_inf", False)
    )

    if epoch <= 5 or epoch % 5 == 0:
        print(f"[sched:clip after ] low={_last_frac(active_low)}  ecg={_last_frac(active_ecg)}")
        print(f"[sched:std flag ] low={_last_std_bad(active_low)}  ecg={_last_std_bad(active_ecg)}")
        print(f"[sched:std last] low={_std_val(active_low)}  ecg={_std_val(active_ecg)}  "
          f"(thr={getattr(_inner(active_low),'ret_std_max',None)})")
        

    # ---------------------------
    # Periodic evaluation
    # ---------------------------
    # Runs every cfg.eval_interval epochs (default 5) and always on the final epoch.
    if (epoch % int(getattr(cfg, "eval_interval", 5)) == 0) or (epoch == total_epochs):
        adapter = current_adapter()
        maybe_eval(epoch, adapter, D_feature=D, val_loader=val_loader, cfg=cfg, device=device)

    # ---------------------------
    # Checkpoint (include scheduler & scaler)
    # ---------------------------
    # --- save WEIGHTS (big file) + STATE (small telemetry) ---
    ckpt_weights = ckpt_dir / f"ckpt_epoch_{epoch:03d}_WEIGHTS.pt"
    ckpt_state   = ckpt_dir / f"ckpt_epoch_{epoch:03d}_STATE.pt"

    # save full model/ema/optimizer (and optionally scheduler/scaler)
    payload = {
        "epoch": epoch,
        "diff_low": diff_low.state_dict(),
        "diff_ecg": diff_ecg.state_dict(),
        "ema_low": None if ema_low is None else ema_low.state_dict(),
        "ema_ecg": None if ema_ecg is None else ema_ecg.state_dict(),
        "optimizer": opt.state_dict(),
        # include these if you want fully faithful resumes:
        "scheduler": None if scheduler is None or not hasattr(scheduler, "state_dict") else scheduler.state_dict(),
        "scaler": None if 'scaler' not in globals() or scaler is None or not hasattr(scaler, "state_dict") else scaler.state_dict(),
    }
    torch.save(payload, ckpt_weights)
    # The schedule state object is saved whole, in its own file.
    torch.save(cfg.schedule_state, ckpt_state)
    print(f"[ckpt] saved weights -> {ckpt_weights}", flush=True)
    print(f"[ckpt] saved state   -> {ckpt_state}",   flush=True)

    # Milestone only when the total validation loss improved by more than min_delta.
    improved_total = (val_total + min_delta) < best_val
    if improved_total:
        save_milestone(
            ckpt_path=ckpt_weights,          # copy the WEIGHTS file, not the STATE file
            milestones_dir=milestones_dir,
            epoch=epoch,
            val_total=val_total, val_low=val_low, val_ecg=val_ecg,
            cfg=cfg,
            m_low=active_low, m_ecg=active_ecg
        )
        best_val = float(val_total)

print("[done] Training complete.", flush=True)

[resume] loaded ckpt_epoch_041.pt (epoch=41, mode=resume)
[resume] mode=resume  path=C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\ckpt_epoch_041.pt
[defaults] low: True 0.995 3.0 0.001
[defaults] ecg: True 0.995 3.0 0.001
[ep 042]    1/330 loss=0.0781 (low=0.0707, ecg=0.0148) | avg=0.0781 | 4.3 samp/s | ETA 10.2m | mem 0.19/0.48 GB
[ep 042]   10/330 loss=0.1067 (low=0.0954, ecg=0.0225) | avg=0.0994 | 73.5 samp/s | ETA 0.6m | mem 0.19/0.51 GB
[ep 042]   20/330 loss=0.1525 (low=0.1432, ecg=0.0186) | avg=0.0964 | 71.3 samp/s | ETA 0.6m | mem 0.19/0.51 GB
[ep 042]   30/330 loss=0.1114 (low=0.0916, ecg=0.0397) | avg=0.0958 | 70.3 samp/s | ETA 0.6m | mem 0.19/0.51 GB
[ep 042]   40/330 loss=0.0812 (low=0.0714, ecg=0.0196) | avg=0.0960 | 68.7 samp/s | ETA 0.6m | mem 0.19/0.51 GB
[ep 042]   50/330 loss=0.0861 (low=0.0749, ecg=0.0225) | avg=0.0941 | 69.6 samp/s | ETA 0.5m | mem 0.19/0.51 GB
[ep 042]   60/330 loss=0.0892 (low=0.0636, ecg=0.0512) | avg=0.0946 | 71.6 samp/

In [None]:
from pathlib import Path
import torch, json, shutil

# Manual checkpoint export.
#
# Snapshots whatever models/optimizer state currently live in the notebook
# namespace into a single ``*_WEIGHTS.pt`` file. Requires `diff_low`,
# `diff_ecg` and `cfg` to be defined by earlier cells; `ema_low`, `ema_ecg`,
# `opt`, `scheduler` and `scaler` are optional and are stored as None when
# they are missing or set to None.


def _state_dict_or_none(name, *, lenient=False):
    """Return ``state_dict()`` of the notebook global *name*, or None.

    Args:
        name: Name of the variable to look up in the notebook globals.
        lenient: When True, an object that has no ``state_dict`` attribute
            also yields None instead of raising ``AttributeError``.

    Returns:
        The object's state dict, or None when the global is undefined or None
        (or, with ``lenient=True``, lacks ``state_dict``).
    """
    obj = globals().get(name)
    if obj is None:
        # Covers both "never defined" and "explicitly disabled (= None)".
        return None
    if lenient and not hasattr(obj, "state_dict"):
        return None
    return obj.state_dict()


# 1) Where to save
out_dir = Path(r"C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion")
out_dir.mkdir(parents=True, exist_ok=True)
# NOTE: epoch_tag is only a label (file name + "epoch" field); nothing below
# checks that the in-memory weights actually come from this epoch.
epoch_tag = 136  # <-- set this to the epoch you want to preserve
ckpt_path = out_dir / f"ckpt_epoch_{epoch_tag:03d}_WEIGHTS.pt"

# 2) Pick the generators you'd want to serve (EMA preferred if available)
_ema_low = globals().get("ema_low")
_ema_ecg = globals().get("ema_ecg")
G_low = _ema_low if _ema_low is not None else diff_low
G_ecg = _ema_ecg if _ema_ecg is not None else diff_ecg

# Sampler flags are read from the `.diff` attribute when the active model has
# one, otherwise from the active model itself.
_sampler_low = getattr(G_low, "diff", G_low)
_sampler_ecg = getattr(G_ecg, "diff", G_ecg)

# 3) Build payload (models + EMA + (opt/sched/scaler if present) + metadata)
payload = {
    "epoch": int(epoch_tag),
    "diff_low": diff_low.state_dict(),
    "diff_ecg": diff_ecg.state_dict(),
    "ema_low": _state_dict_or_none("ema_low"),
    "ema_ecg": _state_dict_or_none("ema_ecg"),
    # Previously only "is `opt` defined" was checked, so `opt = None` crashed
    # with AttributeError; a None optimizer is now stored as None like the
    # other optional entries.
    "optimizer": _state_dict_or_none("opt"),
    "scheduler": _state_dict_or_none("scheduler", lenient=True),
    "scaler": _state_dict_or_none("scaler", lenient=True),
    "meta": {
        # Values come from `cfg`; the second getattr argument is the fallback
        # recorded when `cfg` does not carry that attribute.
        "condition_dim": int(getattr(cfg, "condition_dim", 4)),
        "diffusion_steps": int(getattr(cfg, "diffusion_steps", 1000)),
        "beta_schedule": str(getattr(cfg, "beta_schedule", "cosine")),
        "sampling_method": str(getattr(cfg, "sampling_method", "ddim")),
        "sampling_steps": int(getattr(cfg, "sampling_steps", 50)),
        "cfg_scale": float(getattr(cfg, "cfg_scale", 0.0)),
        "cond_drop_prob": float(getattr(cfg, "cond_drop_prob", 0.0)),
        "x0_clip_q": float(getattr(cfg, "x0_clip_q", 0.0)),
        # sampler flags from the *active* models
        "ddim_alpha_bar_start_low": float(getattr(_sampler_low, "ddim_alpha_bar_start", 0.0)),
        "ddim_alpha_bar_start_ecg": float(getattr(_sampler_ecg, "ddim_alpha_bar_start", 0.0)),
        "debug_clip_x0_low": bool(getattr(_sampler_low, "debug_clip_x0", True)),
        "debug_clip_x0_ecg": bool(getattr(_sampler_ecg, "debug_clip_x0", True)),
    },
}

torch.save(payload, ckpt_path)
print("Saved weights to:", ckpt_path, "| size bytes:", ckpt_path.stat().st_size)