In [None]:
import pymc3 as pm
import numpy as np
import jax.numpy as jnp

from blackjax import nuts
# import blackjax.stan_warmup as stan_warmup
import matplotlib.pyplot as plt

In [None]:
pip install blackjax


Collecting blackjax
  Downloading blackjax-0.4.0-py3-none-any.whl (74 kB)
[?25l[K     |████▍                           | 10 kB 34.4 MB/s eta 0:00:01[K     |████████▉                       | 20 kB 35.0 MB/s eta 0:00:01[K     |█████████████▎                  | 30 kB 27.5 MB/s eta 0:00:01[K     |█████████████████▋              | 40 kB 19.7 MB/s eta 0:00:01[K     |██████████████████████          | 51 kB 7.7 MB/s eta 0:00:01[K     |██████████████████████████▌     | 61 kB 9.1 MB/s eta 0:00:01[K     |██████████████████████████████▉ | 71 kB 8.8 MB/s eta 0:00:01[K     |████████████████████████████████| 74 kB 3.1 MB/s 
[?25hInstalling collected packages: blackjax
Successfully installed blackjax-0.4.0


In [None]:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

import blackjax

# Quick check of the BlackJAX high-level NUTS API on a toy Gaussian model.
# A seeded generator keeps the data reproducible under Restart & Run All.
observed = np.random.default_rng(0).normal(10, 20, size=1_000)


def logprob_fn(x):
    """Gaussian log-likelihood of `observed` given x = {"loc": ..., "scale": ...}."""
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)


# Build the kernel. Named `toy_nuts` so it does not shadow the `nuts` name
# imported at the top of the notebook, which later cells refer to.
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
toy_nuts = blackjax.nuts(logprob_fn, step_size, inverse_mass_matrix)

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = toy_nuts.init(initial_position)

# Iterate. Each step gets a fresh subkey, and the carried key is only ever
# split, never consumed, so no PRNG key is reused.
rng_key = jax.random.PRNGKey(0)
for _ in range(100):
    rng_key, step_key = jax.random.split(rng_key)
    state, _ = toy_nuts.step(step_key, state)

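Data

The classic eight-schools data: the estimated treatment effect `y` of a coaching program in each of `J = 8` schools, together with its standard error `sigma`.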

In [None]:
# Eight-schools data: observed treatment effect `y` and its standard error
# `sigma`, one entry per school.
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# Derive the number of schools from the data instead of hard-coding it, and
# check that every school has an estimate and a valid standard error.
J = len(y)
assert sigma.shape == y.shape, "y and sigma must have one entry per school"
assert np.all(sigma > 0), "standard errors must be positive"

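Modeling

A hierarchical model in the non-centered parameterization: the standardized school effects `theta` get a standard-normal prior and are rescaled to `mu + tau * theta`. This avoids the funnel-shaped posterior of the centered form, which NUTS struggles to explore.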

In [None]:
with pm.Model() as model:
    # Population-level priors: mean and spread of the school effects.
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", beta=5.0)

    # Non-centered parameterization: standardized effects `theta` are
    # shifted and scaled into the per-school effects `theta_1`.
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta

    # Likelihood: each observed effect `y` has a known standard error `sigma`.
    obs = pm.Normal(
        "obs",
        mu=theta_1,
        sigma=sigma,
        shape=J,
        observed=y,
    )

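Configuring the model for BlackJAX

Compile the model's joint log-probability into a JAX function with Theano's JAX backend, and take the initial position from PyMC3's test point.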

In [None]:
from theano.graph.fg import FunctionGraph
from theano.link.jax.jax_dispatch import jax_funcify

seed = jax.random.PRNGKey(1234)
chains = 1

# Graph whose inputs are the model's free random variables and whose single
# output is the model's joint log-probability (`model.logpt`).
fgraph = FunctionGraph(model.free_RVs, [model.logpt])

# jax_funcify builds a JAX variant of the FunctionGraph. `fns[0]` is used below
# as the log-probability function, called with the free-RV values positionally.
fns = jax_funcify(fgraph)
logp_fn_jax = fns[0]

# Initial position: PyMC3's test point, ordered exactly like the graph inputs
# (both are taken from `model.free_RVs`).
rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]

# Per-chain copies of the initial position with a leading `chains` axis.
# NOTE(review): not used elsewhere in this notebook (sampling is single-chain).
# `jax.tree_util.tree_map` replaces the `jax.tree_map` alias, which newer JAX
# releases deprecated and then removed.
init_state_batched = jax.tree_util.tree_map(
    lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
)

In [None]:
# Wrap the Jaxified FunctionGraph for BlackJAX. `logp_fn_jax` takes the free-RV
# values positionally, in the same order as `rv_names` / `init_state`.
def model_logprob_fn(position):
    """Joint log-probability of the model at `position` (a list of free-RV values)."""
    return logp_fn_jax(*position)


def potential(position):
    """Potential energy: the negative joint log-probability."""
    return -model_logprob_fn(position)


initial_position = init_state
# The installed BlackJAX 0.4 has no `nuts.new_state(position, potential)`; it
# builds the NUTS state from a log-density via `blackjax.nuts.init`. Use
# `blackjax.nuts` explicitly so this does not depend on what the bare name
# `nuts` happens to be bound to.
initial_state = blackjax.nuts.init(initial_position, model_logprob_fn)

Sampling

Tune NUTS with Stan-style window adaptation, then draw posterior samples in a `jax.lax.scan` loop.

In [None]:
%%time

# Run lengths: number of adaptation steps and of posterior draws.
NUM_WARMUP = 1_000
NUM_SAMPLES = 50_000

# Split the seed so warmup and sampling use independent random streams.
# Passing the same key to both would correlate them.
warmup_key, sample_key = jax.random.split(seed)

# Stan-style window adaptation of the step size and inverse mass matrix.
# BlackJAX >= 0.3 (0.4.0 is installed above) replaced `stan_warmup.run` and
# `nuts.kernel(potential, ...)` with `blackjax.window_adaptation`, which takes
# a log-density, i.e. the negative of `potential`.
warmup = blackjax.window_adaptation(
    blackjax.nuts, lambda position: -potential(position), NUM_WARMUP
)
last_state, tuned_nuts, _ = warmup.run(warmup_key, initial_position)


def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Draw `num_samples` states by iterating `kernel` inside `jax.lax.scan`.

    `kernel(rng_key, state)` must return `(new_state, info)`. Returns the
    stacked states and infos, with one entry per draw.
    """

    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    # One fresh subkey per step.
    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# In the BlackJAX 0.4 API, the adaptation returns a NUTS sampler already built
# with the tuned step size and inverse mass matrix; use its step function.
kernel = tuned_nuts.step

# Sample from the posterior distribution
states, infos = inference_loop(sample_key, kernel, last_state, NUM_SAMPLES)

Resources


*   https://blackjax-devs.github.io/blackjax/sampling.html#nuts
  
*   https://www.kaggle.com/code/s903124/numpyro-speed-benchmark-gpu

