In [2]:
import jax.numpy as jnp
import numpy as np
import torch

from mm_sbi_review.examples.turin import (
    turin,
    compute_turin_summaries,
    compute_turin_summaries_with_max,
)

In [3]:
import types, sys

sys.modules["networks"] = types.ModuleType("networks")
sys.modules["networks.summary_nets"] = types.ModuleType("networks.summary_nets")

from summary_nets import TurinSummary  # local definition

sys.modules["networks.summary_nets"].TurinSummary = TurinSummary

sum_net = torch.load("../data/turin_robust-sbi/sum_net.pkl", map_location="cpu")
sum_net.eval()  # inference mode → disables dropout / BN updates

  sum_net = torch.load("../data/turin_robust-sbi/sum_net.pkl", map_location="cpu")


TurinSummary(
  (lstm): LSTM(1, 4, batch_first=True)
  (conv): Sequential(
    (0): Conv1d(1, 8, kernel_size=(3,), stride=(3,))
    (1): Conv1d(8, 16, kernel_size=(3,), stride=(3,))
    (2): Conv1d(16, 32, kernel_size=(3,), stride=(3,))
    (3): Conv1d(32, 64, kernel_size=(3,), stride=(3,))
    (4): Conv1d(64, 8, kernel_size=(3,), stride=(3,))
    (5): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
  )
)

In [4]:
@torch.no_grad()
def learnt_stats(x_tensor: torch.Tensor) -> torch.Tensor:
    """
    x_tensor : (N, 801)  or (batch, N, 801)  real‑valued power (dB)
    returns  : (batch, d) summary where d = 8 for TurinSummary
    """
    if x_tensor.ndim == 2:  # single data set
        x_tensor = x_tensor.unsqueeze(0)  # → (1, N, 801)

    _, stat = sum_net(x_tensor)  # (batch, d)
    return stat.squeeze(0)  # (d,) if batch==1

In [5]:
x_obs_tensor = torch.tensor(
    np.load("../data/turin_obs.npy"), dtype=torch.float32
).reshape(100, 801)
x_obs = jnp.array(learnt_stats(x_obs_tensor).numpy())  # (d,)

In [6]:
# get NPE posterior samples
N = 100
B = 4e9
Ns = 801

# x_data_full = (np.load("../data/turin_obs.npy")).reshape(N, 801)
# x_data_full = torch.tensor(x_data_full, dtype=torch.float32)
# x_obs = compute_turin_summaries_with_max(x_data_full, delta_f=(B / (Ns - 1)))
# x_obs = jnp.array(x_obs)

with open("../data/turin_theta_2000_tau0.npy", "rb") as f:
    npe_samples = np.load(f)

In [7]:
import jax.numpy as jnp
import numpy as np
from functools import partial


# --- 1. helpers --------------------------------------------------------------
def rbf_kernel(x, y, ell):
    return jnp.exp(-jnp.sum((x - y) ** 2, axis=-1) / (2 * ell**2))


def median_heuristic(x, batch=1_000):
    """Median of pairwise ℓ2 distances (memory-friendly)."""
    n = x.shape[0]

    def dists(i):
        sl = slice(i * batch, min((i + 1) * batch, n))
        a = x[sl, None, :]
        return jnp.sqrt(jnp.sum((a - x[None, :, :]) ** 2, -1)).ravel()

    return jnp.sqrt(
        jnp.median(jnp.concatenate([dists(i) for i in range((n + batch - 1) // batch)]))
        / 2.0
    )
    # return 1.0


def unbiased_mmd(sims, x_obs, ell):
    l = sims.shape[0]

    k_xx = rbf_kernel(sims[:, None, :], sims[None, :, :], ell)  # (l,l)
    k_xy = rbf_kernel(sims, x_obs[None, :], ell).reshape(l)  # (l,)
    mmd2 = k_xx.sum() / (l**2) - 2.0 * k_xy.sum() / l + 1.0
    # k_xx.sum() / m**2 - 2 * k_xy.sum() / m
    return mmd2


def biased_mmd2(sims, x_obs, ell):
    """
    Squared MMD (biased) between
        P = empirical on `sims` (m ≥ 1)
        Q = point mass at `x_obs`  (n = 1)
    using an RBF kernel with bandwidth `ell`.
    Returns: scalar MMD^2.
    """
    k_xx = rbf_kernel(sims[:, None, :], sims[None, :, :], ell).mean()  # E_P k(X,X')
    k_yy = 1.0  # k(y,y) = 1
    k_xy = rbf_kernel(sims, x_obs[None, :], ell).mean()  # E_P k(X,y)

    return k_xx + k_yy - 2.0 * k_xy


def biased_mmd(sims, x_obs, ell):
    return jnp.sqrt(jnp.maximum(biased_mmd2(sims, x_obs, ell), 0.0))


def mmd_with_median(x, y):
    ell = median_heuristic(jnp.concatenate([x, y], axis=0))
    return biased_mmd2(x, y, ell)


# # convenience wrapper that does the length-scale search once
# def mmd_with_median(x, y):
#     ell = median_heuristic(jnp.concatenate([x, y], axis=0))
#     return
# d_mmd(x, y, ell)

In [10]:
num_post_pred_sims = 25

turin_sim = turin(B=B, Ns=Ns, N=N, tau0=0)


# posterior predictive samples based on NPE

# for i in range(min(npe_samples.shape[0], num_post_pred_sims)):
#     theta_draws = npe_samples[i, :]
#     theta_draws = torch.tensor(theta_draws)
#     data = turin_sim(theta_draws)
#     summaries = compute_turin_summaries_with_max(data)

In [11]:
def posterior_predictive_summaries(theta_array, sim):
    stats = []
    for th in theta_array:
        x_sim = sim(torch.tensor(th))  # (N, 801)
        s = learnt_stats(x_sim).numpy()  # (d,)
        stats.append(s)
    return jnp.stack(stats, axis=0)  # (m, d)

In [9]:
s_rep = posterior_predictive_summaries(
    npe_samples[:num_post_pred_sims], turin_sim
)  # (m, d)

In [10]:
s_rep.shape

(25, 8)

In [11]:
ell = median_heuristic(jnp.concatenate([s_rep, x_obs[None, :]], axis=0))

In [12]:
mmd_obs = biased_mmd(s_rep, x_obs, ell)

In [13]:
s_rep.shape[0]

25

In [14]:
ell

Array(3.4201233, dtype=float32)

In [15]:
import jax.random as random

B_null = 100  # size of null distribution

rng_key = random.PRNGKey(0)  # for reproducibility

m = s_rep.shape[0]  # same size you used in the observed test


def mmd_under_null(key):
    key_theta, key_val, key_obs = random.split(key, 3)

    # 1. draw *one* parameter θ*
    idx = random.randint(key_theta, (), 0, npe_samples.shape[0])
    theta_star = npe_samples[idx]

    # 2. simulate m validation data sets and 1 "observed" data set
    val_summaries = []
    for _ in range(m):
        x_val = turin_sim(torch.tensor(theta_star))
        val_summaries.append(jnp.array(learnt_stats(x_val)))
    x_obs_sim = turin_sim(torch.tensor(theta_star))
    s_obs_sim = learnt_stats(x_obs_sim)
    s_obs_sim = jnp.array(s_obs_sim)

    s_val = jnp.stack(val_summaries, axis=0)  # shape (m,d)
    s_obs_1 = s_obs_sim[None, :]  # shape (1,d)

    return biased_mmd(s_val, s_obs_1, ell)  # same estimator you used before


keys = random.split(rng_key, num=B_null)
mmd_null = jnp.array([mmd_under_null(k) for k in keys])

crit = jnp.quantile(mmd_null, 0.95)  # 5 % test
pval = (mmd_null >= mmd_obs).mean()

print(f"MMD̂ = {mmd_obs:.4f}  |  critical 5 % = {crit:.4f}  |  p = {pval:.3f}")

MMD̂ = 0.9324  |  critical 5 % = 0.3517  |  p = 0.000


In [16]:
mmd_null

Array([0.30802575, 0.28400996, 0.13446763, 0.09258758, 0.1934053 ,
       0.12388543, 0.05393685, 0.2165058 , 0.18614213, 0.1962845 ,
       0.24966933, 0.18680985, 0.10587656, 0.14259276, 0.16471262,
       0.16560115, 0.18295273, 0.15027446, 0.11998314, 0.5646854 ,
       0.06623639, 0.0687905 , 0.26325855, 0.2393258 , 0.0974167 ,
       0.20745835, 0.46021914, 0.10612173, 0.03889889, 0.05787214,
       0.29299173, 0.15531643, 0.16331421, 0.14880526, 0.05029652,
       0.16297881, 0.17082901, 0.27280942, 0.07253072, 0.0591012 ,
       0.1432109 , 0.12280346, 0.16192138, 0.13050681, 0.07665238,
       0.07636491, 0.13040721, 0.07866861, 0.08689074, 0.09557786,
       0.07862768, 0.06823108, 0.22895913, 0.09558472, 0.10791457,
       0.12474655, 0.18222772, 0.30661982, 0.15387717, 0.2298651 ,
       0.18580273, 0.197709  , 0.07605127, 0.14670312, 0.29901814,
       0.10876349, 0.28155777, 0.5534229 , 0.12415505, 0.06927057,
       0.08272585, 0.07554328, 0.32259375, 0.53173053, 0.12898

In [17]:
mmd_obs

Array(0.93239707, dtype=float32)

In [18]:
mmd_null

Array([0.30802575, 0.28400996, 0.13446763, 0.09258758, 0.1934053 ,
       0.12388543, 0.05393685, 0.2165058 , 0.18614213, 0.1962845 ,
       0.24966933, 0.18680985, 0.10587656, 0.14259276, 0.16471262,
       0.16560115, 0.18295273, 0.15027446, 0.11998314, 0.5646854 ,
       0.06623639, 0.0687905 , 0.26325855, 0.2393258 , 0.0974167 ,
       0.20745835, 0.46021914, 0.10612173, 0.03889889, 0.05787214,
       0.29299173, 0.15531643, 0.16331421, 0.14880526, 0.05029652,
       0.16297881, 0.17082901, 0.27280942, 0.07253072, 0.0591012 ,
       0.1432109 , 0.12280346, 0.16192138, 0.13050681, 0.07665238,
       0.07636491, 0.13040721, 0.07866861, 0.08689074, 0.09557786,
       0.07862768, 0.06823108, 0.22895913, 0.09558472, 0.10791457,
       0.12474655, 0.18222772, 0.30661982, 0.15387717, 0.2298651 ,
       0.18580273, 0.197709  , 0.07605127, 0.14670312, 0.29901814,
       0.10876349, 0.28155777, 0.5534229 , 0.12415505, 0.06927057,
       0.08272585, 0.07554328, 0.32259375, 0.53173053, 0.12898

In [14]:
# --- MMD misspecification diagnostic (paper-compliant) -----------------------
# Prior-predictive summaries -> whiten (fit on validation) -> multi-kernel RBF
# Null via independent draws from p(z|M); first sample kept large and fixed.

import numpy as np, torch

# 0) Make sure nets/tensors live on same device as the simulator
_dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
try:
    sum_net.to(_dev)
except NameError:
    pass
try:
    x_obs_tensor = x_obs_tensor.to(_dev)
except NameError:
    raise RuntimeError("x_obs_tensor not found; load your observed data first.")

# 1) Prior from run_turin.py (BoxUniform bounds)
_PRIOR_LOW = np.array([1e-9, 1e-9, 1e7, 1e-10], dtype=np.float64)
_PRIOR_HIGH = np.array([1e-8, 1e-8, 5e9, 1e-9], dtype=np.float64)
_rng = np.random.default_rng(0)


def prior_sample(n: int) -> np.ndarray:
    return _rng.uniform(_PRIOR_LOW, _PRIOR_HIGH, size=(n, 4))


# 2) Summaries via your learned network (expects dB power)
@torch.no_grad()
def sim_summaries(num: int) -> np.ndarray:
    out = []
    thetas = prior_sample(num)
    for th in thetas:
        x = turin_sim(torch.tensor(th, dtype=torch.float32, device=_dev))  # (N,801)
        s = learnt_stats(x).detach().to("cpu").numpy()
        out.append(s)
    return np.asarray(out, dtype=np.float64)  # (num, d)


# 3) Whitening fit on validation only
def whiten_fit(Z: np.ndarray, eps: float = 1e-6):
    mu = Z.mean(0)
    C = np.cov(Z, rowvar=False) + eps * np.eye(Z.shape[1])
    L = np.linalg.cholesky(C)
    return mu, L


def whiten_apply(Z: np.ndarray, mu: np.ndarray, L: np.ndarray) -> np.ndarray:
    return np.linalg.solve(L, (Z - mu).T).T


# 4) Multi-kernel RBF (biased MMD^2, valid for N=1)
def rbf_sum(X: np.ndarray, Y: np.ndarray, sigmas: np.ndarray) -> np.ndarray:
    # X:(n,d), Y:(m,d)
    d2 = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)  # (n,m)
    K = np.exp(-0.5 * d2[:, :, None] / (sigmas[None, None, :] ** 2)).sum(-1)
    return K


def mmd2_biased(X: np.ndarray, Y: np.ndarray, sigmas: np.ndarray) -> float:
    kxx = rbf_sum(X, X, sigmas).mean()
    kyy = rbf_sum(Y, Y, sigmas).mean()
    kxy = rbf_sum(X, Y, sigmas).mean()
    return float(kxx + kyy - 2.0 * kxy)


# 5) Fixed kernel widths from validation only (no peeking at observed)
def median_pairwise(Z: np.ndarray, cap: int = 2000) -> float:
    if Z.shape[0] > cap:
        Z = Z[_rng.choice(Z.shape[0], size=cap, replace=False)]
    d = np.sqrt(((Z[:, None, :] - Z[None, :, :]) ** 2).sum(-1)).ravel()
    return np.median(d[d > 0])


def make_sigmas(Z_val_w: np.ndarray) -> np.ndarray:
    med = max(median_pairwise(Z_val_w), 1e-12)
    return med * np.array([0.5, 1.0, 2.0, 4.0, 8.0], dtype=np.float64)


# ------------------------------ Run the test ---------------------------------
M_val = 100  # prior-predictive validation size
B_null = 25  # null replicates
alpha = 0.05

# A) Validation summaries from p(z|M)
Z_val = sim_summaries(M_val)  # (M,d)
mu, L = whiten_fit(Z_val)
Z_val_w = whiten_apply(Z_val, mu, L)

# B) Observed summaries (N may be 1)
z_obs = learnt_stats(x_obs_tensor).detach().to("cpu").numpy()[None, ...]
Z_obs_w = whiten_apply(z_obs, mu, L)

# C) Multi-kernel widths from validation only
sigmas = make_sigmas(Z_val_w)

# D) Observed statistic
mmd2_obs = mmd2_biased(Z_val_w, Z_obs_w, sigmas)

# E) Null: independent second sample from p(z|M), size N, using the fixed large first set
N = Z_obs_w.shape[0]
Kxx_const = rbf_sum(Z_val_w, Z_val_w, sigmas).mean()  # one-time cost
mmd2_null = np.empty(B_null, dtype=np.float64)
for b in range(B_null):
    print(f"b: {b}")
    Z_b = whiten_apply(sim_summaries(N), mu, L)
    kyy = rbf_sum(Z_b, Z_b, sigmas).mean()
    kxy = rbf_sum(Z_val_w, Z_b, sigmas).mean()
    mmd2_null[b] = Kxx_const + kyy - 2.0 * kxy

crit = np.quantile(mmd2_null, 1.0 - alpha)
pval = float(np.mean(mmd2_null >= mmd2_obs))

print(
    f"MMD^2_obs={mmd2_obs:.6g}  MMD_obs={np.sqrt(mmd2_obs):.6g}  "
    f"crit@{1-alpha:.2f}={crit:.6g}  p={pval:.3f}"
)

b: 0
b: 1
b: 2
b: 3
b: 4
b: 5
b: 6
b: 7
b: 8
b: 9
b: 10
b: 11
b: 12
b: 13
b: 14
b: 15
b: 16
b: 17
b: 18
b: 19
b: 20
b: 21
b: 22
b: 23
b: 24
MMD^2_obs=2.70869  MMD_obs=1.64581  crit@0.95=2.13876  p=0.000


In [15]:
mmd2_null

array([1.91406661, 1.65363886, 1.02260941, 1.80480186, 0.9922519 ,
       1.91380868, 0.90085011, 0.70903236, 1.51323068, 1.89239709,
       2.19312863, 0.91823997, 1.89758444, 1.7354356 , 2.20557598,
       1.20414721, 1.81721313, 1.89032247, 1.92128648, 1.55331989,
       0.81432782, 1.02096912, 1.31130663, 1.5232953 , 1.64831147])