In [1]:
import jax
import numpyro
import blackjax
import numpy as np
import numpyro.distributions as dist
from numpyro.infer.util import initialize_model
from numpyro.infer.reparam import TransformReparam

  from .autonotebook import tqdm as notebook_tqdm
  warn("Couldn't import ipywidgets properly, progress bar will use console behavior")


In [2]:
## Model
def eight_schools_noncentered(J, sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        with numpyro.handlers.reparam(config={"theta": TransformReparam()}):
            theta = numpyro.sample(
                "theta",
                dist.TransformedDistribution(
                    dist.Normal(0.0, 1.0), dist.transforms.AffineTransform(mu, tau)
                ),
            )
        numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)

In [3]:
# Model initialization values
J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

In [4]:
rng_key = jax.random.PRNGKey(0)
init_params, potential_fn_gen, *_ = initialize_model(
    rng_key,
    eight_schools_noncentered,
    model_args=(J, sigma, y),
    dynamic_args=True,
)

In [5]:
logdensity_fn = lambda position: -potential_fn_gen(J, sigma, y)(position)
initial_position = init_params.z

In [6]:
num_warmup = 150

adapt = blackjax.window_adaptation(
    blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8
)

In [7]:
ibuffer_size = 75
fbuffer_size = 0
wsize = 0
num_warmup = ibuffer_size + fbuffer_size + wsize  

(last_state, parameters), intermediate_states = adapt.run(rng_key, initial_position, num_warmup, initial_buffer_size = ibuffer_size,
    final_buffer_size = fbuffer_size,
    first_window_size = wsize)

In [8]:
parameters

{'step_size': Array(0.4991284, dtype=float32, weak_type=True),
 'inverse_mass_matrix': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}

In [9]:
ibuffer_size = 0
fbuffer_size = 0
wsize = 25
num_warmup = ibuffer_size + fbuffer_size + wsize  

(last_state, parameters), intermediate_states = adapt.run(rng_key, last_state[0], num_warmup, initial_step_size = parameters['step_size'] ,initial_buffer_size = ibuffer_size,
    final_buffer_size = fbuffer_size,
    first_window_size = wsize)

Entered the right direction!


In [10]:
parameters

{'step_size': Array(1., dtype=float32, weak_type=True),
 'inverse_mass_matrix': Array([21.636917  ,  0.82218933,  0.9235706 ,  0.9132942 ,  0.5829264 ,
         1.2119806 ,  0.8255245 ,  0.96468395,  0.7741471 ,  0.7101809 ],      dtype=float32)}

In [11]:
ibuffer_size = 0
fbuffer_size = 75
wsize = 0
num_warmup = ibuffer_size + fbuffer_size + wsize  

(last_state, parameters), intermediate_states = adapt.run(rng_key, last_state[0], num_warmup, initial_buffer_size = ibuffer_size,
    final_buffer_size = fbuffer_size,
    first_window_size = wsize)

In [12]:
parameters

{'step_size': Array(0.44459432, dtype=float32, weak_type=True),
 'inverse_mass_matrix': Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)}

In [13]:
kernel = blackjax.nuts(logdensity_fn, **parameters).step

In [14]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    extra, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, (
        infos.acceptance_rate,
        infos.is_divergent,
        infos.num_integration_steps,
    )

In [15]:
num_sample = 5

states, infos = inference_loop(rng_key, kernel, last_state, num_sample)
_ = states.position["mu"].block_until_ready()

In [16]:
states

HMCState(position={'mu': Array([10.23008  ,  8.307863 ,  4.398651 ,  7.6486516,  7.439686 ],      dtype=float32), 'tau': Array([2.3501744 , 0.7994094 , 1.4532665 , 0.51785344, 0.28244084],      dtype=float32), 'theta_base': Array([[ 0.6804302 ,  0.38823453, -0.20930839, -0.6533625 ,  0.6064742 ,
        -1.1144998 ,  0.87010014, -0.7791636 ],
       [ 0.67929256, -1.7555128 ,  0.472714  ,  0.03503663, -2.7181075 ,
         1.5632589 , -0.22113849,  1.4079455 ],
       [ 0.15394486,  1.7087718 , -0.5606208 , -0.07571013,  1.8809739 ,
        -1.874018  ,  1.0483693 , -0.730608  ],
       [ 0.66666377, -1.1821587 , -0.5147217 ,  0.03661108, -1.9523745 ,
         1.5535854 , -0.8931773 , -0.33664265],
       [ 0.9475352 , -1.6104362 ,  0.3292531 ,  0.6192756 , -2.1482735 ,
         1.375846  , -1.2776589 , -0.28885585]], dtype=float32)}, logdensity=Array([-45.55864 , -50.009613, -47.83721 , -46.952873, -48.623524],      dtype=float32), logdensity_grad={'mu': Array([-0.6178483 , -0.3011075

HMCState(position={'mu': Array([9.922787 , 2.6043172, 3.7035708, 8.089965 , 7.3719726], dtype=float32), 'tau': Array([-2.1415424, -0.4907572,  1.4969707,  0.7235381,  0.4811546],      dtype=float32), 'theta_base': Array([[ 0.011798  , -0.6990618 ,  0.5800221 ,  0.6235273 , -0.0350373 ,
        -0.06943975, -2.394491  , -0.13232218],
       [ 0.23016587, -1.844681  ,  0.89817107, -0.6893593 , -1.3917323 ,
         1.8075554 ,  2.0827086 ,  0.20634526],
       [ 0.95044273,  1.6913677 , -0.5630975 ,  0.8642062 ,  1.5816722 ,
        -1.4738693 ,  1.5239587 , -0.8587075 ],
       [-1.0939564 , -0.8527685 , -1.012984  , -0.47866136, -2.1197004 ,
         1.2409846 ,  0.27299595,  2.669759  ],
       [-0.3839416 , -1.2815466 , -0.20384768,  0.45797092, -2.280699  ,
         1.0681396 , -0.48478994,  1.9311254 ]], dtype=float32)}, logdensity=Array([-49.38864 , -50.07078 , -47.313896, -50.438282, -48.44399 ],      dtype=float32), logdensity_grad={'mu': Array([-0.5289164 ,  0.2025292 , -0.11358353, -0.29387444, -0.22782859],      dtype=float32), 'tau': Array([ 0.972792  ,  1.0703826 , -0.6049909 ,  0.6082729 ,  0.72445697],      dtype=float32), 'theta_base': Array([[-2.3605300e-03,  6.9689953e-01, -5.8598340e-01, -6.2643600e-01,
         1.9202059e-02,  6.0784940e-02,  2.4043100e+00,  1.3308096e-01],
       [-1.6145459e-01,  1.8846242e+00, -9.1288722e-01,  7.1373290e-01,
         1.3709313e+00, -1.8212701e+00, -1.9962667e+00, -1.8883181e-01],
       [-5.5228788e-01, -1.8370658e+00,  4.9000901e-01, -8.8506830e-01,
        -2.2309687e+00,  1.6172142e+00, -1.1894215e+00,  1.0260315e+00],
       [ 1.2970624e+00,  8.8716203e-01,  9.4048995e-01,  4.7690460e-01,
         1.9995674e+00, -1.4053854e+00, -8.0283418e-02, -2.6799037e+00],
       [ 5.3674161e-01,  1.3252552e+00,  1.4038040e-01, -4.7285253e-01,
         2.1871793e+00, -1.1764501e+00,  6.6943574e-01, -1.9236170e+00]],      dtype=float32)})

In [17]:
acceptance_rate = np.mean(infos[0])
num_divergent = np.mean(infos[1])

print(f"\Average acceptance rate: {acceptance_rate:.2f}")
print(f"There were {100*num_divergent:.2f}% divergent transitions")

\Average acceptance rate: 0.91
There were 0.00% divergent transitions
