In [2]:
import arviz as az
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

# Data-generation code used by Kotaro: draw N i.i.d. observations from a
# Gamma(shape=2, rate=0.5) distribution using a fixed PRNG key (0).
N = 20
_shape, _rate = 2., 0.5
concentration = jnp.ones(N) * _shape
rate = jnp.ones(N) * _rate
s = dist.Gamma(concentration=concentration, rate=rate).sample(random.PRNGKey(0))
# Last expression is the cell output: show the draws as a plain list.
list(np.array(s))

[2.3762975,
 2.7325578,
 1.3750422,
 0.9992071,
 1.9519926,
 6.251218,
 0.84728867,
 6.076127,
 6.0899625,
 2.6973338,
 3.7665267,
 1.572763,
 6.4568233,
 1.369886,
 6.5856237,
 1.5776805,
 7.356719,
 7.4597907,
 3.3213968,
 4.0831566]

In [3]:
# Observations generated with an arbitrary, fixed pair of (shape, rate)
# parameters. These are the 20 values printed by the generation cell above,
# hard-coded here so that this cell does not itself need `s`.
d = [
    2.3762975, 2.7325578, 1.3750422, 0.9992071, 1.9519926,
    6.251218, 0.84728867, 6.076127, 6.0899625, 2.6973338,
    3.7665267, 1.572763, 6.4568233, 1.369886, 6.5856237,
    1.5776805, 7.356719, 7.4597907, 3.3213968, 4.0831566,
]
data = jnp.asarray(d)


In [4]:
# Write the probabilistic model to estimate the shape and rate parameters.
def gamma_model(data):
    """Gamma likelihood with Exponential(1) priors on its shape and rate.

    Latent sites:
        alpha: concentration (shape) of the Gamma likelihood, prior Exponential(1).
        beta: rate of the Gamma likelihood, prior Exponential(1).

    Args:
        data: Observed values, passed as ``obs`` to the likelihood site ``"g"``.
            Presumably a 1-D array of positive values (e.g. the 20 draws built
            above) - TODO confirm for other callers.

    Returns:
        None. The model only registers sample sites with numpyro.
    """
    alpha = numpyro.sample("alpha", dist.Exponential(rate=1.0))
    beta = numpyro.sample("beta", dist.Exponential(rate=1.0))
    # Observed site: its return value is unused, so it is not bound to a name.
    numpyro.sample("g", dist.Gamma(concentration=alpha, rate=beta), obs=data)

# Fit the model with NUTS. Pass the `data` array built in the previous cell
# (not `s` from the generation cell) so this cell depends only on the cells
# directly above it; the values in `data` are the ones printed from `s`.
mcmc = MCMC(NUTS(model=gamma_model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), data=data)
mcmc.print_summary()
# Reference summary from the recorded run:
#                mean       std    median      5.0%     95.0%     n_eff     r_hat
#     alpha      2.18      0.55      2.14      1.23      3.00    283.12      1.00
#      beta      0.59      0.16      0.58      0.30      0.83    279.37      1.00

sample: 100%|██████████| 2000/2000 [00:01<00:00, 1092.79it/s, 3 steps of size 3.19e-01. acc. prob=0.94] 



                mean       std    median      5.0%     95.0%     n_eff     r_hat
     alpha      2.18      0.55      2.14      1.23      3.00    283.12      1.00
      beta      0.59      0.16      0.58      0.30      0.83    279.37      1.00

Number of divergences: 0
