In [11]:

# ============================================================
# Double Pendulum Metrics Script (robust & shape-safe)
# - per_frame_mae: safe MAE that aligns C,H,W and handles masks
# - summarize_x1e3: mean/std/sem × 1000 reporting
# - Optional: infer G from checkpoint, decode & resize
# Author: M365 Copilot (Esteban-tailored)
# ============================================================

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------
# Diagnostics & small utilities
# -------------------------------

def describe_tensor(name: str, t: torch.Tensor):
    """Print a one-line summary of ``t``: shape, dtype, device and, when computable, min/max.

    ``None`` prints ``"<name>: None"``. Min/max are best-effort: if reducing or
    formatting them raises (e.g. on an empty tensor) they are simply omitted.
    """
    if t is None:
        print(f"{name}: None")
        return
    parts = [f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}, device={t.device}"]
    try:
        lo = t.min().item()
        hi = t.max().item()
        parts.append(f", min={lo:.4g}, max={hi:.4g}")
    except Exception:
        # Statistics are an optional diagnostic; never let them break the caller.
        pass
    print("".join(parts))

def ensure_channel_dim(x: torch.Tensor) -> torch.Tensor:
    """Insert a singleton channel axis at position 2 for 4-D input.

    ``[S, T, H, W]`` becomes ``[S, T, 1, H, W]``; a tensor of any other rank is
    returned unchanged (the very same object).
    """
    return x.unsqueeze(2) if x.dim() == 4 else x

# ----------------------------------------
# Metrics: MAE per-frame (shape-safe)
# ----------------------------------------

def per_frame_mae(
    xh: torch.Tensor,
    xg: torch.Tensor,
    mask: torch.Tensor = None,
    resize_mode: str = 'bilinear',
    align_corners: bool = False,
    flatten: bool = True
) -> torch.Tensor:
    """
    Compute per-frame MAE between predicted xh and ground-truth xg.

    Inputs:
      xh, xg: [S,T,C,H,W] or [S,T,H,W] tensors. 4-D inputs get a singleton
        channel axis so both are handled as [S,T,C,H,W].
      mask: optional [S,T] tensor (converted to bool) selecting valid frames.
      resize_mode: interpolation mode used when xh's H×W differs from xg's;
        'bilinear'|'bicubic'|'nearest'.
      align_corners: passed to interpolate for bilinear/bicubic only.
      flatten: if True -> return a vector ([S*T] without a mask, [N_selected]
        with one); if False -> return [S,T] with unselected frames set to 0.

    Returns:
      Tensor of per-frame MAE (mean |xh - xg| over C,H,W). Never returns None;
      with flatten=True and a mask that selects nothing, the result is empty.

    Notes:
      The arithmetic runs in xg's dtype when it is floating point, otherwise in
      xh's floating dtype, otherwise in torch's default float dtype. Integer
      images are therefore neither truncated nor wrapped on subtraction.
    """
    assert xh.dim() in (4, 5) and xg.dim() in (4, 5), "Expected 4D/5D tensors"
    # Same normalisation as ensure_channel_dim: [S,T,H,W] -> [S,T,1,H,W].
    if xh.dim() == 4:
        xh = xh.unsqueeze(2)
    if xg.dim() == 4:
        xg = xg.unsqueeze(2)

    # Device/dtype alignment in a floating working dtype. Casting xh to an integer
    # xg dtype would truncate predictions, and mean() rejects integer tensors.
    if xg.is_floating_point():
        work_dtype = xg.dtype
    elif xh.is_floating_point():
        work_dtype = xh.dtype
    else:
        work_dtype = torch.get_default_dtype()
    xg = xg.to(dtype=work_dtype)
    xh = xh.to(device=xg.device, dtype=work_dtype)

    # Channel alignment (common case: C=1). If mismatch, try expanding singleton.
    if xh.size(2) != xg.size(2):
        if xh.size(2) == 1 and xg.size(2) > 1:
            xh = xh.expand(-1, -1, xg.size(2), -1, -1)
        elif xg.size(2) == 1 and xh.size(2) > 1:
            xg = xg.expand(-1, -1, xh.size(2), -1, -1)
        else:
            raise RuntimeError(f"Channel mismatch: xh C={xh.size(2)} vs xg C={xg.size(2)}")

    # Spatial alignment: resize xh to xg's H×W (safe for metrics).
    Hg, Wg = xg.shape[-2:]
    Hh, Wh = xh.shape[-2:]
    if (Hh, Wh) != (Hg, Wg):
        # Restore xh's OWN leading dims. Reshaping to xg's S,T could silently
        # re-lay-out frames when the two disagree but have equal products.
        S_h, T_h, C_h = xh.shape[:3]
        xh = F.interpolate(
            xh.reshape(S_h * T_h, C_h, Hh, Wh),
            size=(Hg, Wg),
            mode=resize_mode,
            align_corners=(align_corners if resize_mode in ('bilinear', 'bicubic') else None),
        ).reshape(S_h, T_h, C_h, Hg, Wg)

    # MAE over C,H,W -> [S,T]
    diff = (xh - xg).abs().mean(dim=(2, 3, 4))

    if mask is None:
        return diff.reshape(-1) if flatten else diff

    # Boolean and co-located with diff so indexing / masked_fill work on any device.
    mask = mask.to(device=diff.device, dtype=torch.bool)
    if mask.shape != diff.shape:
        raise RuntimeError(f"Mask shape mismatch: mask {tuple(mask.shape)} vs diff {tuple(diff.shape)}")
    if flatten:
        # May be empty when the mask selects nothing (caller-friendly, not None).
        return diff[mask]
    # masked_fill (not multiplication) so NaN/inf in unselected frames become 0.
    return diff.masked_fill(~mask, 0.0)

# ----------------------------------------
# Reporting: mean/std/sem × 1000
# ----------------------------------------

def summarize_x1e3(vals: torch.Tensor):
    """
    Return (mean, std, sem) of ``vals`` multiplied by 1000, as Python floats.

    - ``vals`` may be a tensor or anything ``torch.as_tensor`` accepts (list,
      NumPy array, ...). It is flattened and reduced in float64 on the CPU, so
      integer/bool inputs work and float32 inputs keep precision in the sums.
    - ``None`` or an empty input yields (0.0, 0.0, 0.0).
    - std is the sample (n-1, Bessel-corrected) standard deviation and
      sem = std / sqrt(n). Both are 0.0 when there is a single value.
    """
    if vals is None:
        return 0.0, 0.0, 0.0
    # float64 on CPU: mean()/std() reject integer dtypes, and some accelerators
    # (e.g. MPS) lack float64; .item() synchronises anyway, so the copy is cheap.
    vals = torch.as_tensor(vals).detach().reshape(-1).to(device='cpu', dtype=torch.float64)
    n = vals.numel()
    if n == 0:
        return 0.0, 0.0, 0.0
    mean = vals.mean().item()
    if n > 1:
        std = vals.std().item()  # default is the unbiased (n-1) estimator
        sem = std / math.sqrt(n)
    else:
        std = sem = 0.0
    return mean * 1000.0, std * 1000.0, sem * 1000.0

# ----------------------------------------
# Optional: build G from checkpoint safely
# ----------------------------------------

def infer_generator_from_checkpoint(state: dict):
    """
    Infer a 4-layer MLP generator G from a state dict with keys:
      'first_layer.weight'   # [h1, d]
      'second_layer.weight'  # [h2, h1]
      'third_layer.weight'   # [h3, h2]
      'fourth_layer.weight'  # [x_dim, h3]
    (plus the matching '.bias' entries, since load_state_dict is strict).

    The model is placed on the device of 'first_layer.weight' and, when that
    weight is floating point, converted to its dtype. A CUDA or float64
    checkpoint therefore yields a CUDA/float64 generator instead of a CPU
    float32 copy. The model is returned in eval mode.

    NOTE(review): the tanh activations between layers are this script's
    assumption; a state dict does not record activations — confirm against the
    training code.

    Returns:
      (G_model, latent_dim, x_dim, H_ckpt, W_ckpt) where H_ckpt = W_ckpt =
      sqrt(x_dim) if x_dim is a perfect square, else (None, None).
    """
    w1 = state['first_layer.weight']   # [h1, d]
    w2 = state['second_layer.weight']  # [h2, h1]
    w3 = state['third_layer.weight']   # [h3, h2]
    w4 = state['fourth_layer.weight']  # [x_dim, h3]

    latent_dim = w1.shape[1]
    h1 = w1.shape[0]
    h2 = w2.shape[0]
    h3 = w3.shape[0]
    x_dim = w4.shape[0]

    class G(nn.Module):
        def __init__(self, d, x, h1, h2, h3):
            super().__init__()
            self.first_layer  = nn.Linear(d,  h1)
            self.second_layer = nn.Linear(h1, h2)
            self.third_layer  = nn.Linear(h2, h3)
            self.fourth_layer = nn.Linear(h3, x)

        def forward(self, z):
            x = torch.tanh(self.first_layer(z))
            x = torch.tanh(self.second_layer(x))
            x = torch.tanh(self.third_layer(x))
            x = self.fourth_layer(x)  # [*, x_dim]
            return x

    G_model = G(latent_dim, x_dim, h1, h2, h3)
    # Match the checkpoint's device/dtype BEFORE loading. load_state_dict copies
    # into the existing parameters, so a CPU/float32 module would otherwise stay
    # on the CPU and down-cast float64 weights. Module.to only accepts floating
    # dtypes, hence the guard.
    if w1.is_floating_point():
        G_model = G_model.to(device=w1.device, dtype=w1.dtype)
    else:
        G_model = G_model.to(device=w1.device)
    G_model.load_state_dict(state)
    G_model.eval()

    # Infer square image size if possible
    r = math.isqrt(x_dim)
    H_ckpt, W_ckpt = (r, r) if r * r == x_dim else (None, None)

    return G_model, latent_dim, x_dim, H_ckpt, W_ckpt

def decode_and_resize(
    y_flat: torch.Tensor,
    H_src: int, W_src: int,
    H_dst: int, W_dst: int,
    mode: str = 'bilinear',
    align_corners: bool = False,
) -> torch.Tensor:
    """
    Reshape flat outputs [..., x_dim] to images [..., 1, H_src, W_src] (row-major
    split of the last axis), then resize to [..., 1, H_dst, W_dst] when the
    geometry differs.

    Args:
      y_flat: tensor whose last axis has H_src*W_src values.
      H_src, W_src: source image size; must not be None.
      H_dst, W_dst: target image size.
      mode: F.interpolate mode ('bilinear', 'bicubic', 'nearest', ...).
      align_corners: forwarded only for 'bilinear'/'bicubic' (ignored otherwise).

    Raises:
      RuntimeError: if H_src/W_src is None or H_src*W_src != x_dim.
    """
    x_dim = y_flat.shape[-1]
    if H_src is None or W_src is None:
        raise RuntimeError(f"Cannot reshape: x_dim={x_dim} is not a perfect square.")
    if H_src * W_src != x_dim:
        # Explicit check gives a clearer message than the generic reshape error.
        raise RuntimeError(f"Cannot reshape: x_dim={x_dim} != H_src*W_src={H_src}*{W_src}.")
    lead = y_flat.shape[:-1]
    y_img = y_flat.reshape(*lead, 1, H_src, W_src)
    if (H_src, W_src) != (H_dst, W_dst):
        # interpolate expects [N, C, H, W]: fold all leading dims into N, then restore them.
        y_img = F.interpolate(
            y_img.reshape(-1, 1, H_src, W_src),
            size=(H_dst, W_dst),
            mode=mode,
            align_corners=align_corners if mode in ('bilinear', 'bicubic') else None,
        ).reshape(*lead, 1, H_dst, W_dst)
    return y_img

# ----------------------------------------
# Optional: sigmoid mapping g(Z) (calib)
# ----------------------------------------

def sigmoid(u: torch.Tensor) -> torch.Tensor:
    """Elementwise logistic function 1 / (1 + exp(-u)).

    Delegates to PyTorch's built-in ``torch.sigmoid`` rather than composing
    exp / add / reciprocal by hand.
    """
    return torch.sigmoid(u)

def g_of_Z(
    Z: torch.Tensor,
    baseline_k: float = 147_000.0,  # paper zero baseline
    amplitude_k: float = 300_000.0, # placeholder; fit to your data
    k: float = 1.0,
    Z0: float = 0.0
) -> torch.Tensor:
    """
    Map Z to calibrated counts with a shifted, scaled logistic:

        counts = baseline_k + amplitude_k * sigmoid(k * (Z - Z0))

    Args:
      Z: tensor, or anything torch.as_tensor accepts (Python float, list, NumPy
        array). Non-floating input is cast to torch's default float dtype.
      baseline_k: additive offset (counts approach it as the sigmoid -> 0).
      amplitude_k: scale of the sigmoid term.
      k: sigmoid slope.
      Z0: sigmoid midpoint.

    Returns:
      Tensor with the same shape as Z.

    Ensure Z is in the expected domain before calling (raw vs normalized).
    """
    Z = torch.as_tensor(Z)
    if not Z.is_floating_point():
        Z = Z.to(torch.get_default_dtype())
    return baseline_k + amplitude_k * torch.sigmoid(k * (Z - Z0))

# ============================================================
# Config & Execution (edit section below to fit your run)
# ============================================================

# Assumptions:
# - You already have X_test and mask in memory:
#   X_test: [S,T,H,W] or [S,T,1,H,W]
#   mask: [S,T] boolean or None
#
# - PRED_X_TEST points to a torch-saved tensor of predictions (Saved DI):
#   shape could be [S,T,H,W] or [S,T,1,H,W]
#
# - CKPT_PATH is optional; set it if you want to recompute from checkpoint.

# ---------- EDIT THESE ----------
PRED_X_TEST = 'checkpoints/pixel_double_pendulum/di/baseline_x_test.pkl'  # e.g., './pred_x_test.pt'
CKPT_PATH   = None                      # e.g., './generator_ckpt.pt' or leave None
RECOMPUTE_FROM_CHECKPOINT = False       # set True to decode with G and resize
# ---------------------------------

# Sanity: X_test must already exist in this session; it is not loaded here.
try:
    X_test
except NameError:
    # `from None` drops the redundant NameError context from the traceback.
    raise RuntimeError("X_test is not defined. Please set X_test before running this script.") from None

# Accept array-likes (e.g. a NumPy array) as well as tensors: the helpers below
# call tensor methods such as .dim(). No-op when X_test is already a tensor.
X_test = torch.as_tensor(X_test)

# Normalize X_test shape to [S,T,C,H,W] & get geometry
X_test = ensure_channel_dim(X_test)
S, T, C, H_data, W_data = X_test.shape
print(f"[data] S={S}, T={T}, C={C}, H×W={H_data}×{W_data}")
describe_tensor("X_test", X_test)

# Validate mask: boolean [S,T]; absent or None means "use all frames".
if 'mask' in globals() and mask is not None:
    assert mask.dtype == torch.bool, f"mask must be boolean, got {mask.dtype}"
    assert tuple(mask.shape) == (S, T), f"mask shape {tuple(mask.shape)} must be {(S,T)}"
    # Co-locate with the data so the per-frame metrics can index with it.
    mask = mask.to(X_test.device)
    print(f"[mask] selected {mask.sum().item()} of {mask.numel()} frames")
else:
    mask = None
    print("[mask] None (metrics computed over all frames)")

# ------------------------------
# 1) Zero baseline (raw scale)
# ------------------------------
# Score an all-zeros "prediction": per frame, its MAE is just mean(|X_test|)
# over C,H,W. Printed mean/std/sem are MAE values scaled by 1000 (summarize_x1e3).
zero = torch.zeros_like(X_test)
mae_zero_raw = per_frame_mae(
    zero,
    X_test,
    mask=mask,
    flatten=True,
)
z_mean_raw, z_std_raw, z_sem_raw = summarize_x1e3(mae_zero_raw)
print(f"[Zero | raw] mean={z_mean_raw:.3f} ×10^3, std={z_std_raw:.3f}, sem={z_sem_raw:.3f}")

# -------------------------------------------------
# 2) Saved DI predictions loaded & aligned to data
# -------------------------------------------------
# NOTE(security): torch.load unpickles the file — only load prediction files you trust.
# Newer PyTorch releases default to weights_only=True, which may reject files
# that pickle non-tensor objects (e.g. NumPy arrays).
try:
    X_hat_saved = torch.load(PRED_X_TEST, map_location=X_test.device)
except FileNotFoundError:
    raise RuntimeError(f"PRED_X_TEST not found: {PRED_X_TEST}") from None

# The file may hold an array-like rather than a tensor; ensure_channel_dim needs .dim().
X_hat_saved = torch.as_tensor(X_hat_saved, device=X_test.device)
X_hat_saved = ensure_channel_dim(X_hat_saved)  # -> [S,T,1,H,W] if needed
# Print range/dtype next to X_test's. Near-identical Zero and Saved-DI scores can
# indicate predictions that are ~0 or on a different scale than the data.
describe_tensor("X_hat_saved", X_hat_saved)

# Predictions must already be laid out as [S,T,...]; a flat [S*T,...] batch is
# NOT reshaped automatically, so any S,T mismatch is reported here.
if tuple(X_hat_saved.shape[:2]) != (S, T):
    raise RuntimeError(
        f"Saved predictions shape {tuple(X_hat_saved.shape[:2])} != data {(S,T)}.\n"
        f"Please ensure predictions were saved in [S,T,...] order and match the dataset."
    )

# Compute MAE (per_frame_mae resizes xh's H×W to xg's when they differ)
mae_saved_raw = per_frame_mae(X_hat_saved, X_test, mask=mask, flatten=True)
s_mean_raw, s_std_raw, s_sem_raw = summarize_x1e3(mae_saved_raw)
print(f"[Saved DI | raw] mean={s_mean_raw:.3f} ×10^3, std={s_std_raw:.3f}, sem={s_sem_raw:.3f}")

# ------------------------------------------------------
# 3) Optional: recompute predictions from checkpoint G
# ------------------------------------------------------
if RECOMPUTE_FROM_CHECKPOINT:
    if CKPT_PATH is None:
        raise RuntimeError("RECOMPUTE_FROM_CHECKPOINT=True but CKPT_PATH is None.")
    # Load state: a raw state dict, or a dict wrapping one under 'state_dict'.
    # NOTE(security): torch.load unpickles — only load checkpoints you trust.
    ckpt_obj = torch.load(CKPT_PATH, map_location=X_test.device)
    state = ckpt_obj['state_dict'] if isinstance(ckpt_obj, dict) and 'state_dict' in ckpt_obj else ckpt_obj

    # Build generator G from checkpoint shapes, then place it on the data device
    # (the module may otherwise be built on the CPU while the latents are not).
    G_model, latent_dim, x_dim, H_ckpt, W_ckpt = infer_generator_from_checkpoint(state)
    G_model = G_model.to(X_test.device)
    g_dtype = next(G_model.parameters()).dtype
    print(f"[ckpt] latent_dim={latent_dim}, x_dim={x_dim}, image={H_ckpt}×{W_ckpt}")

    # Placeholder standard-normal latents for S×T frames — replace with your actual latent sequence.
    # Deliberately NOT named `Z`: section 4 maps any global `Z` through g(Z) as per-frame
    # scalars and would otherwise report "calibrated counts" computed from this random matrix.
    Z_latent = torch.randn(S*T, latent_dim, device=X_test.device, dtype=g_dtype)

    # Inference only: skip building an autograd graph over S*T frames.
    with torch.no_grad():
        Y_flat = G_model(Z_latent)  # [S*T, x_dim]
    Y_img_resized = decode_and_resize(Y_flat, H_ckpt, W_ckpt, H_data, W_data)  # [S*T,1,H_data,W_data]
    Y_img_resized = Y_img_resized.view(S, T, 1, H_data, W_data)

    # MAE against X_test
    mae_recomp_raw = per_frame_mae(Y_img_resized, X_test, mask=mask, flatten=True)
    r_mean_raw, r_std_raw, r_sem_raw = summarize_x1e3(mae_recomp_raw)
    print(f"[Recomputed G(Z) | raw] mean={r_mean_raw:.3f} ×10^3, std={r_std_raw:.3f}, sem={r_sem_raw:.3f}")

# ------------------------------------------------------
# 4) Optional: calibrated sigmoid mapping g(Z)
# ------------------------------------------------------
# Runs only when a global `Z` exists. Every element of Z goes through g(Z), and the
# resulting counts' mean/std/sem are printed divided by 1000 (i.e. in units of 10^3).
# Replace Z, amplitude_k, k, Z0 as appropriate.
# NOTE(review): Z is assumed to hold one scalar per frame; its shape is not checked here.
if 'Z' not in globals():
    print("[info] No Z provided; skip calibrated sigmoid reporting.")
else:
    Z_t = torch.as_tensor(Z, device=X_test.device, dtype=torch.float32)
    counts = g_of_Z(Z_t, baseline_k=147_000.0, amplitude_k=300_000.0, k=1.0, Z0=0.0)
    # Summarize counts (not MAE); a single sample has no spread.
    c_mean = counts.mean().item()
    if counts.numel() > 1:
        c_std = counts.std(unbiased=True).item()
        c_sem = c_std / (counts.numel() ** 0.5)
    else:
        c_std = 0.0
        c_sem = 0.0
    print(f"[Recomputed g(Z) | calib+sigmoid] mean={c_mean/1000.0:.1f} ×10^3, std={c_std/1000.0:.1f}, sem={c_sem/1000.0:.1f}")

# ============================================================
# End of script
# ============================================================


[data] S=50, T=100, C=1, H×W=28×28
X_test: shape=(50, 100, 1, 28, 28), dtype=torch.float32, device=cpu, min=0, max=0.001953
[mask] None (metrics computed over all frames)
[Zero | raw] mean=0.207 ×10^3, std=0.003, sem=0.000
[Saved DI | raw] mean=0.207 ×10^3, std=0.003, sem=0.000
[info] No Z provided; skip calibrated sigmoid reporting.


In [14]:

# ============================================================
# Z_test Integration (latent for G and/or scalar for g(Z))
# Add this block under the Config section in the previous script
# ============================================================

# ---------- EDIT THESE ----------
Z_TEST_PATH = "checkpoints/pixel_double_pendulum/di/baseline_z_test.pkl"  # e.g., './Z_test.pt' (set to None if you don't have it)
Z_KIND = 'latent'   # 'latent' for generator input; 'scalar' for calibrated sigmoid
# ---------------------------------

# Fail fast on typos. Any other value would make the generator recompute fall
# back to random latents and silently skip the scalar g(Z) branch.
if Z_KIND not in ('latent', 'scalar'):
    raise ValueError(f"Z_KIND must be 'latent' or 'scalar', got {Z_KIND!r}")

def load_Z_test(path, device):
    """
    Load Z_test from ``path`` onto ``device`` and return it as a tensor.

    Returns None when ``path`` is None. A non-tensor payload (e.g. a list or a
    NumPy array) is converted with torch.as_tensor, so callers can rely on
    tensor methods such as .dim() / .to(). A missing file propagates
    FileNotFoundError from torch.load.

    NOTE(security): torch.load unpickles the file — only load files you trust.
    """
    if path is None:
        return None
    Z_loaded = torch.load(path, map_location=device)
    return torch.as_tensor(Z_loaded, device=device)

Z_test = load_Z_test(Z_TEST_PATH, X_test.device)

# Report what was loaded (or that nothing was) before either consumer below runs.
if Z_test is not None:
    describe_tensor("Z_test(raw)", Z_test)
else:
    print("[info] No Z_test provided; generator recompute will use random Z, and calibrated sigmoid will be skipped unless Z exists elsewhere.")

# -------------------------
# Z for generator (latent)
# -------------------------
if RECOMPUTE_FROM_CHECKPOINT:
    if CKPT_PATH is None:
        raise RuntimeError("RECOMPUTE_FROM_CHECKPOINT=True but CKPT_PATH is None.")
    # Load state: a raw state dict, or a dict wrapping one under 'state_dict'.
    # NOTE(security): torch.load unpickles — only load checkpoints you trust.
    ckpt_obj = torch.load(CKPT_PATH, map_location=X_test.device)
    state = ckpt_obj['state_dict'] if isinstance(ckpt_obj, dict) and 'state_dict' in ckpt_obj else ckpt_obj

    # Build generator G from checkpoint shapes and put it on the data device.
    # Latents use G's parameter dtype so the Linear layers see matching dtypes.
    G_model, latent_dim, x_dim, H_ckpt, W_ckpt = infer_generator_from_checkpoint(state)
    G_model = G_model.to(X_test.device)
    g_dtype = next(G_model.parameters()).dtype
    print(f"[ckpt] latent_dim={latent_dim}, x_dim={x_dim}, image={H_ckpt}×{W_ckpt}")

    # Prepare Z for generator
    if Z_KIND == 'latent' and Z_test is not None:
        # Z_test may be an array-like; .dim()/.reshape() below need a tensor.
        Z_lat = torch.as_tensor(Z_test)
        # Accept shapes: [S,T,latent_dim] or [S*T,latent_dim]
        if Z_lat.dim() == 3 and Z_lat.shape[:2] == (S, T):
            assert Z_lat.shape[2] == latent_dim, (
                f"Z_test latent dim {Z_lat.shape[2]} != checkpoint latent_dim {latent_dim}"
            )
            Z_for_G = Z_lat.reshape(S*T, latent_dim).to(X_test.device, dtype=g_dtype)
        elif Z_lat.dim() == 2 and Z_lat.shape[0] == S*T and Z_lat.shape[1] == latent_dim:
            Z_for_G = Z_lat.to(X_test.device, dtype=g_dtype)
        else:
            raise RuntimeError(
                f"Z_test shape {tuple(Z_lat.shape)} is not compatible with [S,T,{latent_dim}] or [S*T,{latent_dim}]"
            )
        print(f"[Z_test] Using provided latent Z with shape {tuple(Z_for_G.shape)}")
    else:
        # Fallback: random Z (warn user)
        Z_for_G = torch.randn(S*T, latent_dim, device=X_test.device, dtype=g_dtype)
        print("[warn] Z_test not provided or Z_KIND!='latent'; using random Z for generator recompute.")

    # Generate (inference only: no autograd graph) and resize to dataset geometry
    with torch.no_grad():
        Y_flat = G_model(Z_for_G)  # [S*T, x_dim]
    Y_img_resized = decode_and_resize(Y_flat, H_ckpt, W_ckpt, H_data, W_data)  # [S*T,1,H_data,W_data]
    Y_img_resized = Y_img_resized.view(S, T, 1, H_data, W_data)

    # MAE against X_test
    mae_recomp_raw = per_frame_mae(Y_img_resized, X_test, mask=mask, flatten=True)
    r_mean_raw, r_std_raw, r_sem_raw = summarize_x1e3(mae_recomp_raw)
    print(f"[Recomputed G(Z_test) | raw] mean={r_mean_raw:.3f} ×10^3, std={r_std_raw:.3f}, sem={r_sem_raw:.3f}")

# -----------------------------
# Z for calibrated sigmoid g(Z)
# -----------------------------
if Z_KIND == 'scalar' and Z_test is not None:
    # Accept [S,T], [S*T], or a NumPy array. Convert to a tensor first: a NumPy
    # array has no .dim(), so the shape checks below would otherwise raise.
    Z_sc = torch.as_tensor(Z_test)
    if Z_sc.dim() == 2 and Z_sc.shape == (S, T):
        Z_scalar = Z_sc.reshape(S*T).to(X_test.device, dtype=torch.float32)
    elif Z_sc.dim() == 1 and Z_sc.shape[0] == S*T:
        Z_scalar = Z_sc.to(X_test.device, dtype=torch.float32)
    else:
        raise RuntimeError(
            f"Z_test shape {tuple(Z_sc.shape)} must be [S,T] or [S*T] for scalar g(Z)."
        )

    # TODO: ensure Z_scalar is in the expected domain (raw vs normalized)
    # For now, we pass it directly.
    counts = g_of_Z(Z_scalar, baseline_k=147_000.0, amplitude_k=300_000.0, k=1.0, Z0=0.0)
    # Summarize counts (not MAE); printed values are divided by 1000 (units of 10^3).
    c_mean = counts.mean().item()
    c_std  = counts.std(unbiased=True).item() if counts.numel() > 1 else 0.0
    c_sem  = c_std / (counts.numel() ** 0.5) if counts.numel() > 1 else 0.0
    print(f"[g(Z_test) | calib+sigmoid] mean={c_mean/1000.0:.1f} ×10^3, std={c_std/1000.0:.1f}, sem={c_sem/1000.0:.1f}")
elif Z_KIND == 'scalar' and Z_test is None:
    print("[info] Z_KIND='scalar' but Z_test not provided; skipping calibrated sigmoid.")


Z_test(raw): shape=(50, 100, 4), dtype=torch.float32, device=cpu, min=-3.334, max=3.609
