# WESAD Diffusion — Smoke Test (via `generate/core.py`)

This notebook uses your reusable generation layer (`WESADGenerator`) to:

- Resolve the right milestone JSON and normalization files.
- Load a chosen checkpoint (EMA weights preferred if present).
- Generate **one fused window** with output shape **(1, T, 3)** in channel order **[ECG, Resp, EDA]**.
- Run sanity checks (shape, dtype, finite, non-flat).
- Save quick preview plots.
- (Optional) Generate a small batch and write NPZ/CSV for your evaluation layer.

> **Edit the parameters in the next cell** to match your paths and choices.


In [1]:
from pathlib import Path

# --- Parameters (EDIT THESE) ---
# REPO_ROOT is the single machine-specific setting: every other location below
# is derived from it, so moving the checkout only requires editing this line.
REPO_ROOT = Path(r"C:\Users\Joseph\generative-health-models")

# Diffusion checkpoints and their milestone manifests live side by side.
DIFFUSION_CKPT_DIR = REPO_ROOT / "results" / "checkpoints" / "diffusion"
MILESTONES = DIFFUSION_CKPT_DIR / "milestones"
CKPT = DIFFUSION_CKPT_DIR / "ckpt_epoch_136_WEIGHTS.pt"  # known-good weights

CONDITION = "baseline"   # "baseline" | "stress" | "amusement"
DURATION = 5250           # 5250 (ECG native) or 120 (low-rate native)
SEED = 123

OVERRIDE_STEPS = 100      # None to use manifest's sampling_steps
OVERRIDE_GUIDANCE = 1.0   # None to use manifest's cfg_scale

# Output directory; the batch-generation cell also writes its dataset file here.
PLOT_DIR = REPO_ROOT / "smoke_plots"
N_SAMPLES_FOR_DATASET = 8
OUT_FORMAT = "npz"       # "npz" or "csv"


In [2]:
import sys, numpy as np

# Put your repo's src/ on sys.path so imports like `from generate.core import WESADGenerator` work.
# Any earlier copy of the entry is dropped first, so re-running this cell keeps
# src/ at the head of sys.path without stacking duplicate entries.
src_dir = str((REPO_ROOT / "src").resolve())
sys.path[:] = [entry for entry in sys.path if entry != src_dir]
sys.path.insert(0, src_dir)

# Report (rather than enforce) that the configured locations exist.
print("Python path head:", sys.path[0])
print("Repo root exists:", REPO_ROOT.exists())
print("Milestones dir:", MILESTONES, "exists:", MILESTONES.exists())
print("Checkpoint:", CKPT, "exists:", CKPT.exists())

# Create the output directory up front so later cells can write into it.
PLOT_DIR.mkdir(parents=True, exist_ok=True)


Python path head: C:\Users\Joseph\generative-health-models\src
Repo root exists: True
Milestones dir: C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\milestones exists: True
Checkpoint: C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\ckpt_epoch_136_WEIGHTS.pt exists: True


In [3]:
from generate.core import WESADGenerator

# Build the generator from the configured milestone directory and checkpoint,
# then echo what it resolved so the run can be traced back to exact files.
gen = WESADGenerator(milestones_dir=MILESTONES, ckpt_path=CKPT)
b = gen.bundle

# Resolved artefact locations, printed one per line.
for label, attr in (
    ("Chosen checkpoint:", "ckpt_path"),
    ("Chosen milestone JSON:", "milestone_json"),
    ("Norm low:", "norm_low"),
    ("Norm ecg:", "norm_ecg"),
):
    print(label, getattr(b, attr))

# Content hashes of those same artefacts.
print("\nSHA-256")
for label, attr in (
    ("ckpt:", "sha_ckpt"),
    ("json:", "sha_manifest"),
    ("norm_low:", "sha_norm_low"),
    ("norm_ecg:", "sha_norm_ecg"),
):
    print(label, getattr(b, attr))

print("\nNative lengths -> ECG:", gen.ecg_len, "LOW:", gen.low_len)

# Sampling settings recorded in the manifest; the second argument of each
# .get() is the fallback shown when the manifest omits that key.
print(
    "Manifest sampling:",
    b.manifest.get("sampling_method", "ddim"),
    b.manifest.get("sampling_steps", 50),
    "cfg_scale=",
    b.manifest.get("cfg_scale", 0.0),
)


Chosen checkpoint: C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\ckpt_epoch_136_WEIGHTS.pt
Chosen milestone JSON: C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\milestones\milestone_e136_val0.1183_abL5e-04_abE5e-04_cfg0.5.json
Norm low: C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\milestones\norm_low.npz
Norm ecg: C:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\milestones\norm_ecg.npz

SHA-256
ckpt: f6e8bccaf24555f5f6df646695ce35419deeb76a145cafa0dcb283c7b46f50a4
json: 33eb0fdd2ce163ffd28255f1bb98154e5cae0a5293264f8f77bbe6a602966957
norm_low: e972f51e574ef7226fb638399df3d74903c21b52e76c655d7d71ac18dcf3c0b9
norm_ecg: e9e28c218f48ebf7f76f6cfe9fcc65348d0ec194a30429445c90fd110569704a

Native lengths -> ECG: 5250 LOW: 120
Manifest sampling: ddim 100 cfg_scale= 0.5


In [6]:
# One-sample smoke test: draw a single window, verify it, and save preview plots.
import matplotlib.pyplot as plt

# Divisors that turn sample indices into seconds on the preview x-axis.
# NOTE(review): 175 Hz / 4 Hz are carried over unchanged from the previous
# version of this cell — confirm they match the model's native sample rates.
ECG_FS_HZ = 175.0
LOW_FS_HZ = 4.0
PREVIEW_MAX_SAMPLES = 1000  # plot at most this many leading samples
CHANNEL_NAMES = ("ECG", "Resp", "EDA")
# Per-channel gain applied to the plotted copy only (order: ECG, Resp, EDA).
VIS_SCALE = np.array([1.5, 2.0, 1.2], dtype=np.float32).reshape(1, 1, 3)

x = gen.sample_one(
    condition=CONDITION,
    T=int(DURATION),
    steps=None if OVERRIDE_STEPS is None else int(OVERRIDE_STEPS),
    guidance=None if OVERRIDE_GUIDANCE is None else float(OVERRIDE_GUIDANCE),
    seed=int(SEED),
)

# Print the diagnostics first so they are visible even if a check below fails.
channel_std = x.std(axis=(0, 1))
print("Output shape:", x.shape, "| dtype:", x.dtype, "| finite:", np.isfinite(x).all())
print("Per-channel std [ECG, Resp, EDA]:", channel_std)

# Sanity checks (shape, dtype, finite, non-flat): fail loudly instead of
# relying on someone eyeballing the printed values.
assert x.shape == (1, int(DURATION), len(CHANNEL_NAMES)), f"unexpected output shape {x.shape}"
assert np.issubdtype(x.dtype, np.floating), f"expected a floating dtype, got {x.dtype}"
assert np.isfinite(x).all(), "output contains NaN or inf"
assert (channel_std > 0).all(), f"at least one channel is flat: std={channel_std}"

# Scale a separate array for display; raw x stays untouched for eval/saving.
x_vis = x * VIS_SCALE

n_preview = min(PREVIEW_MAX_SAMPLES, x_vis.shape[1])
# Choose the time base from the generator's own ECG length, not a literal.
fs_hz = ECG_FS_HZ if x_vis.shape[1] == gen.ecg_len else LOW_FS_HZ
t = np.arange(n_preview) / fs_hz

# Write previews to the configured PLOT_DIR (same location the dataset cell
# uses) rather than a directory relative to the current working directory.
PLOT_DIR.mkdir(parents=True, exist_ok=True)
for ch, name in enumerate(CHANNEL_NAMES):
    fig, ax = plt.subplots()
    ax.plot(t, x_vis[0, :n_preview, ch])
    ax.set_title(f"{name} (scaled for visualization)")
    ax.set_xlabel("seconds")
    ax.set_ylabel("a.u.")
    fig.tight_layout()
    fig.savefig(PLOT_DIR / f"smoke_{name.lower()}_scaled.png", dpi=140)
    plt.close(fig)  # release the figure; the PNG on disk is the output


[sampler:init:low] shape=(1, 2, 120) std=1.009682 min=-2.938 max=2.797
[ddim] head-skip=2 (ᾱ_start=5.000e-04 -> keep t=979, ᾱ=9.712e-04)
[ddim:step1_dbg:low] eps_std=1.007813  alpha_bar_t=9.711878e-04  sqrt_ab_t=3.116389e-02  sqrt_1m_ab_t=9.995143e-01
[sampler:step1:low] std=1.010525 min=-2.941 max=2.818
[sampler:return:low] std=0.535631 thr=3.000 min=0.147 max=2.781
[sampler:init:ecg] shape=(1, 1, 5250) std=1.008110 min=-3.682 max=3.681
[ddim] head-skip=2 (ᾱ_start=5.000e-04 -> keep t=979, ᾱ=9.712e-04)
[ddim:step1_dbg:ecg] eps_std=0.997012  alpha_bar_t=9.711878e-04  sqrt_ab_t=3.116389e-02  sqrt_1m_ab_t=9.995143e-01
[sampler:step1:ecg] std=1.013307 min=-3.682 max=3.717
[sampler:return:ecg] std=0.678416 thr=3.000 min=-1.355 max=3.000
Output shape: (1, 5250, 3) | dtype: float32 | finite: True
Per-channel std [ECG, Resp, EDA]: [0.18256988 0.0132799  1.8191347 ]


In [5]:
# Small dataset generation (optional)

# Validate the requested format before the expensive sampling call: previously
# any value other than "npz" (including typos) silently produced a CSV.
out_format = OUT_FORMAT.lower()
if out_format not in ("npz", "csv"):
    raise ValueError(f'OUT_FORMAT must be "npz" or "csv", got {OUT_FORMAT!r}')

signals = gen.sample_batch(
    condition=CONDITION,
    T=int(DURATION),
    n_samples=int(N_SAMPLES_FOR_DATASET),
    base_seed=int(SEED),
    steps=None if OVERRIDE_STEPS is None else int(OVERRIDE_STEPS),
    guidance=None if OVERRIDE_GUIDANCE is None else float(OVERRIDE_GUIDANCE),
)
print("Batch shape:", signals.shape)

# Persist the batch with the writer matching the validated format.
if out_format == "npz":
    path = gen.write_npz(signals, condition=CONDITION, base_seed=int(SEED), out_dir=PLOT_DIR)
else:
    path = gen.write_csv(signals, condition=CONDITION, base_seed=int(SEED), out_dir=PLOT_DIR, time_mode="index")

print("Wrote dataset to:", path)


[sampler:init:low] shape=(1, 2, 120) std=1.009682 min=-2.938 max=2.797
[ddim] head-skip=2 (ᾱ_start=5.000e-04 -> keep t=979, ᾱ=9.712e-04)
[ddim:step1_dbg:low] eps_std=1.007813  alpha_bar_t=9.711878e-04  sqrt_ab_t=3.116389e-02  sqrt_1m_ab_t=9.995143e-01
[sampler:step1:low] std=1.010525 min=-2.941 max=2.818
[sampler:return:low] std=0.535631 thr=3.000 min=0.147 max=2.781
[sampler:init:ecg] shape=(1, 1, 5250) std=1.008110 min=-3.682 max=3.681
[ddim] head-skip=2 (ᾱ_start=5.000e-04 -> keep t=979, ᾱ=9.712e-04)
[ddim:step1_dbg:ecg] eps_std=0.997012  alpha_bar_t=9.711878e-04  sqrt_ab_t=3.116389e-02  sqrt_1m_ab_t=9.995143e-01
[sampler:step1:ecg] std=1.013307 min=-3.682 max=3.717
[sampler:return:ecg] std=0.678416 thr=3.000 min=-1.355 max=3.000
[sampler:init:low] shape=(1, 2, 120) std=0.977380 min=-2.487 max=2.575
[ddim] head-skip=2 (ᾱ_start=5.000e-04 -> keep t=979, ᾱ=9.712e-04)
[ddim:step1_dbg:low] eps_std=0.989942  alpha_bar_t=9.711878e-04  sqrt_ab_t=3.116389e-02  sqrt_1m_ab_t=9.995143e-01


In [6]:
# Optional: low-rate variant — request a window at the generator's low-rate
# native length, but only when that length differs from the ECG-native one.
has_distinct_low_rate = gen.low_len != gen.ecg_len
if has_distinct_low_rate:
    low_rate_kwargs = dict(
        condition=CONDITION,
        T=int(gen.low_len),
        # A None override defers to the manifest's own setting.
        steps=int(OVERRIDE_STEPS) if OVERRIDE_STEPS is not None else None,
        guidance=float(OVERRIDE_GUIDANCE) if OVERRIDE_GUIDANCE is not None else None,
        # Offset the seed so this draw differs from the one-sample smoke test.
        seed=int(SEED + 1),
    )
    x_lowT = gen.sample_one(**low_rate_kwargs)
    print("Low-rate output shape:", x_lowT.shape)


[sampler:init:low] shape=(1, 2, 120) std=0.977380 min=-2.487 max=2.575
[ddim] head-skip=2 (ᾱ_start=5.000e-04 -> keep t=979, ᾱ=9.712e-04)
[ddim:step1_dbg:low] eps_std=0.985435  alpha_bar_t=9.711878e-04  sqrt_ab_t=3.116389e-02  sqrt_1m_ab_t=9.995143e-01
[sampler:step1:low] std=0.974440 min=-2.500 max=2.544
[sampler:return:low] std=0.619839 thr=3.000 min=-2.953 max=0.626
[sampler:init:ecg] shape=(1, 1, 5250) std=0.999190 min=-3.359 max=3.749
[ddim] head-skip=2 (ᾱ_start=5.000e-04 -> keep t=979, ᾱ=9.712e-04)
[ddim:step1_dbg:ecg] eps_std=0.998149  alpha_bar_t=9.711878e-04  sqrt_ab_t=3.116389e-02  sqrt_1m_ab_t=9.995143e-01
[sampler:step1:ecg] std=0.999336 min=-3.365 max=3.751
[sampler:return:ecg] std=0.509254 thr=3.000 min=-1.132 max=3.000
Low-rate output shape: (1, 120, 3)
