In [1]:
import jax
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect

true_coefs = jnp.array([1., 2., 3.])
data = random.normal(random.PRNGKey(2), (2000, 3))
labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(3))

def model(data, labels):
     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(3), jnp.ones(3)))
     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)

model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,))
init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
hmc_state = init_kernel(model_info.param_info,
                         trajectory_length=10,
                         num_warmup=300)
samples = fori_collect(0, 500, sample_kernel, hmc_state,
                        transform=lambda state: model_info.postprocess_fn(state.z))
print(jnp.mean(samples['coefs'], axis=0))  # doctest: +SKIP

100%|██████████| 500/500 [00:02<00:00, 220.95it/s]


[0.92838067 2.0078518  2.8804803 ]
