In [1]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist, infer
from numpyro.handlers import seed

- https://github.com/pyro-ppl/numpyro/tree/master/examples
- https://dfm.io/posts/intro-to-numpyro/
- **TODO:** check the imputation method.
- https://jrnold.github.io/bayesian_notes/mcmc-diagnostics.html
- https://mc-stan.org/docs/2_19/reference-manual/effective-sample-size-section.html

In [2]:
# We'll choose the parameters of our synthetic data.
# `true_frac` is the probability that a point is drawn from the linear model
# (an inlier). Each point is replaced by an outlier with probability
# 1 - true_frac = 20%:
true_frac = 0.8

# The linear model has unit slope and zero intercept:
true_params = [1.0, 0.0]

# The outlier population is a Gaussian with mean true_outliers[0] = 0 and
# variance true_outliers[1] = 1; each outlier's measurement variance yerr**2
# is added on top of that:
true_outliers = [0.0, 1.0]

# For reproducibility, set the (legacy, global) NumPy seed and generate 15 points
# with x uniform on [-2, 2) and Gaussian measurement noise of standard deviation 0.2:
np.random.seed(12)
x = np.sort(np.random.uniform(-2, 2, 15))
yerr = 0.2 * np.ones_like(x)
y = true_params[0] * x + true_params[1] + yerr * np.random.randn(len(x))

# Those points are all drawn from the correct model, so replace a random subset
# of them with outliers. `m_bkg` is True for outliers (probability 1 - true_frac).
m_bkg = np.random.rand(len(x)) > true_frac
y[m_bkg] = true_outliers[0] + np.sqrt(
    true_outliers[1] + yerr[m_bkg] ** 2
) * np.random.randn(np.count_nonzero(m_bkg))

# Then save the *true* line, evaluated on a fine grid (x0, y0).
x0 = np.linspace(-2.1, 2.1, 200)
y0 = np.dot(np.vander(x0, 2), true_params)


In [3]:
def linear_model(x, yerr, y=None):
    """Straight-line model ``y ~ Normal(m * x + b, yerr)``, parameterized by angle.

    Instead of sampling the slope and intercept directly, the model samples the
    line's angle ``theta`` (uniform on (-pi/2, pi/2)) and its perpendicular
    offset ``b_perp`` (standard-normal prior). The slope ``m = tan(theta)`` and
    intercept ``b = b_perp / cos(theta)`` are recorded as deterministic sites, so
    they appear in the posterior samples alongside ``theta`` and ``b_perp``.

    Args:
        x: abscissae of the data points; ``len(x)`` sizes the data plate.
        yerr: Gaussian measurement uncertainty on ``y`` (broadcast against ``x``).
        y: observed ordinates. When ``None``, the ``"y"`` site is unobserved
            and is sampled rather than conditioned on.
    """
    # Priors have to be declared explicitly with numpyro.distributions objects.
    half_pi = 0.5 * jnp.pi
    theta = numpyro.sample("theta", dist.Uniform(-half_pi, half_pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0, 1))

    # Derived quantities are tracked during sampling through deterministic sites.
    slope = numpyro.deterministic("m", jnp.tan(theta))
    intercept = numpyro.deterministic("b", b_perp / jnp.cos(theta))
    mu = slope * x + intercept

    # Likelihood. The plate declares the observations independent. Normal already
    # broadcasts over vectors, so the plate is optional here, but some inference
    # algorithms and distributions can exploit the structure it declares.
    with numpyro.plate("data", len(x)):
        numpyro.sample("y", dist.Normal(mu, yerr), obs=y)

In [4]:
# NUTS over `linear_model`: 2000 warmup + 2000 retained draws for each of 2 chains.
# The kernel has a single CPU device ("No GPU/TPU found"), so NumPyro's default
# parallel chain method cannot run 2 chains at once; it warns and falls back to
# sequential sampling. Request sequential chains explicitly so the behaviour is
# stated in code and the warning goes away. To sample chains in parallel on CPU,
# call `numpyro.set_host_device_count(2)` before JAX initialises instead.
sampler = infer.MCMC(
    infer.NUTS(linear_model),
    num_warmup=2000,
    num_samples=2000,
    num_chains=2,
    chain_method="sequential",
    progress_bar=True,
)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  sampler = infer.MCMC(


In [5]:
# Run the sampler from a fixed PRNG key (reproducible draws) on the synthetic data.
# Besides the samples, record each draw's potential energy and its number of
# integration steps (`num_steps`, which is totalled in a later cell).
%time sampler.run(jax.random.PRNGKey(0), x, yerr, y=y, extra_fields=('potential_energy', 'num_steps'))

sample: 100%|███████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 2001.05it/s, 3 steps of size 9.33e-01. acc. prob=0.91]
sample: 100%|███████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5516.80it/s, 3 steps of size 8.96e-01. acc. prob=0.92]


CPU times: user 4.2 s, sys: 49.9 ms, total: 4.25 s
Wall time: 4.23 s


In [6]:
# Print the posterior summary table for the sampled sites (theta, b_perp).
# This only prints; the cell has no return value to display.
sampler.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
    b_perp      0.10      0.04      0.10      0.03      0.17   3946.07      1.00
     theta      0.65      0.02      0.65      0.61      0.69   3224.38      1.00

Number of divergences: 0


In [7]:
# Posterior draws keyed by site name, including the deterministic sites `m` and `b`.
# Without group_by_chain=True, the chains are presumably concatenated into one axis.
sampler.get_samples()

{'b': Array([0.06753219, 0.15629883, 0.14952852, ..., 0.09567709, 0.11458827,
        0.15024549], dtype=float32),
 'b_perp': Array([0.05319173, 0.12351318, 0.11815203, ..., 0.07576108, 0.08877319,
        0.11912037], dtype=float32),
 'm': Array([0.78222936, 0.7754641 , 0.77565634, ..., 0.771274  , 0.81618613,
        0.76867133], dtype=float32),
 'theta': Array([0.6638109 , 0.65959996, 0.65972   , ..., 0.656978  , 0.68453294,
        0.65534407], dtype=float32)}

In [14]:
# Total number of integration steps taken by all chains while drawing the retained
# samples (roughly one gradient evaluation per step). Summing the flat,
# chain-concatenated array works for any `num_chains`. It also avoids Python's
# builtin `sum`, which would iterate the JAX array element by element.
sampler.get_extra_fields()["num_steps"].sum()

Array(14512, dtype=int32)

In [16]:
# `MCMC.print_summary` only prints the table and returns None, so its result
# cannot be stored. Print the table, then compute the same diagnostics as a dict
# keyed by site name (e.g. summary["theta"]["n_eff"]). Unlike the printed table,
# the dict also covers the deterministic sites `m` and `b`.
sampler.print_summary()
summary = numpyro.diagnostics.summary(sampler.get_samples(group_by_chain=True))


                mean       std    median      5.0%     95.0%     n_eff     r_hat
    b_perp      0.10      0.04      0.10      0.03      0.17   3946.07      1.00
     theta      0.65      0.02      0.65      0.61      0.69   3224.38      1.00

Number of divergences: 0


In [23]:
 # Effective sample size of theta per integration step (ESS / total `num_steps`).
# Both numbers are taken from the sampler itself (effective_sample_size is the
# estimator behind the n_eff column), so the cell stays correct after
# Restart & Run All instead of relying on values copied from earlier outputs.
theta_ess = numpyro.diagnostics.effective_sample_size(
    np.asarray(sampler.get_samples(group_by_chain=True)["theta"])
)
theta_ess / int(sampler.get_extra_fields()["num_steps"].sum())

0.22218741145820028

In [24]:
# Effective sample size of b_perp per integration step (ESS / total `num_steps`).
# Computed from the sampler rather than from the rounded n_eff in the printed
# table, so it has full precision and survives a re-run.
b_perp_ess = numpyro.diagnostics.effective_sample_size(
    np.asarray(sampler.get_samples(group_by_chain=True)["b_perp"])
)
b_perp_ess / int(sampler.get_extra_fields()["num_steps"].sum())

0.27191772326350605