In [1]:
import sys
# Make the project's berrylib package importable. The entry is a relative path,
# so it resolves against the kernel's current working directory (expected to be
# this notebook's folder).
sys.path.append('../imprint/research/berry/')
import berrylib.util as util
# Project notebook setup helper. NOTE(review): what it configures lives in
# berrylib.util (not shown here) — confirm.
util.setup_nb()

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import time
import numpyro
import numpyro.distributions as dist
import scipy.stats
import matplotlib.pyplot as plt
from scipy.special import logit, expit

import inla

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# Hyperparameters shared by `model` and the hand-written gradient/Hessian.
mu_0 = -1.34          # prior mean of theta for every arm
mu_sig2 = 100.0       # value added to every entry of theta's prior covariance
sig2_alpha, sig2_beta = 0.0005, 0.000005  # InverseGamma(concentration, rate) prior on sig2
logit_p1 = -1.0       # offset added to theta to form the binomial logit
def model(data):
    """numpyro model: hierarchical binomial-logit model over the arms in `data`.

    Parameters
    ----------
    data : 2-D array, one row per arm
        Column 0 is the observed count per arm (the binomial `obs`);
        column 1 is the number of trials per arm (`total_count`).

    Priors: sig2 ~ InverseGamma(sig2_alpha, sig2_beta) and
    theta ~ MultivariateNormal(mu_0, cov), where cov is mu_sig2 in every entry
    plus sig2 on the diagonal. The binomial logit of each arm is theta + logit_p1.
    """
    # Derive the number of arms from the data instead of hard-coding 3, so the
    # same model works for any number of arms (identical for the 3-arm case).
    n_arms = data.shape[0]
    sig2 = numpyro.sample("sig2", dist.InverseGamma(sig2_alpha, sig2_beta))
    cov = jnp.full((n_arms, n_arms), mu_sig2) + jnp.diag(jnp.repeat(sig2, n_arms))
    theta = numpyro.sample(
        "theta",
        dist.MultivariateNormal(mu_0, cov),
    )
    numpyro.sample(
        "y",
        dist.BinomialLogits(theta + logit_p1, total_count=data[:, 1]),
        obs=data[:, 0],
    )

In [4]:
# Small 3-arm example with (count, trials) rows, a parameter point, and the
# log-likelihood built from `model` (called as ll_fnc(params, data) below).
# NOTE(review): `params` is not used by any later cell shown here.
data = np.array([
    [6.0, 35.0],
    [5.0, 35.0],
    [4.0, 35.0],
])
params = {"sig2": 10.0, "theta": np.zeros(3)}
ll_fnc = inla.build_log_likelihood(model)

In [5]:
# Rule over sig2: `.pts` is the sig2 grid used throughout, `.wts` is passed to
# scale_posterior. NOTE(review): presumably a 15-point Gauss rule in log(sig2)
# over [1e-2, 1e2] — confirm against berrylib.util.log_gauss_rule.
sig2_rule = util.log_gauss_rule(15, 1e-2, 1e2)

In [82]:
def conditional(theta, sig2, data):
    """Evaluate `ll_fnc` for `data` at the parameter point (theta, sig2)."""
    return ll_fnc({"theta": theta, "sig2": sig2}, data)

def grad_hess_autodiff(theta, sig2, data):
    """Gradient and Hessian of `conditional` w.r.t. theta via JAX autodiff.

    Kept under its own name (instead of being silently shadowed by the
    hand-written `grad_hess` below) so it remains available as a reference
    implementation for cross-checking.
    """
    grad = jax.grad(conditional)(theta, sig2, data)
    hess = jax.hessian(conditional)(theta, sig2, data)
    return grad, hess


# Inner vmap pairs axis 0 of theta with axis 0 of sig2 (data shared); the outer
# vmap pairs axis 0 of theta with axis 0 of data (sig2 shared).
grad_hess_autodiff_vmap = jax.jit(
    jax.vmap(
        jax.vmap(grad_hess_autodiff, in_axes=(0, 0, None)), in_axes=(0, None, 0)
    )
)
conditional_vmap = jax.jit(
    jax.vmap(
        jax.vmap(conditional, in_axes=(0, 0, None)), in_axes=(0, None, 0)
    )
)

# Precompute the negated prior precision for every point of sig2_rule.pts:
# cov is mu_sig2 in every entry plus sig2 on the diagonal, as in `model`.
n_arms = 3
sigma2_n = sig2_rule.pts.shape[0]
arms = np.arange(n_arms)
cov = np.full((sigma2_n, n_arms, n_arms), mu_sig2)
cov[:, arms, arms] += sig2_rule.pts[:, None]
neg_precQ = -np.linalg.inv(cov)


def grad_hess(theta, _, data):
    """Hand-derived gradient and Hessian w.r.t. theta, already batched.

    theta has shape (n_datasets, sigma2_n, n_arms) and data has shape
    (n_datasets, n_arms, 2) with columns (count, trials); no vmap is needed.

    The second argument (sig2) is ignored: the prior precision comes from the
    module-level `neg_precQ`, precomputed for `sig2_rule.pts`, so results are
    only valid when called with exactly those sig2 points.

    With eta = theta + logit_p1, s = sigmoid(eta) and Q the prior precision:
        grad = -Q (theta - mu_0) + y - n * s
        hess = -Q - diag(n * s * (1 - s))
    """
    y = data[..., 0]
    n = data[..., 1]
    theta_m0 = theta - mu_0
    exp_theta_adj = jnp.exp(theta + logit_p1)
    C = 1.0 / (exp_theta_adj + 1)           # 1 - sigmoid(eta)
    nCeta = n[:, None] * C * exp_theta_adj  # n * sigmoid(eta), broadcast over sig2
    grad = (
        jnp.matmul(neg_precQ[None], theta_m0[:, :, :, None])[..., 0]
        + y[:, None] - nCeta
    )
    # nCeta * C == n * s * (1 - s); multiplying by the identity puts it on the diagonal.
    hess = neg_precQ[None] - ((nCeta * C)[:, :, None, :] * jnp.eye(n_arms))
    return grad, hess


grad_hess_vmap = jax.jit(grad_hess)

In [84]:
# Time the jitted hand-written grad/Hessian as the number of datasets N grows.
# Each new N changes input shapes and triggers a fresh JIT compile, so one
# untimed warm-up call is made per N. The reported time is the fastest timed
# repeat, which is less sensitive to scheduling noise than a single run.
bench_rng = np.random.default_rng(0)  # seeded so re-runs benchmark identical inputs
n_sig2 = sig2_rule.pts.shape[0]
n_timed_repeats = 3
for N in 2 ** np.arange(4, 20, 2):
    y = scipy.stats.binom.rvs(35, 0.3, size=(N, 3), random_state=bench_rng)
    n = np.full_like(y, 35)
    data = np.stack((y, n), axis=-1)
    theta = bench_rng.random((N, n_sig2, 3))
    # Warm-up: compiles for this shape; excluded from the timing.
    grad, hess = grad_hess_vmap(theta, sig2_rule.pts, data)
    grad.block_until_ready()
    hess.block_until_ready()
    elapsed = float("inf")
    for _ in range(n_timed_repeats):
        start = time.perf_counter()
        grad, hess = grad_hess_vmap(theta, sig2_rule.pts, data)
        # JAX dispatches asynchronously; wait for the results before stopping the clock.
        grad.block_until_ready()
        hess.block_until_ready()
        elapsed = min(elapsed, time.perf_counter() - start)
    print(f'{N} samples, {elapsed / N * 1e6:.3f} us per sample, {elapsed:.2f}s total')

16 iters, 0.685 us per sample, 0.00s total
64 iters, 0.376 us per sample, 0.00s total
256 iters, 0.278 us per sample, 0.00s total
1024 iters, 0.271 us per sample, 0.00s total
4096 iters, 0.282 us per sample, 0.00s total
16384 iters, 0.162 us per sample, 0.00s total
65536 iters, 0.176 us per sample, 0.01s total
262144 iters, 0.245 us per sample, 0.06s total


In [85]:
# Simulated inputs for the INLA timing below: N datasets, each with 3 arms of
# 35 binomial trials at success probability 0.3, stored as (count, trials).
N = 10000
data_rng = np.random.default_rng(2022)  # fixed seed so Restart & Run All reproduces the data
y = scipy.stats.binom.rvs(35, 0.3, size=(N, 3), random_state=data_rng)
n = np.full_like(y, 35)
data = np.stack((y, n), axis=-1)

In [86]:
# One row per simulated dataset, one per arm, and a (count, trials) pair each.
assert data.shape == (N, 3, 2), data.shape
data.shape

(10000, 3, 2)

In [87]:
# Build the INLA object from the batched log-density, the hand-written batched
# gradient/Hessian (grad_hess_vmap) and the sig2 rule.
# NOTE(review): the trailing 3 presumably is the number of arms (n_arms) —
# confirm against inla.INLA's signature.
inla_obj = inla.INLA(conditional_vmap, grad_hess_vmap, sig2_rule, 3)

In [88]:
# Run the theta optimization for every dataset over the sig2 grid; returns
# theta_max, hess and iters. NOTE(review): 1e-3 is presumably the convergence
# tolerance, and this first call presumably also JIT-compiles the solver so the
# %%timeit cell below measures steady-state time — confirm.
theta_max, hess, iters = inla_obj.optimize_loop(data, sig2_rule.pts, 1e-3)

In [92]:
%%timeit -n 30
# Times the full pipeline: optimize over theta, compute the posterior, then
# scale it with the sig2 rule weights.
# NOTE(review): %%timeit runs this body in its own scope, so post/post2 are
# presumably not kept in the notebook namespace afterwards — confirm.
# NOTE(review): nothing here calls block_until_ready; if these return JAX arrays
# still being computed, the timing may exclude outstanding device work — confirm.
theta_max, hess, iters = inla_obj.optimize_loop(data, sig2_rule.pts, 1e-3)
post = inla_obj.posterior(theta_max, hess, sig2_rule.pts, data)
post2 = inla_obj.scale_posterior(post, sig2_rule.wts)

18.7 ms ± 537 µs per loop (mean ± std. dev. of 7 runs, 30 loops each)
