In [5]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist, infer
from numpyro.handlers import seed

In [2]:
true_frac = 0.8

# The linear model has unit slope and zero intercept:
true_params = [1.0, 0.0]

# The outliers are drawn from a Gaussian with zero mean and unit variance:
true_outliers = [0.0, 1.0]

# For reproducibility, let's set the random number seed and generate the data:
np.random.seed(12)
x = np.sort(np.random.uniform(-2, 2, 15))
yerr = 0.2 * np.ones_like(x)
y = true_params[0] * x + true_params[1] + yerr * np.random.randn(len(x))

# Those points are all drawn from the correct model so let's replace some of
# them with outliers.
m_bkg = np.random.rand(len(x)) > true_frac
y[m_bkg] = true_outliers[0]
y[m_bkg] += np.sqrt(true_outliers[1] + yerr[m_bkg] ** 2) * np.random.randn(sum(m_bkg))

# Then save the *true* line.
x0 = np.linspace(-2.1, 2.1, 200)
y0 = np.dot(np.vander(x0, 2), true_params)

In [13]:
def linear_model(x, yerr, y=None):
    # These are the parameters that we're fitting and we're required to define explicit
    # priors using distributions from the numpyro.distributions module.
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0, 1))

    # Transformed parameters (and other things!) can be tracked during sampling using
    # "deterministics" as follows:
    m = numpyro.deterministic("m", jnp.tan(theta))
    b = numpyro.deterministic("b", b_perp / jnp.cos(theta))

    # Then we specify the sampling distribution for the data, or the likelihood function.
    # Here we're using a numpyro.plate to indicate that the data are independent. This
    # isn't actually necessary here and we could have equivalently omitted the plate since
    # the Normal distribution can already handle vector-valued inputs. But, it's good to
    # get into the habit of using plates because some inference algorithms or distributions
    # can take advantage of knowing this structure.
    with numpyro.plate("data", len(x)):
        samples = numpyro.sample("y", dist.Normal(m * x + b, yerr), obs=y)
    print(samples)
    return samples

In [14]:
with seed(rng_seed=20):
    testing = linear_model(x, yerr, y)

[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]


In [18]:
sampler = infer.MCMC(
    infer.NUTS(linear_model),
    num_warmup=100,
    num_samples=100,
    num_chains=1,
    progress_bar=True,
)

In [19]:
%time sampler.run(jax.random.PRNGKey(0), x, yerr, y=y)

[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]
[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]


  0%|                                                                                                                                      | 0/200 [00:00<?, ?it/s]

[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]


sample: 100%|█████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 152.68it/s, 7 steps of size 8.45e-01. acc. prob=0.92]

[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]
CPU times: user 1.48 s, sys: 9.54 ms, total: 1.49 s
Wall time: 1.48 s





In [20]:
sampler.get_samples()

{'b': Array([0.19321087, 0.11051232, 0.1303351 , 0.09984763, 0.1311001 ,
        0.09537995, 0.09537995, 0.09797271, 0.13073991, 0.05498868,
        0.18083394, 0.07032708, 0.1678519 , 0.08739278, 0.07167584,
        0.13554436, 0.10206417, 0.12331628, 0.08453617, 0.08792339,
        0.08380423, 0.15920007, 0.1251351 , 0.15731004, 0.0510047 ,
        0.19470648, 0.14997862, 0.12658292, 0.05896591, 0.15566877,
        0.08477577, 0.09539074, 0.09618109, 0.14522313, 0.11880869,
        0.15404525, 0.18958052, 0.02545352, 0.12718353, 0.16434951,
        0.13340482, 0.10969486, 0.10690323, 0.1594108 , 0.17310472,
        0.17958575, 0.11810824, 0.11810824, 0.01856358, 0.01076789,
        0.15431204, 0.16992623, 0.07489697, 0.07409239, 0.01781133,
        0.14722526, 0.11527144, 0.09242926, 0.14661925, 0.2711162 ,
        0.06181287, 0.18632114, 0.16489294, 0.02995745, 0.16761327,
        0.08059081, 0.04307056, 0.20465752, 0.23279019, 0.16811176,
        0.08122335, 0.09703864, 0.133247  ,