In [85]:
import numpy as np
import arviz as az
import matplotlib.pyplot as plt
import multiprocessing as mp
import pandas as pd
import seaborn as sns
import numpyro as npr
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.optim import Adam
from numpyro.infer.reparam import TransformReparam
from numpyro.distributions import constraints
from jax import random
import jax.numpy as jnp
import jax.scipy.special as jss

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

sns.set()
npr.set_host_device_count(mp.cpu_count())

# From the doc

In [93]:
def model(data):
    # define the hyperparameters that control the beta prior
    alpha0 = 10.0
    beta0 = 10.0
    # sample f from the beta prior
    f = npr.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the bernoulli
        # likelihood Bernoulli(f)
        npr.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

In [94]:
def guide(data):
    # register the two variational parameters with Pyro.
    alpha_q = npr.param("alpha_q", 10.0,
                        constraint=constraints.positive)
    beta_q = npr.param("beta_q", 10.0,
                       constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    npr.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

In [95]:
# create some data with 6 observed heads and 4 observed tails
data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])

In [96]:
# set up the optimizer
optimizer = Adam(step_size=0.0005)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

svi_result = svi.run(random.PRNGKey(0), 5000, data)

100%|██████████| 5000/5000 [00:05<00:00, 920.39it/s, init loss: 6.7353, avg. loss [4751-5000]: 7.0861] 


In [97]:
params = svi_result.params
params

{'alpha_q': DeviceArray(11.5468235, dtype=float32),
 'beta_q': DeviceArray(10.174291, dtype=float32)}

In [98]:
params["alpha_q"] / (params["alpha_q"] + params["beta_q"])

DeviceArray(0.5315944, dtype=float32)