In [1]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist, infer
from numpyro.handlers import seed

  from .autonotebook import tqdm as notebook_tqdm


- https://github.com/pyro-ppl/numpyro/tree/master/examples
- https://dfm.io/posts/intro-to-numpyro/
- Check imputation method

In [2]:
jax.__version__

'0.4.13'

In [3]:
true_frac = 0.8

# The linear model has unit slope and zero intercept:
true_params = [1.0, 0.0]

# The outliers are drawn from a Gaussian with zero mean and unit variance:
true_outliers = [0.0, 1.0]

# For reproducibility, let's set the random number seed and generate the data:
np.random.seed(12)
x = np.sort(np.random.uniform(-2, 2, 15))
yerr = 0.2 * np.ones_like(x)
y = true_params[0] * x + true_params[1] + yerr * np.random.randn(len(x))

# Those points are all drawn from the correct model so let's replace some of
# them with outliers.
m_bkg = np.random.rand(len(x)) > true_frac
y[m_bkg] = true_outliers[0]
y[m_bkg] += np.sqrt(true_outliers[1] + yerr[m_bkg] ** 2) * np.random.randn(sum(m_bkg))

# Then save the *true* line.
x0 = np.linspace(-2.1, 2.1, 200)
y0 = np.dot(np.vander(x0, 2), true_params)

In [41]:
def linear_model(x, yerr, y=None):
    # These are the parameters that we're fitting and we're required to define explicit
    # priors using distributions from the numpyro.distributions module.
    theta = numpyro.sample("theta", dist.Uniform(-0.5 * jnp.pi, 0.5 * jnp.pi))
    b_perp = numpyro.sample("b_perp", dist.Normal(0, 1))

    # Transformed parameters (and other things!) can be tracked during sampling using
    # "deterministics" as follows:
    m = numpyro.deterministic("m", jnp.tan(theta))
    b = numpyro.deterministic("b", b_perp / jnp.cos(theta))

    # filter nan 
        
    isnan = jnp.isnan(y)
    nanidx = jnp.nonzero(isnan)[0]
    y = jnp.asarray(y).at[isnan].set(1000)
    # Then we specify the sampling distribution for the data, or the likelihood function.
    # Here we're using a numpyro.plate to indicate that the data are independent. This
    # isn't actually necessary here and we could have equivalently omitted the plate since
    # the Normal distribution can already handle vector-valued inputs. But, it's good to
    # get into the habit of using plates because some inference algorithms or distributions
    # can take advantage of knowing this structure.
    with numpyro.plate("data", len(x)):
        samples = numpyro.sample("y", dist.Normal(m * x + b, yerr), obs=y)
    print(samples)
    return samples

In [44]:
y = jnp.asarray(y).at[-1].set(1.0)

In [45]:
with seed(rng_seed=15):
    testing = linear_model(x, yerr, y)

[]
[-1.26513    -1.7183505  -0.5388688  -1.4683361  -0.03998567 -1.0690173
 -0.9345808  -0.00612578  0.2192333   1.3656585  -0.3919706   1.7774864
  1.8571011   1.9864297   1.        ]


In [8]:
sampler = infer.MCMC(
    infer.NUTS(linear_model),
    num_warmup=100,
    num_samples=100,
    num_chains=1,
    progress_bar=True,
)

In [9]:
%time sampler.run(jax.random.PRNGKey(0), x, yerr, y=y)

[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]
[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]


  0%|                                                                                                                                                       | 0/200 [00:00<?, ?it/s]

[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]


sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:01<00:00, 156.84it/s, 7 steps of size 8.45e-01. acc. prob=0.92]

[-1.26513002 -1.71835056 -0.53886879 -1.46833607 -0.03998567 -1.06901729
 -0.93458078 -0.00612578  0.21923331  1.36565853 -0.39197061  1.77748639
  1.85710109  1.98642967  1.58986963]
CPU times: user 2.44 s, sys: 34.5 ms, total: 2.48 s
Wall time: 2.46 s





In [10]:
sampler.get_samples()

{'b': Array([0.19321087, 0.11051232, 0.1303351 , 0.09984763, 0.1311001 ,
        0.09537995, 0.09537995, 0.09797271, 0.13073991, 0.05498868,
        0.18083394, 0.07032708, 0.1678519 , 0.08739278, 0.07167584,
        0.13554436, 0.10206417, 0.12331628, 0.08453617, 0.08792339,
        0.08380423, 0.15920007, 0.1251351 , 0.15731004, 0.0510047 ,
        0.19470648, 0.14997862, 0.12658292, 0.05896591, 0.15566877,
        0.08477577, 0.09539074, 0.09618109, 0.14522313, 0.11880869,
        0.15404525, 0.18958052, 0.02545352, 0.12718353, 0.16434951,
        0.13340482, 0.10969486, 0.10690323, 0.1594108 , 0.17310472,
        0.17958575, 0.11810824, 0.11810824, 0.01856358, 0.01076789,
        0.15431204, 0.16992623, 0.07489697, 0.07409239, 0.01781133,
        0.14722526, 0.11527144, 0.09242926, 0.14661925, 0.2711162 ,
        0.06181287, 0.18632114, 0.16489294, 0.02995745, 0.16761327,
        0.08059081, 0.04307056, 0.20465752, 0.23279019, 0.16811176,
        0.08122335, 0.09703864, 0.133247  ,

In [16]:
cov = np.random.randn(4, 4)

In [17]:
cov = cov @ cov.T

In [18]:
cov = jnp.asarray(cov)

In [23]:
chol = jnp.linalg.cholesky(cov)

In [27]:
mvn = dist.MultivariateNormal(loc = jnp.ones(4), scale_tril=chol)

In [28]:
mvn.sample(jax.random.PRNGKey(0), sample_shape=(10,))

Array([[ 1.052786  , -1.7269607 ,  4.177167  ,  4.322926  ],
       [ 1.2105322 ,  2.2195706 , -2.5608053 ,  0.15977508],
       [ 1.4913502 , -1.1971645 , -1.199831  , -1.7133994 ],
       [ 0.9804101 ,  3.5195224 ,  0.90743124,  4.1972857 ],
       [ 0.7316028 , -0.25510836, -2.6458333 ,  0.10257113],
       [ 1.12325   ,  1.5281603 , -0.13747454,  0.43736088],
       [ 0.22649306,  0.1536848 , -2.1526842 ,  0.61269164],
       [ 0.748422  ,  1.4179997 ,  4.3404984 ,  4.62397   ],
       [ 2.6816187 ,  4.3796167 ,  7.195057  ,  0.9375721 ],
       [-0.33204806, -0.04597795,  1.4204223 ,  1.6626201 ]],      dtype=float32)

In [60]:
@jax.jit
def func(x):
    isnan = jnp.isnan(x)
    return jnp.where(~isnan, x, 1000.0)

In [61]:
testpoint = jnp.array([0.0, 1.0, jnp.nan])

In [70]:
~jnp.isnan(testpoint)

Array([ True,  True, False], dtype=bool)

In [69]:
jnp.where(~jnp.isnan(testpoint), x=testpoint, y=100.0)

Array([  0.,   1., 100.], dtype=float32)

In [63]:
func(testpoint)

Array([   0.,    1., 1000.], dtype=float32)

In [74]:
cond = ~jnp.isnan(testpoint)

In [75]:
cond

Array([ True,  True, False], dtype=bool)

In [78]:
x = np.hstack([np.array([np.nan]*4), np.random.randn(6)*2+5])

In [79]:
x

array([       nan,        nan,        nan,        nan, 7.00188552,
       4.89717001, 5.3195754 , 3.56747283, 5.10104565, 4.71332517])

In [85]:
def model2b(x):
    x_impute = numpyro.sample("x_impute", dist.Normal(0, 1).expand([4]).mask(False))
    print(x_impute)
    x_imputed = jnp.concatenate([x_impute, x[4:]])
    x_sample = numpyro.sample("x", dist.Normal(0, 1), obs=x_imputed)
    print(x_sample)

In [87]:
with seed(rng_seed=10):
    testing = model2b(x)

[ 0.47754705 -0.52989805  2.4254878   0.7563164 ]
[ 0.47754705 -0.52989805  2.4254878   0.7563164   7.0018854   4.89717
  5.3195753   3.567473    5.1010456   4.713325  ]


In [113]:
s = jnp.array([1.5, jnp.nan, 2.0])
u = np.random.randn(3,3)
u = u @ u.T

In [114]:
jnp.linalg.eigvals(u)

Array([5.1863756 +0.j, 0.09253304+0.j, 1.321759  +0.j], dtype=complex64)

In [115]:
discard = jnp.isnan(s)

In [116]:
u = jnp.where(discard, 0, u)

In [117]:
u = jnp.where(discard[:,None], 0, u)

In [118]:
jnp.where(discard, 0, s)

Array([1.5, 0. , 2. ], dtype=float32)

In [119]:
u

Array([[ 1.0252076 ,  0.        , -0.97729826],
       [ 0.        ,  0.        ,  0.        ],
       [-0.97729826,  0.        ,  4.7572317 ]], dtype=float32)

In [120]:
jnp.linalg.eigvals(u)

Array([0.78477407+0.j, 4.9976654 +0.j, 0.        +0.j], dtype=complex64)