In [1]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist, infer
from numpyro.handlers import seed

  from .autonotebook import tqdm as notebook_tqdm


- NumPyro example gallery: https://github.com/pyro-ppl/numpyro/tree/master/examples
- Introductory NumPyro tutorial: https://dfm.io/posts/intro-to-numpyro/
- TODO: check the imputation method

In [2]:
# Parameters of the synthetic data set.
# `true_frac` is the probability that a point follows the linear model, so the
# outlier probability is 1 - true_frac = 20% (a point becomes an outlier when
# its uniform draw exceeds `true_frac`):
true_frac = 0.8

# The linear model has unit slope and zero intercept, stored as [slope, intercept]:
true_params = [1.0, 0.0]

# The outliers are drawn from a Gaussian with zero mean and unit variance,
# stored as [mean, variance]:
true_outliers = [0.0, 1.0]

# For reproducibility, seed the global NumPy random state before drawing anything.
# The sequence of random draws below defines the data set, so their order matters.
np.random.seed(12)
x = np.sort(np.random.uniform(-2, 2, 15))
yerr = 0.2 * np.ones_like(x)
y = true_params[0] * x + true_params[1] + yerr * np.random.randn(len(x))

# Every point so far comes from the linear model, so flag some of them as
# background (outlier) points and overwrite their y values.
m_bkg = np.random.rand(len(x)) > true_frac
n_bkg = int(m_bkg.sum())

# An outlier's scatter combines the outlier variance with the measurement variance.
bkg_sigma = np.sqrt(true_outliers[1] + yerr[m_bkg] ** 2)
y[m_bkg] = true_outliers[0] + bkg_sigma * np.random.randn(n_bkg)

# Then save the *true* line on a dense grid; np.vander(x0, 2) has columns [x0, 1],
# matching the [slope, intercept] ordering of `true_params`.
x0 = np.linspace(-2.1, 2.1, 200)
y0 = np.dot(np.vander(x0, 2), true_params)


In [12]:
def linear_model(x, yerr, y=None):
    """NumPyro model for a straight line ``y = m * x + b`` with Gaussian noise.

    The line is parameterized by its angle ``theta`` and the quantity
    ``b_perp``; slope and intercept are derived from them as
    ``m = tan(theta)`` and ``b = b_perp / cos(theta)``.

    Parameters
    ----------
    x : array-like
        Independent variable; its length sets the size of the "data" plate.
    yerr : array-like or scalar
        Standard deviation passed to the Normal likelihood.
    y : array-like, optional
        Observed values. When ``None`` the "y" site is left unobserved.
    """
    # Priors on the sampled parameters: a uniform angle over (-pi/2, pi/2) and a
    # standard normal on b_perp.
    half_pi = 0.5 * jnp.pi
    theta = numpyro.sample("theta", dist.Uniform(-half_pi, half_pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0, 1))

    # Slope and intercept are recorded as deterministic sites so they are tracked
    # alongside the sampled parameters.
    m = numpyro.deterministic("m", jnp.tan(theta))
    b = numpyro.deterministic("b", b_perp / jnp.cos(theta))

    # Likelihood. The plate declares the data points independent; Normal would
    # broadcast over the vector inputs without it, but stating the structure
    # explicitly lets inference algorithms that can use it do so.
    n_points = len(x)
    with numpyro.plate("data", n_points):
        line = m * x + b
        numpyro.sample("y", dist.Normal(line, yerr), obs=y)

In [13]:
# NUTS kernel for the straight-line model.
nuts_kernel = infer.NUTS(linear_model)

# Run settings: 2 chains, each with 2000 warm-up steps followed by 2000 kept samples.
# NOTE(review): the truncated warning printed under this cell looks like NumPyro's
# notice that fewer devices than chains are available, so the chains run one after
# another — confirm, and consider numpyro.set_host_device_count(2) at start-up.
mcmc_settings = dict(
    num_warmup=2000,
    num_samples=2000,
    num_chains=2,
    progress_bar=True,
)
sampler = infer.MCMC(nuts_kernel, **mcmc_settings)

  sampler = infer.MCMC(


In [14]:
# Draw the posterior samples for the observed data (x, yerr, y), timing the call.
# The PRNG key is fixed at 0 so the same key is used on every run.
%time sampler.run(jax.random.PRNGKey(0), x, yerr, y=y)

sample: 100%|████████████████████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 2093.70it/s, 3 steps of size 9.33e-01. acc. prob=0.91]
sample: 100%|████████████████████████████████████████████████████████████████████████| 4000/4000 [00:00<00:00, 5528.99it/s, 3 steps of size 8.96e-01. acc. prob=0.92]

CPU times: user 2.76 s, sys: 29.5 ms, total: 2.79 s
Wall time: 2.77 s





In [8]:
# Print per-parameter posterior statistics and diagnostics (n_eff, r_hat).
# NOTE(review): this cell's execution count (8) is lower than the sampler run (14),
# so the table below may come from an earlier run — re-run after sampling.
sampler.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
    b_perp      0.10      0.04      0.10      0.03      0.17   3946.07      1.00
     theta      0.65      0.02      0.65      0.61      0.69   3224.38      1.00

Number of divergences: 0
