In [1]:
import jax
import jax.numpy as jnp
import numpyro
import numpy as np
import numpyro.distributions as dist

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

In [3]:
def eight_schools():
    mu = numpyro.sample('mu', dist.Normal(2, 5))
    tau = numpyro.sample('tau', dist.HalfCauchy(5))
    theta = numpyro.sample('theta', dist.Normal(jnp.full(J,mu), tau))
    numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)

In [4]:
import blackjax

num_warmup = 500

adapt = blackjax.window_adaptation(
    blackjax.nuts, eight_schools
)
(last_state, parameters), intermediate_states,logdensity_fn  = adapt.run(jax.random.PRNGKey(0), num_warmup)
kernel = blackjax.nuts(logdensity_fn, **parameters).step

TypeError: Value <function window_adaptation.<locals>.logdensity_create.<locals>.<lambda> at 0x13fca7d80> with type <class 'function'> is not a valid JAX type

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, (
        infos.acceptance_rate,
        infos.is_divergent,
        infos.num_integration_steps,
    )

In [None]:
num_sample = 1000

states, infos = inference_loop(jax.random.PRNGKey(0), kernel, last_state, num_sample)
_ = states.position["mu"].block_until_ready()

In [None]:
acceptance_rate = np.mean(infos[0])
num_divergent = np.mean(infos[1])

print(f"\Average acceptance rate: {acceptance_rate:.2f}")
print(f"There were {100*num_divergent:.2f}% divergent transitions")

\Average acceptance rate: 0.97
There were 0.00% divergent transitions


In [None]:
# \Average acceptance rate: 0.87
# There were 1.20% divergent transitions