In [1]:
import time
import os

import numpy as np

import jax
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_median

import matplotlib.pyplot as plt
import arviz as az

# Run JAX on the CPU and ask XLA for 4 host "devices", so multiple MCMC chains
# can be run in parallel (presumably matching args["num_chains"] — confirm).
# NOTE: NumPyro documents that set_host_device_count only takes effect at the
# start of the program (before the JAX backend is initialised), so keep these
# calls above any JAX computation, including the PRNGKey creation below.
numpyro.set_platform("cpu")
numpyro.set_host_device_count(4)

# Notebook-wide PRNG key with a fixed seed so sampling is reproducible.
rng_key = random.PRNGKey(67)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def run_mcmc(
    rng_key,  # JAX PRNG key for the sampler
    model,  # NumPyro model
    args,  # dict of sampler settings, also forwarded to the model
    verbose=True,  # if True, include deterministic sites in the summary
    progress_bar=True,  # show NumPyro's sampling progress bar
):
    """Run NUTS on ``model``, print a posterior summary and draw trace plots.

    Parameters
    ----------
    rng_key : jax PRNG key
        Key passed to ``MCMC.run``.
    model : callable
        NumPyro model. ``args`` is forwarded to it through ``MCMC.run``.
    args : dict
        Must contain ``"num_warmup"``, ``"num_mcmc_samples"``,
        ``"num_chains"`` and ``"thinning"``. The same dict is passed on to
        ``model``.
    verbose : bool, default True
        If True, the printed summary also lists deterministic sites
        (``exclude_deterministic=False``). Otherwise NumPyro's default
        summary (sample sites only) is printed.
    progress_bar : bool, default True
        Forwarded to ``numpyro.infer.MCMC``.

    Returns
    -------
    mcmc : numpyro.infer.MCMC
        The fitted sampler.
    samples : dict
        ``mcmc.get_samples()`` with NumPyro's default grouping.
    t_elapsed : float
        Seconds spent sampling, measured after the samples are materialised.
    """
    init_strategy = init_to_median(num_samples=10)
    kernel = NUTS(model, init_strategy=init_strategy)
    mcmc = MCMC(
        kernel,
        num_warmup=args["num_warmup"],
        num_samples=args["num_mcmc_samples"],
        num_chains=args["num_chains"],
        thinning=args["thinning"],
        progress_bar=progress_bar,
    )

    # perf_counter is monotonic (immune to wall-clock adjustments). JAX
    # dispatches work asynchronously, so block on the samples before stopping
    # the clock; otherwise the elapsed time can under-report the sampling cost.
    start = time.perf_counter()
    mcmc.run(rng_key, args)
    samples = jax.block_until_ready(mcmc.get_samples())
    t_elapsed = time.perf_counter() - start

    # NumPyro's default is exclude_deterministic=True, so the non-verbose
    # case matches a plain print_summary() call.
    mcmc.print_summary(exclude_deterministic=not verbose)

    print("\nMCMC elapsed time:", round(t_elapsed), "s")

    # plot posterior distribution and traceplots
    data = az.from_numpyro(mcmc)
    az.plot_trace(data, compact=True)
    plt.tight_layout()

    return mcmc, samples, t_elapsed