In [5]:
%env JAX_PLATFORM_NAME=cuda

import warnings

import jax
import jax.numpy as jnp
import jax.random as jr
import numpyro
from numpyro import distributions as dist
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=3, suppress=True)
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))


def model():
    numpyro.sample("a", dist.Exponential(0.01))


num_chains = 2**4
num_samples_per_chain = 2**10
warmup_len = 2**10

env: JAX_PLATFORM_NAME=cuda
[cuda(id=0)]


In [6]:
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(3),
    model,
    (),
    num_chains,
    num_samples_per_chain,
    chain_sep=0.1,
    tol=0.01,
    warmup_mult=warmup_len,
    warmup_tol_mult=8,
    use_adaptive=False,
)

100.00%|██████████| [00:01<00:00, 67.96%/s]
100.00%|██████████| [00:24<00:00,  4.02%/s]

LMC: gradient evaluations per output: 6.512





In [7]:
arr_lmc = jnp.exp(out_logreg_lmc["a"])
print(f"mean: {jnp.mean(arr_lmc):.4}, var:  {jnp.var(arr_lmc):.4}")

mean: 93.61, var:  9.654e+03


In [8]:
nuts = MCMC(
    NUTS(model),
    num_warmup=warmup_len,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.warmup(
    jr.PRNGKey(2),
    extra_fields=("num_steps",),
    collect_warmup=True,
)
warmup_steps = sum(nuts.get_extra_fields()["num_steps"])
nuts.run(jr.PRNGKey(2), extra_fields=("num_steps",))
out_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
print(geps_nuts)

warmup: 100%|██████████| 1024/1024 [00:03<00:00, 328.18it/s]
sample: 100%|██████████| 1024/1024 [00:01<00:00, 920.16it/s]


6.16973876953125


In [9]:
arr_nuts = out_nuts["a"]
print(f"mean: {jnp.mean(arr_nuts):.4}, var:  {jnp.var(arr_nuts):.4}")

mean: 100.3, var:  9.845e+03
