[Numpyro](https://github.com/pyro-ppl/numpyro) is a probabilistic programming library with a NumPy backend, powered by JAX for autograd and JIT compilation to GPU/TPU/CPU. Inspired by [Rémi Louf's blogpost](https://rlouf.github.io/post/jax-random-walk-metropolis/), here is a speed benchmark of numpyro (GPU/CPU) vs. PyMC3 MCMC on a medium-sized, real-life dataset.

In [None]:
# Install pinned versions: numpyro 0.6.0 with jax 0.2.10, and the CUDA 11.0
# build of jaxlib (the "+cuda110" wheel) so that JAX can run on the GPU.
!pip install --upgrade pip
!pip install numpyro==0.6.0
!pip install --upgrade jax==0.2.10 jaxlib==0.1.62+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [None]:
from jax.lib import xla_bridge

# Show which XLA backend JAX is using (e.g. 'gpu' or 'cpu').
backend = xla_bridge.get_backend()
print(backend.platform)

In [None]:
import time
import pandas as pd
import numpy as onp

import jax.numpy as np
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.infer import HMC, MCMC, NUTS

from jax.random import PRNGKey

# Select the GPU platform for numpyro/JAX.
# NOTE(review): numpyro documents set_platform as taking effect only at the start
# of the program; the backend was already queried in an earlier cell, so this
# probably relies on GPU already being JAX's default here. Confirm.
numpyro.util.set_platform('gpu') # Use GPU for MCMC

The data here consist of all at-bats of MLB games; specifically, we would like to model the batting average of players in the 2018 season. Each at-bat involves one batter and one pitcher, and we hope to find out how each of them affects the batting average by grouping at-bats by the batter and the pitcher who played them during the season.

In [None]:
import pandas as pd

# Load the at-bat records; `d` is the working alias used in later cells and
# refers to the same DataFrame object as `baseball`.
d = baseball = pd.read_csv('../input/mlb-pitch-data-20152018/atbats.csv')

In [None]:
# Keep only at-bats whose ab_id lies strictly between 2018000000 and 2019000000
# (presumably the 2018 season, per the leading digits of the id; confirm).
# `.copy()` gives `d` its own data so that the column assignments in the next
# cell modify a real frame instead of raising SettingWithCopyWarning on a slice.
d = d[(d.ab_id > 2018000000) & (d.ab_id < 2019000000)].copy()
len(d)

In [None]:
# d = d[:200]

In [None]:
# Binary outcome: 1 if the at-bat's event is one of these hit types, else 0.
hit_event = ['Single', 'Double', 'Triple', 'Home Run']
d['Hits'] = d['event'].isin(hit_event).astype(int)

# Integer category codes for each player id, used to index the per-player
# effect vectors in the models below (pandas uses -1 for missing ids).
for code_col, id_col in (('pitcher_code', 'pitcher_id'), ('batter_code', 'batter_id')):
    d[code_col] = d[id_col].astype('category').cat.codes

In [None]:
# NOTE(review): numpyro documents set_host_device_count as effective only at
# program start and only for the CPU platform. JAX was already initialised in an
# earlier cell and the run below uses num_chains=1, so this is likely a no-op.
numpyro.set_host_device_count(2)

# Model inputs as JAX arrays: hit outcome plus integer pitcher/batter codes.
dat_list = dict(
    hits=np.array(d.Hits),
    pitcher=np.array(d.pitcher_code),
    batter=np.array(d.batter_code),
)

def model(pitcher, batter, hits=None, link=False):
    """Hierarchical logistic model of at-bat hits with pitcher and batter effects.

    Args:
        pitcher: integer pitcher code per at-bat; indexes the effect vector ``a``.
        batter: integer batter code per at-bat; indexes the effect vector ``b``.
        hits: observed 0/1 hit indicator per at-bat, or ``None`` to leave the
            ``"hits"`` site unobserved.
        link: if True, also record the hit probability ``expit(logit_p)`` under
            the site name ``"p"``.

    The lengths of ``a`` and ``b`` are read from the notebook-global DataFrame
    ``d`` (number of distinct pitcher / batter codes).
    """
    # `expit` was never imported in this notebook, so `link=True` used to fail
    # with NameError; import it where it is needed.
    from jax.scipy.special import expit

    # Population-level hyperpriors for pitcher (a) and batter (b) effects.
    a_bar = numpyro.sample("a_bar", dist.Normal(0, 10))
    sigma_a = numpyro.sample("sigma_a", dist.HalfCauchy(5))
    b_bar = numpyro.sample("b_bar", dist.Normal(0, 10))
    sigma_b = numpyro.sample("sigma_b", dist.HalfCauchy(5))

    # Centered parameterization: one effect per distinct pitcher / batter code.
    a = numpyro.sample("a", dist.Normal(a_bar, sigma_a), sample_shape=(len(d['pitcher_code'].unique()),))
    b = numpyro.sample("b", dist.Normal(b_bar, sigma_b), sample_shape=(len(d['batter_code'].unique()),))

    # non-centered parameterization
#     a = numpyro.sample('a',  dist.TransformedDistribution(dist.Normal(0., 1.), dist.transforms.AffineTransform(a_bar, sigma_a)), sample_shape=(len(d['pitcher_code'].unique()),))
#     b = numpyro.sample('b',  dist.TransformedDistribution(dist.Normal(0., 1.), dist.transforms.AffineTransform(b_bar, sigma_b)), sample_shape=(len(d['batter_code'].unique()),))

    # Log-odds of a hit for each at-bat: pitcher effect + batter effect.
    logit_p = a[pitcher] + b[batter]
    if link:
        p = expit(logit_p)
        # Delta site observed at its own value: records p without adding log-density.
        numpyro.sample("p", dist.Delta(p), obs=p)
    # Binomial with the default total_count=1, i.e. one Bernoulli trial per at-bat.
    numpyro.sample("hits", dist.Binomial(logits=logit_p), obs=hits)

# Single-chain NUTS: 1000 warmup iterations followed by 1000 posterior draws.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(
    PRNGKey(0),
    np.array(d.pitcher_code),
    np.array(d.batter_code),
    hits=np.array(d.Hits),
    extra_fields=('potential_energy', 'mean_accept_prob'),
)
# Summarise with 89% credible intervals.
mcmc.print_summary(0.89)

In [None]:
# Average of the recorded `mean_accept_prob` extra field across draws.
mcmc.get_extra_fields()['mean_accept_prob'].mean()

In terms of computation speed on Kaggle hardware, numpyro on GPU takes about 80 seconds, which is about a 10x speedup over numpyro on CPU, plus a further ~2x speedup with 2-chain MCMC compared to PyMC3 (code below). However, the r_hat (Gelman–Rubin diagnostic) is generally greater than 1, which means that the chains have not fully converged. Sampling problems like this are often addressed with a [non-centered parameterization](https://mc-stan.org/docs/2_21/stan-users-guide/reparameterization-section.html) (code commented out). However, as noted in the Stan User's Guide, it is not expected to help here because of the 180k+ sample size.

Another possible solution is to use a large number of chains, [as noted in Michael Betancourt's talk](https://youtu.be/DJ0c7Bm5Djk?t=19129). Numpyro currently does not vectorize across chains, unlike [PyMC4](https://github.com/pymc-devs/pymc4), so fitting this type of model is still hopeless. But as Rémi Louf points out, the ability to quickly sample many chains could be a breakthrough for MCMC modelling.

In [None]:
import pymc3 as pm
from pymc3 import sample, Normal, HalfCauchy, Uniform
import numpy as np
import pymc3.sampling_jax

In [None]:
# Same inputs as the numpyro model, now as plain NumPy arrays.
dat_list = dict(
    hits=np.array(d.Hits),
    pitcher=np.array(d.pitcher_code),
    batter=np.array(d.batter_code),
)

with pm.Model() as model:
    # Hyperpriors; `tau` is a precision, so tau=0.01 corresponds to sd=10.
    mu_a = pm.Normal('mu_a', mu=0., tau=0.01)
    sigma_a = pm.HalfCauchy('sigma_a', 5)
    mu_b = pm.Normal('mu_b', mu=0., tau=0.01)
    sigma_b = pm.HalfCauchy('sigma_b', 5)

    # One effect per distinct pitcher code and per distinct batter code.
    a = pm.Normal('a', mu=mu_a, sigma=sigma_a, shape=len(d['pitcher_code'].unique()))
    b = pm.Normal('b', mu=mu_b, sigma=sigma_b, shape=len(d['batter_code'].unique()))

    # Log-odds of a hit for every at-bat.
    logit_p = a[dat_list['pitcher']] + b[dat_list['batter']]

    # Bernoulli likelihood of the observed hits.
    p = pm.Bernoulli('y', logit_p=logit_p, observed=dat_list['hits'])

In [None]:
%%time

# Sample the PyMC3 model with numpyro's NUTS via PyMC3's JAX backend.
# NOTE(review): this requests 50_000 draws (plus the function's default tuning),
# whereas the numpyro run above used 1000 warmup + 1000 draws. The wall times are
# therefore not a like-for-like benchmark.
with model:
    hierarchical_trace_jax = pm.sampling_jax.sample_numpyro_nuts(
        50_000, target_accept=0.9, chains=1
    )

In [None]:
# Install BlackJAX from the repository's default branch (unpinned).
# NOTE(review): the cells below use early BlackJAX APIs (blackjax.nuts.new_state,
# blackjax.stan_warmup). Consider pinning a commit so the install matches them.
!pip install git+https://github.com/blackjax-devs/blackjax

In [None]:
import jax
import numpy as np
import pymc3 as pm
import pymc3.sampling_jax

import blackjax.nuts as nuts
import blackjax.stan_warmup as stan_warmup

In [None]:
from theano.graph.fg import FunctionGraph
from theano.link.jax.jax_dispatch import jax_funcify

seed = jax.random.PRNGKey(1234)
chains = 1

# Get the FunctionGraph of the model.
# Graph inputs are the model's free RVs; the single output is model.logpt.
fgraph = FunctionGraph(model.free_RVs, [model.logpt])

# Jax funcify builds Jax variant of the FunctionGraph.
fns = jax_funcify(fgraph)
# NOTE(review): assumes the first returned callable is the log-density taking the
# free RVs positionally in model.free_RVs order. Confirm against theano-pymc.
logp_fn_jax = fns[0]

# Now we build a Jax variant of the initial state/inputs to the model.
# One test-point value per free RV, in the same order as the graph inputs above.
rv_names = [rv.name for rv in model.free_RVs]
init_state = [model.test_point[rv_name] for rv_name in rv_names]
# Stack `chains` copies of each initial value along a new leading axis.
# Not consumed by the single-chain BlackJAX run that follows (it uses init_state).
init_state_batched = jax.tree_map(
    lambda x: np.repeat(x[None, ...], chains, axis=0), init_state
)

In [None]:
# Wrap the Jaxified log-density as a potential (negative log-density) and build
# the initial BlackJAX NUTS state from the model's test point.
def potential(x):
    """Return the NUTS potential energy at position ``x``.

    ``x`` is the list of free-RV values, passed positionally to the log-density.
    """
    return -logp_fn_jax(*x)


initial_position = init_state
initial_state = nuts.new_state(initial_position, potential)

In [None]:
%%time

kernel_factory = lambda step_size, inverse_mass_matrix: nuts.kernel(
    potential, step_size, inverse_mass_matrix
)

last_state, (step_size, inverse_mass_matrix), _ = stan_warmup.run(
    seed, kernel_factory, initial_state, 1000
)


def inference_loop(rng_key, kernel, initial_state, num_samples):
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, infos


# Build the kernel using the step size and inverse mass matrix returned from the window adaptation
kernel = kernel_factory(step_size, inverse_mass_matrix)

# Sample from the posterior distribution
states, infos = inference_loop(seed, kernel, last_state, 50_000)

In [None]:
# with model:
#     trace = sample(1000, tune=1000)