# Testing Multi-HMC Gibbs

This notebook will run through some examples using the new Multi-HMC Gibbs sampler.

First, import some packages.

In [None]:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import arviz
import corner
import matplotlib.pyplot as plt

from numpyro.infer import MCMC, NUTS, HMCGibbs
from jax import random

from MultiHMCGibbs import MultiHMCGibbs

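The last line above imports `MultiHMCGibbs`. If the file providing it is not in the same folder as this notebook (for example, it is one folder up), add that folder to the Python path (`sys.path`) before the import. If it is in the same folder, the import works as-is.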

## 2D normal distribution

This can be sampled just fine with HMC, and has an analytic Gibbs step, so `NUTS`, `HMCGibbs`, and `MultiHMCGibbs` can all be tested head to head.

In [None]:
def model():
    """Two independent Normal(0, 2) latents ``x`` and ``y`` whose sum is observed.

    The single observation ``obs = [1.0]`` is drawn from Normal(x + y, 1).
    """
    prior = dist.Normal(0.0, 2.0)
    x = numpyro.sample("x", prior)
    y = numpyro.sample("y", prior)
    observed = jnp.array([1.0])
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=observed)

### NUTS

In [None]:
# Baseline: a single NUTS kernel samples x and y jointly.
hmc_kernel = NUTS(model)
mcmc = MCMC(hmc_kernel, num_warmup=1000, num_samples=5000, num_chains=4,
            chain_method='vectorized', progress_bar=False)
mcmc.run(random.PRNGKey(0))

In [None]:
# Convert the NUTS run to ArviZ InferenceData for summaries and plotting.
inf_data_hmc = arviz.from_numpyro(mcmc)

# Sum the `diverging` flags over axis 1 (draws, in ArviZ's (chain, draw) layout)
# to get one divergence count per chain.
hmc_divergences = inf_data_hmc.sample_stats.diverging.values.sum(axis=1)
print(f'divergences per chain: {hmc_divergences}')
display(arviz.summary(inf_data_hmc))

fig = corner.corner(inf_data_hmc, color='C0')

### MultiHMCGibbs

To use the new `MultiHMCGibbs` sampler, you need to create a list of HMC kernels (`NUTS` in this case; each can have its own keywords such as `target_accept_prob` or `max_tree_depth`). The other argument is a list of lists containing the **free** parameters updated by each of the inner kernels (the $i$-th list goes with the $i$-th kernel).

**Important**: All free parameters must be listed **exactly once** for the sampler to work.  I have not implemented any checks for this yet!

In [None]:
# One NUTS kernel per block of free parameters; each could be given its own
# settings (e.g. target_accept_prob, max_tree_depth).
inner_kernels = [NUTS(model) for _ in range(2)]

# The i-th list names the free sites handled by the i-th inner kernel, and the
# kernels are stepped in this order: 'y' first, then 'x'.
outer_kernel = MultiHMCGibbs(inner_kernels, [['y'], ['x']])

mcmc_gibbs = MCMC(
    outer_kernel,
    num_warmup=1000, num_samples=5000, num_chains=4,
    chain_method='vectorized', progress_bar=False,
)
mcmc_gibbs.run(random.PRNGKey(0))

In [None]:
inf_data_gibbs = arviz.from_numpyro(mcmc_gibbs)
print('HMC (Blue)\nMultiHMCGibbs (Orange)')

# NOTE(review): `diverging` presumably carries one flag per inner kernel here
# (chain, draw, kernel); summing over draws and transposing gives one row per
# kernel step and one column per chain -- TODO confirm against MultiHMCGibbs.
gibbs_div_per_step = inf_data_gibbs.sample_stats.diverging.values.sum(axis=1).T
print(f'divergences per chain per step:\n {gibbs_div_per_step}')
display(arviz.summary(inf_data_gibbs))

# Overlay the plain-NUTS posterior on top of the MultiHMCGibbs one.
fig = corner.corner(inf_data_gibbs, color='C1')
_ = corner.corner(inf_data_hmc, fig=fig, color='C0')

### HMCGibbs

This distribution has an analytic Gibbs step, so it can also use the built-in `HMCGibbs`; let's try that and compare. We need to use `sequential` chains because `vectorized` is currently broken for `HMCGibbs`.

In [None]:
def gibbs_fn(rng_key, gibbs_sites, hmc_sites, prior_scale=2.0, obs_scale=1.0, obs_value=1.0):
    """Draw ``x`` exactly from its conditional p(x | y, obs) in the 2D normal model.

    ``HMCGibbs`` calls this as ``gibbs_fn(rng_key, gibbs_sites, hmc_sites)``. The
    extra keyword arguments default to the constants hard-coded in the 2D
    ``model`` (x ~ Normal(0, 2), obs ~ Normal(x + y, 1), observed obs = 1), and
    must be kept in sync with that model.

    The prior on x and the likelihood are both Normal, so the conditional is
    Normal with

        precision = 1 / prior_scale**2 + 1 / obs_scale**2
        mean      = (obs_value - y) / (obs_scale**2 * precision)

    For the defaults this is Normal(0.8 * (1 - y), sqrt(0.8)).

    Args:
        rng_key: JAX PRNG key used for the draw.
        gibbs_sites: current values of the Gibbs sites. These are unused, because
            the new x does not depend on the old x.
        hmc_sites: current values of the HMC-updated sites; must contain ``'y'``.
        prior_scale: standard deviation of the zero-mean Normal prior on x.
        obs_scale: standard deviation of the observation noise.
        obs_value: the observed value of ``x + y``.

    Returns:
        dict mapping ``'x'`` to the new draw.
    """
    y = hmc_sites['y']
    # Conjugate Normal-Normal update: combine the prior precision with the
    # likelihood precision, where the likelihood is viewed as a density in x.
    post_var = 1.0 / (1.0 / prior_scale**2 + 1.0 / obs_scale**2)
    post_mean = post_var * (obs_value - y) / obs_scale**2
    new_x = dist.Normal(post_mean, jnp.sqrt(post_var)).sample(rng_key)
    return {'x': new_x}


# NUTS (the kernel from the first section) updates 'y'; gibbs_fn redraws 'x' exactly.
kernel_gibbs_fn = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['x'])

# `vectorized` does not currently work with HMCGibbs, so run the chains sequentially.
mcmc_gibbs_fn = MCMC(kernel_gibbs_fn, num_warmup=1000, num_samples=5000, num_chains=4,
                     chain_method='sequential', progress_bar=False)
mcmc_gibbs_fn.run(random.PRNGKey(0))

In [None]:
inf_data_gibbs_fn = arviz.from_numpyro(mcmc_gibbs_fn)
print('HMC (Blue)\nMultiHMCGibbs (Orange)\nHMCGibbs (Green)')
display(arviz.summary(inf_data_gibbs_fn))

# MultiHMCGibbs is the base layer; NUTS and HMCGibbs are overlaid on the same axes.
fig = corner.corner(inf_data_gibbs, color='C1')
for overlay_data, overlay_color in ((inf_data_hmc, 'C0'), (inf_data_gibbs_fn, 'C2')):
    _ = corner.corner(overlay_data, fig=fig, color=overlay_color)

In all three cases we got the same results!

## Neal's Funnel

Now let's take a distribution where a Gibbs step is needed to get a decent result.

In [None]:
def model(dim=10):
    """Neal's funnel in ``dim`` dimensions. This redefines the 2D ``model`` above.

    y ~ Normal(0, 3), and each of the ``dim - 1`` coordinates of x is
    Normal(0, exp(y / 2)). In other words, y is the log-variance of x.
    """
    y = numpyro.sample("y", dist.Normal(0, 3))
    x_scale = jnp.exp(y / 2)
    x_loc = jnp.zeros(dim - 1)
    numpyro.sample("x", dist.Normal(x_loc, x_scale))

In [None]:
def run_inference(kernel, chain_method, rng_key, num_warmup=8000, num_samples=5000,
                  num_chains=4, progress_bar=False):
    """Run ``kernel`` under numpyro's ``MCMC`` and return the finished sampler.

    The defaults reproduce the budget used for every funnel comparison in this
    notebook: a long warmup and many samples to bring ``r_hat`` down. They are
    exposed as keyword arguments so the budget can be changed without editing
    this function.

    Args:
        kernel: the MCMC kernel to run (e.g. ``NUTS`` or ``MultiHMCGibbs``).
        chain_method: passed straight to ``MCMC``, e.g. ``'vectorized'`` or
            ``'sequential'``.
        rng_key: JAX PRNG key for ``MCMC.run``.
        num_warmup: number of warmup (adaptation) steps per chain.
        num_samples: number of retained draws per chain.
        num_chains: number of chains.
        progress_bar: whether ``MCMC`` shows a progress bar.

    Returns:
        The ``MCMC`` object after ``run`` has completed.
    """
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        chain_method=chain_method,
        progress_bar=progress_bar,
    )
    mcmc.run(rng_key)
    return mcmc

### NUTS

We will use a large `target_accept_prob` to get rid of most divergent samples and use a large number of warmup and samples to get the `r_hat`s down.

In [None]:
# Plain NUTS on the funnel. A very high target acceptance shrinks the step size
# to suppress most divergences in the narrow neck.
funnel_mcmc_hmc = run_inference(
    NUTS(model, target_accept_prob=0.995),
    'vectorized',
    random.PRNGKey(0),
)
inf_funnle_hmc = arviz.from_numpyro(funnel_mcmc_hmc)

# Divergence count per chain: sum the `diverging` flags over the draw axis.
funnel_hmc_divergences = inf_funnle_hmc.sample_stats.diverging.values.sum(axis=1)
print(f'divergences per chain: {funnel_hmc_divergences}')
display(arviz.summary(inf_funnle_hmc))

In [None]:
# Reference curve for the y-marginal. Despite the `x_` prefix, `x_marginal_true`
# is a grid of *y* values. The exact marginal of y is its Normal(0, 3) prior,
# because x depends on y but not the other way round.
x_marginal_true = jnp.linspace(-10, 10, 1000)
y_prior = dist.Normal(0, 3)
y_marginal_true = jnp.exp(y_prior.log_prob(x_marginal_true))

In [None]:
# Pool all chains: first coordinate of x and the scalar y for every draw.
x_model_hmc = inf_funnle_hmc.posterior.x[..., 0].data.flatten()
y_model_hmc = inf_funnle_hmc.posterior.y.data.flatten()

hmc_fig, (ax_joint, ax_marg) = plt.subplots(1, 2, figsize=(10, 4))

# Left: joint draws of (x[0], y), which trace out the funnel shape.
ax_joint.plot(x_model_hmc, y_model_hmc, '.')
ax_joint.set_xlabel('x[0]')
ax_joint.set_ylabel('y')
ax_joint.set_xlim(-100, 100)

# Right: sampled y-marginal against the exact Normal(0, 3) density.
ax_marg.hist(y_model_hmc, bins=30, histtype='step', density=True, label='HMC')
ax_marg.plot(x_marginal_true, y_marginal_true, color='k', label='True marginal')
ax_marg.set_xlabel('y')
ax_marg.legend();

We can see that `NUTS` is struggling with this model: the `y` marginal is still missing some of the negative values at the bottom of the funnel.

### MultiHMCGibbs

For the `MultiHMCGibbs` sampler we will put a large `target_accept_prob` only on the `x` values (as these are the difficult ones to draw) and keep the default value (0.8) for the `y` values. To keep it on the same footing as the previous run, we use the same number of warmup and sample draws.

In [None]:
# x (the hard-to-sample funnel coordinates) gets a very high target acceptance,
# while y keeps the NUTS default of 0.8. The i-th kernel updates the sites in
# the i-th list, so x is stepped first and y second.
funnel_kernels = [
    NUTS(model, target_accept_prob=0.995),
    NUTS(model, target_accept_prob=0.8),
]
funnel_mcmc_gibbs = run_inference(
    MultiHMCGibbs(funnel_kernels, [['x'], ['y']]),
    'vectorized',
    random.PRNGKey(0),
)
inf_funnle_gibbs = arviz.from_numpyro(funnel_mcmc_gibbs)

# NOTE(review): presumably one divergence row per inner kernel after `.T` -- TODO confirm.
funnel_gibbs_div_per_step = inf_funnle_gibbs.sample_stats.diverging.values.sum(axis=1).T
print(f'divergences per chain per step:\n {funnel_gibbs_div_per_step}')
display(arviz.summary(inf_funnle_gibbs))

In [None]:
# Pool all chains of the MultiHMCGibbs run, matching the HMC arrays above.
x_model_gibbs = inf_funnle_gibbs.posterior.x[..., 0].data.flatten()
y_model_gibbs = inf_funnle_gibbs.posterior.y.data.flatten()

compare_fig, (ax_joint, ax_marg) = plt.subplots(1, 2, figsize=(10, 4))

# Left: joint (x[0], y) draws. The HMC points get the higher zorder, so they
# are drawn above the Gibbs points.
ax_joint.plot(x_model_hmc, y_model_hmc, '.', label='HMC', zorder=2)
ax_joint.plot(x_model_gibbs, y_model_gibbs, '.', label='Gibbs', zorder=1)
ax_joint.set_xlabel('x[0]')
ax_joint.set_ylabel('y')
ax_joint.legend()
ax_joint.set_xlim(-100, 100)

# Right: both sampled y-marginals against the exact Normal(0, 3) density.
for y_draws, sampler_label in ((y_model_hmc, 'HMC'), (y_model_gibbs, 'Gibbs')):
    ax_marg.hist(y_draws, bins=30, histtype='step', label=sampler_label, density=True)
ax_marg.plot(x_marginal_true, y_marginal_true, color='k', label='True marginal')
ax_marg.set_xlabel('y')
ax_marg.legend();

We can see that with the same set up `MultiHMCGibbs` was able to reach deeper into the funnel and pull out the negative `y` values missed by `NUTS`.

## Other notes

- You can use as many `inner_kernels` as you want
- The order in which the kernels are stepped is set by the order of the parameter list (in the example above, `x` steps first, followed by `y`)
- The order matters!  Typically you want to step the parameters closest to the likelihood first and the hyper-parameters second.  But for some models this might not be so clear, so some experimentation could be needed.