In [1]:
# Find repo root (directory that contains src/)
from pathlib import Path
import sys, contextlib, io, json
import numpy as np,  pandas as pd
import torch
from tqdm.auto import tqdm
from datetime import datetime
from shutil import copy2

# Resolve REPO root and make src importable
here = Path.cwd()
REPO = next((p for p in (here, *here.parents) if (p / "src").exists()), None)
assert REPO is not None, "Couldn't find repo root (folder with src/)"
sys.path.insert(0, str(REPO / "src"))

# Core imports from the repo
from generate.core import WESADGenerator, _denorm, _interp_to_len, sha256_file
from evaluation.wesad_real import RealWESADPreparer
from evaluation.wesad_eval import WESADEvaluator, EvalConfig
from evaluation.calibration import WESADCalibrator



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    x = torch.randn(1, device="cuda")
    print("Sample tensor device:", x.device)

CUDA available: True
GPU: NVIDIA GeForce RTX 3080
Sample tensor device: cuda:0


In [3]:
# === Paths you’ll most often tweak ===
MILESTONES   = REPO / "results/checkpoints/diffusion/milestones"
CKPT         = REPO / "results/checkpoints/diffusion/ckpt_epoch_130_WEIGHTS.pt"  # swap to _041.pt if needed
FOLD_DIR     = REPO / "data/processed/two_stream/fold_S10"                      # pick the fold you want
OUT_DIR      = REPO / "results/evaluation/eval_ckpt130"                          # outputs go here
CALIB_JSON   = OUT_DIR / "calibration_targets.json"
NORM_LOW_PATH = FOLD_DIR / "norm_low.npz"
NORM_ECG_PATH = FOLD_DIR / "norm_ecg.npz"

# Where the per-class real/synth npz live
REAL_SPLIT_DIR = REPO / "results/evaluation/real_3class_split"
SYN_DIR        = REPO / "data/generated/3class_calibrated"

# (Legacy scratch eval dir; we’ll create run-tagged dirs later)
EVAL_DIR = REPO / "results" / "evaluation" / "ckpt130_cls_run_3class"
EVAL_DIR.mkdir(parents=True, exist_ok=True)

maybe_real_base   = FOLD_DIR / "test_baseline.npz"
maybe_real_stress = FOLD_DIR / "test_stress.npz"
maybe_real_amuse  = FOLD_DIR / "test_amusement.npz"

# ---------- Generation knobs ----------
CONDITION          = "baseline"   # or "stress", "amusement" ... (must be supported by your model)
BASE_SEED          = 42
OVERRIDE_STEPS     = 150          # e.g., 100 to override manifest sampling_steps
OVERRIDE_GUIDANCE  = 0.5          # e.g., 0.5 to override manifest cfg_scale
FORCE_REBUILD_CAL  = True
FORCE_REBUILD_CA   = False

# ---- centralized sampling knobs & norm paths ----
STEPS_ECG = 150     # or 100 if you prefer
STEPS_LOW = 150
GUID_ECG  = 0.5
GUID_LOW  = 0.1
STORE_ECG_QMAP = True     # set True to enable ECG histogram matching
RESP_Q_N = 4001           # finer resolution

# ---------- Mapping (raw test_cond IDs → 3-class labels) ----------
# IMPORTANT for this fold: ID_MAP = {3:0, 2:1, 1:2}
BASELINE_ID  = 3
STRESS_ID    = 2
AMUSEMENT_ID = 1
USE_IDS  = [BASELINE_ID, STRESS_ID, AMUSEMENT_ID]
REMAPPED = {BASELINE_ID: 0, STRESS_ID: 1, AMUSEMENT_ID: 2}
ORDER    = ["baseline", "stress", "amusement"]

# ---------- Sampling rates (used by evaluator for PSD/ACF) ----------
FS_ECG = 175.0
FS_LOW = 4.0

# A run tag we’ll reuse for output folders
RUN_TAG = f"3class_run_{CKPT.stem}_seed{BASE_SEED}_stE{STEPS_ECG}_stL{STEPS_LOW}_gE{GUID_ECG}_gL{GUID_LOW}_cal"
BEST_TAG = "stE200_stL100_gE040_gL010_a80"  # from your sweep summary

OUT_DIR.mkdir(parents=True, exist_ok=True)
REAL_SPLIT_DIR.mkdir(parents=True, exist_ok=True)
SYN_DIR.mkdir(parents=True, exist_ok=True)

print("Repo:", REPO)
print("Milestones:", MILESTONES)
print("Ckpt:", CKPT)
print("Fold:", FOLD_DIR)
print("Out:", OUT_DIR)
print("Real split dir:", REAL_SPLIT_DIR)
print("Synth dir:", SYN_DIR)

Repo: c:\Users\Joseph\generative-health-models
Milestones: c:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\milestones
Ckpt: c:\Users\Joseph\generative-health-models\results\checkpoints\diffusion\ckpt_epoch_130_WEIGHTS.pt
Fold: c:\Users\Joseph\generative-health-models\data\processed\two_stream\fold_S10
Out: c:\Users\Joseph\generative-health-models\results\evaluation\eval_ckpt130
Real split dir: c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split
Synth dir: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated


In [4]:
# Build the real set in the standard shape (N, T, 3) with channels [ECG, Resp, EDA]
prep = RealWESADPreparer(FOLD_DIR)
X_real, y_real = prep.prepare(target="ecg")   # T=5250
print("REAL:", X_real.shape, X_real.dtype, "labels:", None if y_real is None else y_real.shape)

assert y_real is not None, "test_cond.npy missing → labels unavailable."
assert len(X_real) == len(y_real), "Label count doesn’t match X count."
assert X_real.shape[2] == 3 and X_real.shape[1] == 5250, f"Expected (N,5250,3), got {X_real.shape}"

# Save the full real set once (handy for other tools)
real_npz = OUT_DIR / "real_test_ecgT.npz"
np.savez_compressed(
    real_npz,
    signals=X_real.astype(np.float32, copy=False),
    channels=np.array(["ECG","Resp","EDA"], dtype=object),
    labels=y_real.astype(np.int32, copy=False),
)
print("Saved real set ->", real_npz)

REAL: (194, 5250, 3) float32 labels: (194,)
Saved real set -> c:\Users\Joseph\generative-health-models\results\evaluation\eval_ckpt130\real_test_ecgT.npz


In [5]:
# ---- Build/refresh the 3-class real split (exclude other IDs, remap with REMAPPED) ----
keep = np.isin(y_real, USE_IDS)
X3 = X_real[keep]
y3_raw = y_real[keep]
y3 = np.vectorize(REMAPPED.get)(y3_raw).astype(np.int32)  # {3:0, 2:1, 1:2} -> {0,1,2}

def _save_one(label_id: int, name: str) -> Path:
    m = (y3 == label_id)
    p = REAL_SPLIT_DIR / f"real_{name}.npz"
    np.savez_compressed(
        p,
        signals=X3[m].astype(np.float32, copy=False),
        channels=np.array(["ECG","Resp","EDA"], dtype=object),
        labels=np.full(int(m.sum()), label_id, dtype=np.int32),
    )
    print(f"{name:<10} -> {p} (N={int(m.sum())})")
    return p

p_baseline = _save_one(0, "baseline")
p_stress   = _save_one(1, "stress")
p_amuse    = _save_one(2, "amusement")

# Quick sanity: label purity & shapes
for name, expected in zip(ORDER, [0,1,2]):
    d = np.load(REAL_SPLIT_DIR / f"real_{name}.npz", allow_pickle=True)
    uu = np.unique(d["labels"])
    print(f"[check] {name}: signals={d['signals'].shape}, labels unique={uu} (expected [{expected}])")

# Keep for later cells
T_target = X_real.shape[1]

baseline   -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_baseline.npz (N=23)
stress     -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_stress.npz (N=46)
amusement  -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_amusement.npz (N=77)
[check] baseline: signals=(23, 5250, 3), labels unique=[0] (expected [0])
[check] stress: signals=(46, 5250, 3), labels unique=[1] (expected [1])
[check] amusement: signals=(77, 5250, 3), labels unique=[2] (expected [2])


In [6]:
CALIB_JSON.parent.mkdir(parents=True, exist_ok=True)

def _build_cal_targets():
    cal_ = WESADCalibrator.from_real(
        X_real,
        store_ecg_qmap=STORE_ECG_QMAP,
        resp_q_n=int(RESP_Q_N),
        # ecg_q_n uses library default unless you want to override
    )
    cal_.save(CALIB_JSON)
    print("Saved calibration targets ->", CALIB_JSON)
    return cal_

if FORCE_REBUILD_CAL or not CALIB_JSON.exists():
    cal = _build_cal_targets()
else:
    cal = WESADCalibrator.load(CALIB_JSON)
    need_rebuild = False
    try:
        # Rebuild if ECG q-map on/off differs from current knob
        if bool(STORE_ECG_QMAP) != bool(cal.has_ecg_qmap()):
            need_rebuild = True
        # Rebuild if resp quantile resolution changed
        rq = getattr(cal.targets, "resp_qs", None)
        if rq is None or len(rq) != int(RESP_Q_N):
            need_rebuild = True
    except Exception:
        need_rebuild = True

    if need_rebuild:
        cal = _build_cal_targets()
    else:
        print("Loaded calibration targets:", CALIB_JSON)

# Quick summary
try:
    rqn = len(getattr(cal.targets, "resp_qs", []))
except Exception:
    rqn = None
print(f"[cal] ecg_qmap={cal.has_ecg_qmap()}  resp_q_n={rqn}")

Saved calibration targets -> c:\Users\Joseph\generative-health-models\results\evaluation\eval_ckpt130\calibration_targets.json
[cal] ecg_qmap=True  resp_q_n=4001


In [7]:
# 1) Load real windows — reuse the already prepared arrays from RealWESADPreparer
X_real_full = X_real.astype(np.float32, copy=False)         # (N, 5250, 3)
y_real_full = y_real.astype(np.int32,  copy=False)          # (N,)

print("X_real_full:", X_real_full.shape, X_real_full.dtype)
print("y_real_full:", y_real_full.shape, y_real_full.dtype)

# Sanity checks
assert X_real_full.ndim == 3 and X_real_full.shape[2] == 3 and X_real_full.shape[1] == 5250, \
    f"Expected (N,5250,3), got {X_real_full.shape}"
assert y_real_full is not None and len(X_real_full) == len(y_real_full), \
    "Label count doesn’t match X count (check you’re using the same fold)."

# Keep target length handy for synth generation
T_target = X_real_full.shape[1]

X_real_full: (194, 5250, 3) float32
y_real_full: (194,) int32


In [8]:
# Keep only the chosen three classes, and remap raw IDs -> {0,1,2}
keep = np.isin(y_real_full, USE_IDS)

remap = np.full_like(y_real_full, fill_value=-1)
for raw, mapped in REMAPPED.items():
    remap[y_real_full == raw] = mapped

X3 = X_real_full[keep]
y3 = remap[keep].astype(np.int32)

assert X3.shape[1:] == (5250, 3), f"Expected (N,5250,3), got {X3.shape}"
assert set(np.unique(y3).tolist()).issubset({0,1,2}), f"Bad remap: {np.unique(y3)}"

# Save class-specific real sets for the evaluator / classifier
REAL_SPLIT_DIR.mkdir(parents=True, exist_ok=True)
channels = np.array(["ECG","Resp","EDA"], dtype=object)

def _save_one(label_id: int, name: str) -> Path:
    m = (y3 == label_id)
    p = REAL_SPLIT_DIR / f"real_{name}.npz"
    np.savez_compressed(
        p,
        signals=X3[m].astype(np.float32, copy=False),
        channels=channels,
        labels=np.full(int(m.sum()), label_id, dtype=np.int32),
    )
    print(f"{name:<10} -> {p} (N={int(m.sum())})")
    return p

p_baseline = _save_one(0, "baseline")
p_stress   = _save_one(1, "stress")
p_amuse    = _save_one(2, "amusement")

# Quick sanity: label purity
for (p, exp) in [(p_baseline,0),(p_stress,1),(p_amuse,2)]:
    d = np.load(p, allow_pickle=True)
    u = np.unique(d["labels"])
    print(f"[check] {p.name}: labels unique={u} (expected [{exp}])")

baseline   -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_baseline.npz (N=23)
stress     -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_stress.npz (N=46)
amusement  -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_amusement.npz (N=77)
[check] real_baseline.npz: labels unique=[0] (expected [0])
[check] real_stress.npz: labels unique=[1] (expected [1])
[check] real_amusement.npz: labels unique=[2] (expected [2])


In [9]:
def synth_match_real_fast(
    gen,
    condition,
    T,
    N,
    base_seed,
    *,
    # shared defaults (fallback to manifest)
    steps=None,
    guidance=None,
    batch_size=16,
    show_pbar=True,
    suppress_model_logs=True,
    # which normalization stats to use
    norm_source="milestone",
    norm_low_path=None,
    norm_ecg_path=None,
    # NEW: per-head overrides
    steps_ecg=None,
    steps_low=None,
    guidance_ecg=None,
    guidance_low=None,
):
    import io, contextlib
    from tqdm.auto import tqdm
    import numpy as np
    import torch

    # --- labels/condition dim ---
    label_idx = {'baseline':0, 'stress':1, 'amusement':2}[condition]
    K = getattr(gen, "condition_dim", None) or int(gen.bundle.manifest.get("condition_dim", 3))
    device   = gen.device
    manifest = gen.bundle.manifest
    method   = str(manifest.get("sampling_method", "ddim")).lower()

    # --- resolve steps & guidance safely ---
    raw_steps = steps if steps is not None else manifest.get("sampling_steps", 50)
    try:
        base_steps = int(raw_steps)
    except Exception:
        base_steps = 50
    if base_steps < 1:
        base_steps = 50  # avoid DDIM division-by-zero

    se = int(steps_ecg) if steps_ecg is not None else base_steps
    sl = int(steps_low) if steps_low is not None else base_steps
    if se < 1: se = base_steps
    if sl < 1: sl = base_steps

    default_guidance = float(guidance) if guidance is not None else float(manifest.get("cfg_scale", 0.0))
    ge = float(guidance_ecg) if guidance_ecg is not None else default_guidance
    gl = float(guidance_low) if guidance_low is not None else default_guidance

    print(f"Sampling: method={method}  steps(ecg/low)={se}/{sl}  guidance(ecg/low)={ge}/{gl}")

    # --- seeding ---
    torch.manual_seed(int(base_seed))
    np.random.seed(int(base_seed))

    # --- choose norms ---
    if norm_source == "milestone":
        nl_path = gen.bundle.norm_low
        ne_path = gen.bundle.norm_ecg
    elif norm_source == "paths":
        if norm_low_path is None or norm_ecg_path is None:
            raise ValueError("Provide norm_low_path and norm_ecg_path when norm_source='paths'.")
        nl_path = norm_low_path
        ne_path = norm_ecg_path
    else:
        raise ValueError("norm_source must be 'milestone' or 'paths'.")

    nl = np.load(nl_path, allow_pickle=False)
    ne = np.load(ne_path, allow_pickle=False)
    print("Using norm_low from:", nl_path)
    print("Using norm_ecg from:", ne_path)
    print("low mean/std:", nl["mean"], nl["std"])
    print("ecg mean/std:", ne["mean"], ne["std"])

    out = np.empty((N, T, 3), dtype=np.float32)
    pbar = tqdm(total=N, desc="Synth", unit="win", leave=True) if show_pbar else None

    for i in range(0, N, batch_size):
        b = min(batch_size, N - i)
        cond = torch.zeros(b, K, device=device, dtype=torch.float32)
        cond[:, label_idx] = 1.0

        def _sample():
            with torch.no_grad():
                # per-head steps/guidance
                x_low = gen.low.sample(cond, num_steps=sl, method=method, cfg_scale=gl)
                x_ecg = gen.ecg.sample(cond, num_steps=se, method=method, cfg_scale=ge)
            return x_low, x_ecg

        if suppress_model_logs:
            with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
                x_low, x_ecg = _sample()
        else:
            x_low, x_ecg = _sample()

        # de-normalize using the chosen stats
        x_low = _denorm(x_low.cpu().numpy(), nl)
        x_ecg = _denorm(x_ecg.cpu().numpy(), ne)

        # map low-stream channels to [Resp, EDA]
        mapped = False
        if "channels" in nl.files:
            try:
                chs = [str(c).lower() for c in nl["channels"].tolist()]
                resp_idx = chs.index("resp") if "resp" in chs else chs.index("respiration")
                eda_idx  = chs.index("eda")  if "eda"  in chs else chs.index("electrodermal activity")
                x_low = x_low[..., [resp_idx, eda_idx]]
                mapped = True
            except Exception:
                pass
        if not mapped:
            # heuristic: smaller std ~ Resp, larger ~ EDA
            stds = x_low.std(axis=(0, 1))
            idx_small = int(np.argmin(stds))  # Resp
            idx_large = 1 - idx_small         # EDA
            x_low = x_low[..., [idx_small, idx_large]]

        # EDA non-negative
        x_low[..., 1] = np.clip(x_low[..., 1], 0.0, None)

        # fuse to requested length T
        if T == gen.ecg_len:
            fused = np.concatenate([x_ecg.astype(np.float32), _interp_to_len(x_low, gen.ecg_len)], axis=-1)
        elif T == gen.low_len:
            fused = np.concatenate([_interp_to_len(x_ecg, gen.low_len), x_low.astype(np.float32)], axis=-1)
        else:
            raise ValueError(f"T={T} must equal {gen.ecg_len} (ECG) or {gen.low_len} (LOW).")

        out[i:i+b] = fused
        if pbar: pbar.update(b)

    if pbar: pbar.close()
    return out

In [10]:
# Where to save (we'll use SYN_DIR defined earlier)
SAVE_DIR = REPO / "data" / "generated"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Make a generator tied to your checkpoint
gen = WESADGenerator(milestones_dir=MILESTONES, ckpt_path=CKPT)

# Show which norm stats we're USING for generation (paths mode -> fold stats)
print("Using norm_low path:", NORM_LOW_PATH)
print("Using norm_ecg path:", NORM_ECG_PATH)
nl = np.load(str(NORM_LOW_PATH), allow_pickle=False)
ne = np.load(str(NORM_ECG_PATH), allow_pickle=False)
print("low mean/std:", nl["mean"], nl["std"])
print("ecg mean/std:", ne["mean"], ne["std"])

# Match N per class to the real split
counts = {}
for name in ORDER:
    d = np.load(REAL_SPLIT_DIR / f"real_{name}.npz", allow_pickle=True)
    counts[name] = int(d["signals"].shape[0])
print("[real counts]", counts)

# Helper: generate & calibrate one class
def _gen_and_calibrate(name: str, N: int, T: int) -> Path:
    out_p = SYN_DIR / f"synth_{name}_calibrated.npz"
    if out_p.exists():
        print("[synth] exists:", out_p)
        return out_p

    print(f"[synth] generating '{name}' (N={N}, T={T}) → {out_p}")
    X_syn = synth_match_real_fast(
        gen, name, T, N, BASE_SEED,
        steps=OVERRIDE_STEPS, guidance=OVERRIDE_GUIDANCE,
        steps_ecg=STEPS_ECG, steps_low=STEPS_LOW,
        guidance_ecg=GUID_ECG, guidance_low=GUID_LOW,
        batch_size=16,
        norm_source="paths",
        norm_low_path=str(NORM_LOW_PATH),
        norm_ecg_path=str(NORM_ECG_PATH),
    )

    # Calibrate in physical units using targets we built earlier
    X_syn_cal = cal.apply(
        X_syn,
        do_ecg=True, do_resp=True, do_eda=True,
        ecg_qmap=cal.has_ecg_qmap(), ecg_qmap_alpha=0.6,
        enforce_resp_std=True,
    )

    # Save with labels {0,1,2} matching ORDER
    label_idx = ORDER.index(name)
    y = np.full(N, label_idx, dtype=np.int32)
    np.savez_compressed(
        out_p,
        signals=X_syn_cal.astype(np.float32, copy=False),
        channels=np.array(["ECG","Resp","EDA"], dtype=object),
        labels=y,
    )
    print(f"[synth] wrote {out_p} (N={N})")
    return out_p

# Generate all three classes to match real counts
assert 'T_target' in globals(), "Define T_target from the real set earlier."
synth_paths = {name: _gen_and_calibrate(name, counts[name], T_target) for name in ORDER}
print("Synth files:", synth_paths)

Using norm_low path: c:\Users\Joseph\generative-health-models\data\processed\two_stream\fold_S10\norm_low.npz
Using norm_ecg path: c:\Users\Joseph\generative-health-models\data\processed\two_stream\fold_S10\norm_ecg.npz
low mean/std: [ 0.00161831 -0.00026734] [0.08802962 3.0933654 ]
ecg mean/std: [-2.3908885e-06] [0.2691641]
[real counts] {'baseline': 23, 'stress': 46, 'amusement': 77}
[synth] exists: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_baseline_calibrated.npz
[synth] exists: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_stress_calibrated.npz
[synth] exists: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_amusement_calibrated.npz
Synth files: {'baseline': WindowsPath('c:/Users/Joseph/generative-health-models/data/generated/3class_calibrated/synth_baseline_calibrated.npz'), 'stress': WindowsPath('c:/Users/Joseph/generative-health-models/data/generated/3class_calibrated/synth_str

In [11]:
# ---- knobs (locked classifier, vary only curation) ----
FRACTIONS = [0.00, 0.03, 0.05]   # amusement fractions to try
EPOCHS    = 40                   # lock classifier epochs here
SEED      = 0

# ---- preconditions ----
assert 'RUN_TAG' in globals()
assert 'ORDER' in globals()
assert 'REAL_SPLIT_DIR' in globals() and Path(REAL_SPLIT_DIR).exists()
assert 'SYN_DIR' in globals() and Path(SYN_DIR).exists()
assert 'T_target' in globals() and isinstance(T_target, int)
assert 'FS_ECG' in globals() and 'FS_LOW' in globals()

rng = np.random.default_rng(SEED)
channels = np.array(["ECG","Resp","EDA"], dtype=object)  # not saved (avoid pickle); here for reference only

def _empty_X(): return np.empty((0, T_target, 3), dtype=np.float32)
def _empty_y(): return np.empty((0,), dtype=np.int32)

# Real class counts (RAW)
real_counts = {}
for name in ORDER:
    with np.load(REAL_SPLIT_DIR / f"real_{name}.npz", allow_pickle=False) as d:
        real_counts[name] = int(d["signals"].shape[0])
majority_n = max(real_counts.values())
print("[real RAW counts]", real_counts, "majority=", majority_n)

summary_rows = []
INTEGRATED_DIR = REPO / "results" / "evaluation" / (RUN_TAG + "_INTEGRATED")
INTEGRATED_DIR.mkdir(parents=True, exist_ok=True)

for frac in FRACTIONS:
    tag = f"A{int(frac*100)}_e{EPOCHS}"
    cur_dir = SYN_DIR.with_name(SYN_DIR.name + f"_curated_{tag}")
    cur_dir.mkdir(parents=True, exist_ok=True)

    # --- curate synth per class on RAW calibrated NPZs ---
    for name in ORDER:
        src = SYN_DIR / f"synth_{name}_calibrated.npz"
        if not src.exists():
            print(f"[curate] missing synth raw: {src} -> writing empty stub")
            np.savez_compressed(cur_dir / f"synth_{name}_calibrated.npz",
                                signals=_empty_X(), labels=_empty_y())
            continue

        with np.load(src, allow_pickle=False) as d:
            Xs = d["signals"].astype(np.float32); ys = d["labels"].astype(np.int32)

        if name == "amusement":
            k = int(np.floor(frac * len(Xs)))
        else:
            need = max(0, majority_n - real_counts[name])
            k = min(need, len(Xs))

        if k > 0:
            idx = rng.choice(len(Xs), size=k, replace=False)
            X_keep = Xs[idx]; y_keep = ys[idx]
        else:
            X_keep = _empty_X(); y_keep = _empty_y()

        # Safety: ensure correct T
        if X_keep.shape[1] != T_target:
            X_keep = _empty_X(); y_keep = _empty_y()

        np.savez_compressed(cur_dir / f"synth_{name}_calibrated.npz",
                            signals=X_keep, labels=y_keep)
        print(f"[curate {tag}] {name:<10} -> kept {k} / {len(Xs)}; shape={X_keep.shape}")

    # --- run classifier with locked recipe (z-score + balanced) ---
    real_files_cls  = {k: REAL_SPLIT_DIR / f"real_{k}.npz"               for k in ORDER}
    synth_files_cls = {k: cur_dir        / f"synth_{k}_calibrated.npz"   for k in ORDER}

    out_dir = REPO / "results" / "evaluation" / (RUN_TAG + f"_CLS_{tag}_zsLocked")
    out_dir.mkdir(parents=True, exist_ok=True)

    cfg_cls = EvalConfig(
        T_target=T_target, fs_ecg=FS_ECG, fs_low=FS_LOW,
        results_dir=out_dir,
        run_classifier=True,
        clf_labels=(0,1,2),
        clf_epochs=EPOCHS,
        clf_batch_size=64,
        clf_seed=0,
        clf_lr=1e-3,
        clf_zscore=True,               # <-- lock: z-score w.r.t REAL
        clf_class_weight="balanced",   # <-- lock: inverse-frequency weights
    )
    ev = WESADEvaluator(cfg_cls)
    res = ev.evaluate_all(real_files_cls, synth_files_cls)

    # Show/collect Table 2
    t2 = Path(res["table2_csv"])
    df2 = pd.read_csv(t2)

    def _pick(rowname: str, col: str):
        # handle potential unicode arrow variants robustly
        m = df2['setting'].astype(str).str.contains("Real\+Synth", regex=True, na=False) if rowname=="Real+Synth→Real" \
            else df2['setting'].astype(str).str.contains(rowname.split('→')[0], na=False)
        if not m.any():
            return np.nan
        return float(df2.loc[m, col].values[0])

    auroc_r   = _pick("Real→Real", "AUROC")
    f1_r      = _pick("Real→Real", "F1")
    auroc_mix = _pick("Real+Synth→Real", "AUROC")
    f1_mix    = _pick("Real+Synth→Real", "F1")

    print(f"\n[{tag}] Table 2 -> {t2}")
    print(df2.to_string(index=False))

    # copy to integrated folder
    dest_csv = INTEGRATED_DIR / f"table2_classifier_metrics_{tag}.csv"
    copy2(t2, dest_csv)

    summary_rows.append({
        "amusement_frac": frac,
        "epochs": EPOCHS,
        "AUROC_Real→Real": auroc_r,
        "F1_Real→Real": f1_r,
        "AUROC_Real+Synth→Real": auroc_mix,
        "F1_Real+Synth→Real": f1_mix,
        "csv": str(dest_csv),
    })

# --- summary across fractions ---
summary = pd.DataFrame(summary_rows)
print("\n=== Summary (locked classifier; vary amusement fraction) ===")
print(summary.to_string(index=False))

print("\nTop by F1 (Real+Synth→Real):")
print(summary.sort_values("F1_Real+Synth→Real", ascending=False).head(3).to_string(index=False))

print("\nTop by AUROC (Real+Synth→Real):")
print(summary.sort_values("AUROC_Real+Synth→Real", ascending=False).head(3).to_string(index=False))

[real RAW counts] {'baseline': 23, 'stress': 46, 'amusement': 77} majority= 77
[curate A0_e40] baseline   -> kept 23 / 23; shape=(23, 5250, 3)
[curate A0_e40] stress     -> kept 31 / 46; shape=(31, 5250, 3)
[curate A0_e40] amusement  -> kept 0 / 77; shape=(0, 5250, 3)

[A0_e40] Table 2 -> c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_CLS_A0_e40_zsLocked\table2_classifier_metrics.csv
        setting    AUROC       F1
      Real→Real 0.978738 0.943352
     Synth→Real 0.693817 0.364004
Real+Synth→Real 0.962325 0.180791
[curate A3_e40] baseline   -> kept 23 / 23; shape=(23, 5250, 3)
[curate A3_e40] stress     -> kept 31 / 46; shape=(31, 5250, 3)
[curate A3_e40] amusement  -> kept 2 / 77; shape=(2, 5250, 3)

[A3_e40] Table 2 -> c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_CLS_A3_e40_zsLocked\table2_classifier_metrics.csv
 

In [12]:
# --- Per-class counts from REAL split (physical units) ---
counts = {}
for name in ORDER:
    with np.load(REAL_SPLIT_DIR / f"real_{name}.npz", allow_pickle=True) as d:
        counts[name] = int(d["signals"].shape[0])
print("[real counts]", counts)

# --- Per-head sampling for this run (as requested) ---
STEPS_ECG = 100
STEPS_LOW = 100
GUID_ECG  = 0.5
GUID_LOW  = 0.1

# --- Generate & calibrate helper (idempotent) ---
from pathlib import Path
def _gen_and_calibrate(name: str, N: int, T: int) -> Path:
    out_p = SYN_DIR / f"synth_{name}_calibrated.npz"
    if out_p.exists():
        print("[synth] exists:", out_p)
        return out_p

    print(f"[synth] generating '{name}' (N={N}, T={T}) → {out_p}")
    X_syn = synth_match_real_fast(
        gen, name, T, N, BASE_SEED,
        steps=OVERRIDE_STEPS, guidance=OVERRIDE_GUIDANCE,
        steps_ecg=STEPS_ECG, steps_low=STEPS_LOW,
        guidance_ecg=GUID_ECG, guidance_low=GUID_LOW,
        batch_size=16,
        norm_source="paths",
        norm_low_path=str(NORM_LOW_PATH),
        norm_ecg_path=str(NORM_ECG_PATH),
    )
    # Calibrate in physical units using prebuilt targets
    X_syn_cal = cal.apply(
        X_syn,
        do_ecg=True, do_resp=True, do_eda=True,
        ecg_qmap=cal.has_ecg_qmap(), ecg_qmap_alpha=0.7,
        enforce_resp_std=True,
    )
    # Labels {0,1,2} matching ORDER
    label_idx = ORDER.index(name)
    y = np.full(N, label_idx, dtype=np.int32)
    np.savez_compressed(
        out_p,
        signals=X_syn_cal.astype(np.float32, copy=False),
        channels=np.array(["ECG","Resp","EDA"], dtype=object),
        labels=y,
    )
    print(f"[synth] wrote {out_p} (N={N})")
    return out_p

# --- Generate any missing classes (match T to real) ---
assert 'T_target' in globals(), "T_target not set; run the real-prep cell first."
synth_paths = {name: _gen_and_calibrate(name, counts[name], T_target) for name in ORDER}
print("Synth files:", synth_paths)

[real counts] {'baseline': 23, 'stress': 46, 'amusement': 77}
[synth] exists: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_baseline_calibrated.npz
[synth] exists: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_stress_calibrated.npz
[synth] exists: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_amusement_calibrated.npz
Synth files: {'baseline': WindowsPath('c:/Users/Joseph/generative-health-models/data/generated/3class_calibrated/synth_baseline_calibrated.npz'), 'stress': WindowsPath('c:/Users/Joseph/generative-health-models/data/generated/3class_calibrated/synth_stress_calibrated.npz'), 'amusement': WindowsPath('c:/Users/Joseph/generative-health-models/data/generated/3class_calibrated/synth_amusement_calibrated.npz')}


In [13]:
channels = np.array(["ECG","Resp","EDA"], dtype=object)

if 'X_cal' in globals():
    # Prefer evaluator-friendly schema and include a label vector
    label_idx = ORDER.index(CONDITION) if 'CONDITION' in globals() else 0
    cal_npz = SAVE_DIR / f"synth_{CONDITION}_N{X_cal.shape[0]}_T{X_cal.shape[1]}_seed{BASE_SEED}_calibrated_preview.npz"
    np.savez_compressed(
        cal_npz,
        signals=X_cal.astype(np.float32, copy=False),
        channels=channels,
        labels=np.full(X_cal.shape[0], label_idx, dtype=np.int32),
        condition=str(CONDITION),
        window_ids=np.arange(X_cal.shape[0], dtype=np.int32),
    )
    print("Saved preview calibrated NPZ ->", cal_npz)
else:
    print("Skipping single-condition save; calibrated per-class files already exist in:", SYN_DIR)
    for name in ORDER:
        p = SYN_DIR / f"synth_{name}_calibrated.npz"
        if p.exists():
            with np.load(p, allow_pickle=True) as d:
                X = d["signals"]
                print(f"  {name:<10} {p}  (N={X.shape[0]}, T={X.shape[1]})")
        else:
            print(f"  [missing] {p}")

Skipping single-condition save; calibrated per-class files already exist in: c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated
  baseline   c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_baseline_calibrated.npz  (N=23, T=5250)
  stress     c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_stress_calibrated.npz  (N=46, T=5250)
  amusement  c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_amusement_calibrated.npz  (N=77, T=5250)


In [14]:
# --- Build a consolidated run meta (works with per-class calibrated files) ---
manifest = getattr(gen.bundle, "manifest", {})
sampling_method = str(manifest.get("sampling_method", "ddim"))

base_steps = int(OVERRIDE_STEPS or manifest.get("sampling_steps", 50))
base_guid  = float(OVERRIDE_GUIDANCE if OVERRIDE_GUIDANCE is not None else manifest.get("cfg_scale", 0.0))

def _summ_stats(arr: np.ndarray):
    return {
        "mean": [float(arr[..., i].mean()) for i in range(3)],
        "std":  [float(arr[..., i].std())  for i in range(3)],
    }

# Real stats are always available
stats_real = _summ_stats(X_real)

# Optional single-condition preview stats (only if those vars exist)
stats_pre  = _summ_stats(X_synth) if "X_synth" in globals() else None
stats_post = _summ_stats(X_cal)   if "X_cal"   in globals() else None

# If no single-condition preview, aggregate calibrated synth across the three class files
if stats_post is None:
    try:
        agg = []
        for name in ORDER:
            p = SYN_DIR / f"synth_{name}_calibrated.npz"
            if p.exists():
                with np.load(p, allow_pickle=True) as d:
                    agg.append(d["signals"].astype(np.float32))
        if agg:
            S = np.concatenate(agg, axis=0)  # (N_total, T, 3)
            stats_post = _summ_stats(S)
    except Exception as e:
        print("[meta] warn: could not aggregate synth stats:", e)

# Per-class counts (recompute if not in scope)
try:
    counts  # noqa: F401
except NameError:
    counts = {}
    for name in ORDER:
        with np.load(REAL_SPLIT_DIR / f"real_{name}.npz", allow_pickle=True) as d:
            counts[name] = int(d["signals"].shape[0])

# Optional previous evaluation block (if you already ran an eval and have 'results')
eval_block = None
if "results" in globals() and isinstance(results, dict) and "metrics" in results:
    eval_block = {
        "table1_csv": results.get("table1_csv"),
        "figure_psd": results.get("figure_psd"),
        "figure_acf": results.get("figure_acf"),
        "metrics": results["metrics"],
    }

# Collect synthesized file paths for traceability
synth_files_map = {name: str((SYN_DIR / f"synth_{name}_calibrated.npz").resolve()) for name in ORDER}

meta = {
    "timestamp_utc": datetime.utcnow().isoformat() + "Z",
    "run": {
        "seed": int(BASE_SEED),
        "save_dir": str(SYN_DIR.resolve()),
        "real_split_dir": str(REAL_SPLIT_DIR.resolve()),
        "per_class_counts": {k: int(v) for k, v in counts.items()},
        # single-condition fields (present only if preview was run)
        "single_condition": {
            "condition": str(CONDITION) if "CONDITION" in globals() else None,
            "N": int(X_synth.shape[0]) if "X_synth" in globals() else None,
            "T": int(X_synth.shape[1]) if "X_synth" in globals() else None,
            "npz_raw": str(npz_path) if "npz_path" in globals() else None,
            "npz_calibrated": str(cal_npz) if "cal_npz" in globals() else None,
        },
        "synth_files_calibrated": synth_files_map,
        "real_full_npz": str((OUT_DIR / "real_test_ecgT.npz").resolve()),
    },
    "model": {
        "ckpt_path": str(CKPT),
        "ckpt_sha256": sha256_file(CKPT),
        "milestones_dir": str(MILESTONES),
        "condition_dim": int(getattr(gen, "condition_dim", manifest.get("condition_dim", 3))),
        "sampling_method": sampling_method,
    },
    "sampling": {
        "base_steps": base_steps,
        "base_guidance": base_guid,
        "ecg": {"steps": int(STEPS_ECG), "guidance": float(GUID_ECG)},
        "low": {"steps": int(STEPS_LOW), "guidance": float(GUID_LOW)},
    },
    "norms": {
        "used": {
            "low_path": str(NORM_LOW_PATH),
            "low_sha256": sha256_file(NORM_LOW_PATH),
            "ecg_path": str(NORM_ECG_PATH),
            "ecg_sha256": sha256_file(NORM_ECG_PATH),
        },
        "milestone_defaults": {
            "low_path": str(gen.bundle.norm_low),
            "low_sha256": sha256_file(gen.bundle.norm_low),
            "ecg_path": str(gen.bundle.norm_ecg),
            "ecg_sha256": sha256_file(gen.bundle.norm_ecg),
        },
    },
    "calibration": {
        "ecg_scaled_to_real_std": True,
        "resp_quantile_mapped": True,
        "eda_mean_std_matched": True,
        "ecg_hist_qmap": bool(STORE_ECG_QMAP),
        "enforce_resp_std": True,
        "calibration_targets_path": str(CALIB_JSON),
        "ecg_qmap_alpha": 0.6,  # keep consistent with _gen_and_calibrate
    },
    "stats": {
        "real": stats_real,
        "pre_calibration": stats_pre,    # may be None
        "post_calibration": stats_post,  # aggregated if preview not run
        "channel_order": ["ECG", "Resp", "EDA"],
    },
    "evaluation": eval_block,  # may be None; will be filled after you run the eval cells
}

meta_path = (OUT_DIR / f"{RUN_TAG}_meta.json").resolve()
meta_path.write_text(json.dumps(meta, indent=2))
print("Wrote meta ->", meta_path)

Wrote meta -> C:\Users\Joseph\generative-health-models\results\evaluation\eval_ckpt130\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_meta.json


In [15]:
# --- Per-class file maps for the evaluator (RAW / physical units) ---
real_files = {
    "baseline": REAL_SPLIT_DIR / "real_baseline.npz",
    "stress":   REAL_SPLIT_DIR / "real_stress.npz",
    "amusement":REAL_SPLIT_DIR / "real_amusement.npz",
}
synth_files = {
    "baseline": SYN_DIR / "synth_baseline_calibrated.npz",
    "stress":   SYN_DIR / "synth_stress_calibrated.npz",
    "amusement":SYN_DIR / "synth_amusement_calibrated.npz",
}

# Reconfirm RUN_TAG (if not already set earlier)
try:
    RUN_TAG
except NameError:
    CKPT_TAG = Path(CKPT).stem
    RUN_TAG  = f"3class_run_{CKPT_TAG}_seed{BASE_SEED}_stE{STEPS_ECG}_stL{STEPS_LOW}_gE{GUID_ECG}_gL{GUID_LOW}_cal"

print("RUN_TAG:", RUN_TAG)
for k, p in real_files.items():   print(f"[real]  {k:<10} -> {p}")
for k, p in synth_files.items():  print(f"[synth] {k:<10} -> {p}")

RUN_TAG: 3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal
[real]  baseline   -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_baseline.npz
[real]  stress     -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_stress.npz
[real]  amusement  -> c:\Users\Joseph\generative-health-models\results\evaluation\real_3class_split\real_amusement.npz
[synth] baseline   -> c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_baseline_calibrated.npz
[synth] stress     -> c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_stress_calibrated.npz
[synth] amusement  -> c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated\synth_amusement_calibrated.npz


In [16]:
# ---- knobs to try ----
GUID_ECG_LIST        = [0.35, 0.30, 0.40]   # start at 0.35; also test 0.30/0.40
STEPS_ECG            = 200
STEPS_LOW            = 100
GUID_LOW             = 0.10
ECG_QMAP_ALPHA_LIST  = [0.8]                # can add 0.6 or 1.0 later

ORDER = ["baseline", "stress", "amusement"]  # sanity

assert 'gen' in globals(), "Create WESADGenerator as 'gen' earlier."

# Where to write synth variants
SYN_SWEEP_BASE = Path(REPO) / "data" / "generated" / "3class_calibrated_sweep"
SYN_SWEEP_BASE.mkdir(parents=True, exist_ok=True) 

# Real counts + T_target from the split you already built
real_counts = {}
T_target = None
for name in ORDER:
    with np.load(Path(REAL_SPLIT_DIR) / f"real_{name}.npz", allow_pickle=False) as d:
        real_counts[name] = int(d["signals"].shape[0])
        if T_target is None:
            T_target = int(d["signals"].shape[1])
print("[real counts]", real_counts, "T_target:", T_target)

# Load (or build) calibrator with ECG q-map
try:
    cal = WESADCalibrator.load(CALIB_JSON)
except Exception:
    prep = RealWESADPreparer(FOLD_DIR)
    Xr_tmp, _ = prep.prepare(target="ecg")
    cal = WESADCalibrator.from_real(Xr_tmp, store_ecg_qmap=True, resp_q_n=4001)
    cal.save(CALIB_JSON)

if not cal.has_ecg_qmap():
    print("[warn] calibration targets lack ECG q-map; ecg_qmap will be disabled for apply().")

def _gen_and_calibrate_to(out_dir: Path, name: str, N: int, T: int, ge: float, alpha: float) -> Path:
    out_dir.mkdir(parents=True, exist_ok=True)
    out_p = out_dir / f"synth_{name}_calibrated.npz"

    # Reuse if already present at correct shape
    if out_p.exists():
        try:
            with np.load(out_p, allow_pickle=False) as d:
                if d["signals"].shape == (N, T, 3):
                    print(f"[reuse] {out_p}")
                    return out_p
        except Exception:
            pass

    print(f"[gen] {name}  N={N} T={T}  steps_ecg={STEPS_ECG} steps_low={STEPS_LOW}  gE={ge:.2f} gL={GUID_LOW:.2f}  qmap_alpha={alpha:.2f}")
    X_syn = synth_match_real_fast(
        gen, name, T, N, BASE_SEED,
        steps=OVERRIDE_STEPS, guidance=OVERRIDE_GUIDANCE,   # base (unused by per-head but kept)
        steps_ecg=STEPS_ECG, steps_low=STEPS_LOW,
        guidance_ecg=ge, guidance_low=GUID_LOW,
        batch_size=16,
        norm_source="paths",
        norm_low_path=str(Path(FOLD_DIR) / "norm_low.npz"),
        norm_ecg_path=str(Path(FOLD_DIR) / "norm_ecg.npz"),
    )

    X_syn_cal = cal.apply(
        X_syn, do_ecg=True, do_resp=True, do_eda=True,
        ecg_qmap=(cal.has_ecg_qmap() and True),
        ecg_qmap_alpha=float(alpha),
        enforce_resp_std=True
    )

    y = np.full(N, ORDER.index(name), dtype=np.int32)
    np.savez_compressed(
        out_p,
        signals=X_syn_cal.astype(np.float32, copy=False),
        labels=y,
        channels=np.array(["ECG","Resp","EDA"], dtype=object),
    )
    return out_p

summary = []
for ge in GUID_ECG_LIST:
    for alpha in ECG_QMAP_ALPHA_LIST:
        tag = f"stE{STEPS_ECG}_stL{STEPS_LOW}_gE{int(round(ge*100)):03d}_gL{int(round(GUID_LOW*100)):03d}_a{int(round(alpha*100)):02d}"
        out_dir = SYN_SWEEP_BASE / tag

        synth_paths = {}
        for name in ORDER:
            synth_paths[name] = _gen_and_calibrate_to(out_dir, name, real_counts[name], T_target, ge, alpha)

        real_files  = {k: Path(REAL_SPLIT_DIR) / f"real_{k}.npz"                  for k in ["baseline","stress","amusement"]}
        synth_files = {k: SYN_DIR               / f"synth_{k}_calibrated.npz"     for k in ["baseline","stress","amusement"]}

         # RAW eval (Table 1 only) with EDA linear fix
        EVAL_RAW_BAND = REPO / "results" / "evaluation" / (RUN_TAG + f"_RAW_{BEST_TAG}_ecgBandHP")
        EVAL_RAW_BAND.mkdir(parents=True, exist_ok=True)

        cfg_raw = EvalConfig(
            T_target=T_target, fs_ecg=FS_ECG, fs_low=FS_LOW,
            results_dir=EVAL_RAW_BAND,
            run_classifier=False,
            apply_eda_linfix=True,        # keep your EDA correction
            psd_low_max_hz=1.5,           # low streams capped for PSD compare
            psd_ecg_band=(0.5, 40.0),     # ECG band
            ecg_hp_cut_hz=None,           # no HPF (was hurting)
            psd_log=True,                 # <- the key that lifted PSD_sim
            psd_norm=None,                # correlation already scale/offset-invariant
        )

        ev = WESADEvaluator(cfg_raw)
        res = ev.evaluate_all(real_files, synth_files)

        # Read Table 1 and collect ECG PSD_sim
        t1 = Path(res["table1_csv"])
        df1 = pd.read_csv(t1)
        ecg_row  = df1[df1["channel"].str.upper() == "ECG"].iloc[0]
        resp_row = df1[df1["channel"].str.upper() == "RESP"].iloc[0]
        eda_row  = df1[df1["channel"].str.upper() == "EDA"].iloc[0]

        row_ecg = df1[df1["channel"].str.upper()=="ECG"].iloc[0]
        ecg_psd = float(ecg_row["PSD_sim"])
        print(f"\n[ECG PSD_sim after band/HPF] {row_ecg['PSD_sim']:.3f}  (target ≥ 0.70)")

        summary.append({
            "tag": tag,
            "gE": ge,
            "alpha": alpha,
            "ECG_PSD_sim": ecg_psd,
            "ECG_KS": float(ecg_row["KS"]),
            "ECG_W1": float(ecg_row["W1"]),
            "ROW_ECG": float(row_ecg['PSD_sim']),
            "Resp_PSD_sim": float(resp_row["PSD_sim"]),
            "EDA_PSD_sim": float(eda_row["PSD_sim"]),
            "table1_csv": str(t1),
            "figure_psd": res["figure_psd"],
            "figure_acf": res["figure_acf"],
        })

# ---- summary & best pick ----
sum_df = pd.DataFrame(summary).sort_values("ECG_PSD_sim", ascending=False)
print("\n=== ECG-focused sweep summary (higher is better) ===")
print(sum_df[["tag","gE","alpha","ECG_PSD_sim","Resp_PSD_sim","EDA_PSD_sim"]].to_string(index=False))

if len(sum_df):
    best = sum_df.iloc[0]
    print(f"\nBest combo: {best['tag']}  ECG_PSD_sim={best['ECG_PSD_sim']:.3f}")
    print(f"Table 1 -> {best['table1_csv']}")
    print(f"PSD fig -> {best['figure_psd']}")
    print(f"ACF fig -> {best['figure_acf']}")

[real counts] {'baseline': 23, 'stress': 46, 'amusement': 77} T_target: 5250
[reuse] c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated_sweep\stE200_stL100_gE035_gL010_a80\synth_baseline_calibrated.npz
[reuse] c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated_sweep\stE200_stL100_gE035_gL010_a80\synth_stress_calibrated.npz
[reuse] c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated_sweep\stE200_stL100_gE035_gL010_a80\synth_amusement_calibrated.npz

[ECG PSD_sim after band/HPF] 0.729  (target ≥ 0.70)
[reuse] c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated_sweep\stE200_stL100_gE030_gL010_a80\synth_baseline_calibrated.npz
[reuse] c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated_sweep\stE200_stL100_gE030_gL010_a80\synth_stress_calibrated.npz
[reuse] c:\Users\Joseph\generative-health-models\data\generated\3class_calibrated_sweep\stE200_stL100_gE030_gL010_a80\synth_amusement

In [17]:
import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    x = torch.randn(1, device="cuda")
    print("Sample tensor device:", x.device)

CUDA available: True
GPU: NVIDIA GeForce RTX 3080
Sample tensor device: cuda:0


In [18]:
rng = np.random.default_rng(123)  # reproducible baseline

def phase_randomize_batch(X: np.ndarray, rng=np.random.default_rng(0)) -> np.ndarray:
    """
    Phase-randomized surrogate: preserves magnitude spectrum (thus PSD),
    destroys temporal phase relationships. Works per sample, per channel.
    X: (N, T, 3) -> returns same shape.
    """
    X = np.asarray(X, dtype=np.float32)
    N, T, C = X.shape
    Y = np.empty_like(X)
    for n in range(N):
        for c in range(C):
            x = X[n, :, c].astype(np.float64)
            Xf = np.fft.rfft(x)           # length T//2 + 1
            mag = np.abs(Xf)
            theta = rng.uniform(-np.pi, np.pi, Xf.shape)

            # keep DC (and Nyquist if even-length) as real to preserve mean/energy
            theta[0] = 0.0
            if (T % 2) == 0:
                theta[-1] = 0.0

            Yf = mag * np.exp(1j * theta)
            y = np.fft.irfft(Yf, n=T)
            Y[n, :, c] = y.astype(np.float32)
    return Y

# Use the same RAW settings you’ve been using for Table 1
cfg_raw_bl = EvalConfig(
    T_target=T_target,
    fs_ecg=FS_ECG,
    fs_low=FS_LOW,
    results_dir=EVAL_DIR,     # put outputs alongside your existing results
    run_classifier=False,
    apply_eda_linfix=True,    # keep your EDA linear correction for fairness
    psd_low_max_hz=1.5,
)

ev_bl = WESADEvaluator(cfg_raw_bl)

# Load aligned real & diffusion arrays via your evaluator
Xr, yr, Xs, ys = ev_bl.load_and_align(real_files, synth_files)

# Build baseline (same shape as real), then compute metrics
Xb = phase_randomize_batch(Xr, rng=rng)

dist_diff = ev_bl.distribution_metrics(Xr, Xs)
psd_diff  = ev_bl.psd_similarity(Xr, Xs)
acf_diff  = ev_bl.acf_similarity(Xr, Xs)

dist_base = ev_bl.distribution_metrics(Xr, Xb)
psd_base  = ev_bl.psd_similarity(Xr, Xb)
acf_base  = ev_bl.acf_similarity(Xr, Xb)

# Write a combined Table 1 with a Model column
table1_with_baseline = Path(EVAL_DIR) / "table1_with_baseline.csv"
with table1_with_baseline.open("w", encoding="utf-8") as f:
    f.write("Channel,Model,KS,W1,JSD,PSD_sim,ACF_sim\n")
    for ch in ["ECG","Resp","EDA"]:
        f.write(f"{ch},Diffusion,{dist_diff[ch]['KS']},{dist_diff[ch]['W1']},{dist_diff[ch]['JSD']},{psd_diff[ch]['PSD_sim']},{acf_diff[ch]['ACF_sim']}\n")
        f.write(f"{ch},Baseline (Phase-Rand),{dist_base[ch]['KS']},{dist_base[ch]['W1']},{dist_base[ch]['JSD']},{psd_base[ch]['PSD_sim']},{acf_base[ch]['ACF_sim']}\n")

# (Optional) save separate PSD/ACF overlays for the baseline
psd_fig_diff = ev_bl.figure_psd_overlay(Xr, Xs)
psd_fig_base = ev_bl.figure_psd_overlay(Xr, Xb)
acf_fig_diff = ev_bl.figure_acf_overlay(Xr, Xs)
acf_fig_base = ev_bl.figure_acf_overlay(Xr, Xb)

print("Wrote Table 1 (+ baseline):", table1_with_baseline)
print("PSD fig (Diffusion):", psd_fig_diff)
print("PSD fig (Baseline) :", psd_fig_base)
print("ACF fig (Diffusion):", acf_fig_diff)
print("ACF fig (Baseline) :", acf_fig_base)

Wrote Table 1 (+ baseline): c:\Users\Joseph\generative-health-models\results\evaluation\ckpt130_cls_run_3class\table1_with_baseline.csv
PSD fig (Diffusion): c:\Users\Joseph\generative-health-models\results\evaluation\ckpt130_cls_run_3class\figure_psd_overlay.png
PSD fig (Baseline) : c:\Users\Joseph\generative-health-models\results\evaluation\ckpt130_cls_run_3class\figure_psd_overlay.png
ACF fig (Diffusion): c:\Users\Joseph\generative-health-models\results\evaluation\ckpt130_cls_run_3class\figure_acf_overlay.png
ACF fig (Baseline) : c:\Users\Joseph\generative-health-models\results\evaluation\ckpt130_cls_run_3class\figure_acf_overlay.png


In [19]:
# --- choose your final setting here ---
CHOSEN_FRAC = 0.05   # 0.05 for best F1; use 0.03 for best AUROC; 0.00 for middle ground
EPOCHS = 40          # keep the locked classifier recipe; bump to 50 if you want

# --- preconditions (already defined earlier in your notebook) ---
assert 'RUN_TAG' in globals()
assert 'ORDER' in globals()
assert 'REAL_SPLIT_DIR' in globals() and Path(REAL_SPLIT_DIR).exists()
assert 'SYN_DIR' in globals() and Path(SYN_DIR).exists()
assert 'T_target' in globals()
assert 'FS_ECG' in globals() and 'FS_LOW' in globals()
from evaluation.wesad_eval import EvalConfig, WESADEvaluator  # uses your patched evaluator

rng = np.random.default_rng(0)
def _empty_X(): return np.empty((0, T_target, 3), dtype=np.float32)
def _empty_y(): return np.empty((0,), dtype=np.int32)

# --- real RAW counts (for topping-up baseline/stress) ---
real_counts = {}
for name in ORDER:
    with np.load(REAL_SPLIT_DIR / f"real_{name}.npz", allow_pickle=False) as d:
        real_counts[name] = int(d["signals"].shape[0])
majority_n = max(real_counts.values())
print("[finalize] real RAW counts:", real_counts, "majority=", majority_n)

# --- curate synth on RAW calibrated files for the chosen fraction ---
tag = f"FINAL_A{int(CHOSEN_FRAC*100)}_e{EPOCHS}"
CUR_SYN_RAW_DIR = SYN_DIR.with_name(SYN_DIR.name + f"_{tag}")
CUR_SYN_RAW_DIR.mkdir(parents=True, exist_ok=True)

for name in ORDER:
    src = SYN_DIR / f"synth_{name}_calibrated.npz"
    if not src.exists():
        print(f"[finalize] missing synth raw: {src} -> writing empty stub")
        np.savez_compressed(CUR_SYN_RAW_DIR / f"synth_{name}_calibrated.npz",
                            signals=_empty_X(), labels=_empty_y())
        continue

    with np.load(src, allow_pickle=False) as d:
        Xs = d["signals"].astype(np.float32); ys = d["labels"].astype(np.int32)

    if name == "amusement":
        k = int(np.floor(CHOSEN_FRAC * len(Xs)))
    else:
        need = max(0, majority_n - real_counts[name])
        k = min(need, len(Xs))

    if k > 0:
        idx = rng.choice(len(Xs), size=k, replace=False)
        X_keep = Xs[idx]; y_keep = ys[idx]
    else:
        X_keep = _empty_X(); y_keep = _empty_y()

    if X_keep.shape[1] != T_target:  # safety
        X_keep = _empty_X(); y_keep = _empty_y()

    np.savez_compressed(CUR_SYN_RAW_DIR / f"synth_{name}_calibrated.npz",
                        signals=X_keep, labels=y_keep)
    print(f"[finalize] {name:<10} -> kept {k} / {len(Xs)}; shape={X_keep.shape}")

# --- build file maps ---
real_files = {k: REAL_SPLIT_DIR  / f"real_{k}.npz"              for k in ORDER}
synth_files= {k: CUR_SYN_RAW_DIR / f"synth_{k}_calibrated.npz"  for k in ORDER}

# --- where to store outputs ---
FINAL_DIR = (REPO / "results" / "evaluation" / (RUN_TAG + f"_{tag}_BUNDLE"))
FINAL_DIR.mkdir(parents=True, exist_ok=True)

# ------------------ PASS 1: RAW (Table 1 + figures) ------------------
cfg_raw = EvalConfig(
    T_target=T_target, fs_ecg=FS_ECG, fs_low=FS_LOW,
    results_dir=FINAL_DIR / "RAW",
    run_classifier=False,
    apply_eda_linfix=True,   # EDA linear mean/std correction for KS/W1/JSD only
    psd_low_max_hz=1.5,
)
ev_raw = WESADEvaluator(cfg_raw)
res_raw = ev_raw.evaluate_all(real_files, synth_files)
print("[FINAL RAW] Table 1:", res_raw["table1_csv"])
print("[FINAL RAW] PSD fig:", res_raw["figure_psd"])
print("[FINAL RAW] ACF fig:", res_raw["figure_acf"])

# ------------------ PASS 2: CLS (Table 2) ------------------
cfg_cls = EvalConfig(
    T_target=T_target, fs_ecg=FS_ECG, fs_low=FS_LOW,
    results_dir=FINAL_DIR / "CLS",
    run_classifier=True,
    clf_labels=(0,1,2),
    clf_epochs=EPOCHS,
    clf_batch_size=64,
    clf_seed=0,
    clf_lr=1e-3,
    clf_zscore=True,               # on-the-fly standardization from REAL stats
    clf_class_weight="balanced",   # inverse-frequency weighting
)
ev_cls = WESADEvaluator(cfg_cls)
res_cls = ev_cls.evaluate_all(real_files, synth_files)
print("[FINAL CLS] Table 2:", res_cls["table2_csv"])

# --- tiny recap (prints the Real+Synth→Real row) ---
t2 = Path(res_cls["table2_csv"])
df2 = pd.read_csv(t2)
row = df2[df2["setting"].astype(str).str.contains("Real\\+Synth", regex=True, na=False)]
if len(row):
    auroc = float(row["AUROC"].values[0]); f1 = float(row["F1"].values[0])
    print(f"[FINAL summary] Real+Synth→Real: AUROC={auroc:.4f}  F1={f1:.4f}")
else:
    print("[FINAL summary] Could not find Real+Synth→Real row in Table 2.")

[finalize] real RAW counts: {'baseline': 23, 'stress': 46, 'amusement': 77} majority= 77
[finalize] baseline   -> kept 23 / 23; shape=(23, 5250, 3)
[finalize] stress     -> kept 31 / 46; shape=(31, 5250, 3)
[finalize] amusement  -> kept 3 / 77; shape=(3, 5250, 3)
[FINAL RAW] Table 1: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_FINAL_A5_e40_BUNDLE\RAW\table1_distribution_psd_acf.csv
[FINAL RAW] PSD fig: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_FINAL_A5_e40_BUNDLE\RAW\figure_psd_overlay.png
[FINAL RAW] ACF fig: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_FINAL_A5_e40_BUNDLE\RAW\figure_acf_overlay.png
[FINAL CLS] Table 2: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_s

In [20]:
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from matplotlib.lines import Line2D
from pathlib import Path

# ---- Resolve paths & constants from your notebook (with safe fallbacks) ----
REPO      = Path(globals().get("REPO", r"C:\Users\Joseph\generative-health-models"))
ORDER     = list(globals().get("ORDER", ["baseline","stress","amusement"]))
REAL_DIR  = Path(globals().get("REAL_SPLIT_DIR", REPO / "results/evaluation/real_3class_split"))
SYN_DIR   = Path(globals().get("SYN_DIR",      REPO / "data/generated/3class_calibrated"))
RUN_TAG   = str(globals().get("RUN_TAG",       "signal_snippets"))
FS_ECG    = float(globals().get("FS_ECG",      175.0))   # fused arrays at ECG rate
SECONDS   = 8
OUT_DIR   = REPO / "results" / "evaluation" / f"{RUN_TAG}_PLOTS"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Display knobs (same look as before)
ECG_DECIM  = 3
REAL_STYLE = dict(color="#1f77b4", lw=1.8, alpha=0.95, zorder=3, label="Real")
SYN_STYLE  = dict(color="#ff7f0e", lw=1.4, alpha=0.85, ls=(0,(6,3)), zorder=2, label="Synth")

def _load_signals(npz_path: Path) -> np.ndarray:
    with np.load(npz_path, allow_pickle=False) as d:
        if "signals" in d:
            return d["signals"].astype(np.float32, copy=False)
        # fallback to any (N,T,3)
        for k in d.files:
            arr = d[k]
            if arr.ndim == 3 and arr.shape[-1] == 3:
                return arr.astype(np.float32, copy=False)
    raise ValueError(f"No (N,T,3) signals found in {npz_path}")

def _pick_snippet(X: np.ndarray, seconds: int, fs: float):
    N, T, C = X.shape
    L = int(round(seconds * fs))
    L = min(L, T)
    # median-ECG-std window for "typical" sample
    ecg_std = X[..., 0].std(axis=1)
    mid_idx = int(np.argsort(ecg_std)[len(ecg_std)//2])
    x = X[mid_idx]
    s = max(0, (T - L)//2); e = s + L
    return x[s:e]  # (L,3)

# ---- Load snippets for each class (real & synth) ----
snips_real, snips_synth = [], []
T_ref = None
for name in ORDER:
    Xr = _load_signals(REAL_DIR / f"real_{name}.npz")
    Xs_path = SYN_DIR / f"synth_{name}_calibrated.npz"
    if T_ref is None: T_ref = Xr.shape[1]
    if Xs_path.exists():
        Xs = _load_signals(Xs_path)
    else:
        Xs = np.zeros((1, T_ref, 3), dtype=np.float32)  # empty placeholder
    snips_real.append(_pick_snippet(Xr, SECONDS, FS_ECG))
    snips_synth.append(_pick_snippet(Xs, SECONDS, FS_ECG))

L = snips_real[0].shape[0]
t = np.arange(L) / FS_ECG

# ---- Robust, row-wise y-limits (percentiles) ----
def _row_limits(channel, q_lo=1.0, q_hi=99.0, pad=0.06):
    vals = np.concatenate([snips_real[i][:,channel] for i in range(len(ORDER))]
                          + [snips_synth[i][:,channel] for i in range(len(ORDER))])
    lo, hi = np.percentile(vals, [q_lo, q_hi])
    if channel == 0:  # ECG symmetric looks nicer
        m = max(abs(lo), abs(hi)); lo, hi = -m, m
    span = hi - lo
    return (lo - pad*span, hi + pad*span)

ylims = {
    0: _row_limits(0, 1, 99, 0.05),  # ECG
    1: _row_limits(1, 1, 99, 0.12),  # Resp (more pad for tiny amplitudes)
    2: _row_limits(2, 1, 99, 0.08),  # EDA
}

# ---- Figure with a dedicated header row for title + legend ----
fig = plt.figure(figsize=(14, 10))
gs  = fig.add_gridspec(nrows=4, ncols=3, height_ratios=[0.18, 1, 1, 1],
                       hspace=0.75, wspace=0.42)

# Header row (axis is hidden). Title + legend live here so they never overlap plots.
hdr = fig.add_subplot(gs[0, :])
hdr.axis("off")
hdr.set_title("Signal snippets — real vs. synth (per channel)",
              fontsize=14, fontweight="semibold", pad=14)

legend_handles = [
    Line2D([0],[0], **REAL_STYLE),
    Line2D([0],[0], **SYN_STYLE),
]
# Place legend centered in the header row
hdr.legend(handles=legend_handles, loc="center", ncol=2, frameon=False,
           columnspacing=1.4, handlelength=2.8, handletextpad=0.6,
           bbox_to_anchor=(0.5, 0.25))  # lower in the header strip, far from titles

# Axes grid (3x3) for ECG/Resp/EDA × classes
axes = {
    "ecg":  [fig.add_subplot(gs[1, c]) for c in range(3)],
    "resp": [fig.add_subplot(gs[2, c]) for c in range(3)],
    "eda":  [fig.add_subplot(gs[3, c]) for c in range(3)],
}
row_labels = ["ECG [a.u.]", "Resp [a.u.]", "EDA [a.u.]"]

for c_idx, name in enumerate(ORDER):
    # Titles on the ECG row
    ax0 = axes["ecg"][c_idx]
    ax0.set_title(name, fontsize=12, pad=6)

    # ECG (decimated for clarity)
    d = max(1, int(ECG_DECIM))
    ax0.plot(t[::d], snips_real[c_idx][::d, 0], **REAL_STYLE)
    ax0.plot(t[::d], snips_synth[c_idx][::d, 0], **SYN_STYLE)
    ax0.set_ylim(*ylims[0]); ax0.grid(alpha=0.25)

    # Resp (scientific offset; tiny variations visible)
    ax1 = axes["resp"][c_idx]
    ax1.plot(t, snips_real[c_idx][:, 1], **REAL_STYLE)
    ax1.plot(t, snips_synth[c_idx][:, 1], **SYN_STYLE)
    ax1.set_ylim(*ylims[1]); ax1.grid(alpha=0.25)
    fmt = ScalarFormatter(useMathText=True); fmt.set_powerlimits((-2, 2))
    ax1.yaxis.set_major_formatter(fmt)
    ax1.yaxis.get_offset_text().set_x(-0.06)

    # EDA
    ax2 = axes["eda"][c_idx]
    ax2.plot(t, snips_real[c_idx][:, 2], **REAL_STYLE)
    ax2.plot(t, snips_synth[c_idx][:, 2], **SYN_STYLE)
    ax2.set_ylim(*ylims[2]); ax2.grid(alpha=0.25)

# Row y-labels & bottom x-labels
for ax in axes["ecg"]:  ax.set_ylabel(row_labels[0])
for ax in axes["resp"]: ax.set_ylabel(row_labels[1])
for ax in axes["eda"]:
    ax.set_ylabel(row_labels[2])
    ax.set_xlabel("time [s]")

# Tidy spines
for ax in [*axes["ecg"], *axes["resp"], *axes["eda"]]:
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

# Save
out_path = OUT_DIR / "signal_snippets_overlay_clean.png"
fig.savefig(out_path, dpi=220, bbox_inches="tight")
plt.close(fig)
print("Saved:", out_path)

Saved: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_PLOTS\signal_snippets_overlay_clean.png


In [21]:
out_path = OUT_DIR / "signal_snippets_overlay_clean.png"

# Save PNG (raster)
fig.savefig(out_path, dpi=220)

# Save PDF (vector) — add this line
fig.savefig(out_path.with_suffix(".pdf"), bbox_inches="tight")

plt.close(fig)
print("Saved PNG:", out_path)
print("Saved PDF:", out_path.with_suffix(".pdf"))

Saved PNG: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_PLOTS\signal_snippets_overlay_clean.png
Saved PDF: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_PLOTS\signal_snippets_overlay_clean.pdf


In [22]:
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
import pandas as pd

# Resolve paths/vars from your notebook (safe defaults)
REPO     = Path(globals().get("REPO", r"C:\Users\Joseph\generative-health-models"))
ORDER    = list(globals().get("ORDER", ["baseline", "stress", "amusement"]))
REAL_DIR = Path(globals().get("REAL_SPLIT_DIR", REPO / "results/evaluation/real_3class_split"))
SYN_DIR  = Path(globals().get("SYN_DIR", REPO / "data/generated/3class_calibrated"))
RUN_TAG  = str(globals().get("RUN_TAG", "leakage_check"))
OUT_DIR  = REPO / "results" / "evaluation" / f"{RUN_TAG}_LEAKAGE"
OUT_DIR.mkdir(parents=True, exist_ok=True)

rng = np.random.default_rng(0)

def _load_signals(npz_path: Path) -> np.ndarray:
    with np.load(npz_path, allow_pickle=False) as d:
        if "signals" in d:
            return d["signals"].astype(np.float32, copy=False)
        for k in d.files:
            arr = d[k]
            if arr.ndim == 3 and arr.shape[-1] == 3:
                return arr.astype(np.float32, copy=False)
    raise ValueError(f"No (N,T,3) signals in {npz_path}")

def _stack_map(folder: Path, prefix: str) -> np.ndarray:
    xs = []
    for name in ORDER:
        p = (folder / f"{prefix}_{name}.npz")
        if not p.exists() and prefix.startswith("synth"):
            p = folder / f"synth_{name}_calibrated.npz"
        if p.exists():
            try:
                x = _load_signals(p)
                if len(x):
                    xs.append(x)
            except Exception as e:
                print(f"[warn] skipping {p}: {e}")
        else:
            print(f"[warn] missing: {p}")
    if not xs:
        raise RuntimeError(f"No data found in {folder} with prefix='{prefix}'")
    return np.concatenate(xs, axis=0)

def _features12_safe(X: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """
    12-D per-window features: for each channel c in {ECG,Resp,EDA}:
      mean_c, std_c, skew_c, kurtosis_c(Fisher)
    NaN-safe: when std ~ 0, set skew/kurt to 0.
    """
    N, T, C = X.shape
    feats = []
    for c in range(3):
        x = X[..., c].astype(np.float64)
        m  = x.mean(axis=1)
        xc = x - m[:, None]
        v  = (xc * xc).mean(axis=1)                      # population variance
        sd = np.sqrt(np.maximum(v, 0.0))
        # guarded skew/kurt
        m3 = (xc ** 3).mean(axis=1)
        m4 = (xc ** 4).mean(axis=1)
        sk = m3 / np.maximum(sd ** 3, eps)
        ku = m4 / np.maximum(v ** 2, eps) - 3.0          # Fisher

        const = sd < eps
        sk[const] = 0.0
        ku[const] = 0.0

        feats.extend([m, sd, sk, ku])

    F = np.stack(feats, axis=1).T    # (12, N)
    F = F.T.astype(np.float32)       # (N, 12)
    # Final guard: finite only
    F = np.nan_to_num(F, nan=0.0, posinf=0.0, neginf=0.0)
    return F

# 1) Load real & synth
X_real_all  = _stack_map(REAL_DIR,  "real")
X_synth_all = _stack_map(SYN_DIR,   "synth")
print("[leak] real:", X_real_all.shape, "synth:", X_synth_all.shape)

# 2) Split real into train/val (control distances use val)
idx = np.arange(len(X_real_all))
rng.shuffle(idx)
n_val = max(1, int(0.3 * len(idx)))
val_idx, train_idx = idx[:n_val], idx[n_val:]
X_real_tr, X_real_val = X_real_all[train_idx], X_real_all[val_idx]

# 3) Features
F_tr  = _features12_safe(X_real_tr)
F_val = _features12_safe(X_real_val)
F_syn = _features12_safe(X_synth_all)

# 4) Standardize by REAL-TRAIN stats (and re-guard for finiteness)
scaler = StandardScaler().fit(F_tr)
Z_tr  = np.nan_to_num(scaler.transform(F_tr),  nan=0.0, posinf=0.0, neginf=0.0)
Z_val = np.nan_to_num(scaler.transform(F_val), nan=0.0, posinf=0.0, neginf=0.0)
Z_syn = np.nan_to_num(scaler.transform(F_syn), nan=0.0, posinf=0.0, neginf=0.0)

# 5) NN distances to real-train
nn = NearestNeighbors(n_neighbors=1, metric="euclidean")
nn.fit(Z_tr)
d_val = nn.kneighbors(Z_val, return_distance=True)[0].ravel()
d_syn = nn.kneighbors(Z_syn, return_distance=True)[0].ravel()

# 6) Save CSV + summary
df_dist = pd.DataFrame({
    "group": (["real_val->real_train"] * len(d_val)) + (["synth->real_train"] * len(d_syn)),
    "distance": np.concatenate([d_val, d_syn])
})
csv_path = OUT_DIR / "nn_leakage_distances.csv"
df_dist.to_csv(csv_path, index=False)

def _summ(name, arr):
    return {
        "group": name,
        "n": len(arr),
        "median": float(np.median(arr)),
        "p05": float(np.percentile(arr, 5)),
        "p95": float(np.percentile(arr, 95)),
        "mean": float(np.mean(arr)),
        "std": float(np.std(arr, ddof=1)) if len(arr) > 1 else 0.0,
    }

stats = pd.DataFrame([
    _summ("real_val->real_train", d_val),
    _summ("synth->real_train", d_syn)
])
stats_path = OUT_DIR / "nn_leakage_summary.csv"
stats.to_csv(stats_path, index=False)

print("\n[leakage distances — summary]")
print(stats.to_string(index=False))
print("Saved distances ->", csv_path)
print("Saved summary   ->", stats_path)

# 7) Histogram overlay
plt.figure(figsize=(8.6, 4.8))
bins = "auto"
plt.hist(d_val, bins=bins, alpha=0.55, label=f"Real-val → Real-train (n={len(d_val)})")
plt.hist(d_syn, bins=bins, alpha=0.55, label=f"Synth → Real-train (n={len(d_syn)})")
plt.axvline(np.median(d_val), color="#1f77b4", lw=2, ls="--", alpha=0.9)
plt.axvline(np.median(d_syn), color="#ff7f0e", lw=2, ls="--", alpha=0.9)
plt.xlabel("Nearest-neighbor distance in 12-D feature space (z-scored by real-train)")
plt.ylabel("count")
plt.title("Sanity check for leakage: NN distances to real-train")
plt.legend(frameon=False)
plt.tight_layout()
png_path = OUT_DIR / "nn_leakage_hist.png"
pdf_path = OUT_DIR / "nn_leakage_hist.pdf"
plt.savefig(png_path, dpi=200)
plt.savefig(pdf_path, bbox_inches="tight")
plt.close()
print("Saved plots ->", png_path, "and", pdf_path)

# 8) One-line interpretation
ratio = np.median(d_syn) / max(np.median(d_val), 1e-9)
print(f"\n[quick read] median(synth→real) / median(real-val→real) = {ratio:.2f}")
print("Rule of thumb: >>1 is good (no trivial copying). ≈1 or <1 → investigate further (e.g., DTW, embeddings).")

[leak] real: (146, 5250, 3) synth: (146, 5250, 3)

[leakage distances — summary]
               group   n     median        p05        p95       mean       std
real_val->real_train  43   1.245964   0.404575   2.216059   1.344767  0.714672
   synth->real_train 146 256.582749 229.409842 354.734592 269.335578 45.808333
Saved distances -> c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\nn_leakage_distances.csv
Saved summary   -> c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\nn_leakage_summary.csv
Saved plots -> c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\nn_leakage_hist.png and c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\n

In [23]:
assert 'd_val' in globals() and 'd_syn' in globals(), "Run the leakage distances cell first."

OUT = Path(globals().get('OUT_DIR', '.'))
OUT.mkdir(parents=True, exist_ok=True)

# Quick stats for caption/report
def _summ(arr):
    return dict(n=len(arr),
                median=float(np.median(arr)),
                p05=float(np.percentile(arr,5)),
                p95=float(np.percentile(arr,95)),
                mean=float(np.mean(arr)),
                std=float(np.std(arr, ddof=1)) if len(arr)>1 else 0.0)
stats = pd.DataFrame([
    {'group':'real_val->real_train', **_summ(d_val)},
    {'group':'synth->real_train',    **_summ(d_syn)}
])
print(stats.to_string(index=False))
ratio = np.median(d_syn) / max(np.median(d_val), 1e-9)
print(f"\nMedian distance ratio (synth/real-val): {ratio:.1f}×  "
      "(>>1 implies no trivial copying)")

# Shared bin edges so counts are comparable
xmax = float(max(np.max(d_val), np.max(d_syn)))
bins = np.linspace(0, xmax, 60)

fig, (ax_full, ax_zoom) = plt.subplots(1, 2, figsize=(11.5, 4.6), sharey=True)

# Full range
ax_full.hist(d_val, bins=bins, alpha=0.55, label=f"Real-val → Real-train (n={len(d_val)})")
ax_full.hist(d_syn, bins=bins, alpha=0.55, label=f"Synth → Real-train (n={len(d_syn)})")
ax_full.axvline(np.median(d_val), color="#1f77b4", lw=2, ls="--")
ax_full.axvline(np.median(d_syn), color="#ff7f0e", lw=2, ls="--")
ax_full.set_title("Full range")
ax_full.set_xlabel("NN distance in 12-D (z-scored by real-train)")
ax_full.set_ylabel("count")
ax_full.legend(frameon=False, loc="upper right")

# Zoom panel (focus where blue lives)
xzoom = max(10.0, float(np.percentile(d_val, 99)) * 1.5)
ax_zoom.hist(d_val, bins=bins, alpha=0.55, label="Real-val → Real-train")
ax_zoom.hist(d_syn, bins=bins, alpha=0.25, label="Synth → Real-train")  # lighter so blue stands out
ax_zoom.axvline(np.median(d_val), color="#1f77b4", lw=2, ls="--")
ax_zoom.set_xlim(0, xzoom)
ax_zoom.set_title("Zoom near 0")
ax_zoom.set_xlabel("NN distance (zoomed)")

fig.suptitle("Sanity check for leakage: distances to nearest real-train window", y=1.02, fontsize=13)
fig.tight_layout()
p_png = OUT / "nn_leakage_hist_full_plus_zoom.png"
p_pdf = OUT / "nn_leakage_hist_full_plus_zoom.pdf"
fig.savefig(p_png, dpi=220)
fig.savefig(p_pdf, bbox_inches="tight")
plt.close(fig)
print("Saved:", p_png, "and", p_pdf)

               group   n     median        p05        p95       mean       std
real_val->real_train  43   1.245964   0.404575   2.216059   1.344767  0.714672
   synth->real_train 146 256.582749 229.409842 354.734592 269.335578 45.808333

Median distance ratio (synth/real-val): 205.9×  (>>1 implies no trivial copying)
Saved: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\nn_leakage_hist_full_plus_zoom.png and c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\nn_leakage_hist_full_plus_zoom.pdf


In [24]:
from sklearn.preprocessing import RobustScaler
from sklearn.covariance import LedoitWolf

# Expect Z_tr, Z_val, Z_syn (12-D features) OR Xr_tr, Xr_val, Xs (N,T,3) to already exist.
# If you only have Z_*, the code will use those directly.

def _feat_12(X):
    """12-D per window: mean/std/skew/kurtosis per channel (ECG,Resp,EDA)."""
    import scipy.stats as st
    m  = X.mean(axis=1)
    sd = X.std(axis=1)
    sk = st.skew(X, axis=1, bias=False)
    ku = st.kurtosis(X, axis=1, fisher=True, bias=False)  # excess kurtosis
    F = np.concatenate([m, sd, sk, ku], axis=1).astype(np.float64)
    return np.nan_to_num(F, copy=False, posinf=0.0, neginf=0.0)

if 'Z_tr' not in globals():
    assert all(k in globals() for k in ['Xr_tr','Xr_val','Xs']), \
        "Provide either Z_tr/Z_val/Z_syn or Xr_tr/Xr_val/Xs."
    Z_tr  = _feat_12(Xr_tr)
    Z_val = _feat_12(Xr_val)
    Z_syn = _feat_12(Xs)

# Robust whitening on *train* only
scaler = RobustScaler(quantile_range=(5,95)).fit(Z_tr)
R_tr  = scaler.transform(Z_tr)
R_val = scaler.transform(Z_val)
R_syn = scaler.transform(Z_syn)

# Optional: light clipping after robust scaling to mute remaining outliers
R_tr  = np.clip(R_tr,  -8, 8)
R_val = np.clip(R_val, -8, 8)
R_syn = np.clip(R_syn, -8, 8)

# Shrinkage covariance (stable inverse)
lw  = LedoitWolf().fit(R_tr)
mu  = lw.location_          # (12,)
prec= lw.precision_         # (12,12)

def _maha_rows(X, mu, prec):
    D = X - mu
    return np.sqrt(np.einsum('ni,ij,nj->n', D, prec, D))  # length-N

d_val_m = _maha_rows(R_val, mu, prec)
d_syn_m = _maha_rows(R_syn, mu, prec)

# Copy-rate at a conservative threshold: τ = 1st percentile of real-val distances
tau = float(np.percentile(d_val_m, 1.0))
copy_rate = float((d_syn_m <= tau).mean())
print(f"Copy-rate@τ (τ = 1st pct of real-val): {copy_rate*100:.2f}%  "
      f"(τ={tau:.3f}, med_val={np.median(d_val_m):.3f}, med_syn={np.median(d_syn_m):.3f})")

# Summary table for the report
stats = pd.DataFrame([
    {"group":"real-val→real-train", "n":len(d_val_m),
     "median":np.median(d_val_m), "p05":np.percentile(d_val_m,5), "p95":np.percentile(d_val_m,95)},
    {"group":"synth→real-train",    "n":len(d_syn_m),
     "median":np.median(d_syn_m), "p05":np.percentile(d_syn_m,5), "p95":np.percentile(d_syn_m,95)},
])
print(stats.to_string(index=False))

# Plot: full range + zoom near 0
OUT = Path(globals().get('OUT_DIR', '.')); OUT.mkdir(parents=True, exist_ok=True)

xmax = float(max(d_val_m.max(), d_syn_m.max()))
bins = np.linspace(0, xmax, 48)

fig, ax = plt.subplots(figsize=(10.2, 4.8))

# Histograms (overlaid)
ax.hist(d_val_m, bins=bins, alpha=0.65, label=f"Real-val → Real-train (n={len(d_val_m)})")
ax.hist(d_syn_m, bins=bins, alpha=0.65, label=f"Synth → Real-train (n={len(d_syn_m)})")

# Medians (dashed)
ax.axvline(np.median(d_val_m), color="#1f77b4", lw=2.2, ls="--")
ax.axvline(np.median(d_syn_m), color="#ff7f0e", lw=2.2, ls="--")

# Cosmetics
ax.set_xlim(0, bins[-1])
ax.set_xlabel("Mahalanobis distance (robust-whitened 12-D)")
ax.set_ylabel("count")
ax.set_title("Sanity check for leakage: distance to nearest real-train window")
ax.spines["top"].set_visible(False); ax.spines["right"].set_visible(False)

# Legend ABOVE the axes (outside the plotting area)
handles, labels = ax.get_legend_handles_labels()   # from the hist(label=...) calls
leg = fig.legend(handles, labels,
                 loc="upper center",
                 bbox_to_anchor=(0.5, 0.80),   # ← move DOWN by lowering this (e.g., 0.92–0.90)
                 bbox_transform=fig.transFigure,  # explicit: coordinates are figure-relative
                 ncol=2, frameon=False)

# Make room for the legend
fig.subplots_adjust(top=0.86)

# Save
p_png = OUT / "nn_leakage_mahalanobis_single.png"
p_pdf = OUT / "nn_leakage_mahalanobis_single.pdf"
fig.savefig(p_png, dpi=220, bbox_inches="tight")
fig.savefig(p_pdf, bbox_inches="tight")
plt.close(fig)
print("Saved:", p_png, "and", p_pdf)

Copy-rate@τ (τ = 1st pct of real-val): 0.00%  (τ=1.433, med_val=2.173, med_syn=33.200)
              group   n    median       p05       p95
real-val→real-train  43  2.173273  1.516632  3.532882
   synth→real-train 146 33.199661 27.013119 33.528175
Saved: c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\nn_leakage_mahalanobis_single.png and c:\Users\Joseph\generative-health-models\results\evaluation\3class_run_ckpt_epoch_130_WEIGHTS_seed42_stE150_stL150_gE0.5_gL0.1_cal_LEAKAGE\nn_leakage_mahalanobis_single.pdf
