<a href="https://colab.research.google.com/github/probml/pyprobml/blob/master/book1/bayes_stats/numpyro_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[NumPyro](https://github.com/pyro-ppl/numpyro) is a probabilistic programming language built on top of JAX. It is very similar to [Pyro](https://pyro.ai/), which is built on top of PyTorch, but [tends to be faster](https://stackoverflow.com/questions/61846620/numpyro-vs-pyro-why-is-former-100x-faster-and-when-should-i-use-the-latter). (Both Pyro flavors are usually also [faster than PyMC3](https://www.kaggle.com/s903124/numpyro-speed-benchmark).)

This Colab notebook gives a brief introduction to NumPyro (work in progress).

# Installation

In [None]:
# Standard Python libraries
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import time
#import numpy as np
#np.set_printoptions(precision=3)
import glob
import matplotlib.pyplot as plt
import PIL
import imageio

from IPython import display
%matplotlib inline

import sklearn

import seaborn as sns
sns.set(style="ticks", color_codes=True)

import pandas as pd
pd.set_option('display.precision', 2) # 2 decimal places
pd.set_option('display.max_rows', 20)
pd.set_option('display.max_columns', 30)
pd.set_option('display.width', 100) # wide windows

In [None]:
import jax
import jax.numpy as np
import numpy as onp # original numpy

print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))

jax version 0.2.7
jax backend gpu


In [None]:
# https://github.com/pyro-ppl/numpyro
!pip install numpyro

# It seems that numpyro installs jaxlib for CPU
#https://github.com/pyro-ppl/numpyro/issues/531

Collecting numpyro
  Downloading https://files.pythonhosted.org/packages/35/f1/7bada66676245f9e085b870b1051ba183b377af287002e10a2e1bea1b498/numpyro-0.4.1-py3-none-any.whl (176kB)
Collecting jax==0.2.3
  Downloading https://files.pythonhosted.org/packages/d7/b2/738298445cb0d9445e84f58f1fdaf73aa7b1d4199e6360620461d6fe3a8b/jax-0.2.3.tar.gz (473kB)
Collecting jaxlib==0.1.56
  Downloading https://files.pythonhosted.org/packages/aa/44/16d06ee6418ae1b020b0722f7b7465baa08031a85728392e5413dd4e3e04/jaxlib-0.1.56-cp36-none-manylinux2010_x86_64.whl (32.1MB)
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.3-cp36-none-any.whl size=542178 sha256=f258c0d1f96711cc0b308e64517b4d916ae57c44003c7a217fc8b6cf71fdccd8
 

In [None]:
import jax
import jax.numpy as np
import numpy as onp # original numpy
from jax import random

print("jax version {}".format(jax.__version__))
print("jax backend {}".format(jax.lib.xla_bridge.get_backend().platform))

jax version 0.2.7
jax backend gpu


# Distributions

In [None]:
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import MCMC, NUTS, Predictive

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

## 1d Gaussian

In [None]:
# A single 1d Gaussian with mean mu and standard deviation sigma
mu = 1.5
sigma = 2
d = dist.Normal(mu, sigma)
dir(d)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_batch_shape',
 '_event_shape',
 '_validate_args',
 '_validate_sample',
 'arg_constraints',
 'batch_shape',
 'enumerate_support',
 'event_dim',
 'event_shape',
 'expand',
 'expand_by',
 'has_enumerate_support',
 'icdf',
 'is_discrete',
 'loc',
 'log_prob',
 'mask',
 'mean',
 'reparametrized_params',
 'sample',
 'sample_with_intermediates',
 'scale',
 'set_default_validate_args',
 'shape',
 'support',
 'to_event',
 'tree_flatten',
 'tree_unflatten',
 'variance']

In [None]:
rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d.sample(rng_key_, (nsamples,))
print(ys.shape)
mu_hat = np.mean(ys,0)
print(mu_hat)
sigma_hat = np.std(ys, 0)
print(sigma_hat)

(1000,)
1.5070927
2.0493808


## Multivariate Gaussian



In [None]:
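mu = np.array([-1, 1])
sigma = np.array([1, 2])  # diagonal of the covariance, i.e. variances (not standard deviations)
Sigma = np.diag(sigma)  # 2x2 diagonal covariance matrix
d2 = dist.MultivariateNormal(mu, Sigma)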

In [None]:
#rng_key, rng_key_ = random.split(rng_key)
nsamples = 1000
ys = d2.sample(rng_key_, (nsamples,))
print(ys.shape)
mu_hat = np.mean(ys,0)
print(mu_hat)
Sigma_hat = np.cov(ys, rowvar=False)  # rows are samples, columns are variables
print(Sigma_hat)

(1000, 2)
[-1.0127413  1.0091063]
[[ 0.9770031  -0.00533966]
 [-0.00533966  1.9718108 ]]


## Shape semantics

NumPyro, [Pyro](https://pyro.ai/examples/tensor_shapes.html) and [TFP](https://www.tensorflow.org/probability/examples/Understanding_TensorFlow_Distributions_Shapes) all distinguish between 'event shape' and 'batch shape'.
For a single D-dimensional Gaussian, the event shape is (D,) and the batch shape
is (), meaning we have one instance of this distribution.
If the covariance is diagonal, we can instead view this as D independent
1d Gaussians, stored along the batch dimension; this has event shape () but batch shape (D,).

When we sample from a distribution, we also specify the sample_shape.
Suppose we draw N samples from a single D-dimensional diagonal Gaussian,
and N samples from D 1d Gaussians. Both sets of samples have shape (N, D).
However, the semantics of `log_prob` differ: the first returns one number per sample, the second returns D numbers per sample.
We illustrate this below.


In [None]:
d2 = dist.MultivariateNormal(mu, Sigma)
print(f'event shape {d2.event_shape}, batch shape {d2.batch_shape}') 
nsamples = 3
ys2 = d2.sample(rng_key_, (nsamples,))
print('samples, shape {}'.format(ys2.shape))
print(ys2)

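# 2 independent 1d Gaussians. Note: dist.Normal takes a standard deviation (scale),
# so passing the variances np.diag(Sigma) does not reproduce d2 exactly;
# np.sqrt(np.diag(Sigma)) would.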
d3 = dist.Normal(mu, np.diag(Sigma))
print(f'event shape {d3.event_shape}, batch shape {d3.batch_shape}') 
ys3 = d3.sample(rng_key_, (nsamples,))
print('samples, shape {}'.format(ys3.shape))
print(ys3)

print(np.allclose(ys2, ys3))

event shape (2,), batch shape ()
samples, shape (3, 2)
[[-0.06819373  0.9942934 ]
 [-1.740325   -1.0183868 ]
 [ 0.05969942  2.314332  ]]
event shape (), batch shape (2,)
samples, shape (3, 2)
[[-0.06819373  0.99192965]
 [-1.740325   -1.85443   ]
 [ 0.05969942  2.8587465 ]]
False
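
The shape bookkeeping described above can be sketched without any library. The two helper functions below are hypothetical (they are not part of NumPyro); they only encode the convention that a draw has shape `sample_shape + batch_shape + event_shape`, while `log_prob` reduces over the event dimensions and returns shape `sample_shape + batch_shape`.

```python
def draw_shape(sample_shape, batch_shape, event_shape):
    """Shape of a draw: sample dims first, then batch dims, then event dims."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


def log_prob_shape(sample_shape, batch_shape):
    """log_prob reduces over event dims, so only sample and batch dims remain."""
    return tuple(sample_shape) + tuple(batch_shape)


N, D = 3, 2

# One D-dimensional Gaussian: batch_shape (), event_shape (D,)
mvn_draw = draw_shape((N,), (), (D,))
mvn_lp = log_prob_shape((N,), ())

# D independent 1d Gaussians: batch_shape (D,), event_shape ()
indep_draw = draw_shape((N,), (D,), ())
indep_lp = log_prob_shape((N,), (D,))

print(mvn_draw, indep_draw)  # same sample shape
print(mvn_lp, indep_lp)      # different log_prob shapes
```

The draws are indistinguishable by shape alone; only the split between batch and event dimensions tells `log_prob` which axes to sum over.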


In [None]:
y = ys2[0,:] # 2 numbers
print(d2.log_prob(y)) # log prob of a single 2d distribution on 2d input 
print(d3.log_prob(y)) # log prob of two 1d distributions on 2d input


-2.6185904
[-1.35307   -1.6120898]


We can turn a set of independent distributions into a single product
distribution using the [Independent class](http://num.pyro.ai/en/stable/distributions.html#independent).
Its `log_prob` is the sum of the component log probabilities.


In [None]:
d4 = dist.Independent(d3, 1) # reinterpret the rightmost batch dimension as an event dimension
print(d4.event_shape)
print(d4.batch_shape)
print(d4.log_prob(y))

(2,)
()
-2.96516
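
The printed log probabilities can be checked by hand. The sketch below uses only the standard library (it does not call NumPyro) and a hypothetical helper `normal_log_prob` that implements the 1d Gaussian log density. It assumes the values of `y`, `mu` and `Sigma` shown in the cells above. It suggests why `d4.log_prob(y)` differs from `d2.log_prob(y)`: `d3` was built with the variances (1, 2) as its scale, whereas the diagonal covariance of `d2` corresponds to standard deviations (1, sqrt(2)).

```python
import math


def normal_log_prob(y, loc, scale):
    """Log density of a 1d Gaussian with mean `loc` and standard deviation `scale`."""
    z = (y - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


y = (-0.06819373, 0.9942934)  # first row of ys2, copied from the output above
loc = (-1.0, 1.0)
variances = (1.0, 2.0)        # diagonal of Sigma

# Scale set to the variances: this is what d3 (and hence d4) uses.
lp_var = [normal_log_prob(a, m, v) for a, m, v in zip(y, loc, variances)]

# Scale set to sqrt(variance): this matches the diagonal covariance of d2.
lp_sd = [normal_log_prob(a, m, math.sqrt(v)) for a, m, v in zip(y, loc, variances)]

print(lp_var, sum(lp_var))  # compare with d3.log_prob(y) and d4.log_prob(y)
print(lp_sd, sum(lp_sd))    # compare with d2.log_prob(y)
```

Summing the per-dimension terms is exactly what `Independent` does, so the first sum should agree with `d4.log_prob(y)` and the second with `d2.log_prob(y)`, up to float32 rounding.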


# Posterior inference with MCMC


## Example: 1d Gaussian with unknown mean.

We use the simple example from the [Pyro intro](https://pyro.ai/examples/intro_part_ii.html#A-Simple-Example). The goal is to infer the weight $\theta$ of an object, given noisy measurements $y$. We assume the following model:
$$
\begin{align}
\theta &\sim N(\mu=8.5, \tau^2=1.0^2)\\
y &\sim N(\theta, \sigma^2=0.75^2)
\end{align}
$$

Here $\mu=8.5$ is our initial guess of the weight, $\tau$ is the prior standard deviation, and $\sigma$ is the measurement noise.

By Bayes' rule for Gaussians, the exact posterior
given a single observation $y=9.5$ is


$$
\begin{align}
\theta|y &\sim N(m, s^2) \\
m &=\frac{\sigma^2 \mu + \tau^2 y}{\sigma^2 + \tau^2} 
  = \frac{0.75^2 \times 8.5 + 1^2 \times 9.5}{0.75^2 + 1^2}
  = 9.14 \\
s^2 &= \frac{\sigma^2 \tau^2}{\sigma^2  + \tau^2} 
= \frac{0.75^2 \times 1^2}{0.75^2 + 1^2}= 0.6^2
\end{align}
$$

In [None]:
mu = 8.5; tau = 1.0; sigma = 0.75; y = 9.5
m = (sigma**2 * mu + tau**2 * y)/(sigma**2 + tau**2)
s2 = (sigma**2 * tau**2)/(sigma**2 + tau**2)
s = np.sqrt(s2)
print(m)
print(s)

9.14
0.6
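
As a sanity check on the closed-form result, the posterior can also be approximated numerically on a grid. This is a minimal standard-library sketch, not part of the original notebook: the grid bounds, the grid size, and the helper `log_normal_pdf` are assumptions chosen for illustration. It evaluates the unnormalized log posterior (log prior plus log likelihood), normalizes it, and computes the posterior mean and standard deviation.

```python
import math


def log_normal_pdf(x, loc, scale):
    """Log density of a 1d Gaussian with mean `loc` and standard deviation `scale`."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


mu, tau, sigma, y = 8.5, 1.0, 0.75, 9.5

# Uniform grid over theta; the bounds are an assumption, wide enough to
# cover the posterior mass (roughly 8 posterior standard deviations each side).
lo, hi, n = 4.0, 14.0, 20001
step = (hi - lo) / (n - 1)
grid = [lo + i * step for i in range(n)]

# Unnormalized log posterior = log prior + log likelihood.
log_post = [log_normal_pdf(t, mu, tau) + log_normal_pdf(y, t, sigma) for t in grid]

# Subtract the max before exponentiating for numerical stability.
mx = max(log_post)
w = [math.exp(lp - mx) for lp in log_post]
total = sum(w)

m_grid = sum(t * wi for t, wi in zip(grid, w)) / total
s_grid = math.sqrt(sum((t - m_grid) ** 2 * wi for t, wi in zip(grid, w)) / total)

print(m_grid, s_grid)  # should be close to the exact m and s computed above
```

A grid works here only because the parameter is one-dimensional; in higher dimensions this approach does not scale, which is why sampling methods are used instead.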


In [None]:
def model(prior_mean, prior_sd, obs_sd, measurement=None):
    theta = numpyro.sample("theta", dist.Normal(prior_mean, prior_sd))
    return numpyro.sample("y", dist.Normal(theta, obs_sd), obs=measurement)


In [None]:
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000)
mcmc.run(rng_key_, mu, tau, sigma, y)

mcmc.print_summary()
samples  = mcmc.get_samples()
 


sample: 100%|██████████| 1100/1100 [00:03<00:00, 286.64it/s, 3 steps of size 9.41e-01. acc. prob=0.91]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
     theta      9.17      0.60      9.13      8.27     10.15    365.16      1.00

Number of divergences: 0
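
To judge whether the MCMC estimate of the mean (9.17) is consistent with the exact value (9.14), we can compare their difference to the Monte Carlo standard error (MCSE). The sketch below uses the standard approximation MCSE ≈ posterior std / sqrt(n_eff); the numbers are copied by hand from the summary table above rather than read from the `mcmc` object.

```python
import math

# Values copied from the print_summary() output above.
post_mean = 9.17
post_std = 0.60
n_eff = 365.16

# Exact posterior mean from the closed-form calculation.
exact_mean = 9.14

# Monte Carlo standard error of the posterior-mean estimate.
mcse = post_std / math.sqrt(n_eff)

# How many standard errors the estimate is from the exact value.
z = (post_mean - exact_mean) / mcse

print(f"MCSE = {mcse:.3f}, deviation = {z:.2f} standard errors")
```

A deviation of about one standard error is what we would expect from sampling noise alone, so the MCMC result agrees with the analytic posterior.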
