# Computational Inference via NumPyro

We want to try out what we learned about computational inference by repeating the inferences from the previous notebook, this time with sampling methods. We will use the NumPyro library, which we briefly introduce before digging into the inference part.

## 1. Model specification in NumPyro

As before, we want to estimate the probability of a coin landing heads from a number of observed flips. Assuming a Beta prior with parameters $\alpha=\beta=10$, the model has the form

$$
\begin{aligned}
    \theta &\sim \mathrm{Beta}(10, 10) \\
    y_i &\sim \mathrm{Bernoulli}(\theta), \quad i = 1, \dots, N
\end{aligned}
$$

In [None]:
ALPHA = 10
BETA = 10

### NumPy

Let's start by building a probabilistic model as a Python function just using NumPy.

In [None]:
import numpy as np

We can use `np.random` to sample from distributions, e.g., from the prior.

In [None]:
np.random.beta(ALPHA, BETA)

In [None]:
def model(n=1):
    theta = np.random.beta(ALPHA, BETA)  # prior
    y = np.random.binomial(1, theta, size=n)  # likelihood
    return theta, y

We can also sample from the full model: we draw `theta` from the prior, then draw `n` values of `y` given that `theta`, and return both.

In [None]:
model(20)
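The `np.random` functions above draw from NumPy's hidden global random state. As a sketch of the alternative, NumPy's Generator API (`np.random.default_rng`) lets us pass the random state in explicitly, which makes runs reproducible and is close in spirit to how NumPyro handles randomness below. The function `model_np` is a hypothetical variant of `model` and is not used elsewhere in the notebook.

```python
import numpy as np

ALPHA, BETA = 10, 10

def model_np(n=1, rng=None):
    """Hypothetical variant of `model` that takes an explicit NumPy Generator."""
    rng = np.random.default_rng() if rng is None else rng
    theta = rng.beta(ALPHA, BETA)        # prior draw
    y = rng.binomial(1, theta, size=n)   # n Bernoulli(theta) flips
    return theta, y

# two Generators created with the same seed reproduce the same theta and flips
theta1, y1 = model_np(20, rng=np.random.default_rng(42))
theta2, y2 = model_np(20, rng=np.random.default_rng(42))
print(theta1, y1)
```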

### NumPyro

The corresponding model in NumPyro will look very similar, but instead of using distributions from `numpy.random` we use `numpyro.sample` statements and pass the corresponding distribution.

In [None]:
import numpyro
import numpyro.distributions as dist

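# expose 4 CPU devices so that the 4 MCMC chains below can run in parallel;
# this only takes effect at the start of the program, before any JAX computation
numpyro.set_host_device_count(4)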

In [None]:
def model(n=1):
    theta = numpyro.sample("theta", dist.Beta(ALPHA, BETA))  # prior
    y = numpyro.sample("y", dist.Bernoulli(theta).expand((n,)))  # likelihood
    return theta, y

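One final difference is that NumPyro does not maintain a global random state, so we have to seed the model manually with a JAX random key. The `seed` handler supplies randomness to every `numpyro.sample` statement executed inside it; you don't need to understand the details of JAX's random number handling for now.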

In [None]:
from jax import random

rng_key = random.PRNGKey(42)

with numpyro.handlers.seed(rng_seed=rng_key):
    theta, y = model(n=20)

theta, y

### Tracing models

We annotate the variables we want to sample with `numpyro.sample` statements so that NumPyro can trace the execution of the model and understand its structure for inference. NumPyro exposes the `trace` handler, so we can run it ourselves and see what information is captured when the model executes.

In [None]:
with numpyro.handlers.seed(rng_seed=rng_key), numpyro.handlers.trace() as t:
    theta, y = model(n=20)

t

As you can see, the trace picks out the annotated variables, their values and a bunch of other information that might be relevant to inference algorithms.

### Improving the model

We can make a few improvements to the model.

- We will allow the user to pass in y values for the model to condition on. This ability to condition on observed data is crucial in order to make inferences about the latent variables ($\theta$ in this case).
- We will replace `.expand` with a `numpyro.plate` context manager. This NumPyro primitive is another way of specifying batch dimensions, inspired by [plate notation](https://en.wikipedia.org/wiki/Plate_notation); it makes the code more readable and adds useful information to the model trace.
- Finally, we won't bother with return values because, as we saw above, all of the values are captured in the model trace. Generally we only use return values in NumPyro models for debugging purposes.

In [None]:
def model(y=None):

    n = y.shape[0] if y is not None else 20

    # prior
    theta = numpyro.sample("theta", dist.Beta(ALPHA, BETA))

    # likelihood
    with numpyro.plate("n", n):
        numpyro.sample("y", dist.Bernoulli(theta), obs=y)

NumPyro can also render a neat visualization of our probabilistic model as a graph (this requires the `graphviz` package to be installed).

In [None]:
numpyro.render_model(model)

### Sampling from the prior
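NumPyro's `Predictive` class provides a convenient interface for drawing many samples from a model. Since we pass no observed data, the draws of `theta` and `y` come from the prior and the prior predictive distribution, respectively.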

In [None]:
from numpyro.infer import Predictive

In [None]:
N_SAMPLES = 4_000

In [None]:
prior = Predictive(model, num_samples=N_SAMPLES)

rng_key, subkey = random.split(rng_key)
prior_samples = prior(subkey)

prior_samples["theta"].shape, prior_samples["y"].shape

In [None]:
prior_samples

## 2. Bayesian inference via NumPyro

Once we've built the model, inference is straightforward. Let's fit the model to the data from our lazy 3-coin toss trial.

In [None]:
# Our data described as 3 outcomes from Bernoulli trials
y = np.array([1, 1, 1])

In [None]:
from numpyro.infer import MCMC, NUTS

In [None]:
# 4 chains x (N_SAMPLES // 4) draws = N_SAMPLES posterior samples in total
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=N_SAMPLES // 4, num_chains=4)

rng_key, subkey = random.split(rng_key)
mcmc.run(subkey, y=y)

NumPyro can summarize the MCMC run, reporting statistics of the samples together with two diagnostics: the effective sample size `n_eff` and `r_hat`. We won't go into the details of `r_hat` here; values close to 1 (a common rule of thumb is below about 1.01) are consistent with the chains having converged, while values clearly above 1 signal a problem.

In [None]:
mcmc.print_summary()

In [None]:
posterior_samples = mcmc.get_samples()

In [None]:
posterior_samples["theta"].shape

In [None]:
posterior_samples["theta"]

### Plot prior and posterior histograms

In [None]:
import matplotlib.pyplot as plt

f, ax = plt.subplots(figsize=(8, 5))
ax.hist(np.array(posterior_samples["theta"]), alpha=0.5, bins=50, label="posterior")
ax.hist(np.array(prior_samples["theta"]), alpha=0.5, bins=50, label="prior")
plt.legend(loc="best")
plt.show()

Comparing these histograms with the exact prior and posterior densities plotted in the previous notebook, we see that they look very similar.

### Posterior Statistics

#### A. Expectation value

**Exercise:** Compute the Monte Carlo estimate of the expectation of $\theta$ under the posterior. For posterior samples $\theta_i$, $i=1, ..., N$, this estimate is defined by

$$\mathbb{E}\left[\theta \vert y\right] \approx \frac{1}{N}\sum_{i=1}^N \theta_i.$$
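To get a feel for this estimator, the sketch below applies it to independent draws from a distribution with a known mean (an illustrative $\mathrm{Beta}(13, 10)$, mean $13/23$), not to `posterior_samples`. For independent draws the error shrinks roughly like $1/\sqrt{N}$; MCMC draws are autocorrelated, so for our posterior samples the relevant $N$ is closer to the `n_eff` reported by `print_summary`.

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = 13, 10              # illustrative Beta parameters with known mean a / (a + b)
true_mean = a / (a + b)

for n in [10, 100, 1_000, 10_000]:
    draws = rng.beta(a, b, size=n)
    estimate = draws.mean()   # (1/N) * sum of the draws
    print(n, estimate, abs(estimate - true_mean))
```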

In [None]:
mean = ...
print(mean)

#### B. Event Probabilities

Compute the Monte Carlo estimate for the probability of $\theta \in [\theta_1, \theta_2]$ under the posterior. For posterior samples $\theta_i$, $i=1, ..., N$, this estimate is defined by,

$$ \int_{\theta_1}^{\theta_2} p(\theta \vert y) \mathrm{d} \theta \approx \frac{ \#\{\theta_i \in [\theta_1, \theta_2]\} }{N}$$
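The counting estimator is the average of an indicator over the samples. As an illustration on stand-in samples rather than the posterior, the sketch uses $\mathrm{Uniform}(0, 1)$ draws, for which the exact probability of an interval is simply its length, with the interval $[0.475, 0.525]$.

```python
import numpy as np

rng = np.random.default_rng(1)
draws = rng.uniform(0.0, 1.0, size=100_000)   # stand-in samples: Uniform(0, 1) = Beta(1, 1)

lower, upper = 0.5 - 0.025, 0.5 + 0.025
in_interval = (draws >= lower) & (draws <= upper)   # indicator of theta_i in [lower, upper]
estimate = in_interval.sum() / draws.size           # #{theta_i in [lower, upper]} / N

exact = upper - lower   # for Uniform(0, 1) the probability is the interval length
print(estimate, exact)
```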

**Exercise:** What is the estimated probability that the coin is fair within a tolerance of 0.025, i.e. that $\theta \in [0.475, 0.525]$?


In [None]:
proba_fair = ...
print(proba_fair)

#### C. Quantiles

For a given probability $P \in [0,1]$, the associated posterior quantile is the largest $x$ whose cumulative posterior probability does not exceed $P$. It can be approximated via posterior samples $\theta_i$, $i=1, ..., N$, by

$$ \sup \left\{ x : \int_0^{x} p(\theta \vert y)\, \mathrm{d}\theta \le P \right\} \approx \sup \left\{ x : \frac{1}{N}\sum_{i=1}^N \chi(\theta_i \in [0, x]) \le P \right\} $$

where $\chi$ is the indicator function.
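For distinct samples, the sample-based definition can be evaluated by sorting: with $m = \lfloor PN \rfloor$, the supremum is the $(m+1)$-th smallest $\theta_i$. The sketch below implements that with a hypothetical helper `empirical_quantile`, again on $\mathrm{Uniform}(0, 1)$ stand-in samples whose exact quantiles equal $P$, and compares it with NumPy's built-in `np.quantile`, which uses a slightly different (interpolating) convention by default.

```python
import numpy as np

rng = np.random.default_rng(2)
draws = rng.uniform(0.0, 1.0, size=100_000)   # stand-in samples with exact quantiles q_P = P

def empirical_quantile(samples, P):
    """Hypothetical helper: sup{x : fraction of samples <= x is at most P}."""
    s = np.sort(samples)
    m = int(np.floor(P * s.size))   # the empirical CDF stays <= P below the (m+1)-th smallest sample
    return s[min(m, s.size - 1)]    # clamp so that P = 1 returns the largest sample

q05 = empirical_quantile(draws, 0.05)
q95 = empirical_quantile(draws, 0.95)
print(q05, q95)
print(np.quantile(draws, [0.05, 0.95]))   # NumPy's built-in estimate for comparison
```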

**Exercise:** What are the estimated 5th and 95th percentiles?

In [None]:
perc_5 = ...
perc_95 = ...

In [None]:
print(perc_5)
print(perc_95)