# Estimating ln(Z) with NumPyro + MorphZ

This notebook demonstrates how to compute the Bayesian evidence (log marginal likelihood, **ln Z**) using:

- **NumPyro** for posterior sampling (via NUTS)
- **MorphZ** for evidence estimation from posterior samples

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/EL-MZ/MorphZ/blob/main/examples/numpyro_morphz_lnz.ipynb)

## 1. What We Are Computing

Given a model with parameters $\theta$ and data $y$:

Posterior:
$$
p(\theta \mid y) \propto p(y \mid \theta)\, p(\theta)
$$

Evidence (marginal likelihood):
$$
Z = p(y) = \int p(y \mid \theta)\, p(\theta)\, d\theta
$$

We want $\ln Z$. NumPyro provides posterior samples of $\theta$. MorphZ estimates $\ln Z$ from those samples together with the unnormalized log posterior, $\ln p(y \mid \theta) + \ln p(\theta)$, evaluated at each sample.
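To make the definition concrete: with a single parameter, $\ln Z$ can be computed by brute-force quadrature over a grid of $\mu$ values. The sketch below does this for the Gaussian-mean toy model from Section 3 using plain NumPy. `lnz_by_quadrature` and its grid settings (`half_width`, `n_grid`) are illustrative choices, not part of NumPyro or MorphZ. Grid integration does not scale beyond a few dimensions, which is why sample-based estimators are needed in practice.

```python
import numpy as np

def log_normal_pdf(x, mean, sd):
    """Elementwise log density of N(mean, sd^2)."""
    return -0.5 * np.log(2 * np.pi * sd**2) - 0.5 * ((x - mean) / sd) ** 2

def lnz_by_quadrature(y, sigma, tau0, half_width=10.0, n_grid=20001):
    """Approximate ln Z = ln ∫ p(y|mu) p(mu) dmu with a Riemann sum on a grid.

    Assumes the posterior mass lies within ±half_width * tau0 of zero.
    Works in log space (log-sum-exp) so tiny likelihoods do not underflow.
    """
    y = np.asarray(y, dtype=float)
    grid = np.linspace(-half_width * tau0, half_width * tau0, n_grid)
    dmu = grid[1] - grid[0]
    # log p(y | mu) + log p(mu) at every grid point, shape (n_grid,)
    log_joint = (log_normal_pdf(y[None, :], grid[:, None], sigma).sum(axis=1)
                 + log_normal_pdf(grid, 0.0, tau0))
    m = log_joint.max()
    # ln( sum_k exp(log_joint_k) * dmu ), evaluated stably
    return m + np.log(np.exp(log_joint - m).sum() * dmu)
```

For the notebook's data, `lnz_by_quadrature(y, sigma, tau0)` should agree with the analytic value computed in Section 9.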

## 2. Colab Setup

In [None]:
%pip -q install "jax[cpu]" numpyro morphz

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import log_density
from morphZ import evidence

print('jax:', jax.__version__)
print('numpyro:', numpyro.__version__)

## 3. Minimal Toy Model: Gaussian Mean Inference

We infer an unknown mean $\mu$ from $n$ observations with known noise standard deviation $\sigma$ and known prior scale $\tau_0$.

- Prior: $\mu \sim \mathcal{N}(0, \tau_0^2)$
- Likelihood: $y_i \sim \mathcal{N}(\mu, \sigma^2)$
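These two lines fully specify the unnormalized log posterior $\ln p(y \mid \mu) + \ln p(\mu)$. As a minimal sketch, here it is written out by hand in NumPy and vectorized over a batch of $\mu$ draws. `log_joint_by_hand` is a hypothetical helper for cross-checking, not a NumPyro or MorphZ function.

```python
import numpy as np

def log_joint_by_hand(mu, y, sigma, tau0):
    """Unnormalized log posterior log p(y|mu) + log p(mu) for the toy model.

    `mu` may be a scalar or a 1D array of draws; the result has mu's shape.
    """
    mu = np.asarray(mu, dtype=float)
    y = np.asarray(y, dtype=float)
    n = y.size
    # log-likelihood: sum over i of log N(y_i | mu, sigma^2)
    resid = mu[..., None] - y                      # shape mu.shape + (n,)
    log_lik = (-0.5 * n * np.log(2 * np.pi * sigma**2)
               - 0.5 * np.sum(resid**2, axis=-1) / sigma**2)
    # log-prior: log N(mu | 0, tau0^2)
    log_prior = -0.5 * np.log(2 * np.pi * tau0**2) - 0.5 * mu**2 / tau0**2
    return log_lik + log_prior
```

Once `post_smp` and `lp` exist (Section 6), `np.allclose(log_joint_by_hand(post_smp[:, 0], y, sigma, tau0), lp)` should hold. JAX computes in float32 by default, so expect agreement only to roughly single precision.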

In [None]:
rng = np.random.default_rng(0)

# Synthetic dataset
n_obs = 40
mu_true = 1.25
sigma = 0.7
tau0 = 2.0
y = rng.normal(mu_true, sigma, size=n_obs)

print('n_obs =', n_obs)
print('sample mean =', y.mean())

In [None]:
def gaussian_mean_model(y, sigma, tau0):
    mu = numpyro.sample('mu', dist.Normal(0.0, tau0))
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=y)

## 4. Run NUTS and Collect Posterior Samples

In [None]:
nuts = NUTS(gaussian_mean_model)
mcmc = MCMC(
    nuts,
    num_warmup=800,
    num_samples=2000,
    num_chains=1,
    progress_bar=True,
)

rng_key = jax.random.PRNGKey(42)
mcmc.run(rng_key, y=jnp.asarray(y), sigma=sigma, tau0=tau0)
samples = mcmc.get_samples(group_by_chain=False)

print('posterior keys:', samples.keys())
print('mu sample shape:', np.asarray(samples['mu']).shape)
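Because the prior is conjugate, the exact posterior is also Gaussian, which gives a cheap check on the NUTS draws before any evidence is computed. The helper below is a minimal sketch (hypothetical name `conjugate_posterior`) of the standard Normal-Normal update with known $\sigma$ and prior mean 0.

```python
import numpy as np

def conjugate_posterior(y, sigma, tau0):
    """Exact posterior N(mean, sd^2) of mu for the Normal-Normal model.

    Prior mu ~ N(0, tau0^2); likelihood y_i ~ N(mu, sigma^2) with sigma known.
    """
    y = np.asarray(y, dtype=float)
    precision = 1.0 / tau0**2 + y.size / sigma**2   # posterior precision
    mean = (y.sum() / sigma**2) / precision         # prior mean is 0
    return mean, np.sqrt(1.0 / precision)
```

Compare `conjugate_posterior(y, sigma, tau0)` with `np.mean(samples['mu'])` and `np.std(samples['mu'])`; they should agree to within Monte Carlo error.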

## 5. Build a Log Posterior Callable

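MorphZ needs a callable `lp_fn(sample_vector)` that returns the same unnormalized log posterior, $\ln p(y \mid \theta) + \ln p(\theta)$, that NUTS sampled. NumPyro's `log_density` computes this log joint for a model given a dict of parameter values in the constrained space, which is also the space of the draws returned by `mcmc.get_samples()`.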

In [None]:
def build_log_density_fn(model, model_kwargs):
    model_kwargs = jax.tree_util.tree_map(jnp.asarray, model_kwargs)

    def _logpost(params):
        log_prob, _ = log_density(model, (), model_kwargs, params)
        return log_prob

    return jax.jit(_logpost)

model_kwargs = {
    'y': jnp.asarray(y),
    'sigma': sigma,
    'tau0': tau0,
}

logpost_fn = build_log_density_fn(gaussian_mean_model, model_kwargs)

## 6. Pack Samples into a 2D Array

MorphZ expects posterior samples as a 2D array of shape `(n_draws, n_parameters)`. This model has a single parameter (`mu`), so we add a trailing axis.

In [None]:
mu_samples = np.asarray(samples['mu'])
post_smp = mu_samples[:, None]  # (n_draws, 1)

def lp_fn(sample_vec):
    params = {'mu': jnp.asarray(sample_vec[0])}
    return float(logpost_fn(params))

lp = np.array([lp_fn(v) for v in post_smp])

print('post_smp shape:', post_smp.shape)
print('lp shape:', lp.shape)
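This model has one scalar parameter, but `samples` is a dict keyed by site name, and real models often have several sites, some vector-valued. A minimal sketch of a general packing scheme (hypothetical helpers `pack_samples` and `unpack_vector`, not part of NumPyro or MorphZ) flattens each site into columns in a fixed order and records the layout, so a single row can be turned back into the `params` dict that `log_density` expects.

```python
import numpy as np

def pack_samples(samples):
    """Flatten {name: array(n_draws, *shape)} into an (n_draws, n_params) matrix.

    Also returns the (name, shape) layout needed to undo the flattening.
    """
    names = sorted(samples)                    # fixed, reproducible column order
    arrays = [np.asarray(samples[k]) for k in names]
    n_draws = arrays[0].shape[0]
    layout = [(k, a.shape[1:]) for k, a in zip(names, arrays)]
    columns = [a.reshape(n_draws, -1) for a in arrays]
    return np.concatenate(columns, axis=1), layout

def unpack_vector(vec, layout):
    """Invert pack_samples for one row: 1D vector -> {name: array(shape)}."""
    params, start = {}, 0
    for name, shape in layout:
        size = int(np.prod(shape))             # np.prod(()) is 1.0 for scalar sites
        params[name] = np.asarray(vec[start:start + size]).reshape(shape)
        start += size
    return params
```

With these, `post_smp, layout = pack_samples(samples)` replaces the manual `[:, None]`, and `lp_fn` can build `params = {k: jnp.asarray(v) for k, v in unpack_vector(sample_vec, layout).items()}`. If the model has `numpyro.deterministic` sites, they typically also appear in `samples` and should be removed before packing, since they are not free parameters.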

## 7. Sanity Check: `lp` vs `lp_fn`

In [None]:
for i in range(5):
    print(f'i={i}  precomputed={lp[i]: .6f}  callable={lp_fn(post_smp[i]): .6f}')
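Printing five rows is a useful eyeball check; a stricter version compares a random subset of rows against a numerical tolerance. This is a minimal sketch with a hypothetical helper `max_lp_mismatch`; the default subset size of 100 is an arbitrary choice to keep Python-level calls to `lp_fn` cheap.

```python
import numpy as np

def max_lp_mismatch(post_smp, lp, lp_fn, n_check=100, seed=0):
    """Largest |lp[i] - lp_fn(post_smp[i])| over a random subset of rows."""
    post_smp = np.asarray(post_smp)
    lp = np.asarray(lp, dtype=float)
    n = post_smp.shape[0]
    if lp.shape != (n,):
        raise ValueError(f"lp has shape {lp.shape}, expected ({n},)")
    # Sample row indices without replacement so no row is checked twice.
    idx = np.random.default_rng(seed).choice(n, size=min(n_check, n), replace=False)
    recomputed = np.array([lp_fn(post_smp[i]) for i in idx], dtype=float)
    return float(np.max(np.abs(recomputed - lp[idx])))
```

For example, `assert max_lp_mismatch(post_smp, lp, lp_fn) < 1e-4`. Here `lp` was produced by the same jitted function, so the mismatch should be essentially zero; a large value points to a parameter-ordering mistake or a different model.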

## 8. Estimate ln(Z) with MorphZ

In [None]:
results = evidence(
    post_samples=post_smp,
    log_posterior_values=lp,
    log_posterior_function=lp_fn,
    n_resamples=1000,
    thin=2,
    kde_fraction=0.6,
    bridge_start_fraction=0.5,
    max_iter=2000,
    tol=1e-4,
    morph_type='indep',
    kde_bw='silverman',
    param_names=['mu'],
    output_path='morphz_numpyro_demo',
    n_estimations=3,
    verbose=False,
    plot=False,
    show_progress=False,
)

results = np.asarray(results)
results

## 9. Compare to Analytic ln(Z)

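For this conjugate Gaussian model the exact ln(Z) is available in closed form: integrating out $\mu$ leaves $y$ jointly Gaussian with zero mean and covariance $\sigma^2 I + \tau_0^2 \mathbf{1}\mathbf{1}^\top$. We use it to validate the MorphZ estimate.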
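Written out, the closed form follows from expressing each observation as the shared mean plus independent noise, with $\mu$ independent of $\varepsilon$. A sum of independent Gaussians is Gaussian, with covariance $\operatorname{Cov}(\mu\,\mathbf{1}_n) + \operatorname{Cov}(\varepsilon)$:

$$
y = \mu\,\mathbf{1}_n + \varepsilon, \qquad \mu \sim \mathcal{N}(0, \tau_0^2), \qquad \varepsilon \sim \mathcal{N}(\mathbf{0}, \sigma^2 I_n)
$$

$$
\Rightarrow \quad y \sim \mathcal{N}(\mathbf{0}, C), \qquad C = \sigma^2 I_n + \tau_0^2\,\mathbf{1}_n \mathbf{1}_n^\top
$$

$$
\ln Z = \ln p(y) = -\tfrac{1}{2}\left(n \ln 2\pi + \ln\det C + y^\top C^{-1} y\right)
$$

This is the quantity `analytic_lnz` evaluates with `slogdet` and `solve`.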

In [None]:
def analytic_lnz(y, sigma, tau0):
    n = y.size
    C = (sigma ** 2) * np.eye(n) + (tau0 ** 2) * np.ones((n, n))
    sign, logdet = np.linalg.slogdet(C)
    if sign <= 0:
        raise RuntimeError('Covariance matrix is not positive definite.')
    quad = y @ np.linalg.solve(C, y)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)

lnz_true = analytic_lnz(y, sigma, tau0)
lnz_est = results[:, 0]
lnz_err = results[:, 1]

print('MorphZ ln(Z) per run:', lnz_est)
print('MorphZ reported errors:', lnz_err)
print(f'MorphZ mean ln(Z): {lnz_est.mean():.6f}')
print(f'Analytic ln(Z):     {lnz_true:.6f}')
print(f'Absolute difference: {abs(lnz_est.mean() - lnz_true):.6f}')
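`analytic_lnz` builds and solves an $n \times n$ system, which costs $O(n^3)$. Because $C$ is a rank-one update of $\sigma^2 I$, the matrix determinant lemma and the Sherman-Morrison formula give an $O(n)$ version. A second independent route is the identity $\ln Z = \ln p(y \mid \mu) + \ln p(\mu) - \ln p(\mu \mid y)$, which follows from Bayes' theorem and holds at any $\mu$ when the posterior density is known exactly, as it is here. Both functions below are minimal sketches with hypothetical names, written for this toy model only.

```python
import numpy as np

def analytic_lnz_rank_one(y, sigma, tau0):
    """O(n) ln Z for C = sigma^2 I + tau0^2 11^T, without forming C.

    log det C = (n-1) log sigma^2 + log(sigma^2 + n tau0^2)   (determinant lemma)
    C^{-1} = (I - tau0^2 11^T / (sigma^2 + n tau0^2)) / sigma^2 (Sherman-Morrison)
    """
    y = np.asarray(y, dtype=float)
    n = y.size
    s2, t2 = sigma**2, tau0**2
    denom = s2 + n * t2
    logdet = (n - 1) * np.log(s2) + np.log(denom)
    quad = (np.sum(y**2) - t2 * np.sum(y) ** 2 / denom) / s2
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)

def lnz_from_identity(y, sigma, tau0, mu=None):
    """ln Z = ln p(y|mu) + ln p(mu) - ln p(mu|y), using the exact posterior.

    Defaults to evaluating at the posterior mean, where cancellation is mildest.
    """
    y = np.asarray(y, dtype=float)
    n = y.size

    def log_norm(x, m, sd):
        return -0.5 * np.log(2 * np.pi * sd**2) - 0.5 * ((x - m) / sd) ** 2

    precision = 1.0 / tau0**2 + n / sigma**2      # conjugate Normal-Normal update
    post_mean = (y.sum() / sigma**2) / precision
    post_sd = np.sqrt(1.0 / precision)
    if mu is None:
        mu = post_mean
    return (np.sum(log_norm(y, mu, sigma)) + log_norm(mu, 0.0, tau0)
            - log_norm(mu, post_mean, post_sd))
```

Both should match `lnz_true` to floating-point precision, which also makes them a cross-check on `analytic_lnz` itself.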

## 10. Summary

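MorphZ needs three inputs:

- `post_smp`: posterior samples, shape `(n_draws, n_parameters)`
- `lp`: the unnormalized log posterior evaluated at each row of `post_smp`
- `lp_fn`: a callable that computes the same unnormalized log posterior for any parameter vector

Your evidence estimate is consistent with your sampled posterior only if `lp[i]` matches `lp_fn(post_smp[i])` up to floating-point tolerance and both equal $\ln p(y \mid \theta) + \ln p(\theta)$ for the model NUTS sampled. Dropping a normalizing constant from the likelihood or prior leaves the posterior unchanged but shifts $\ln Z$ by that constant.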