
# TC‑MultiGAN — Training Notebook (Projection‑Style D, with Fixes)

This notebook contains a **drop‑in training script** with the following fixes applied:

- **Projection‑style conditioning in D**: replaces the auxiliary emotion head.
- **Safe CE handling**: `ce_optional(...)` avoids crashes when the aux head is `None` (keeps code future‑proof).
- **Feature‑matching fix**: handles the tuple return `(pooled, pooled_proj)` from `D.extract_features`.
- **AntiAliasUp1D freeze**: ensures the fixed smoothing kernels are not trainable (safety freeze even if model file wasn’t edited).
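- **Config guard**: `lambda_tc` should be `0.0` (classification loss disabled when using projection‑style D). Note that the parsed configuration printed below still reports `lambda_tc: 2.0`; this is harmless while D returns no auxiliary logits (`ce_optional` then yields zero), but override it explicitly if an aux head is ever re‑enabled.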
- **Sanity**: shape asserts use your configured `seq_length_low` / `seq_length_ecg`.

> **Before you run**: make sure your project `src` folder is reachable (next cell shows how).


In [1]:

# --- Project path setup (edit PROJECT_SRC if needed) ---
import os, sys
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt

# Auto-detect the repo's 'src' directory: check the notebook's working
# directory first, then each ancestor in turn; the nearest match wins.
nb_dir = Path.cwd()
auto_src = next(
    ((d / "src").resolve() for d in (nb_dir, *nb_dir.parents) if (d / "src").exists()),
    None,
)

# Fallback used only when auto-detection finds nothing: the PROJECT_SRC
# environment variable, else the hard-coded path below.
# Example on Windows:
# PROJECT_SRC = r"C:\Users\Joseph\generative-health-models\src"
PROJECT_SRC = os.environ.get("PROJECT_SRC", r"C:\Users\Joseph\generative-health-models\src")

if auto_src and auto_src.exists():
    chosen = auto_src
else:
    chosen = Path(PROJECT_SRC)

if not chosen.exists():
    raise FileNotFoundError(
        f"Could not find a 'src' folder at {chosen}.\n"
        "Set PROJECT_SRC to your absolute path or place this notebook inside the repo."
    )

# Put the chosen folder at the front of the import path (only once).
src_entry = str(chosen)
if src_entry not in sys.path:
    sys.path.insert(0, src_entry)

print("Using src path:", chosen)

Using src path: C:\Users\Joseph\generative-health-models\src


In [2]:

# --- Imports ---
import os
import math
import json
from pathlib import Path
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import torch.optim as optim
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from evaluation.evaluator import run_epoch_evaluations
from datasets.wesad import make_loader
from models.tc_multigan import create_tc_multigan, boundary_loss
from utils.config import _build_parser, _load_json_defaults
from utils.config import parse_args



print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

Torch: 2.7.1+cu118 | CUDA available: True


In [5]:
@contextmanager
def _clean_argv():
    old = sys.argv
    try:
        sys.argv = [old[0]]  # keep program name only
        yield
    finally:
        sys.argv = old

# --- Find repo root automatically (fallback to your known path) ---
def find_project_root(start: Path | None = None) -> Path:
    start = Path(start or Path.cwd()).resolve()
    for p in [start, *start.parents]:
        if (p / "src").exists() and (p / "data" / "processed").exists():
            return p
    # fallback (edit if your path changes)
    return Path(r"C:\Users\Joseph\generative-health-models").resolve()

# Parse the project configuration while sys.argv holds only the program name,
# so parse_args() does not see whatever arguments this process was started with
# (presumably the Jupyter kernel's launch arguments).
with _clean_argv():
    cfg = parse_args()  # safe in notebooks now


Configuration:
  aug_jitter: 0.01
  aug_scale: 0.1
  batch_size: 16
  boundary_margin_low: 0.92
  ckpt_dir: results\checkpoints
  ckpt_interval: 5
  condition_dim: 4
  config_json: None
  d_steps: 1
  data_root: C:\Users\Joseph\generative-health-models\src\train\data\processed
  device: cuda
  ecg_boundary_margin: 0.95
  ecg_margin: 0.98
  ema_decay: 0.999
  epochs_ae: 20
  epochs_gan: 50
  fake_label: 0.2
  fm_warmup_epochs: 15
  fold: tc_multigan_fold_S10
  fs_ecg: 175
  fs_low: 4
  g_steps: 2
  gen_label: 0.8
  hidden_dim: 256
  inst_noise_std: 0.01
  inst_noise_warm_epochs: 20
  lambda_adv: 1.0
  lambda_boundary_ecg: 0.08
  lambda_boundary_low: 0.003
  lambda_boundary_low_eda: None
  lambda_boundary_low_resp: None
  lambda_fm: 10.0
  lambda_mismatch: 0.5
  lambda_mm: 0.0
  lambda_rec: 100.0
  lambda_spec_ecg: 1.0
  lambda_spec_low: 0.5
  lambda_spike: 0.4
  lambda_tc: 2.0
  lambda_tv_ecg: 0.0
  lambda_tv_eda: None
  lambda_tv_low: 0.0
  lambda_tv_resp: None
  log_dir: results\logs


In [6]:
# --- Anchor paths to the repo root ---
# Every data/output location on cfg becomes an absolute path under the detected root.
project_root = find_project_root()
results_root = project_root / "results"
cfg.data_root = str((project_root / "data" / "processed").resolve())
for attr_name, subdir in (("ckpt_dir", "checkpoints"), ("sample_dir", "samples"), ("log_dir", "logs")):
    setattr(cfg, attr_name, str((results_root / subdir).resolve()))

# --- Pick the fold you want here ---
cfg.fold = "tc_multigan_fold_S10"


In [None]:

# --- Utilities ---

class EMA:
    """
    Exponential moving average of a model's trainable parameters.

    `shadow` maps parameter name -> averaged tensor (a detached clone taken at
    construction). `apply_to` swaps the averaged values into a model, keeping the
    live values in `backup`; `restore` copies them back and clears `backup`.
    Only parameters with `requires_grad=True` are tracked or touched.
    """

    def __init__(self, model, decay=0.999):
        self.decay = float(decay)
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.clone().detach()
        self.backup = {}

    @staticmethod
    def _trainable(model, known):
        """Yield (name, param) for trainable parameters whose name is a key of `known`."""
        for name, param in model.named_parameters():
            if param.requires_grad and name in known:
                yield name, param

    @torch.no_grad()
    def update(self, model):
        """In place: shadow <- decay * shadow + (1 - decay) * current parameter."""
        keep = self.decay
        for name, param in self._trainable(model, self.shadow):
            averaged = self.shadow[name]
            averaged.mul_(keep)
            averaged.add_(param.detach(), alpha=1.0 - keep)

    def apply_to(self, model):
        """Save the model's current values in `backup`, then load the averaged ones."""
        self.backup = {}
        for name, param in self._trainable(model, self.shadow):
            self.backup[name] = param.data.clone()
            param.data.copy_(self.shadow[name].data)

    def restore(self, model):
        """Copy the values saved by `apply_to` back into the model and clear `backup`."""
        for name, param in self._trainable(model, self.backup):
            param.data.copy_(self.backup[name])
        self.backup = {}


def _next_pow2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def spectral_l1(x, y, fs, nfft=0, fmin=None, fmax=None, shape_only=False, eps=1e-8):
    """
    L1 distance between the log1p-magnitude spectra of two batches of waveforms.

    Args:
        x, y: (B, T) waveforms; the real FFT is taken along dim 1.
        fs: sampling rate in Hz; only used to map `fmin`/`fmax` onto FFT bins.
        nfft: FFT length. 0 or None -> next power of two >= T; any other value
            is raised to at least T (shorter inputs are zero-padded by rfft).
        fmin, fmax: optional inclusive band limits in Hz; bins outside are dropped.
        shape_only: if True, each spectrum is divided by its sum over the kept
            bins (plus `eps`) so only the spectral shape is compared, not raw
            magnitudes.
        eps: stabiliser for the `shape_only` normalisation.

    Returns:
        Scalar tensor: mean absolute difference of log1p(|X|) and log1p(|Y|).

    Raises:
        ValueError: if `fmin`/`fmax` select no FFT bin. Without this check the
            loss would silently be NaN (mean over an empty tensor).
    """
    _, T = x.shape  # also enforces that x is 2-D
    use_nfft = _next_pow2(T) if (nfft is None or nfft == 0) else max(nfft, T)

    magX = torch.fft.rfft(x, n=use_nfft, dim=1).abs()
    magY = torch.fft.rfft(y, n=use_nfft, dim=1).abs()

    if (fmin is not None) or (fmax is not None):
        # Build and validate the band mask on CPU (no device sync), then move it.
        freqs = torch.fft.rfftfreq(use_nfft, d=1.0 / fs)
        mask = torch.ones_like(freqs, dtype=torch.bool)
        if fmin is not None:
            mask &= (freqs >= float(fmin))
        if fmax is not None:
            mask &= (freqs <= float(fmax))
        if not bool(mask.any()):
            raise ValueError(
                f"spectral_l1: band [fmin={fmin}, fmax={fmax}] Hz selects no FFT bin "
                f"(fs={fs}, nfft={use_nfft})"
            )
        mask = mask.to(x.device)
        magX = magX[:, mask]
        magY = magY[:, mask]

    if shape_only:
        magX = magX / (magX.sum(dim=1, keepdim=True) + eps)
        magY = magY / (magY.sum(dim=1, keepdim=True) + eps)

    return F.l1_loss(torch.log1p(magX), torch.log1p(magY))


def set_requires_grad(net, flag: bool):
    """Switch gradient tracking on (`flag=True`) or off for every parameter of `net`, in place."""
    for param in net.parameters():
        param.requires_grad_(flag)


def d_hinge(real_logits, fake_logits):
    """
    Discriminator hinge loss: mean(relu(1 - real)) + mean(relu(1 + fake)).

    The real term is zero once real logits exceed 1; the fake term is zero once
    fake logits fall below -1.
    """
    real_term = F.relu(1.0 - real_logits)
    fake_term = F.relu(1.0 + fake_logits)
    return real_term.mean() + fake_term.mean()


def g_hinge(fake_logits):
    """Generator hinge loss: the negated mean logit, so minimising it pushes fake logits up."""
    return fake_logits.mean().neg()


def ce_optional(logits, y, criterion, device):
    """
    Return `criterion(logits, y)` when both inputs are present; otherwise a
    scalar zero tensor on `device`, so callers can always add the result to a loss.
    """
    have_both = (logits is not None) and (y is not None)
    if have_both:
        return criterion(logits, y)
    return torch.zeros((), device=device)

In [None]:

# --- Sampling ---

def generate_and_save_samples(G, noise, cond_low, cfg, epoch, n_plot=4, save_denorm=True):
    """
    Run the generator once and write the generated signals to `cfg.sample_dir`.

    Files written (names carry the zero-padded `epoch`):
      - fake_low_epoch_XXX.npy / fake_ecg_epoch_XXX.npy: generator outputs as returned.
      - *_DENORM.npy: the same arrays rescaled with the mean/std stored in
        `<cfg.data_root>/<cfg.fold>/norm_low.npz` and `norm_ecg.npz`
        (only when `save_denorm` is True and the stats can be loaded).
      - PNG quick-look plots for the first `n_plot` samples.

    Args:
        G: generator module; called as `G(noise, cond_low)`, expected to return
            `(fake_low, fake_ecg)`.
        noise, cond_low: passed straight to `G`.
        cfg: reads `sample_dir` and `seq_length_low`; `data_root` and `fold` when
            `save_denorm` is True; `fs_low` / `fs_ecg` (optional, defaults 4 / 175)
            for the plot axis labels.
        epoch: integer used in the output file names.
        n_plot: maximum number of samples to plot.
        save_denorm: also save and plot de-normalised copies.

    Raises:
        ValueError: if `fake_low` is not 3-D or has fewer than two channels, or if
            `fake_ecg` is neither 2-D nor 3-D or has zero channels.

    Notes:
        `G` is switched to eval mode for the forward pass and then returned to
        the mode it was in. De-normalisation is best-effort: any failure is
        reported with a warning and the DENORM outputs are skipped.
    """
    outdir = Path(cfg.sample_dir)
    outdir.mkdir(parents=True, exist_ok=True)

    # Sample in eval mode, then put G back in its previous mode so a call made
    # between training steps does not leave the generator in eval mode.
    was_training = bool(getattr(G, "training", False))
    G.eval()
    try:
        with torch.no_grad():
            fake_low, fake_ecg = G(noise, cond_low)
            fake_low_np = fake_low.detach().cpu().float().numpy()
            fake_ecg_np = fake_ecg.detach().cpu().float().numpy()
    finally:
        if was_training:
            G.train()

    # -------- shape checks & safe slicing --------
    print(f"[sample dbg] fake_low={fake_low_np.shape}  fake_ecg={fake_ecg_np.shape}  cond={tuple(cond_low.shape)}")
    if fake_low_np.ndim != 3:
        raise ValueError(f"[bad shape] fake_low ndim={fake_low_np.ndim}, expected 3 (N,T,C)")
    N, T_low, C_low = fake_low_np.shape
    if T_low != cfg.seq_length_low:
        print(f"[warn] fake_low T={T_low}, expected {cfg.seq_length_low}")
    if C_low < 2:
        raise ValueError(f"[bad shape] fake_low has {C_low} channel(s); need ≥2 (EDA, RESP)")
    if C_low > 2:
        print(f"[warn] fake_low has {C_low} channels; plotting ONLY the first two (EDA/RESP)." )
    low_plot = fake_low_np[:, :, :2]

    if fake_ecg_np.ndim == 2:
        fake_ecg_np = fake_ecg_np[..., None]  # (N,T)->(N,T,1)
    elif fake_ecg_np.ndim != 3:
        raise ValueError(f"[bad shape] fake_ecg ndim={fake_ecg_np.ndim}, expected 3")
    if fake_ecg_np.shape[-1] < 1:
        raise ValueError("[bad shape] fake_ecg has 0 channels")
    ecg_plot = fake_ecg_np[:, :, 0:1]

    # -------- save normalized arrays --------
    np.save(outdir / f"fake_low_epoch_{epoch:03d}.npy", fake_low_np)
    np.save(outdir / f"fake_ecg_epoch_{epoch:03d}.npy", fake_ecg_np)

    # -------- optional: denormalize for visualization --------
    low_denorm = None
    ecg_denorm = None
    if save_denorm:
        try:
            fold_dir = Path(cfg.data_root) / cfg.fold
            # Context managers close the .npz archives right away; an open handle
            # would otherwise keep the stats files locked (notably on Windows).
            with np.load(fold_dir / "norm_low.npz") as st_low:
                mu_low = st_low["mean"].astype(np.float32)
                sd_low = st_low["std"].astype(np.float32)
            with np.load(fold_dir / "norm_ecg.npz") as st_ecg:
                mu_ecg = st_ecg["mean"].astype(np.float32)
                sd_ecg = st_ecg["std"].astype(np.float32)

            # Rescale only as many channels as there are stats for.
            k_low = min(fake_low_np.shape[-1], mu_low.shape[0])
            k_ecg = min(fake_ecg_np.shape[-1], mu_ecg.shape[0])

            low_denorm = fake_low_np.copy()
            low_denorm[..., :k_low] = low_denorm[..., :k_low] * sd_low[:k_low][None, None, :] + mu_low[:k_low][None, None, :]

            ecg_denorm = fake_ecg_np.copy()
            ecg_denorm[..., :k_ecg] = ecg_denorm[..., :k_ecg] * sd_ecg[:k_ecg][None, None, :] + mu_ecg[:k_ecg][None, None, :]

            np.save(outdir / f"fake_low_epoch_{epoch:03d}_DENORM.npy", low_denorm)
            np.save(outdir / f"fake_ecg_epoch_{epoch:03d}_DENORM.npy", ecg_denorm)
        except Exception as e:
            # Deliberate best-effort: missing/invalid stats must not abort sampling.
            print(f"[warn] denorm skipped: {e}")

    # Axis labels follow the configured sampling rates (defaults match the
    # previously hard-coded 4 Hz / 175 Hz).
    fs_low = getattr(cfg, "fs_low", None) or 4
    fs_ecg = getattr(cfg, "fs_ecg", None) or 175
    low_xlabel = f"Time ({fs_low:g} Hz steps)"
    ecg_xlabel = f"Time ({fs_ecg:g} Hz steps)"

    # -------- quick plots (normalized) --------
    n = int(min(n_plot, low_plot.shape[0], ecg_plot.shape[0]))
    have_denorm = save_denorm and (low_denorm is not None) and (ecg_denorm is not None)
    for i in range(n):
        # EDA/RESP
        fig, axes = plt.subplots(2, 1, figsize=(8, 4), sharex=True)
        axes[0].plot(low_plot[i, :, 0]); axes[0].set_title(f"Low-rate (EDA) — sample {i}")
        axes[1].plot(low_plot[i, :, 1]); axes[1].set_title("Low-rate (RESP)")
        axes[1].set_xlabel(low_xlabel)
        fig.tight_layout()
        fig.savefig(outdir / f"fake_low_{epoch:03d}_sample_{i}.png")
        plt.close(fig)

        # ECG
        fig2, ax2 = plt.subplots(1, 1, figsize=(8, 2.5))
        ax2.plot(ecg_plot[i, :, 0]); ax2.set_title(f"ECG — sample {i}")
        ax2.set_xlabel(ecg_xlabel)
        fig2.tight_layout()
        fig2.savefig(outdir / f"fake_ecg_{epoch:03d}_sample_{i}.png")
        plt.close(fig2)

        # Optional denorm plots for eyeballing (if available)
        if have_denorm:
            fig3, axes3 = plt.subplots(2, 1, figsize=(8, 4), sharex=True)
            axes3[0].plot(low_denorm[i, :, 0]); axes3[0].set_title(f"[DENORM] EDA — sample {i}")
            if low_denorm.shape[-1] > 1:
                axes3[1].plot(low_denorm[i, :, 1]); axes3[1].set_title("[DENORM] RESP")
            axes3[1].set_xlabel(low_xlabel)
            fig3.tight_layout()
            fig3.savefig(outdir / f"fake_low_{epoch:03d}_sample_{i}_DENORM.png")
            plt.close(fig3)

            fig4, ax4 = plt.subplots(1, 1, figsize=(8, 2.5))
            ax4.plot(ecg_denorm[i, :, 0]); ax4.set_title(f"[DENORM] ECG — sample {i}")
            ax4.set_xlabel(ecg_xlabel)
            fig4.tight_layout()
            fig4.savefig(outdir / f"fake_ecg_{epoch:03d}_sample_{i}_DENORM.png")
            plt.close(fig4)

    plt.close("all")

In [None]:

# --- One training epoch (with projection D fixes) ---
def train_one_epoch(
    G, D, data_loader, opt_g, opt_d, device, cfg,
    criterion_aux, epoch, ema=None, use_ema=False
):
    """
    Train G and D for one epoch with hinge adversarial losses.

    Per batch: the discriminator loss is back-propagated `d_steps` times
    (gradients accumulate across those passes) followed by a single
    `opt_d.step()`; the generator is then updated `g_steps` times, each update
    starting from zeroed gradients and ending with `opt_g.step()`.

    Args:
        G: generator, called as `G(z, cond_low)` -> `(fake_low, fake_ecg)`.
        D: discriminator, called as `D(low, ecg, cond)` -> `(logits, aux_logits, features)`;
            must also provide `extract_features(low, ecg, cond)`.
        data_loader: yields dict batches with keys "signal_low", "signal_ecg",
            "condition" and optionally "label".
        opt_g, opt_d: optimizers for G and D.
        device: device the batch tensors are moved to.
        cfg: configuration namespace; loss weights, margins, schedules, `z_dim`
            and spectral settings are read from it.
        criterion_aux: classification criterion, applied through `ce_optional`
            (contributes zero when D returns no aux logits or the batch has no labels).
        epoch: epoch index driving the noise / feature-matching / spectral ramps
            and the step schedule (presumably 1-based; the noise and spectral
            ramps use `epoch - 1`).
        ema: optional EMA tracker, updated after every generator step when `use_ema`.
        use_ema: enable the EMA update.

    Returns:
        (mean generator loss, mean discriminator loss) over the batches; (0.0, 0.0)
        for an empty loader.

    Raises:
        ValueError: if label indices on the first labelled batch fall outside
            [0, cfg.condition_dim).

    Side effects:
        Puts G and D in train mode. After any non-empty epoch D's parameters are
        left with `requires_grad=False` (it is re-enabled at the start of each batch).
    """
    G.train(); D.train()
    total_loss_g = 0.0
    total_loss_d = 0.0

    # ----- epoch-scheduled instance noise (linear decay to 0 over `warm` epochs) -----
    noise0 = float(getattr(cfg, "inst_noise_std", 0.0))
    warm   = int(getattr(cfg, "inst_noise_warm_epochs", 0))
    noise_std = float(noise0 * max(0.0, 1.0 - (epoch - 1) / max(1, warm)))
    def add_noise(x):
        return x + noise_std * torch.randn_like(x) if noise_std > 0 else x

    # ----- feature-matching warmup ramp -----
    fm_base = float(getattr(cfg, "lambda_fm", 0.0))
    fm_warm = int(getattr(cfg, "fm_warmup_epochs", 0))
    ramp = min(1.0, max(0.0, epoch / fm_warm)) if fm_warm > 0 else 1.0
    fm_w_curr = fm_base * ramp

    # ----- spectral warmup (ECG) -----
    spec_warm = int(getattr(cfg, "spec_warmup_epochs", 12))
    spec_ramp = min(1.0, max(0.0, (epoch - 1) / float(max(1, spec_warm))))
    lambda_spec_ecg_eff = float(cfg.lambda_spec_ecg) * spec_ramp

    # ---- symmetric observation margins ----
    margin_low = float(getattr(cfg, "boundary_margin_low", 0.92))
    ecg_margin = float(getattr(cfg, "ecg_margin", 0.98))  # clamp used for the D (OBS view)
    # Boundary penalty can use the same or a slightly tighter margin; both tunable via cfg
    ecg_boundary_margin = float(getattr(cfg, "ecg_boundary_margin", ecg_margin))
    lambda_boundary_ecg = float(getattr(cfg, "lambda_boundary_ecg", 0.0))

    # ---- loop-invariant weights (read once per epoch) ----
    mismatch_w = float(getattr(cfg, "lambda_mismatch", 0.5))
    lam_mm_low = float(getattr(cfg, "lambda_mm", 0.0))
    lam_mm_ecg = float(getattr(cfg, "lambda_mm_ecg", 0.0))  # optional; default 0.0
    # Per-channel TV / boundary weights fall back to the shared low-rate weight
    # when the per-channel value is missing, None or 0.
    tv_w_low_default = float(getattr(cfg, "lambda_tv_low", 0.0))
    tv_w_eda  = float(getattr(cfg, "lambda_tv_eda",  tv_w_low_default or 0.0) or tv_w_low_default)
    tv_w_resp = float(getattr(cfg, "lambda_tv_resp", tv_w_low_default or 0.0) or tv_w_low_default)
    tau = float(getattr(cfg, "spike_tau", 2.0))
    lam_b_base  = float(getattr(cfg, "lambda_boundary_low", 0.0))
    lam_b_eda   = float(getattr(cfg, "lambda_boundary_low_eda",  None) or lam_b_base)
    lam_b_resp  = float(getattr(cfg, "lambda_boundary_low_resp", None) or lam_b_base)

    # ---- steps schedule ----
    # NOTE(review): the hard-coded epoch schedule below overrides cfg.g_steps /
    # cfg.d_steps unconditionally; kept as-is to preserve current training
    # behaviour — confirm whether the config values should win.
    g_steps = int(getattr(cfg, "g_steps", 1))
    d_steps = int(getattr(cfg, "d_steps", 1))
    if epoch <= 15:
        g_steps, d_steps = 1, 2
    else:
        g_steps, d_steps = 1, 1

    did_label_check = False

    for batch in data_loader:
        # --------- unpack ---------
        sig_low  = batch["signal_low"].to(device).float()   # (B, T_low, 2)
        sig_ecg  = batch["signal_ecg"].to(device).float()   # (B, T_ecg, 1)
        cond_low = batch["condition"].to(device).float()    # (B, T_low, K)

        labels = batch.get("label")
        if labels is not None:
            labels = labels.to(device)
            y = labels if labels.ndim == 1 else labels.argmax(dim=-1)
            y = y.long()
            # Heuristic 1-based -> 0-based shift, decided per batch from the
            # labels present in that batch.
            if y.min() >= 1 and y.max() == cfg.condition_dim:
                y = y - 1
            if not did_label_check:
                if y.min() < 0 or y.max() >= cfg.condition_dim:
                    raise ValueError(f"Label indices out of range: {y.min().item()}..{y.max().item()} (K={cfg.condition_dim})")
                did_label_check = True
        else:
            y = None

        # ===================== D STEP(S) =====================
        set_requires_grad(D, True)
        opt_d.zero_grad()
        batch_loss_d = 0.0

        # Each pass draws fresh noise / fakes / permutation; gradients from all
        # passes accumulate into the single opt_d.step() below.
        for _ in range(d_steps):
            # --- REAL (clamp -> noise -> D) ---
            real_low_obs = sig_low.clamp(-margin_low, margin_low)
            real_ecg_obs = sig_ecg.clamp(-ecg_margin, ecg_margin)

            # Inputs need grad so the R1 penalty can differentiate w.r.t. them.
            real_low_in = add_noise(real_low_obs).detach().requires_grad_(True)
            real_ecg_in = add_noise(real_ecg_obs).detach().requires_grad_(True)

            real_logits, real_aux, _ = D(real_low_in, real_ecg_in, cond_low)
            real_adv_loss = torch.relu(1.0 - real_logits).mean()  # hinge real
            real_aux_loss = ce_optional(real_aux, y, criterion_aux, device)

            # --- R1 on observed inputs to D ---
            r1 = torch.zeros((), device=device)
            if getattr(cfg, "use_r1", False):
                grads = torch.autograd.grad(real_logits.sum(),
                                            [real_low_in, real_ecg_in],
                                            create_graph=True, retain_graph=True, only_inputs=True)
                r1 = sum(g.reshape(g.size(0), -1).pow(2).sum(dim=1).mean() for g in grads)

            # --- FAKE for D (clamp -> noise -> D) ---
            with torch.no_grad():
                z_d = torch.randn(sig_low.size(0), cfg.z_dim, device=device)
                fake_low_d, fake_ecg_d = G(z_d, cond_low)

            fake_low_obs_d = fake_low_d.clamp(-margin_low, margin_low)
            fake_ecg_obs_d = fake_ecg_d.clamp(-ecg_margin, ecg_margin)

            fake_logits_d, fake_aux_d, _ = D(add_noise(fake_low_obs_d), add_noise(fake_ecg_obs_d), cond_low)
            fake_adv_loss = torch.relu(1.0 + fake_logits_d).mean()  # hinge fake
            fake_aux_loss = ce_optional(fake_aux_d, y, criterion_aux, device)

            # --- mismatch (real signals, wrong condition) on observed view ---
            B = sig_low.size(0)
            perm = torch.randperm(B, device=device)
            # An identity permutation would pair every sample with its own condition.
            if B > 1 and torch.all(perm == torch.arange(B, device=device)):
                perm = torch.roll(perm, 1)
            cond_wrong = cond_low[perm]
            wrong_logits, _, _ = D(add_noise(real_low_obs.detach()), add_noise(real_ecg_obs.detach()), cond_wrong)
            wrong_adv_loss = torch.relu(1.0 + wrong_logits).mean()

            # --- total D loss ---
            loss_d = (
                cfg.lambda_adv * (real_adv_loss + fake_adv_loss + mismatch_w * wrong_adv_loss)
                + cfg.lambda_tc * (real_aux_loss + fake_aux_loss)
                + 0.5 * float(getattr(cfg, "r1_gamma", 0.0)) * r1
            )
            loss_d.backward()
            batch_loss_d += loss_d.item()

        opt_d.step()
        total_loss_d += batch_loss_d / d_steps

        # ===================== G STEP(S) =====================
        set_requires_grad(D, False)
        batch_loss_g = 0.0

        for _ in range(g_steps):
            # Zero before every generator update: opt_g.step() runs once per
            # iteration, so stale gradients must not leak into the next step.
            opt_g.zero_grad()

            z_g = torch.randn(sig_low.size(0), cfg.z_dim, device=device)
            fake_low_g, fake_ecg_g = G(z_g, cond_low)

            # Observed (clamped) views for D/FM/spec
            real_low_obs = sig_low.clamp(-margin_low, margin_low)
            real_ecg_obs = sig_ecg.clamp(-ecg_margin, ecg_margin)
            fake_low_obs = fake_low_g.clamp(-margin_low, margin_low)
            fake_ecg_obs = fake_ecg_g.clamp(-ecg_margin, ecg_margin)

            gen_logits, gen_aux, fake_feat = D(fake_low_obs, fake_ecg_obs, cond_low)
            adv_g = -gen_logits.mean()  # hinge generator
            gen_aux_loss = ce_optional(gen_aux, y, criterion_aux, device)

            # ----- Feature Matching on observed view -----
            fm_loss = torch.zeros((), device=device)
            if fm_w_curr > 0.0:
                with torch.no_grad():
                    rf_low = add_noise(real_low_obs)
                    rf_ecg = add_noise(real_ecg_obs)
                    real_feat = D.extract_features(rf_low, rf_ecg, cond_low)
                    # Feature extractors may return a tuple; only its first item is matched.
                    if isinstance(real_feat, tuple):
                        real_feat = real_feat[0]
                if isinstance(fake_feat, tuple):
                    fake_feat = fake_feat[0]
                fm_loss = F.l1_loss(fake_feat.mean(dim=0), real_feat.mean(dim=0))

            # ----- Moment Matching (tanh-aware on observed view) -----
            mm_low = torch.zeros((), device=device)
            mm_ecg = torch.zeros((), device=device)

            if lam_mm_low > 0.0:
                def mm_stats_low(x):
                    x = x.clamp(-margin_low, margin_low)   # same observed view as D/spec
                    mu = x.mean(dim=(0,1))                 # per-channel (EDA, RESP)
                    sd = x.std (dim=(0,1))
                    return mu, sd

                mu_r_low, sd_r_low = mm_stats_low(sig_low)
                mu_f_low, sd_f_low = mm_stats_low(fake_low_g)
                mm_low = F.l1_loss(mu_f_low, mu_r_low) + F.l1_loss(sd_f_low, sd_r_low)

            if lam_mm_ecg > 0.0:
                def mm_stats_ecg(x):                       # x: (B,T,1)
                    x = x.clamp(-ecg_margin, ecg_margin)   # observed view
                    return x.mean(), x.std()
                mu_re, sd_re = mm_stats_ecg(sig_ecg)
                mu_fe, sd_fe = mm_stats_ecg(fake_ecg_g)
                mm_ecg = F.l1_loss(mu_fe, mu_re) + F.l1_loss(sd_fe, sd_re)

            # ----- temporal TV penalties -----
            tv_eda  = (fake_low_g[:, 1:, 0] - fake_low_g[:, :-1, 0]).abs().mean()
            tv_resp = (fake_low_g[:, 1:, 1] - fake_low_g[:, :-1, 1]).abs().mean()
            tv_low = tv_w_eda * tv_eda + tv_w_resp * tv_resp

            tv_ecg = (fake_ecg_g[:, 1:, :] - fake_ecg_g[:, :-1, :]).abs().mean()

            # ----- Spike penalty (hinged TV on Δ) -----
            d_ecg = fake_ecg_g[:, 1:, :] - fake_ecg_g[:, :-1, :]
            spike_loss = torch.relu(d_ecg.abs() - tau).pow(2).mean()

            # ---- ECG boundary penalty on RAW generator output (not the clamped view) ----
            boundary_pen_ecg = lambda_boundary_ecg * boundary_loss(fake_ecg_g, ecg_boundary_margin)

            # ---- Spectral losses on observed view (shape to (B, T)) ----
            real_low_eda  = real_low_obs[..., 0]
            real_low_resp = real_low_obs[..., 1]
            fake_low_eda  = fake_low_obs[..., 0]
            fake_low_resp = fake_low_obs[..., 1]

            spec_low_eda  = spectral_l1(fake_low_eda,  real_low_eda,  fs=cfg.fs_low,
                                        nfft=cfg.spec_nfft_low,
                                        fmin=getattr(cfg, "spec_eda_fmin", None),
                                        fmax=getattr(cfg, "spec_eda_fmax", None))
            spec_low_resp = spectral_l1(fake_low_resp, real_low_resp, fs=cfg.fs_low,
                                        nfft=cfg.spec_nfft_low,
                                        fmin=getattr(cfg, "spec_resp_fmin", None),
                                        fmax=getattr(cfg, "spec_resp_fmax", None),
                                        shape_only=bool(getattr(cfg, "spec_resp_shape_only", True)))
            spec_low = 0.5 * spec_low_eda + 0.5 * spec_low_resp

            real_ecg_1c = real_ecg_obs.squeeze(-1)
            fake_ecg_1c = fake_ecg_obs.squeeze(-1)
            spec_ecg = spectral_l1(fake_ecg_1c, real_ecg_1c, fs=cfg.fs_ecg,
                                   nfft=cfg.spec_nfft_ecg,
                                   fmin=cfg.spec_ecg_fmin, fmax=cfg.spec_ecg_fmax)

            # ---- Boundary penalty for LOW tanh head (on raw generator output) ----
            pen_eda  = torch.relu(fake_low_g[..., 0].abs() - margin_low).pow(2).mean()
            pen_resp = torch.relu(fake_low_g[..., 1].abs() - margin_low).pow(2).mean()
            boundary_pen = lam_b_eda * pen_eda + lam_b_resp * pen_resp

            # ----- final generator loss -----
            loss_g = (
                cfg.lambda_adv * adv_g
                + cfg.lambda_tc * gen_aux_loss
                + fm_w_curr * fm_loss
                + tv_low
                + float(cfg.lambda_tv_ecg) * tv_ecg
                + float(cfg.lambda_spec_low) * spec_low
                + float(lambda_spec_ecg_eff) * spec_ecg
                + float(getattr(cfg, "lambda_spike", 0.0)) * spike_loss
                + boundary_pen
                + boundary_pen_ecg
                + lam_mm_low * mm_low
                + lam_mm_ecg * mm_ecg
            )

            loss_g.backward()
            opt_g.step()
            if use_ema and (ema is not None):
                ema.update(G)
            batch_loss_g += loss_g.item()

        total_loss_g += batch_loss_g / g_steps

    # Avoid ZeroDivisionError on an empty loader (totals are 0.0 in that case).
    n_batches = max(1, len(data_loader))
    return total_loss_g / n_batches, total_loss_d / n_batches

In [None]:

# --- Validation (hinge + optional FM) ---
def validate(G, D, data_loader, device, cfg, criterion_aux):
    """
    Compute mean validation losses for G and D (hinge, optional feature matching).

    No instance noise, R1 or mismatch term is applied here; real and fake
    signals are clamped to the same observation margins used in training
    before they reach D.

    Args:
        G: generator, called as `G(z, cond_low)` -> `(fake_low, fake_ecg)`.
        D: discriminator, called as `D(low, ecg, cond)` -> `(logits, aux_logits, features)`;
            `D.extract_features` is used when `cfg.lambda_fm > 0`.
        data_loader: yields dict batches with "signal_low", "signal_ecg",
            "condition" and optionally "label".
        device: device the batch tensors are moved to.
        cfg: reads `z_dim`, `condition_dim`, `lambda_adv`, `lambda_tc` and the
            optional `lambda_fm`, `boundary_margin_low`, `ecg_margin`.
        criterion_aux: classification criterion, applied through `ce_optional`.

    Returns:
        (mean generator loss, mean discriminator loss) over the batches;
        (0.0, 0.0) for an empty loader.

    Side effects:
        Leaves G and D in eval mode.
    """
    G.eval(); D.eval()
    val_loss_g = 0.0; val_loss_d = 0.0

    # --- symmetric observation margins (same as train); loop-invariant ---
    margin_low = float(getattr(cfg, "boundary_margin_low", 0.92))
    ecg_margin = float(getattr(cfg, "ecg_margin", 0.98))
    lambda_fm = float(getattr(cfg, "lambda_fm", 0.0))

    with torch.no_grad():
        for batch in data_loader:
            sig_low  = batch["signal_low"].to(device).float()
            sig_ecg  = batch["signal_ecg"].to(device).float()
            cond_low = batch["condition"].to(device).float()

            labels = batch.get("label")
            y = None
            if labels is not None:
                y = (labels.to(device).long()
                     if labels.ndim == 1 else labels.to(device).argmax(dim=-1).long())
                # Heuristic 1-based -> 0-based shift, decided per batch.
                if y.min() >= 1 and y.max() == cfg.condition_dim:
                    y = y - 1

            # Clamp real before the discriminator / metrics
            real_low_obs = sig_low.clamp(-margin_low,  margin_low)
            real_ecg_obs = sig_ecg.clamp(-ecg_margin,   ecg_margin)

            # Make fake, then clamp before the discriminator / metrics
            z = torch.randn(sig_low.size(0), cfg.z_dim, device=device)
            fake_low, fake_ecg = G(z, cond_low)
            fake_low_obs = fake_low.clamp(-margin_low, margin_low)
            fake_ecg_obs = fake_ecg.clamp(-ecg_margin,  ecg_margin)

            # One D pass on the fakes serves both the D and the G loss: the
            # inputs are identical and D is in eval mode under no_grad
            # (assumes D is deterministic in eval mode).
            real_logits, real_aux, _ = D(real_low_obs, real_ecg_obs, cond_low)
            fake_logits, fake_aux, fake_feat = D(fake_low_obs, fake_ecg_obs, cond_low)

            # -------------------- D loss (hinge, no R1 in val) --------------------
            aux_fake = ce_optional(fake_aux, y, criterion_aux, device)
            adv_d = d_hinge(real_logits, fake_logits)
            aux_d = ce_optional(real_aux, y, criterion_aux, device) + aux_fake
            loss_d = cfg.lambda_adv * adv_d + cfg.lambda_tc * aux_d
            val_loss_d += loss_d.item()

            # -------------------- G loss (hinge + FM) --------------------
            adv_g = g_hinge(fake_logits)
            aux_g = aux_fake

            # Feature matching uses clamped real features
            if lambda_fm > 0.0:
                real_feat = D.extract_features(real_low_obs, real_ecg_obs, cond_low)
                # Feature extractors may return a tuple; only its first item is matched.
                if isinstance(real_feat, tuple):
                    real_feat = real_feat[0]
                if isinstance(fake_feat, tuple):
                    fake_feat = fake_feat[0]
                fm_loss = F.l1_loss(fake_feat.mean(dim=0), real_feat.mean(dim=0))
            else:
                fm_loss = torch.zeros((), device=device)

            loss_g = cfg.lambda_adv * adv_g + cfg.lambda_tc * aux_g + lambda_fm * fm_loss
            val_loss_g += loss_g.item()

    # Avoid ZeroDivisionError on an empty loader (totals are 0.0 in that case).
    n = max(1, len(data_loader))
    return val_loss_g / n, val_loss_d / n

In [None]:
# Resolve the directory of the selected fold (created if absent) and fail
# early when any array file this notebook needs is missing from it.
fold_dir = Path(cfg.data_root) / cfg.fold
fold_dir.mkdir(parents=True, exist_ok=True)
print("Using fold_dir:", fold_dir)

required = ["train_X_low.npy", "train_X_ecg.npy", "train_m1_seq.npy"]
missing = list(filter(lambda fname: not (fold_dir / fname).exists(), required))
if len(missing) > 0:
    raise FileNotFoundError(f"Missing in {fold_dir}: {missing}")

In [None]:
# # Load RAW train arrays (the same files your dataset uses)
# X_low = np.load(fold_dir / f"{cfg.train_split}_X_low.npy").astype(np.float32)   # (N, T_low, 2) [EDA, RESP]
# X_ecg = np.load(fold_dir / f"{cfg.train_split}_X_ecg.npy").astype(np.float32)   # (N, T_ecg, 1)

# # Compute per-channel mean/std across N and T
# mu_low = X_low.reshape(-1, X_low.shape[-1]).mean(axis=0).astype(np.float32)  # (2,)
# sd_low = X_low.reshape(-1, X_low.shape[-1]).std (axis=0).astype(np.float32)  # (2,)
# mu_ecg = X_ecg.reshape(-1, X_ecg.shape[-1]).mean(axis=0).astype(np.float32)  # (1,)
# sd_ecg = X_ecg.reshape(-1, X_ecg.shape[-1]).std (axis=0).astype(np.float32)  # (1,)

# # Optional: print to review
# print("Train stats (LOW)  mean:", mu_low.tolist(), "std:", sd_low.tolist())
# print("Train stats (ECG)  mean:", mu_ecg.tolist(), "std:", sd_ecg.tolist())

# # Save in the exact files your code expects for de-normalization
# np.savez(fold_dir / "norm_low.npz", mean=mu_low, std=sd_low, channels=np.array(["EDA","RESP"]))
# np.savez(fold_dir / "norm_ecg.npz", mean=mu_ecg, std=sd_ecg, channels=np.array(["ECG"]))
# print("Saved:", fold_dir / "norm_low.npz", "and", fold_dir / "norm_ecg.npz")

In [None]:

# --- Data loaders ---
device = torch.device(cfg.device)

# Drop loaders left over from an earlier run of this cell (e.g. after saving or
# editing norm_*.npz) so they are rebuilt from scratch below.
for name in ("train_loader", "val_loader"):
    globals().pop(name, None)
import gc; gc.collect()

# Settings shared by the train and validation loaders: both are normalised with
# the fold's saved statistics (no per-split fallback).
_shared_loader_kwargs = dict(
    root_dir=cfg.data_root, fold=cfg.fold,
    window_size_low=cfg.seq_length_low, batch_size=cfg.batch_size,
    num_workers=cfg.num_workers, condition_dim=cfg.condition_dim,
    normalize=True, normalize_ecg=True,
    stats_low_path=str(fold_dir / "norm_low.npz"),
    stats_ecg_path=str(fold_dir / "norm_ecg.npz"),
    expected_ecg_len=cfg.seq_length_ecg,
    force_use_stats=True,
    use_split_stats_if_needed=False,
)

train_loader = make_loader(
    split=cfg.train_split,
    shuffle=True,
    weighted_sampling=cfg.weighted_sampling,
    augment=True,
    **_shared_loader_kwargs,
)

val_loader = make_loader(
    split=cfg.val_split,
    shuffle=False,
    augment=False,
    **_shared_loader_kwargs,
)

print("Train batches:", len(train_loader), "Val batches:", len(val_loader))

# Quick sanity: shapes should match your cfg lengths
b = next(iter(val_loader))
low_len, ecg_len = b["signal_low"].shape[1], b["signal_ecg"].shape[1]
assert low_len == cfg.seq_length_low and ecg_len == cfg.seq_length_ecg, (
    f"Expected low/ecg lengths {cfg.seq_length_low}/{cfg.seq_length_ecg}; got "
    f"{low_len}/{ecg_len}"
)
print("[sanity] val low mean/std:", b["signal_low"].mean().item(), b["signal_low"].std().item())
print("[sanity] val ecg mean/std:", b["signal_ecg"].mean().item(),  b["signal_ecg"].std().item())
del b

# Global val stats (optional): accumulate sum, sum of squares and element count
# in float64 over the whole validation loader, one signal at a time.
sum_low = sumsq_low = count_low = 0.0
sum_ecg = sumsq_ecg = count_ecg = 0.0
for batch in val_loader:
    x_low = batch["signal_low"].double()
    sum_low   += x_low.sum().item()
    sumsq_low += (x_low * x_low).sum().item()
    count_low += x_low.numel()

    x_ecg = batch["signal_ecg"].double()
    sum_ecg   += x_ecg.sum().item()
    sumsq_ecg += (x_ecg * x_ecg).sum().item()
    count_ecg += x_ecg.numel()

# std = sqrt(E[x^2] - E[x]^2), clipped at 0 against round-off.
mean_low = sum_low / count_low
mean_ecg = sum_ecg / count_ecg
std_low  = math.sqrt(max(sumsq_low / count_low - mean_low * mean_low, 0.0))
std_ecg  = math.sqrt(max(sumsq_ecg / count_ecg - mean_ecg * mean_ecg, 0.0))
print(f"[val GLOBAL] low mean/std ~ {mean_low:.4f} / {std_low:.4f}")
print(f"[val GLOBAL] ecg mean/std ~ {mean_ecg:.4f} / {std_ecg:.4f}")

In [None]:
import numpy as np, hashlib, pathlib

# Locate the fold-level normalization statistics used by the loaders.
fold_dir = Path(cfg.data_root) / cfg.fold
low_stats_path = str(fold_dir / "norm_low.npz")
ecg_stats_path = str(fold_dir / "norm_ecg.npz")

# 1) Confirm paths
print("Using:", low_stats_path)
print("Using:", ecg_stats_path)

# 2) Print the actual arrays we used for the manual transform (and shapes/dtypes)
# np.load() on an .npz returns a lazy NpzFile that keeps the archive open
# until it is closed. Read the two arrays inside `with` so the file handle is
# released immediately; `sl` / `se` remain indexable by "mean" and "std".
with np.load(low_stats_path) as _npz_low:
    sl = {"mean": _npz_low["mean"], "std": _npz_low["std"]}
with np.load(ecg_stats_path) as _npz_ecg:
    se = {"mean": _npz_ecg["mean"], "std": _npz_ecg["std"]}
print("low mean:", sl["mean"], " std:", sl["std"], " dtype:", sl["mean"].dtype, sl["std"].dtype)
print("ecg mean:", se["mean"], " std:", se["std"], " dtype:", se["mean"].dtype, se["std"].dtype)

# 3) md5 so we can compare to the one you logged earlier
def md5(p):
    """Return the hexadecimal MD5 digest of the file at path ``p``.

    The whole file is read into memory in one go, which is fine for the small
    statistics files this helper is used on in this cell.
    """
    payload = pathlib.Path(p).read_bytes()
    return hashlib.md5(payload).hexdigest()
# Fingerprint both statistics files so the exact file contents in use can be
# compared against a digest recorded elsewhere.
print("low npz md5:", md5(low_stats_path))
print("ecg npz md5:", md5(ecg_stats_path))

# --- 1) Build two loaders over the same split:
# Everything except the normalization settings is shared: same fold/split,
# same windowing, fixed order (shuffle=False), single-process loading and no
# augmentation.
_shared_loader_kwargs = dict(
    root_dir=cfg.data_root,
    fold=cfg.fold,
    split=cfg.train_split,
    window_size_low=cfg.seq_length_low,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=0,
    condition_dim=cfg.condition_dim,
    augment=False,
    expected_ecg_len=cfg.seq_length_ecg,
)

#     a) RAW (no normalization, no aug)
raw_loader = make_loader(
    normalize=False,
    normalize_ecg=False,
    **_shared_loader_kwargs,
)

#     b) NORMED by dataset (use same .npz files, no aug)
norm_loader = make_loader(
    normalize=True,
    normalize_ecg=True,
    stats_low_path=low_stats_path,
    stats_ecg_path=ecg_stats_path,
    force_use_stats=True,
    use_split_stats_if_needed=False,
    **_shared_loader_kwargs,
)

# --- 2) Pull one batch from each
# Both loaders were built with shuffle=False over the same split, so the first
# batch of each is expected to cover the same windows (raw vs. normalized).
b_raw  = next(iter(raw_loader))
b_norm = next(iter(norm_loader))

x_raw_low = b_raw["signal_low"].to(device).float()     # (B, T_low, 2)
x_raw_ecg = b_raw["signal_ecg"].to(device).float()     # (B, T_ecg, 1)

x_norm_low = b_norm["signal_low"].to(device).float()
x_norm_ecg = b_norm["signal_ecg"].to(device).float()

# --- 3) Manual normalization using the same .npz stats
# np.load() on an .npz returns a lazy NpzFile that keeps the archive open
# until closed; read the arrays inside `with` so the handle is released.
# `stats_low` / `stats_ecg` stay indexable by "mean" and "std".
with np.load(low_stats_path) as _npz_low:
    stats_low = {"mean": _npz_low["mean"], "std": _npz_low["std"]}
# Reshape to (1, 1, C) so the stats broadcast over batch and time.
mu_low = torch.as_tensor(stats_low["mean"], device=device, dtype=x_raw_low.dtype).view(1,1,-1)
sd_low = torch.as_tensor(stats_low["std"],  device=device, dtype=x_raw_low.dtype).view(1,1,-1)

with np.load(ecg_stats_path) as _npz_ecg:
    stats_ecg = {"mean": _npz_ecg["mean"], "std": _npz_ecg["std"]}
# ECG stats are reshaped to a single broadcastable value (1, 1, 1).
mu_ecg = torch.as_tensor(stats_ecg["mean"], device=device, dtype=x_raw_ecg.dtype).view(1,1,1)
sd_ecg = torch.as_tensor(stats_ecg["std"],  device=device, dtype=x_raw_ecg.dtype).view(1,1,1)

# Plain z-score with the file statistics.
# NOTE(review): no epsilon is added to the std here; if the loader adds one,
# small differences between "manual" and "loader" values are expected.
x_manual_low = (x_raw_low - mu_low) / sd_low
x_manual_ecg = (x_raw_ecg - mu_ecg) / sd_ecg

# --- 4) Report means/stds and diffs
def ms(x):
    """Return ``(mean, std)`` of ``x`` reduced over its first two dimensions.

    Both results are converted to Python lists. The standard deviation is the
    population form (``unbiased=False``).
    """
    reduce_dims = (0, 1)
    mean_vals = x.mean(dim=reduce_dims)
    std_vals = x.std(dim=reduce_dims, unbiased=False)
    return mean_vals.tolist(), std_vals.tolist()

# Low-rate signals: per-channel mean/std for raw, loader-normalized and
# manually normalized batches.
m_raw_low, s_raw_low = ms(x_raw_low)
m_norm_low, s_norm_low = ms(x_norm_low)
m_manual_low, s_manual_low = ms(x_manual_low)

# ECG: scalar mean and population std over the whole batch tensor.
m_raw_ecg = x_raw_ecg.mean().item()
s_raw_ecg = x_raw_ecg.std(unbiased=False).item()
m_norm_ecg = x_norm_ecg.mean().item()
s_norm_ecg = x_norm_ecg.std(unbiased=False).item()
m_manual_ecg = x_manual_ecg.mean().item()
s_manual_ecg = x_manual_ecg.std(unbiased=False).item()

# Element-wise absolute disagreement between the manual z-score and what the
# normalizing loader produced.
diff_low = torch.abs(x_manual_low - x_norm_low)
diff_ecg = torch.abs(x_manual_ecg - x_norm_ecg)
max_abs_low = diff_low.max().item()
max_abs_low_per_chan = diff_low.flatten(0, 1).max(dim=0).values.tolist()
max_abs_ecg = diff_ecg.max().item()

print("--- LOW (EDA/RESP) ---")
print("raw     mean≈", m_raw_low,    " std≈", s_raw_low)
print("manual  mean≈", m_manual_low, " std≈", s_manual_low)
print("loader  mean≈", m_norm_low,   " std≈", s_norm_low)
print("Δ(low)  max_abs≈", max_abs_low,
      " per-chan max_abs≈", max_abs_low_per_chan)

print("\n--- ECG ---")
print(f"raw     mean≈ {m_raw_ecg:.4f}  std≈ {s_raw_ecg:.4f}")
print(f"manual  mean≈ {m_manual_ecg:.4f}  std≈ {s_manual_ecg:.4f}")
print(f"loader  mean≈ {m_norm_ecg:.4f}  std≈ {s_norm_ecg:.4f}")
print("Δ(ecg)  max_abs≈", max_abs_ecg)

In [None]:
def per_channel_low_stats(loader):
    """Compute per-channel mean and population std of ``signal_low`` over a loader.

    Each batch's ``"signal_low"`` tensor is reduced over its first two
    dimensions (batch, time); the last dimension is treated as the channel
    axis, and its size is taken from the data rather than fixed at 2.

    Args:
        loader: Iterable of dict-like batches with a ``"signal_low"`` tensor.

    Returns:
        ``(mean, std)`` as two Python lists with one entry per channel.

    Raises:
        ValueError: If the loader yields no samples, since the statistics
            would otherwise silently come out as NaN.
    """
    sum_c = None
    sumsq_c = None
    cnt = 0
    for b in loader:
        x = b["signal_low"].double()  # (B,T,C), promoted to float64
        batch_sum = x.sum(dim=(0, 1))
        batch_sumsq = (x * x).sum(dim=(0, 1))
        if sum_c is None:
            sum_c, sumsq_c = batch_sum, batch_sumsq
        else:
            sum_c = sum_c + batch_sum
            sumsq_c = sumsq_c + batch_sumsq
        cnt += x.shape[0] * x.shape[1]

    if sum_c is None or cnt == 0:
        raise ValueError("per_channel_low_stats: loader yielded no samples")

    # Keep the E[x^2] - E[x]^2 subtraction in float64. Casting to float32
    # first (as was done before) cancels catastrophically whenever the mean is
    # large relative to the spread, collapsing the std towards 0.
    mu = sum_c / cnt
    var = sumsq_c / cnt - mu * mu
    std = var.clamp_min(0).sqrt()
    # Down-cast only the final results, matching the previous output precision.
    return mu.float().tolist(), std.float().tolist()

# Per-channel statistics of the low-rate signals over the whole validation
# loader (which was built with normalize=True).
mu_c, sd_c = per_channel_low_stats(val_loader)
print("Per‑channel low mean:", mu_c, "std:", sd_c)

In [None]:
# Inspect a single validation batch: tensor shapes first, then mean/std.
b = next(iter(val_loader))

shape_report = []
for label, key in (("low", "signal_low"), ("ecg", "signal_ecg"), ("cond", "condition")):
    shape_report.extend([label, tuple(b[key].shape)])
print("Shapes:", *shape_report)

# Mean / std over every element of each signal tensor in this one batch.
print("[sanity] batch low mean/std:",
      b["signal_low"].mean().item(), b["signal_low"].std().item())
print("[sanity] batch ecg mean/std:",
      b["signal_ecg"].mean().item(), b["signal_ecg"].std().item())

In [None]:
import math

def global_mean_std(loader):
    """Return global mean/std of the low-rate and ECG signals over ``loader``.

    Every batch is expected to be a mapping with ``"signal_low"`` and
    ``"signal_ecg"`` tensors. Values are promoted to float64, and the sum,
    sum of squares and element count are accumulated over all elements.

    Returns:
        ``(mu_low, sd_low, mu_ecg, sd_ecg)`` as Python floats. The standard
        deviation is the population form, with the variance floored at 0.
    """
    signal_keys = ("signal_low", "signal_ecg")
    # Per signal: [sum, sum of squares, element count]
    totals = {key: [0.0, 0.0, 0] for key in signal_keys}

    for batch in loader:
        promoted = {key: batch[key].double() for key in signal_keys}
        for key, values in promoted.items():
            acc = totals[key]
            acc[0] += values.sum().item()
            acc[1] += (values * values).sum().item()
            acc[2] += values.numel()

    def _finalize(acc):
        # mean = E[x]; std = sqrt(max(E[x^2] - E[x]^2, 0))
        total, total_sq, n = acc
        mean = total / n
        return mean, math.sqrt(max(total_sq / n - mean * mean, 0.0))

    mu_low, sd_low = _finalize(totals["signal_low"])
    mu_ecg, sd_ecg = _finalize(totals["signal_ecg"])
    return mu_low, sd_low, mu_ecg, sd_ecg

# Global (all-element) mean/std of both signals over the validation loader.
mu_l, sd_l, mu_e, sd_e = global_mean_std(val_loader)
print(f"[val GLOBAL] low mean/std ~ {mu_l:.4f} / {sd_l:.4f}")
print(f"[val GLOBAL] ecg mean/std ~ {mu_e:.4f} / {sd_e:.4f}")

In [None]:
# Magnitude thresholds used below to count near-boundary samples; read from
# cfg when present, otherwise fall back to 0.92 (low-rate) and 0.98 (ECG).
margin_low = float(getattr(cfg, "boundary_margin_low", 0.92))
ecg_margin = float(getattr(cfg, "ecg_margin", 0.98))

def clamp_fraction(loader, margin_low, ecg_margin, n_batches=5):
    """Estimate the share of samples whose magnitude reaches a margin.

    For each inspected batch, computes the fraction of ``signal_low`` elements
    with ``|x| >= margin_low`` and of ``signal_ecg`` elements with
    ``|x| >= ecg_margin``, then averages those per-batch fractions.

    Args:
        loader: Iterable of batches with ``"signal_low"`` / ``"signal_ecg"``.
        margin_low: Absolute-value threshold for the low-rate signals.
        ecg_margin: Absolute-value threshold for the ECG signal.
        n_batches: Stop after this many batches. At least one batch is always
            processed, because the limit is checked after each batch.

    Returns:
        ``(mean_low_fraction, mean_ecg_fraction)`` as produced by ``np.mean``.
    """
    low_fracs, ecg_fracs = [], []
    for seen, b in enumerate(loader, start=1):
        low_hits = b["signal_low"].abs() >= margin_low    # (B,T,2)
        ecg_hits = b["signal_ecg"].abs() >= ecg_margin    # (B,T,1)

        low_fracs.append(low_hits.float().mean().item())
        ecg_fracs.append(ecg_hits.float().mean().item())

        if seen >= n_batches:
            break
    return np.mean(low_fracs), np.mean(ecg_fracs)

# Average near-boundary fraction over the first few validation batches
# (clamp_fraction's default n_batches applies).
fr_low, fr_ecg = clamp_fraction(val_loader, margin_low, ecg_margin)
print(f"[clamp fraction] low≈{fr_low:.3f}, ecg≈{fr_ecg:.3f}")

In [None]:
# One-off repair: swap the two per-channel entries of norm_low.npz in place.
#
# The swap is destructive and self-inverse, so running this cell twice used to
# silently undo it. The first run now saves the untouched file as
# norm_low.pre_swap.npz; if that backup already exists the cell is a no-op.
# To redo the swap on purpose, restore or delete the backup first.
stats_path = fold_dir / "norm_low.npz"
backup_path = fold_dir / "norm_low.pre_swap.npz"

if backup_path.exists():
    print(f"Found {backup_path.name}: norm_low.npz was already swapped; leaving it unchanged.")
else:
    # Read inside `with` so the archive handle is closed before overwriting.
    with np.load(stats_path) as st:
        mu, sd = st["mean"].copy(), st["std"].copy()

    # Indexing with [1, 0] is only a swap when there are exactly two entries
    # along the first axis; anything else would silently drop or break data.
    if mu.ndim < 1 or sd.ndim < 1 or mu.shape[0] != 2 or sd.shape[0] != 2:
        raise ValueError(
            f"Expected 2-channel mean/std in {stats_path}; got shapes {mu.shape} / {sd.shape}"
        )

    # Byte-for-byte copy of the original file before it is overwritten.
    backup_path.write_bytes(stats_path.read_bytes())

    mu_sw, sd_sw = mu[[1,0]], sd[[1,0]]
    np.savez(stats_path, mean=mu_sw, std=sd_sw, channels=np.array(["EDA","RESP"]))
    print("Swapped channels in norm_low.npz; rebuild loaders and re-check.")

In [None]:

# --- Model init ---
# Build generator / discriminator from the config and move both to `device`.
G, D = create_tc_multigan(cfg)
G.to(device)
D.to(device)

# Safety freeze for fixed smoothing convs (even if model file wasn't edited)
def _freeze_if_exists(module, attr):
    if hasattr(module, attr):
        m = getattr(module, attr)
        for p in m.parameters():
            p.requires_grad_(False)

# Each entry of G.ecg_up (if the generator has one) may carry a `smooth`
# submodule; freeze its parameters where present.
for m in getattr(G, "ecg_up", []):
    _freeze_if_exists(m, "smooth")

# Same for the low-rate upsamplers; a missing attribute resolves to None,
# which the helper ignores.
for name in ("low_up1", "low_up2"):
    _freeze_if_exists(getattr(G, name, None), "smooth")

# Final smoothing layer attached directly to the generator, if any.
_freeze_if_exists(G, "final_smooth")
print("[init] Fixed smoothing kernels set to requires_grad=False (safety freeze).")

# --- EMA (optional) ---
# Enabled only when cfg.use_ema is truthy; decay falls back to 0.999.
use_ema = bool(getattr(cfg, "use_ema", False))
if use_ema:
    ema = EMA(G, decay=float(getattr(cfg, "ema_decay", 0.999)))
else:
    ema = None

# --- Losses / optimizers / schedulers ---
criterion_aux = nn.CrossEntropyLoss()

adam_betas = (0.5, 0.999)
opt_g = optim.Adam(G.parameters(), lr=cfg.lr_g, betas=adam_betas)
opt_d = optim.Adam(D.parameters(), lr=cfg.lr_d, betas=adam_betas)

# Record each param group's starting LR under 'initial_lr' unless it is
# already set.
for opt in (opt_g, opt_d):
    for pg in opt.param_groups:
        pg.setdefault('initial_lr', pg['lr'])

# Cosine schedules spanning the full GAN training length, started fresh.
sched_g = CosineAnnealingLR(opt_g, T_max=cfg.epochs_gan, last_epoch=-1)
sched_d = CosineAnnealingLR(opt_d, T_max=cfg.epochs_gan, last_epoch=-1)

# --- Dirs ---
for out_dir in (cfg.ckpt_dir, cfg.sample_dir, cfg.log_dir):
    os.makedirs(out_dir, exist_ok=True)

# --- Fixed noise / cond for sampling ---
# A dedicated, seeded generator keeps the sampling noise identical across
# runs regardless of the global RNG state.
n_samples = cfg.sample_n
g = torch.Generator(device=device)
g.manual_seed(cfg.seed)
fixed_noise = torch.randn(n_samples, cfg.z_dim, generator=g, device=device)

# Build deterministic condition sequences from present classes (or all K)
present = None
try:
    ds = train_loader.dataset
    cond_labels = getattr(ds, "cond_np", None)
    if cond_labels is not None:
        present = np.unique(cond_labels)
        labels_are_one_based = present.min() == 1 and present.max() == cfg.condition_dim
        if labels_are_one_based:
            present = present - 1  # map 1..K -> 0..K-1 if needed
        # Keep only ids that are valid one-hot indices.
        present = [int(c) for c in present if 0 <= c < cfg.condition_dim]
except Exception:
    # Best effort: on any failure fall back to "all classes" below.
    present = None
if not present:
    present = list(range(cfg.condition_dim))

# One-hot condition held constant over time; classes are assigned to samples
# round-robin over the `present` list.
fixed_cond_low = torch.zeros(n_samples, cfg.seq_length_low, cfg.condition_dim, device=device)
for i in range(n_samples):
    fixed_cond_low[i, :, present[i % len(present)]] = 1.0

print("[init] Model/opt/sched/EMA ready.")

In [None]:
import os, time, json
import torch

best_val_loss = float('inf')

# Sensible defaults if not present in cfg
# Only attributes that cfg does not already define are filled in.
_eval_defaults = {
    "eval_interval": 5,
    "eval_n_batches": 8,
    "eval_bootstrap": 0,                       # set 200 for CI95
    "eval_channels": ("eda","resp","ecg"),
    "eval_viz_n": 4,
    "eval_use_clamp": True,                    # match D's observed view
    "eval_standardize": False,                 # compare in native units
    "probe_n_train": 2000,
    "probe_n_val": 1000,
}
for _name, _default in _eval_defaults.items():
    if not hasattr(cfg, _name):
        setattr(cfg, _name, _default)

epoch_times = []

# Resolve the EMA switch once. getattr() mirrors how `use_ema` is read when the
# EMA object is created, so a cfg without `use_ema` means "off", not a crash.
ema_enabled = bool(getattr(cfg, "use_ema", False))
swap_ema = ema_enabled and (ema is not None)

for epoch in range(1, cfg.epochs_gan + 1):
    t0 = time.time()

    # ---- Train ----
    loss_g, loss_d = train_one_epoch(
        G, D, train_loader, opt_g, opt_d,
        device, cfg, criterion_aux, epoch,
        ema=ema, use_ema=ema_enabled
    )

    # ---- Validate (with EMA symmetry if enabled) ----
    # apply_to/restore are paired through try/finally: if validation raises or
    # the cell is interrupted, G is still restored and does not keep the
    # weights that apply_to installed.
    if swap_ema:
        ema.apply_to(G)
    try:
        # NOTE: validate signature: (..., _criterion_adv_unused, criterion_aux)
        # NOTE(review): the call passes a single criterion after cfg; confirm
        # this matches the current validate() signature.
        val_loss_g, val_loss_d = validate(G, D, val_loader, device, cfg, criterion_aux)
    finally:
        if swap_ema:
            ema.restore(G)

    print(f"Epoch {epoch:03d} | Train G: {loss_g:.4f}, D: {loss_d:.4f} | "
          f"Val G: {val_loss_g:.4f}, D: {val_loss_d:.4f}")

    # ---- Distribution + Label-Structure Evaluation (on schedule) ----
    # Runs every cfg.eval_interval epochs and always on the final epoch.
    if (epoch % int(cfg.eval_interval) == 0) or (epoch == cfg.epochs_gan):
        try:
            summary = run_epoch_evaluations(
                G=G, D=D,
                train_loader=train_loader,
                val_loader=val_loader,
                cfg=cfg, device=device,
                epoch=epoch, ema=ema
            )

            # compact headline for logs
            parts = []
            for k in ("ks_eda", "w1_ecg", "js_resp",
                    "probe_acc_real", "probe_acc_fake",
                    "probe_bacc_real", "probe_bacc_fake",
                    "probe_f1_real", "probe_f1_fake"):
                v = summary.get(k)
                if v is not None:
                    parts.append(f"{k}={v:.3f}")
            if parts:
                print("[eval] " + " | ".join(parts))
        except Exception as e:
            # Evaluation is best-effort: a failure is reported but must not
            # abort training.
            print(f"[eval] Skipped this epoch due to error: {e}")

    # ---- Sampling (fixed grid) ----
    if epoch % int(cfg.sample_interval) == 0:
        if swap_ema:
            ema.apply_to(G)
        try:
            generate_and_save_samples(G, fixed_noise, fixed_cond_low, cfg, epoch)
        finally:
            # Same guarantee as for validation: always restore G.
            if swap_ema:
                ema.restore(G)

    # ---- Schedulers ----
    sched_g.step()
    sched_d.step()

    # ---- Checkpointing ----
    # Save on the regular interval, and whenever the discriminator validation
    # loss beats the best value seen so far.
    if (epoch % int(cfg.ckpt_interval) == 0) or (val_loss_d < best_val_loss):
        checkpoint = {
            'epoch': epoch,
            'G_state_dict': G.state_dict(),
            'D_state_dict': D.state_dict(),
            'opt_g_state_dict': opt_g.state_dict(),
            'opt_d_state_dict': opt_d.state_dict(),
            'ema_state_dict': (ema.shadow if (ema is not None) else None),
        }
        ckpt_path = os.path.join(cfg.ckpt_dir, f"ckpt_epoch_{epoch:03d}.pt")
        torch.save(checkpoint, ckpt_path)
        best_val_loss = min(best_val_loss, val_loss_d)

    # ---- Epoch timing (optional) ----
    epoch_times.append(time.time() - t0)
