In [1]:
import os
# Set the tracer-leak checking flag in the environment before jax is imported
# in the next cell.
os.environ['JAX_CHECK_TRACER_LEAKS']='1'
import sys
# Add the berry research directory (relative to this notebook) to the import
# path so that `berrylib` can be imported.
sys.path.append('../imprint/research/berry/')
import berrylib.util as util
# NOTE(review): notebook setup helper from berrylib; its exact effects are not
# visible from this notebook.
util.setup_nb()

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import time
import numpyro
import numpyro.distributions as dist
import numpyro.handlers as handlers
import scipy.stats
import matplotlib.pyplot as plt
from scipy.special import logit, expit

import inla

  from .autonotebook import tqdm as notebook_tqdm


In [62]:
# Number of arms: dimension of theta and of its covariance matrix.
narms = 4
# Location passed to the MultivariateNormal prior on theta.
mu_0 = -1.34
# Value added to every entry of theta's covariance matrix (the component
# shared by all arms).
mu_sig2 = 100.0
# First and second parameters of the InverseGamma prior on sig2; the scipy
# log-pdf call further down uses them as shape and `scale` respectively.
sig2_alpha = 0.0005
sig2_beta = 0.000005
# Constant offset added to theta before it is used as the binomial logits.
logit_p1 = logit(0.3)
def model(data):
    """Numpyro model: binomial outcomes per arm with a correlated normal prior.

    ``sig2`` gets an InverseGamma(sig2_alpha, sig2_beta) prior. ``theta`` is
    multivariate normal with location ``mu_0`` and covariance equal to
    ``mu_sig2`` in every entry plus ``sig2`` on the diagonal. The observation
    ``y`` is binomial with logits ``theta + logit_p1``.

    Args:
        data: array indexed as ``data[:, 0]`` (observed counts) and
            ``data[:, 1]`` (total counts).
    """
    sig2 = numpyro.sample("sig2", dist.InverseGamma(sig2_alpha, sig2_beta))

    # Covariance = constant shared block + per-arm diagonal.
    shared_part = jnp.full((narms, narms), mu_sig2)
    arm_part = jnp.diag(jnp.repeat(sig2, narms))
    theta_prior = dist.MultivariateNormal(mu_0, shared_part + arm_part)
    theta = numpyro.sample("theta", theta_prior)

    observed, totals = data[:, 0], data[:, 1]
    likelihood = dist.BinomialLogits(theta + logit_p1, total_count=totals)
    numpyro.sample("y", likelihood, obs=observed)

In [63]:
# One row per arm: (observed count, total count), as indexed inside `model`.
data = np.array([[7.0, 35.0], [6.0, 35.0], [5.0, 35.0], [4.0, 35.0]])
params = {"sig2": 10.0, "theta": np.zeros(4)}
ll_fnc = inla.build_log_likelihood(model)
# Second argument is all-None here; other cells pass pinned values in that
# slot. NOTE(review): presumably None means "nothing fixed" — confirm in inla.
ll_fnc(params, {"sig2": None, "theta": None}, data)

DeviceArray(-35.054466, dtype=float32)

In [64]:
# Benchmark the autodiff gradient/hessian built by inla over growing batch sizes.
sig2_rule = util.log_gauss_rule(15, 1e-2, 1e2)
# theta_fixed is NaN except for arm 0, which is pinned at -1.0 for every sig2
# point. Width comes from `narms` rather than a literal so it tracks the model.
theta_fixed = np.full((sig2_rule.pts.shape[0], narms), np.nan)
theta_fixed[:, 0] = -1.0
grad_hess_vmap = inla.build_grad_hess(
    ll_fnc, dict(sig2=jnp.array([10.0]), theta=theta_fixed[0, :])
)
# Seeded generator so the benchmark inputs are reproducible across runs.
bench_rng = np.random.default_rng(0)
for N in 2 ** np.arange(4, 20, 2):
    y = scipy.stats.binom.rvs(35, 0.3, size=(N, narms), random_state=bench_rng)
    n = np.full_like(y, 35)
    D = np.stack((y, n), axis=-1)
    T = bench_rng.random((N, sig2_rule.pts.shape[0], narms))
    # Run three times and report only the last timing; presumably the earlier
    # passes absorb compilation / warm-up cost.
    for i in range(3):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        grad, hess = grad_hess_vmap(
            dict(sig2=None, theta=T), dict(sig2=sig2_rule.pts, theta=theta_fixed), D
        )
        # Wait for the result so the measured interval covers the computation.
        hess.block_until_ready()
        end = time.perf_counter()
    print(
        f"{N} iters, {(end - start) / N * 1e6:.3f} us per sample, {end - start:.2f}s total"
    )


16 iters, 2.682 us per sample, 0.00s total
64 iters, 2.425 us per sample, 0.00s total
256 iters, 1.590 us per sample, 0.00s total
1024 iters, 1.213 us per sample, 0.00s total
4096 iters, 0.596 us per sample, 0.00s total
16384 iters, 0.596 us per sample, 0.01s total
65536 iters, 0.587 us per sample, 0.04s total
262144 iters, 0.695 us per sample, 0.18s total


In [77]:


# Round-trip check for the ravel/unravel helpers: the NaN entry of theta and
# the None-valued key are expected to be dropped by ravel_f and restored by
# unravel_f.
# NOTE(review): build_ravel_fncs is not defined or imported in the visible
# cells — confirm where it comes from.
tt = np.array([0.0, np.nan, 2.0, 3.0])
tt_pin = np.array([np.nan, -1.0, np.nan, np.nan])
ex = {"sig2": None, "theta": tt}
ravel_f, unravel_f = build_ravel_fncs(ex)
r = ravel_f(ex)
np.testing.assert_allclose(r, [0, 2, 3])
ur = unravel_f(r)
for k in ex:
    if ex[k] is not None:
        np.testing.assert_allclose(ur[k], ex[k])
    else:
        assert ur[k] is None
r, ur

(DeviceArray([0., 2., 3.], dtype=float32),
 {'sig2': None,
  'theta': DeviceArray([ 0., nan,  2.,  3.], dtype=float32, weak_type=True)})

In [71]:
# Differentiate the log-likelihood with respect to the free parameters in `ex`
# while sig2 and theta[1] are supplied through the second (pinned) argument.
p_pinned = {"sig2": 10.0, "theta": tt_pin}
grad = jax.grad(ll_fnc)(ex, p_pinned, data)
hess = jax.hessian(ll_fnc)(ex, p_pinned, data)
# Show the gradient both as a dict and flattened with the pinned entry removed.
grad, ravel_f(grad)

({'sig2': None,
  'theta': DeviceArray([ -3.4057074,   0.       , -21.70585  , -27.562943 ], dtype=float32)},
 DeviceArray([ -3.4057074, -21.70585  , -27.562943 ], dtype=float32))

In [76]:
# Attempt to flatten the hessian along both axes by applying ravel_f twice.
# NOTE(review): as the recorded output shows, this raises IndexError (a length-4
# boolean mask applied to an axis of length 3) — presumably the inner ravel_f
# has already removed the pinned entry; the nested call needs rework.
ravel_f({k: None if hess[k] is None  else ravel_f(hess[k]) for k in ex}, axis=-2)

IndexError: boolean index did not match shape of indexed array in index 0: got (4,), expected (3,)

In [None]:

# Precompute, for every sig2 quadrature point, the negated precision matrix of
# theta and the theta-independent part of the log joint density.
n_arms = narms
sigma2_n = sig2_rule.pts.shape[0]
arms = np.arange(n_arms)
# Covariance per sig2 point: mu_sig2 everywhere plus sig2 on the diagonal.
cov = np.full((sigma2_n, n_arms, n_arms), mu_sig2)
cov[:, arms, arms] += sig2_rule.pts[:, None]
neg_precQ = -np.linalg.inv(cov)

# With this covariance structure the precision has one off-diagonal value (b)
# and one diagonal value (a + b), so it can be written as a * I + b * ones.
neg_precQ_b = neg_precQ[:, 0, 1]
neg_precQ_a = neg_precQ[:, 0, 0] - neg_precQ_b

# slogdet returns (sign, log|det|); taking the log-determinant directly avoids
# the under/overflow that log(det(.)) can hit. -neg_precQ is the inverse of a
# covariance matrix, so the log of the absolute determinant is the log-det.
logprecQdet = 0.5 * jnp.linalg.slogdet(-neg_precQ)[1]
log_prior = jnp.array(
    scipy.stats.invgamma.logpdf(sig2_rule.pts, sig2_alpha, scale=sig2_beta)
)
# Terms of the log joint that depend on sig2 only.
const = log_prior + logprecQdet
na = jnp.arange(narms)

def conditional(theta, sig2, data):
    """Evaluate the log joint density at theta for every sig2 quadrature point.

    The result is the sum of
      * 0.5 * theta_m0^T (neg_precQ) theta_m0, computed through the
        ``a * I + b * ones`` structure of the precision matrix,
      * the binomial log-likelihood with logits ``theta + logit_p1`` (without
        the binomial coefficient), and
      * ``const``, the precomputed sig2-only terms.

    Args:
        theta: array whose last axis indexes arms and whose second-to-last
            axis lines up with the sig2 points of ``neg_precQ_a`` /
            ``neg_precQ_b`` / ``const``. NOTE(review): callers appear to pass
            (datasets, sig2 points, arms) — confirm.
        sig2: unused; kept so the signature matches what the caller expects.
        data: array with ``data[..., 0]`` the observed counts and
            ``data[..., 1]`` the total counts, one row per dataset.

    Returns:
        Array with theta's shape minus its last axis.
    """
    y = data[..., 0]
    n = data[..., 1]
    theta_m0 = theta - mu_0
    theta_adj = theta + logit_p1
    exp_theta_adj = jnp.exp(theta_adj)

    # The per-sig2 diagonal coefficient needs a trailing axis so it broadcasts
    # across arms; without it a (nsig,) vector is matched against the arm axis
    # and the product fails whenever nsig != narms.
    coef_a = jnp.asarray(neg_precQ_a)[..., None]
    coef_b = jnp.asarray(neg_precQ_b)

    # (a * I + b * ones) @ theta_m0 without forming the matrix.
    prec_times_theta = (theta_m0.sum(axis=-1) * coef_b)[..., None] + theta_m0 * coef_a
    quad_form = jnp.sum(theta_m0 * prec_times_theta, axis=-1)

    log_lik = jnp.sum(
        theta_adj * y[:, None] - n[:, None] * jnp.log(exp_theta_adj + 1),
        axis=-1,
    )
    return 0.5 * quad_form + log_lik + const


def grad_hess(theta, _, data):
    """Analytic gradient and hessian of the log joint with respect to theta.

    Args:
        theta: rank-3 array (it is indexed as ``theta[:, :, :, None]``); the
            last axis indexes arms and the middle axis lines up with the
            leading axis of ``neg_precQ``. NOTE(review): the first axis
            presumably indexes datasets — confirm.
        _: ignored second positional argument.
        data: array with ``data[..., 0]`` the observed counts and
            ``data[..., 1]`` the total counts.

    Returns:
        ``(grad, hess)`` where ``grad`` has theta's shape and ``hess`` has one
        extra trailing arm axis.
    """
    observed, totals = data[..., 0], data[..., 1]
    centered = theta - mu_0
    odds = jnp.exp(theta + logit_p1)
    # 1 / (1 + odds) is the failure probability; totals * odds / (1 + odds) is
    # the expected count under the current logits.
    fail_prob = 1.0 / (odds + 1)
    expected_count = totals[:, None] * fail_prob * odds

    # Gaussian prior term: neg_precQ @ (theta - mu_0), batched over sig2 points.
    prior_grad = jnp.matmul(neg_precQ[None], centered[:, :, :, None])[..., 0]
    grad = prior_grad + observed[:, None] - expected_count

    # The likelihood contributes only to the diagonal of the hessian.
    lik_curvature = (expected_count * fail_prob)[:, :, None, :] * jnp.eye(narms)
    hess = neg_precQ[None] - lik_curvature
    return grad, hess


# Jit-compile the hand-written functions. Both already broadcast over their
# leading axes, so only jax.jit is applied (the `_vmap` suffix is kept for the
# names used elsewhere). This rebinds `grad_hess_vmap`, which an earlier cell
# assigned from inla.build_grad_hess.
conditional_vmap = jax.jit(conditional)
grad_hess_vmap = jax.jit(grad_hess)


In [53]:
# Simulate N datasets: binomial(35, 0.3) counts for each arm, stacked with the
# constant total count so the last axis is (observed, total).
N = 10_000
y = scipy.stats.binom.rvs(n=35, p=0.3, size=(N, narms))
n = np.full_like(y, fill_value=35)
data = np.stack([y, n], axis=-1)

In [130]:
# Scalar coefficients for the first sig2 point only: b is the off-diagonal
# entry of the negated precision and a is the diagonal entry minus b, so that
# neg_precQ[0] == a * I + b * ones.
b = neg_precQ[0,0,1]
a = neg_precQ[0,0,0] - b

In [131]:
@jax.jit
def quad(theta_max, a, b):
    """Quadratic form theta^T (a * I + b * ones) theta along the last axis.

    The matrix is never formed: (a * I + b * ones) @ theta equals
    ``b * sum(theta) + a * theta``.

    Args:
        theta_max: array whose last axis holds the vector entries.
        a: coefficient of the identity part.
        b: coefficient of the all-ones part.

    Returns:
        Array with ``theta_max``'s shape minus its last axis.
    """
    row_total = theta_max.sum(axis=-1)
    matvec = (row_total * b)[..., None] + theta_max * a
    return jnp.sum(theta_max * matvec, axis=-1)

In [132]:
%%timeit
# Time the jitted closed-form quadratic form; block_until_ready waits for the
# result so the measurement covers the computation itself.
quad(theta_max, a, b).block_until_ready()

126 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [129]:
# Display the closed-form value at index [0, 0] for comparison with the einsum
# versions.
quad(theta_max, a, b)[0,0]

DeviceArray(-0.37294438, dtype=float32)

In [122]:
# Reference implementations using the full negated precision matrices:
# dotprod2 is the matrix-vector product, quad2 the quadratic form in a single
# einsum, and quad3 the quadratic form assembled from dotprod2.
dotprod2 = jnp.einsum("...ij,...j", neg_precQ, theta_max)
quad2 = jnp.einsum("...i,...ij,...j", theta_max, neg_precQ, theta_max)
quad3 = np.sum(theta_max * dotprod2, axis=-1)

In [123]:
# Compare the closed-form (a, b) computation against the einsum references at
# index [0, 0]. The closed-form pieces are computed here explicitly: `quad` is
# a jitted function (not an array) and no module-level `dotprod` is defined,
# so the values must not be taken from leftover kernel state.
dotprod_closed = (theta_max.sum(axis=-1) * b)[..., None] + theta_max * a
quad_closed = quad(theta_max, a, b)
dotprod_closed[0, 0], dotprod2[0, 0], quad_closed[0, 0], quad2[0, 0], quad3[0, 0]

(DeviceArray([-5.1011925,  1.39013  ,  2.319191 ,  1.390357 ], dtype=float32),
 DeviceArray([-5.101193 ,  1.3901292,  2.319191 ,  1.390357 ], dtype=float32),
 DeviceArray(-0.3729442, dtype=float32),
 DeviceArray(-0.37294447, dtype=float32),
 DeviceArray(-0.37294444, dtype=float32))

In [55]:
# Build the INLA driver from the jitted log-joint and gradient/hessian
# functions, the sig2 quadrature rule and the number of arms.
inla_obj = inla.INLA(conditional_vmap, grad_hess_vmap, sig2_rule, narms)

In [56]:
# Optimize theta for every dataset and sig2 point, then compute the posterior.
# NOTE(review): 1e-3 is presumably the convergence tolerance — confirm in inla.
# `theta_max` assigned here is what the quadratic-form cells earlier in the
# notebook read, so those cells only run after this one has been executed.
theta_max, hess, iters = inla_obj.optimize_loop(data, sig2_rule.pts, 1e-3)
post = inla_obj.posterior(theta_max, hess, sig2_rule.pts, sig2_rule.wts, data)

In [58]:
%%timeit -n 20 -r 5
# Time the full optimize + posterior pipeline: 5 runs of 20 loops each.
theta_max, hess, iters = inla_obj.optimize_loop(data, sig2_rule.pts, 1e-3)
post = inla_obj.posterior(theta_max, hess, sig2_rule.pts, sig2_rule.wts, data)

31.4 ms ± 444 µs per loop (mean ± std. dev. of 5 runs, 20 loops each)
