In [None]:
# Cell 1: Setup and Parameters
import os
import time

import jax
import jax.numpy as jnp
import numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from scipy.spatial.distance import cdist
from sklearn.neighbors import kneighbors_graph

import matplotlib.pyplot as plt

# --- Simulation design ---
N_sims = 100            # independent simulated datasets to generate and fit
I = 200                 # patients (rows of the I x C grid)
C = 50                  # conditions (columns of the I x C grid)
k = 10                  # neighbours per patient in the k-NN graph
masking_fraction = 0.2  # share of grid entries to hide (mask)

# --- Ground-truth values used to generate data ---
tau_s_true = 1.5         # precision of the structured (graph) component
tau_u_true = 2.0         # precision of the unstructured component
sigma_delta_true = 0.75  # standard deviation of per-condition effects

# --- MCMC settings (edit here to trade accuracy for runtime) ---
num_warmup = 500
num_samples = 500
num_chains = 1

# Echo the configuration so each run's log records what was used.
print(
    "Simulation parameters set:",
    dict(
        N_sims=N_sims,
        I=I,
        C=C,
        k=k,
        masking_fraction=masking_fraction,
        tau_s_true=tau_s_true,
        tau_u_true=tau_u_true,
        sigma_delta_true=sigma_delta_true,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    ),
)


In [None]:
# Cell 2: Helper Functions
from numpyro import handlers


def get_laplacian_from_locs(locations: np.ndarray, k: int) -> jnp.ndarray:
    """
    Build an (n x n) graph Laplacian from point locations using k-NN adjacency.

    Parameters
    ----------
    locations : (n, d) array of point coordinates, one row per node
        (2D in this notebook, but any feature dimension d is accepted).
    k : number of nearest neighbours per node. Must satisfy 1 <= k <= n - 1,
        because each point is excluded from its own neighbour list.

    Returns
    -------
    Dense float32 JAX array ``L = D - W``. ``W`` is the symmetric 0/1
    adjacency (i ~ j if either point is among the other's k nearest
    neighbours). ``D`` is the diagonal degree matrix.

    Raises
    ------
    ValueError
        If ``k`` is outside ``[1, n - 1]``.
    """
    locations = np.asarray(locations)
    n = locations.shape[0]
    # Fail early with a clear message instead of an opaque error from sklearn.
    if not 1 <= k <= n - 1:
        raise ValueError(
            f"k must be in [1, {n - 1}] for {n} locations, got k={k}"
        )

    # Directed k-NN adjacency; include_self=False keeps self-loops out of W.
    A = kneighbors_graph(locations, n_neighbors=k, mode='connectivity', include_self=False)
    # Symmetrize: mutual neighbours sum to 2, so astype(bool) collapses
    # every edge back to a 0/1 weight.
    W = (A + A.T).astype(bool).astype(np.float32).toarray()
    # Combinatorial Laplacian: degree on the diagonal minus adjacency.
    L = np.diag(W.sum(axis=1)) - W
    return jnp.asarray(L, dtype=jnp.float32)


def bym_model(L_pat: jnp.ndarray, y: jnp.ndarray, I: int, C: int, L_eigvals=None):
    """
    NumPyro BYM-style model for an (I x C) binary grid with NaN-masked entries.

    Prior on the patient effect
    ---------------------------
    The patient effect ``phi`` has a Gaussian prior with precision
    ``Q = tau_s * L_pat + (tau_u + 1) * Id``. The ``+ 1`` comes from the
    N(0, 1) base distribution of the ``phi`` site.

    Because ``tau_s`` and ``tau_u`` are inferred, the normalizing term
    ``0.5 * log|Q|`` is added explicitly. Without it, the energy factors alone
    make the density decrease monotonically in both precisions. Their
    posterior is then biased toward zero.

    Parameters
    ----------
    L_pat : (I x I) dense, symmetric graph Laplacian.
    y : (I x C) 0/1 observations with NaNs marking masked entries.
        Pass None to draw "obs" from the model, e.g. with ``Predictive``.
    I : number of patients (rows of y).
    C : number of conditions (columns of y).
    L_eigvals : optional precomputed ``jnp.linalg.eigvalsh(L_pat)``.
        Passing it avoids an O(I^3) eigendecomposition on every
        log-density evaluation.
    """
    # Hyperpriors
    tau_s = numpyro.sample("tau_s", dist.HalfCauchy(2.0))
    tau_u = numpyro.sample("tau_u", dist.HalfCauchy(2.0))
    sigma_delta = numpyro.sample("sigma_delta", dist.HalfCauchy(1.0))

    # Latents: per-condition effects and per-patient effect
    delta = numpyro.sample("delta", dist.Normal(0, sigma_delta).expand([C]))
    phi = numpyro.sample("phi", dist.Normal(0, 1.0).expand([I]))

    # Structured (ICAR-like) energy and unstructured energy, both acting on phi.
    # NOTE(review): textbook BYM uses two separate fields (structured + iid).
    # Here both energies act on the same phi, and the N(0, 1) base adds one
    # more unit of precision. Confirm this matches the data-generating process
    # before comparing tau_* posteriors with the true values.
    # Factors do not affect ancestral sampling, so prior-predictive draws of
    # phi ignore the graph. Posterior inference is unaffected.
    U_structured = tau_s * (phi @ (L_pat @ phi))
    numpyro.factor("structured_effect", -0.5 * U_structured)

    U_unstructured = tau_u * jnp.sum(phi ** 2)
    numpyro.factor("unstructured_effect", -0.5 * U_unstructured)

    # Gaussian normalizer 0.5 * log|Q|, computed from the eigenvalues of L_pat:
    #   log|tau_s * L + (tau_u + 1) * Id| = sum_i log1p(tau_s * lam_i + tau_u)
    # L_pat is PSD. Clip round-off negatives at 0 so the argument stays >= tau_u.
    lam = jnp.linalg.eigvalsh(L_pat) if L_eigvals is None else jnp.asarray(L_eigvals)
    lam = jnp.maximum(lam, 0.0)
    numpyro.factor("phi_log_normalizer", 0.5 * jnp.sum(jnp.log1p(tau_s * lam + tau_u)))

    # Linear predictor (logits), shape (I, C)
    Lambda = phi[:, None] + delta[None, :]

    if y is None:
        # No data supplied: sample the full grid (predictive use).
        numpyro.sample("obs", dist.Bernoulli(logits=Lambda))
    else:
        # The mask is True for observed entries. NaN (masked) entries are
        # excluded from the log-likelihood by the mask handler.
        obs_mask = ~jnp.isnan(y)
        with handlers.mask(mask=obs_mask):
            numpyro.sample("obs", dist.Bernoulli(logits=Lambda), obs=y)
