# BYM Prediction-Based FDR Simulation

This notebook runs a simulation study to estimate the prediction-based False Discovery Rate (FDR) of a BYM (Besag–York–Mollié) model. In each simulation, we:
- Generate ground-truth patient and condition effects
- Mask a random fraction of observations
- Fit a BYM model on masked data (handling NaNs)
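- Predict the masked entries, count an entry as a discovery when its predicted probability exceeds 0.5, and compute FDR = FP / (FP + TP) (taken as 0 when there are no discoveries)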

The final section summarizes the distribution of FDR across runs and visualizes it.


## Cell 1: Setup and Parameters

Imports core libraries (JAX, NumPy, NumPyro, SciPy/sklearn, Matplotlib) and defines:
- Simulation controls: number of simulations `N_sims`, patients `I`, conditions `C`, k-NN parameter `k`, masking fraction.
- Ground-truth parameters for data generation (`tau_s_true`, `tau_u_true`, `sigma_delta_true`).
- Sampler configuration (`num_warmup`, `num_samples`, `num_chains`).


In [1]:
# Cell 1: Setup and Parameters
import os
import time

import jax
import jax.numpy as jnp
import numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from scipy.spatial.distance import cdist
from sklearn.neighbors import kneighbors_graph

import matplotlib.pyplot as plt

# --- Simulation parameters ---
N_sims = 100             # how many independent simulation repetitions to run
I, C = 200, 50           # grid size: I patients (rows) x C conditions (columns)
k = 10                   # neighbours per patient when building the k-NN graph
masking_fraction = 0.2   # share of the I x C grid hidden from the model fit

# --- True parameters for data generation ---
tau_s_true = 1.5         # precision of the structured (graph) component
tau_u_true = 2.0         # precision of the unstructured (iid) component
sigma_delta_true = 0.75  # standard deviation of the condition effects

# --- Sampler parameters (exposed for runtime control) ---
num_warmup, num_samples, num_chains = 500, 500, 1

# Echo the configuration so the notebook output records what was run.
print(
    "Simulation parameters set:",
    dict(
        N_sims=N_sims,
        I=I,
        C=C,
        k=k,
        masking_fraction=masking_fraction,
        tau_s_true=tau_s_true,
        tau_u_true=tau_u_true,
        sigma_delta_true=sigma_delta_true,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    ),
)


Simulation parameters set: {'N_sims': 100, 'I': 200, 'C': 50, 'k': 10, 'masking_fraction': 0.2, 'tau_s_true': 1.5, 'tau_u_true': 2.0, 'sigma_delta_true': 0.75, 'num_warmup': 500, 'num_samples': 500, 'num_chains': 1}


## Cell 2: Helper Functions

- `get_laplacian_from_locs`: builds the patient k-NN graph and returns the graph Laplacian `L = D - W` as a JAX array.
- `bym_model`: a NumPyro BYM model with:
  - Structured component via `numpyro.factor` using the Laplacian energy `phi^T L phi` scaled by `tau_s`.
  - Unstructured component via `||phi||^2` scaled by `tau_u`.
  - Condition effects `delta` with prior scale `sigma_delta`.
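  - Likelihood `Bernoulli(logits=phi[:,None]+delta[None,:])` masked to ignore `NaN` entries in `y`.

The code for these helpers is in the last code cell of this notebook (executed as `In [2]`); it must be run before Cell 3, which calls both functions.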


In [None]:
# Cell 3: Main Simulation Loop
# NOTE: this cell calls `get_laplacian_from_locs` and `bym_model`, which are
# defined in the "Cell 2: Helper Functions" code cell at the end of this file;
# that cell must be executed before this one.

def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), via the numerically stable jax.nn.sigmoid."""
    return jax.nn.sigmoid(x)


fdr_results = []
key = jax.random.PRNGKey(0)

for sim in range(N_sims):
    t0 = time.time()
    # Draw every subkey this iteration needs up front. `key` itself is only
    # ever split (never consumed), so no two random draws share a key --
    # previously the MCMC run reused `key` before it was split again.
    key, k_locs, k_z, k_delta, k_y, k_mask, k_mcmc = jax.random.split(key, 7)

    # --- Generate Ground Truth Data ---
    # Random 2D locations in the unit square
    locations = jax.random.uniform(k_locs, shape=(I, 2), minval=0.0, maxval=1.0)
    locations_np = np.array(locations)

    # Build true patient Laplacian
    L_pat_true = get_laplacian_from_locs(locations_np, k)

    # True precision Q_true = tau_s_true * L_pat_true + tau_u_true * Id,
    # with a tiny jitter on the diagonal to keep the Cholesky stable.
    Q_true = tau_s_true * L_pat_true + (tau_u_true + 1e-5) * jnp.eye(I)

    # Sample phi_true ~ N(0, Q_true^{-1}). With Q_true = L L^T (Cholesky),
    # phi = L^{-T} z with z ~ N(0, Id) has Cov = L^{-T} L^{-1} = Q_true^{-1}.
    # (Solving with L first as well, i.e. phi = Q_true^{-1} z, would give
    # Cov = Q_true^{-2} -- the wrong distribution.)
    L_chol = jnp.linalg.cholesky(Q_true)
    z = jax.random.normal(k_z, shape=(I,))
    phi_true = jax.scipy.linalg.solve_triangular(L_chol.T, z, lower=False)

    # Condition effects
    delta_true = jax.random.normal(k_delta, shape=(C,)) * sigma_delta_true

    # Linear predictor, probabilities and binary outcomes
    Lambda_true = phi_true[:, None] + delta_true[None, :]
    p_true = sigmoid(Lambda_true)
    y_true = jax.random.bernoulli(k_y, p=p_true).astype(jnp.float32)

    # --- Masked Training Data ---
    total_entries = I * C
    num_mask = int(masking_fraction * total_entries)
    flat_indices = jax.random.choice(k_mask, total_entries, shape=(num_mask,), replace=False)
    mask_rows = flat_indices // C
    mask_cols = flat_indices % C

    # `.at[...].set` returns a new array, so y_true is left untouched.
    y_train = y_true.at[mask_rows, mask_cols].set(jnp.nan)
    true_mask_values = y_true[mask_rows, mask_cols]

    # --- Fit BYM Model ---
    # MCMC.run forwards the extra positional arguments to the model.
    mcmc = MCMC(
        NUTS(bym_model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=False,
    )
    mcmc.run(k_mcmc, L_pat_true, y_train, I, C)
    samples = mcmc.get_samples()

    # --- Predictions for Masked Entries ---
    # Posterior predictive P(y = 1 | data): average the success probability
    # over posterior draws instead of plugging in posterior-mean parameters.
    masked_logits = samples['phi'][:, mask_rows] + samples['delta'][:, mask_cols]
    pred_mask_probs = sigmoid(masked_logits).mean(axis=0)

    # Discoveries: p > 0.5
    discoveries = pred_mask_probs > 0.5

    # Compute FP and TP
    tp = jnp.sum(discoveries & (true_mask_values == 1))
    fp = jnp.sum(discoveries & (true_mask_values == 0))

    # FDR = FP / (FP + TP), defined as 0 when nothing is declared a discovery.
    denom = tp + fp
    fdr = float(fp / denom) if denom > 0 else 0.0
    fdr_results.append(fdr)

    t1 = time.time()
    print(f"Sim {sim+1}/{N_sims}: FDR={fdr:.3f} (elapsed {t1 - t0:.1f}s, masked={num_mask})")


In [None]:
# Cell 4: Final Analysis and Visualization

fdr_arr = np.asarray(fdr_results, dtype=float)
# Summaries and the histogram use the same finite values: np.nanmean would
# tolerate NaN, but np.histogram (behind plt.hist) raises on non-finite data.
finite_fdr = fdr_arr[np.isfinite(fdr_arr)]
if finite_fdr.size:
    mean_fdr = float(finite_fdr.mean())
    std_fdr = float(finite_fdr.std())
else:
    mean_fdr = std_fdr = float('nan')

print(f"Prediction-based FDR across {len(fdr_arr)} simulations:")
print(f"  Mean: {mean_fdr:.4f}")
print(f"  Std:  {std_fdr:.4f}")

plt.figure(figsize=(7, 4))
# FDR lies in [0, 1]; a fixed range keeps bin edges comparable between studies.
plt.hist(finite_fdr, bins=20, range=(0.0, 1.0), alpha=0.75, color='steelblue', edgecolor='white')
if np.isfinite(mean_fdr):
    # Only draw (and label) the mean line when there is a mean to show.
    plt.axvline(mean_fdr, color='red', linestyle='--', linewidth=2, label=f"Mean = {mean_fdr:.3f}")
    plt.legend()
plt.xlabel('FDR')
plt.ylabel('Frequency')
plt.title('Distribution of Prediction-Based FDR')
plt.tight_layout()
plt.show()


## Cell 4: Interpretation

This section aggregates the FDR across all simulation runs and visualizes its distribution. A lower mean FDR means that, on average, a smaller share of the masked entries predicted positive (probability > 0.5) are actually negative. The histogram shows how much the FDR varies from run to run.


In [2]:
# Cell 2: Helper Functions
from numpyro import handlers


def get_laplacian_from_locs(locations: np.ndarray, k: int) -> jnp.ndarray:
    """
    Build an (n x n) graph Laplacian from point locations using k-NN adjacency.

    Each point is linked to its ``k`` nearest neighbours (itself excluded).
    The adjacency is then symmetrised (i ~ j if either point lists the other)
    and binarised, so W is a symmetric 0/1 matrix.

    Args:
        locations: (n x d) array of coordinates (2D in this notebook).
        k: number of nearest neighbours per point.

    Returns:
        Dense float32 JAX array L = D - W, where D is the diagonal degree matrix.
    """
    # k-NN adjacency (exclude self-loops by setting include_self=False)
    A = kneighbors_graph(locations, n_neighbors=k, mode='connectivity', include_self=False)
    # Symmetrize: A + A.T holds 2 for mutual neighbours, so clip to 0/1 via bool.
    W = (A + A.T).astype(bool).astype(np.float32).toarray()
    # Degree matrix and Laplacian
    D = np.diag(W.sum(axis=1))
    return jnp.asarray(D - W, dtype=jnp.float32)


def bym_model(L_pat: jnp.ndarray, y: jnp.ndarray, I: int, C: int):
    """
    NumPyro BYM-style model for a binary (I x C) grid with NaN-masked observations.

    The patient effect ``phi`` gets the Gaussian Markov random field prior
    ``phi ~ N(0, Q^{-1})`` with precision ``Q = tau_s * L_pat + tau_u * Id``.
    This is the form the simulation cell uses to generate ``phi_true``.
    The prior is written as two energy factors plus the log-determinant
    normaliser ``0.5 * log|Q|``. Without that term the density is unnormalised
    in (tau_s, tau_u), and the sampler pushes both precisions towards zero.

    Args:
        L_pat: (I x I) dense graph Laplacian.
        y: (I x C) array of 0/1 outcomes; NaN marks unobserved entries.
        I: number of patients (rows of ``y``).
        C: number of conditions (columns of ``y``).
    """
    # Hyperpriors
    tau_s = numpyro.sample("tau_s", dist.HalfCauchy(2.0))
    tau_u = numpyro.sample("tau_u", dist.HalfCauchy(2.0))
    sigma_delta = numpyro.sample("sigma_delta", dist.HalfCauchy(1.0))

    # Latents
    delta = numpyro.sample("delta", dist.Normal(0, sigma_delta).expand([C]))
    # Flat base measure: the entire prior on phi comes from the factors below
    # (a Normal(0, 1) base would silently add an extra identity term to Q).
    phi = numpyro.sample(
        "phi", dist.ImproperUniform(dist.constraints.real, (), event_shape=(I,))
    )

    # Structured (ICAR-like) energy and unstructured energy
    U_structured = tau_s * (phi @ (L_pat @ phi))
    numpyro.factor("structured_effect", -0.5 * U_structured)

    U_unstructured = tau_u * jnp.sum(phi ** 2)
    numpyro.factor("unstructured_effect", -0.5 * U_unstructured)

    # GMRF normaliser 0.5 * log|Q|, computed from the Cholesky factor of Q.
    # Q is positive definite because L_pat is PSD and tau_u > 0.
    Q = tau_s * L_pat + tau_u * jnp.eye(I, dtype=L_pat.dtype)
    log_det_Q = 2.0 * jnp.sum(jnp.log(jnp.diag(jnp.linalg.cholesky(Q))))
    numpyro.factor("gmrf_log_normalizer", 0.5 * log_det_Q)

    # Linear predictor
    Lambda = phi[:, None] + delta[None, :]

    # Likelihood with masking (mask True for observed entries).
    obs_mask = ~jnp.isnan(y)
    # Replace NaNs with a valid placeholder so the masked-out terms and their
    # gradients stay finite; the mask removes them from the log-likelihood.
    y_filled = jnp.where(obs_mask, y, 0.0)
    with handlers.mask(mask=obs_mask):
        numpyro.sample("obs", dist.Bernoulli(logits=Lambda), obs=y_filled)
