In [33]:

import os
import numpy as np
import torch
import torch.nn as nn

# --- Paths (edit) ---
# Every path below is read with torch.load(..., weights_only=False) later in this
# script, i.e. the files are unpickled. Only point these at files you trust.
DATA_PATH       = "data/pendulum/processed_data.pkl"  # dict with 'train' and 'test' as numpy arrays [S,T,28,28]
PRED_X_TEST     = "checkpoints/pixel_pendulum/di/di_baseline_x_test.pkl"     # [S,T,28,28]
Z_TEST_PATH     = "checkpoints/pixel_pendulum/di/di_baseline_z_test.pkl"     # [S,T,2]
G_STATE_PATH    = "checkpoints/pixel_pendulum/di/di_baseline_generative.pkl" # 2->784 MLP state_dict
# When this is None (or does not name an existing file) no mask is applied and
# every frame is included in the reported statistics.
EXTRAP_MASK_PATH = None  # optional mask [S,T] of booleans for extrapolation-only evaluation

# Frame height/width. to_torch_img5d validates inputs against these, and the
# recomputed decoder output is reshaped to [S,T,1,H,W].
H, W = 28, 28
# Used to derive calib_factor: the raw-scale zero-baseline mean (already ×10³)
# is linearly rescaled so that it equals this value.
PAPER_ZERO_L1_X1E3 = 147.0  # target zero-baseline mean L1 ×10³ from the paper (pixel pendulum)

def to_torch_img5d(np_arr):
    """Convert an [S,T,H,W] image stack to a float32 torch tensor [S,T,1,H,W].

    Args:
        np_arr: numpy array (a torch tensor is accepted as well) laid out as
            [S,T,H,W]; the last two dims must equal the module-level H and W.

    Returns:
        float32 tensor with a singleton channel axis inserted at dim 2.

    Raises:
        ValueError: if the input is not 4-D or its last two dims are not (H, W).
    """
    # as_tensor handles ndarrays and tensors alike (from_numpy rejects tensors).
    x = torch.as_tensor(np_arr).float()
    # Explicit raise rather than assert so the check survives `python -O`.
    if x.dim() != 4 or x.shape[2] != H or x.shape[3] != W:
        raise ValueError(f"Expected [S,T,{H},{W}], got {tuple(x.shape)}")
    return x.unsqueeze(2)

def per_frame_mae(xh, xg, mask=None):
    """Mean absolute error of each frame.

    Args:
        xh: predictions, [S,T,1,H,W].
        xg: ground truth, [S,T,1,H,W]; its first two dims define S and T.
        mask: optional boolean [S,T] tensor; when given, only the frames where
            it is True are kept.

    Returns:
        1-D tensor of per-frame MAE values: S*T entries without a mask,
        otherwise one entry per selected frame.
    """
    n_seq, n_steps = xg.shape[0], xg.shape[1]
    # reshape (not view): the element-wise difference can be non-contiguous
    # (e.g. transposed inputs), and view() raises on such tensors.
    diff = (xh - xg).abs().reshape(n_seq, n_steps, -1).mean(dim=2)  # [S,T]
    if mask is not None:
        diff = diff[mask]
    return diff.reshape(-1)

def summarize_x1e3(vals):
    """Summarise a tensor of errors as (mean, std, sem), each multiplied by 1000.

    The standard deviation is the unbiased (n-1) estimate and the standard
    error is std / sqrt(n). With fewer than two values both are reported as 0.
    """
    count = vals.numel()
    mu = vals.mean().item()
    if count > 1:
        sigma = vals.std(unbiased=True).item()
        stderr = sigma / (count ** 0.5)
    else:
        # A spread cannot be estimated from fewer than two samples.
        sigma = 0.0
        stderr = 0.0
    return 1000.0 * mu, 1000.0 * sigma, 1000.0 * stderr

def print_stats(name, x):
    """Print min, max, mean and unbiased std of tensor `x` on one line, tagged with `name`."""
    lo = x.min().item()
    hi = x.max().item()
    avg = x.mean().item()
    spread = x.std(unbiased=True).item()
    print(f"[{name}] min={lo:.6g} max={hi:.6g} mean={avg:.6g} std={spread:.6g}")

# --- Load data ---
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
data = torch.load(DATA_PATH, weights_only=False)
X_train_np = data["train"]  # numpy [S_train,T_train,28,28]
X_test_np  = data["test"]   # numpy [S_test,T_test,28,28]

X_train = to_torch_img5d(X_train_np)  # [S,T,1,28,28]
X_test  = to_torch_img5d(X_test_np)   # [S,T,1,28,28]
S, T = X_test.shape[:2]

# Optional [S,T] boolean mask restricting every statistic below to selected frames.
mask = None
if EXTRAP_MASK_PATH:
    if os.path.isfile(EXTRAP_MASK_PATH):
        # map_location keeps the mask on the same (CPU) device as X_test.
        m = torch.load(EXTRAP_MASK_PATH, map_location="cpu", weights_only=False)
        # as_tensor covers ndarray, tensor and nested-list masks in one path.
        mask = torch.as_tensor(m).bool()
        # Explicit raise rather than assert so the check survives `python -O`.
        if tuple(mask.shape) != (S, T):
            raise ValueError(f"Mask must be [S,T], got {tuple(mask.shape)}")
    else:
        # A configured but missing mask would otherwise silently turn an
        # extrapolation-only evaluation into an all-frames one.
        print(f"[warn] EXTRAP_MASK_PATH={EXTRAP_MASK_PATH!r} not found; evaluating on ALL frames")

# --- Inspect raw scale ---
# Prints min/max/mean/std of the unscaled train and test tensors so the pixel
# range of the stored data can be checked before any metric is interpreted.
print("=== SCALE INSPECTION (raw) ===")
print_stats("train(raw)", X_train)
print_stats(" test(raw)", X_test)

# --- Compute zero baseline in raw scale ---
# The baseline predicts an all-zero frame, so its per-frame MAE is simply the
# mean absolute pixel value of each (optionally masked) ground-truth frame.
zero = torch.zeros_like(X_test)
mae_zero_raw = per_frame_mae(zero, X_test, mask=mask)
# summarize_x1e3 returns values already multiplied by 1000; z_mean_raw is
# reused further down to derive calib_factor.
z_mean_raw, z_std_raw, z_sem_raw = summarize_x1e3(mae_zero_raw)
print(f"[Zero | raw] mean={z_mean_raw:.3f} ×10^3, std={z_std_raw:.3f}, sem={z_sem_raw:.3f}")

# --- Evaluate saved DI preds in raw scale ---
# map_location="cpu": predictions saved from a GPU run must end up on the same
# device as X_test. NOTE: weights_only=False unpickles arbitrary objects; only
# load trusted files.
X_hat_saved = torch.load(PRED_X_TEST, map_location="cpu", weights_only=False)  # [S,T,28,28]
# Accept either a numpy array or a tensor, and drop any autograd history.
X_hat_saved = torch.as_tensor(X_hat_saved).detach()
X_hat_saved = X_hat_saved.unsqueeze(2)  # -> [S,T,1,28,28]
# Without this check a prediction file that does not line up with the test set
# could broadcast silently inside per_frame_mae and yield meaningless numbers.
if X_hat_saved.shape != X_test.shape:
    raise ValueError(
        f"Saved predictions have shape {tuple(X_hat_saved.shape)}, "
        f"expected {tuple(X_test.shape)}"
    )
mae_saved_raw = per_frame_mae(X_hat_saved, X_test, mask=mask)
s_mean_raw, s_std_raw, s_sem_raw = summarize_x1e3(mae_saved_raw)
print(f"[Saved DI | raw] mean={s_mean_raw:.3f} ×10^3, std={s_std_raw:.3f}, sem={s_sem_raw:.3f}")

# --- Approximate rescale to [0,1] using TRAIN data (min-max or max-only) ---
# The scaling constants come from the TRAIN tensor only and are applied
# unchanged to train, test and the saved predictions.
train_min, train_max = X_train.min().item(), X_train.max().item()
if train_max > train_min:
    # Regular case: min-max normalisation to [0,1].
    _span = train_max - train_min
    X_train_01, X_test_01, X_hat_01 = [
        (t - train_min) / _span for t in (X_train, X_test, X_hat_saved)
    ]
else:
    # Degenerate case (constant training data): divide by the max if it is
    # positive, otherwise leave the values as they are.
    scale = train_max if train_max > 0 else 1.0
    X_train_01, X_test_01, X_hat_01 = [
        t / scale for t in (X_train, X_test, X_hat_saved)
    ]

print("=== SCALE INSPECTION ([0,1] approx) ===")
print_stats("train([0,1])", X_train_01)
print_stats(" test([0,1])", X_test_01)

# Zero baseline in the rescaled space.
mae_zero_01 = per_frame_mae(torch.zeros_like(X_test_01), X_test_01, mask=mask)
z_mean_01, z_std_01, z_sem_01 = summarize_x1e3(mae_zero_01)
# Saved predictions in the rescaled space.
mae_saved_01 = per_frame_mae(X_hat_01, X_test_01, mask=mask)
s_mean_01, s_std_01, s_sem_01 = summarize_x1e3(mae_saved_01)
print(f"[Zero | ~[0,1]] mean={z_mean_01:.1f} ×10^3, std={z_std_01:.1f}, sem={z_sem_01:.1f}")
print(f"[Saved DI | ~[0,1]] mean={s_mean_01:.1f} ×10^3, std={s_std_01:.1f}, sem={s_sem_01:.1f}")

# --- Calibrate scale to match paper's zero baseline (linear rescale) ---
# Pick one multiplicative factor so that the zero-baseline mean L1 (×10³) of the
# rescaled test set equals PAPER_ZERO_L1_X1E3, then apply that same factor to
# the ground truth and to the saved predictions.
if z_mean_raw > 0:
    calib_factor = PAPER_ZERO_L1_X1E3 / z_mean_raw
else:
    # Baseline mean is not positive (or is NaN): leave the scale untouched.
    calib_factor = 1.0
X_test_cal = X_test * calib_factor
X_hat_cal = X_hat_saved * calib_factor

# Zero baseline after calibration; should reproduce the target value.
mae_zero_cal = per_frame_mae(torch.zeros_like(X_test_cal), X_test_cal, mask=mask)
z_mean_cal, z_std_cal, z_sem_cal = summarize_x1e3(mae_zero_cal)
# Saved predictions after calibration.
mae_saved_cal = per_frame_mae(X_hat_cal, X_test_cal, mask=mask)
s_mean_cal, s_std_cal, s_sem_cal = summarize_x1e3(mae_saved_cal)

print("=== CALIBRATED TO PAPER ZERO BASELINE ===")
print(f"[Zero | calib]     mean={z_mean_cal:.1f} ×10^3 (target {PAPER_ZERO_L1_X1E3}), std={z_std_cal:.1f}, sem={z_sem_cal:.1f}")
print(f"[Saved DI | calib] mean={s_mean_cal:.1f} ×10^3, std={s_std_cal:.1f}, sem={s_sem_cal:.1f}")

# --- Recompute g(Z) with bounded output (sigmoid), then evaluate in calibrated scale ---
# map_location="cpu": a state_dict saved from a GPU run must load on a CPU-only
# machine too. NOTE: weights_only=False unpickles arbitrary objects; only load
# trusted files.
state = torch.load(G_STATE_PATH, map_location="cpu", weights_only=False)
# Layer widths are read from the stored weight matrices; nn.Linear weights are
# laid out as [out_features, in_features].
latent_dim = state['first_layer.weight'].shape[1]
h1         = state['first_layer.weight'].shape[0]
h2         = state['second_layer.weight'].shape[0]
h3         = state['third_layer.weight'].shape[0]
x_dim      = state['fourth_layer.weight'].shape[0]
# Explicit raise rather than assert so the check survives `python -O`.
if latent_dim != 2 or x_dim != H * W:
    raise ValueError(
        f"Expected a 2->{H * W} decoder, got {latent_dim}->{x_dim}"
    )

class G(nn.Module):
    """MLP decoder: three ReLU hidden layers followed by a sigmoid-bounded output layer.

    The attribute names (first_layer ... fourth_layer) determine the
    state_dict keys, so they must stay as they are for checkpoints to load.

    Args:
        d: input (latent) width.
        x: output width.
        h1, h2, h3: widths of the three hidden layers.
    """

    def __init__(self, d, x, h1, h2, h3):
        super().__init__()
        self.first_layer = nn.Linear(in_features=d, out_features=h1)
        self.second_layer = nn.Linear(in_features=h1, out_features=h2)
        self.third_layer = nn.Linear(in_features=h2, out_features=h3)
        self.fourth_layer = nn.Linear(in_features=h3, out_features=x)
        self.act = nn.ReLU()
        self.out_act = nn.Sigmoid()  # squashes every output into [0,1]

    def forward(self, z):
        """Map latents [..., d] to outputs [..., x] with values in [0,1]."""
        hidden = z
        for layer in (self.first_layer, self.second_layer, self.third_layer):
            hidden = self.act(layer(hidden))
        logits = self.fourth_layer(hidden)
        return self.out_act(logits)

g = G(latent_dim, x_dim, h1,h2,h3).eval()
# strict=True already raises on any key mismatch; the explicit check below is a
# second guard on the returned key lists.
missing, unexpected = g.load_state_dict(state, strict=True)
if missing or unexpected:
    raise RuntimeError(f"State load mismatch: missing={missing}, unexpected={unexpected}")

# map_location="cpu" keeps the latents on the same device as g and X_test.
# NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
Z_test = torch.load(Z_TEST_PATH, map_location="cpu", weights_only=False)  # [S,T,2]
# Accept ndarray or tensor latents and match the float32 dtype of g's weights
# (nn.Linear rejects float64 inputs against float32 parameters).
Z_test = torch.as_tensor(Z_test).detach().float()
Z_flat = Z_test.reshape(S*T, latent_dim)
with torch.no_grad():
    X_hat_flat_sig = g(Z_flat)                     # [S*T,784] in [0,1]
    X_hat_img_sig  = X_hat_flat_sig.view(S, T, 1, H, W)

# Evaluate recomputed in calibrated scale (multiply by same factor)
# NOTE(review): X_hat_img_sig is sigmoid-bounded to [0,1], while calib_factor was
# derived from the RAW-scale zero baseline (z_mean_raw). If the raw pixels are
# not in [0,1] units, predictions and ground truth are on different scales here
# and this metric is not comparable with the others — confirm which scale (and
# which output activation) g was trained with before relying on this number.
X_gt_cal_sig   = X_test * calib_factor
X_hat_cal_sig  = X_hat_img_sig * calib_factor
mae_re_cal     = per_frame_mae(X_hat_cal_sig, X_gt_cal_sig, mask=mask)
r_mean_cal, r_std_cal, r_sem_cal = summarize_x1e3(mae_re_cal)
print(f"[Recomputed g(Z) | calib+sigmoid] mean={r_mean_cal:.1f} ×10^3, std={r_std_cal:.1f}, sem={r_sem_cal:.1f}")


=== SCALE INSPECTION (raw) ===
[train(raw)] min=0 max=0.00195309 mean=0.0002069 std=0.000546321
[ test(raw)] min=0 max=0.00195309 mean=0.00020686 std=0.000546319
[Zero | raw] mean=0.207 ×10^3, std=0.003, sem=0.000
[Saved DI | raw] mean=0.208 ×10^3, std=0.003, sem=0.000
=== SCALE INSPECTION ([0,1] approx) ===
[train([0,1])] min=0 max=1 mean=0.105934 std=0.279721
[ test([0,1])] min=0 max=1 mean=0.105914 std=0.27972
[Zero | ~[0,1]] mean=105.9 ×10^3, std=1.4, sem=0.0
[Saved DI | ~[0,1]] mean=106.3 ×10^3, std=1.4, sem=0.0
=== CALIBRATED TO PAPER ZERO BASELINE ===
[Zero | calib]     mean=147.0 ×10^3 (target 147.0), std=2.0, sem=0.0
[Saved DI | calib] mean=147.6 ×10^3, std=2.0, sem=0.0
[Recomputed g(Z) | calib+sigmoid] mean=438.5 ×10^3, std=17.8, sem=0.3
