In [1]:
# Turn on JAX's tracer-leak check. The variable is set here, before `jax` is
# imported in a later cell, so it is already in the environment at import time.
import os
os.environ['JAX_CHECK_TRACER_LEAKS']='1'
import sys
# Make the project-local `berrylib` package importable. The path is relative to
# the notebook's working directory.
sys.path.append('../imprint/research/berry/')
import berrylib.util as util
# Project helper; presumably applies notebook defaults — confirm in berrylib.util.
util.setup_nb()

In [27]:
import jax
import jax.numpy as jnp
import numpy as np
import time
import numpyro
import numpyro.distributions as dist
import numpyro.handlers as handlers
import scipy.stats
import matplotlib.pyplot as plt
from scipy.special import logit, expit

import inla

In [3]:
# Hyperparameters read as module-level globals by `model`.
narms = 4  # number of arms, i.e. the length of theta
mu_0 = -1.34  # mean passed to theta's MultivariateNormal prior
mu_sig2 = 1e2  # value filling every entry of theta's covariance before sig2 is added
sig2_alpha = 5e-4  # first parameter of the InverseGamma prior on sig2
sig2_beta = 5e-6  # second parameter of the InverseGamma prior on sig2
logit_p1 = -1.0  # offset added to theta to form the binomial logits
def model(data):
    """NumPyro model for binomial outcomes across `narms` arms.

    Sample sites:
      * "sig2"  ~ InverseGamma(sig2_alpha, sig2_beta)
      * "theta" ~ MultivariateNormal(mu_0, cov), where cov is `mu_sig2` in
        every entry plus `sig2` on the diagonal
      * "y"     ~ BinomialLogits(theta + logit_p1), observed

    Args:
        data: 2-D array; column 0 is used as the observed counts and
            column 1 as the binomial total counts.
    """
    sig2 = numpyro.sample("sig2", dist.InverseGamma(sig2_alpha, sig2_beta))
    shared_part = jnp.full((narms, narms), mu_sig2)
    per_arm_part = jnp.diag(jnp.repeat(sig2, narms))
    theta_prior = dist.MultivariateNormal(mu_0, shared_part + per_arm_part)
    theta = numpyro.sample("theta", theta_prior)
    counts, totals = data[:, 0], data[:, 1]
    likelihood = dist.BinomialLogits(theta + logit_p1, total_count=totals)
    numpyro.sample("y", likelihood, obs=counts)

In [5]:
# Example observations: one row per arm; `model` reads column 0 as the observed
# counts and column 1 as the binomial totals.
data = np.array([[7, 35], [6, 35], [5, 35], [4, 35]], dtype=float)
# Point at which the log-likelihood is evaluated below.
params = {"sig2": 10.0, "theta": np.zeros(4)}
# Build the log-likelihood from an all-NaN template.
# NOTE(review): NaN appears to mark an entry as "not fixed" — confirm against
# inla.build_raw_log_likelihood.
ll_fnc, ravel_f, unravel_f = inla.build_raw_log_likelihood(
    model, {"sig2": np.nan, "theta": np.full(4, np.nan)}
)
# Arguments: flattened parameters, flattened fixed-value array (all NaN), data.
ll_fnc(ravel_f(params), np.full(5, np.nan), data)


DeviceArray(-32.336067, dtype=float32)

In [6]:
# Integration rule over sig2; later cells read its `.pts` and `.wts`.
# NOTE(review): presumably 15 points spanning [1e-2, 1e2] on a log scale —
# confirm against berrylib.util.log_gauss_rule.
sig2_rule = util.log_gauss_rule(15, 1e-2, 1e2)

In [7]:
# Gradient and Hessian of `ll_fnc` at the example point, with nothing fixed
# (the second argument is all NaN). The Hessian is displayed as the cell output.
grad, hess = inla.build_grad_hess(ll_fnc)(
    ravel_f(params), np.full(5, np.nan), data
)
hess

DeviceArray([[ 2.5007863e-02,  7.9713209e-06,  7.9716556e-06,  7.9714810e-06,  7.9714519e-06],
             [ 7.9711317e-06, -6.9570279e+00,  2.4390262e-02,  2.4390247e-02,  2.4390250e-02],
             [ 7.9716556e-06,  2.4390258e-02, -6.9570279e+00,  2.4390256e-02,  2.4390252e-02],
             [ 7.9714955e-06,  2.4390241e-02,  2.4390258e-02, -6.9570279e+00,  2.4390236e-02],
             [ 7.9714664e-06,  2.4390245e-02,  2.4390254e-02,  2.4390237e-02, -6.9570279e+00]],            dtype=float32)

In [24]:
# Template for the fixed parameters: sig2 carries a concrete value, theta is NaN.
# NOTE(review): NaN presumably means "free" — confirm in inla.
param_fixed = {"sig2": 10.0, "theta": np.full(4, np.nan)}
ll_fnc, ravel_f, unravel_f = inla.build_raw_log_likelihood(model, param_fixed)
array_fixed = ravel_f(param_fixed)
grad_hess_f = inla.build_grad_hess(ll_fnc)
# Inner vmap: maps axis 0 of the parameters and of the fixed array, data shared.
# Outer vmap: maps axis 0 of the parameters and of the data, fixed array shared.
grad_hess_vmap = jax.jit(
    jax.vmap(
        jax.vmap(grad_hess_f, in_axes=(0, 0, None)),
        in_axes=(0, None, 0),
    )
)

In [21]:

# Per-leaf flag over `param_fixed`: 1 if the leaf contains any NaN, else 0.
jax.tree_util.tree_map(lambda x: jnp.any(jnp.isnan(x)).astype(int), param_fixed)

{'sig2': DeviceArray(0, dtype=int32), 'theta': DeviceArray(1, dtype=int32)}

In [25]:

# Call the doubly-vmapped function on a single point by adding length-1 batch
# axes: two on the parameters, one on the fixed array, one on the data.
grad, hess = grad_hess_vmap(ravel_f(params)[None,None,:], array_fixed[None], data[None])
grad, hess

(DeviceArray([[[ 0.       , -2.4162188, -3.4162188, -4.4162188, -5.4162188]]], dtype=float32),
 DeviceArray([[[[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ],
                [ 0.        , -6.9570274 ,  0.02439025,  0.02439025,  0.02439025],
                [ 0.        ,  0.02439025, -6.9570274 ,  0.02439025,  0.02439025],
                [ 0.        ,  0.02439025,  0.02439025, -6.9570274 ,  0.02439024],
                [ 0.        ,  0.02439025,  0.02439025,  0.02439024, -6.9570274 ]]]], dtype=float32))

In [None]:
# Run the model once with both latent sites pinned and record every site in
# `trace`. The values come from the `params` dict defined in an earlier cell;
# the previous version referenced bare `sig2` / `theta` names that are not
# defined at this point in the notebook, so the cell failed on a fresh kernel.
seeded_model = handlers.seed(model, jax.random.PRNGKey(10))
subs_model = handlers.substitute(seeded_model, params)
trace = handlers.trace(subs_model).get_trace(data)

In [26]:
# Benchmark the jitted, doubly-vmapped grad/hess for growing batch sizes N.
# Each size is run three times and only the last run is reported, so that
# compilation for a new input shape is excluded from the printed timing.
for N in 2 ** np.arange(4, 20, 2):
    y = scipy.stats.binom.rvs(35, 0.3, size=(N, narms))
    n = np.full_like(y, 35)
    D = np.stack((y, n), axis=-1)
    # Random evaluation points: one flattened (sig2, theta) vector per
    # (sample, sig2 grid point).
    T = np.random.rand(N, sig2_rule.pts.shape[0], narms + 1)
    # Fixed-value array: column 0 holds the sig2 grid value, the theta slots
    # are NaN. The width is narms + 1 to match T (it was a hard-coded 5).
    S = np.tile(sig2_rule.pts[:, None], (1, narms + 1))
    S[:, 1:] = np.nan
    for rep in range(3):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        grad, hess = grad_hess_vmap(T, S, D)
        # JAX dispatches asynchronously; wait for the result before stopping
        # the clock.
        hess.block_until_ready()
        end = time.perf_counter()
    print(
        f"{N} iters, {(end - start) / N * 1e6:.3f} us per sample, {end - start:.2f}s total"
    )


16 iters, 51.633 us per sample, 0.00s total
64 iters, 47.501 us per sample, 0.00s total


KeyboardInterrupt: 

In [139]:
# Benchmark the dict-based grad/hess for growing batch sizes N.
# NOTE(review): this cell passes dicts to `grad_hess_vmap`, so it relies on a
# dict-accepting definition of that name being the one currently in the kernel.

# Fixed values shared by every sample: the sig2 grid, plus a theta template
# whose last entry is pinned to 0 and whose other entries are NaN. It is built
# once, outside the timed region, and sized from the rule instead of a
# hard-coded 15.
fixed_params = dict(
    sig2=sig2_rule.pts,
    theta=np.tile(
        np.array([np.nan, np.nan, np.nan, 0])[None, :],
        (sig2_rule.pts.shape[0], 1),
    ),
)
for N in 2 ** np.arange(4, 20, 2):
    y = scipy.stats.binom.rvs(35, 0.3, size=(N, narms))
    n = np.full_like(y, 35)
    data = np.stack((y, n), axis=-1)
    theta = np.random.rand(N, sig2_rule.pts.shape[0], narms)
    # Three repetitions; only the last is reported so compilation is excluded.
    for rep in range(3):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        grad, hess = grad_hess_vmap(
            dict(sig2=None, theta=theta),
            fixed_params,
            data,
        )
        # JAX dispatches asynchronously; wait for the result before stopping
        # the clock.
        hess["theta"]["theta"].block_until_ready()
        end = time.perf_counter()
    print(
        f"{N} iters, {(end - start) / N * 1e6:.3f} us per sample, {end - start:.2f}s total"
    )


16 iters, 7.123 us per sample, 0.00s total
64 iters, 14.998 us per sample, 0.00s total
256 iters, 1.535 us per sample, 0.00s total
1024 iters, 1.391 us per sample, 0.00s total
4096 iters, 0.923 us per sample, 0.00s total
16384 iters, 0.747 us per sample, 0.01s total
65536 iters, 0.880 us per sample, 0.06s total
262144 iters, 0.845 us per sample, 0.22s total


In [152]:
# Evaluate the log-likelihood over the whole (sample, sig2 grid) batch with the
# same vmap axes used for grad/hess, then show the entry for sample 0 and
# sig2 grid point 0.
jax.vmap(jax.vmap(ll_fnc, in_axes=(0, 0, None)), in_axes=(0, None, 0))(
    dict(sig2=None, theta=theta),
    dict(
        sig2=sig2_rule.pts,
        theta=np.tile(np.array([np.nan, np.nan, np.nan, 0])[None, :], (15, 1)),
    ),
    data
)[0,0]

DeviceArray(-18.26399, dtype=float32)

In [165]:
# Un-vmapped call for sample 0 at sig2 grid point 0. The recorded run of this
# cell raised IndexError.
# NOTE(review): `theta[0]` still carries the sig2-grid axis, whereas each
# vmapped call receives a single theta row; `theta[0, 0]` may have been
# intended — confirm.
ll_fnc(
    dict(sig2=None, theta=theta[0]),
    dict(sig2=sig2_rule.pts[0], theta=np.array([np.nan, np.nan, np.nan, 0])),
    data[0],
)


IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

In [146]:
# theta-theta Hessian block for sample 0 at sig2 grid point 0. This requires
# `hess` to be the nested-dict result of the dict-based grad/hess.
hess['theta']['theta'][0,0]

DeviceArray([[-78.92571 ,  23.660025,  23.65211 ,   0.      ],
             [ 23.660028, -79.05701 ,  23.653358,   0.      ],
             [ 23.652115,  23.65336 , -78.954544,   0.      ],
             [  0.      ,   0.      ,   0.      ,   0.      ]], dtype=float32)

In [127]:
# Flatten the `hess` pytree: `a` is the list of leaves, `b` the tree structure.
# NOTE(review): an assignment displays nothing, so the shape recorded under
# this cell came from a different version of it. `a` and `b` are also
# reassigned by a later cell.
a, b = jax.tree_util.tree_flatten(hess)


(262144, 15, 4, 4)

In [51]:
def grad_hess(theta, sig2, data):
    """Autodiff gradient and Hessian of `ll_fnc` for the dict-based API.

    The first dict passed to `ll_fnc` is the one differentiated (JAX
    differentiates the first positional argument by default); the second
    carries the values held fixed.

    NOTE(review): as written, `sig2` sits in the differentiated dict and
    `theta` in the fixed dict, whereas other cells in this notebook
    differentiate with respect to theta — confirm this is intended.

    Returns:
        (grad, hess) as returned by jax.grad / jax.hessian.
    """
    free = {"sig2": sig2, "theta": None}
    held = {"sig2": None, "theta": theta}
    return (
        jax.grad(ll_fnc)(free, held, data),
        jax.hessian(ll_fnc)(free, held, data),
    )


# Batched, jitted versions of `grad_hess` and `conditional`.
# Inner vmap: maps axis 0 of the first two arguments, data shared.
# Outer vmap: maps axis 0 of the first argument and of the data, second
# argument shared.
# NOTE(review): `conditional` is defined in a later cell, so this cell only
# runs once that cell has been executed.
grad_hess_vmap = jax.jit(
    jax.vmap(jax.vmap(grad_hess, in_axes=(0, 0, None)), in_axes=(0, None, 0))
)
conditional_vmap = jax.jit(
    jax.vmap(jax.vmap(conditional, in_axes=(0, 0, None)), in_axes=(0, None, 0))
)


In [56]:
%%time
# Time one batched grad/hess evaluation, using the `theta` and `data` arrays
# left in the kernel by earlier cells.
grad, hess = grad_hess_vmap(theta, sig2_rule.pts, data)

CPU times: user 2.31 s, sys: 490 ms, total: 2.8 s
Wall time: 1.08 s


In [46]:
# Shapes of the per-parameter gradient blocks. This requires `grad` to be a
# dict keyed by parameter name.
grad['sig2'].shape, grad['theta'].shape

((262144, 15), (262144, 15, 4))

In [None]:

# Precompute everything that depends only on the sig2 grid.
n_arms = narms
sigma2_n = sig2_rule.pts.shape[0]
arms = np.arange(n_arms)
# One covariance matrix per sig2 grid point: mu_sig2 in every entry, plus the
# grid point's sig2 on the diagonal.
cov = np.full((sigma2_n, n_arms, n_arms), mu_sig2)
cov[:, arms, arms] += sig2_rule.pts[:, None]
# Negative precision matrices, one per sig2 grid point.
neg_precQ = -np.linalg.inv(cov)

# Because cov is (constant + scaled identity), so is its inverse. Each
# negative precision matrix is therefore a * I + b * ones, with b the shared
# off-diagonal entry and a the diagonal entry minus b.
neg_precQ_b = neg_precQ[:, 0, 1]
neg_precQ_a = neg_precQ[:, 0, 0] - neg_precQ_b

# Half the log-determinant of each precision matrix.
logprecQdet = 0.5 * jnp.log(jnp.linalg.det(-neg_precQ))
# Inverse-gamma log prior density evaluated at each sig2 grid point.
log_prior = jnp.array(
    scipy.stats.invgamma.logpdf(sig2_rule.pts, sig2_alpha, scale=sig2_beta)
)
# Terms of the log density that do not depend on theta, one per sig2 grid point.
const = log_prior + logprecQdet
# Not referenced elsewhere in the visible notebook.
na = jnp.arange(narms)

def conditional(theta, sig2, data):
    """Log density in theta at every sig2 grid point, up to theta-independent
    terms that are not included in `const`.

    The result is the sum of:
      * the Gaussian prior quadratic term 0.5 * (theta - mu_0)' neg_precQ
        (theta - mu_0), using the per-grid-point negative precision,
      * the binomial log-likelihood with logits theta + logit_p1, without the
        binomial coefficient,
      * `const`, the sig2 log prior plus half the log-determinant of the
        precision.

    Args:
        theta: array whose last two axes are (sig2 grid point, arm).
        sig2: unused; the sig2 grid is baked into the module-level
            `neg_precQ_a`, `neg_precQ_b` and `const`.
        data: array whose last axis is (observed count, total count) and whose
            second-to-last axis indexes arms.

    Returns:
        Array with theta's shape minus the arm axis.
    """
    y = data[..., 0]
    n = data[..., 1]
    theta_m0 = theta - mu_0
    theta_adj = theta + logit_p1
    # neg_precQ = a * I + b * ones, so neg_precQ @ v = b * sum(v) + a * v.
    # `a` and `b` hold one value per sig2 grid point, so both terms need a
    # trailing axis to broadcast over arms. Without it on `a`, the product
    # tries to align the arm axis with the sig2 axis and fails.
    prec_dot = (
        (theta_m0.sum(axis=-1) * neg_precQ_b)[..., None]
        + theta_m0 * neg_precQ_a[..., None]
    )
    quad = jnp.sum(theta_m0 * prec_dot, axis=-1)
    # logaddexp(x, 0) equals log(exp(x) + 1) but does not overflow for large x.
    # y and n get a broadcast axis for the sig2 grid, placed just before the
    # arm axis so that unbatched data works as well.
    log_lik = jnp.sum(
        theta_adj * y[..., None, :]
        - n[..., None, :] * jnp.logaddexp(theta_adj, 0.0),
        axis=-1,
    )
    return 0.5 * quad + log_lik + const


def grad_hess(theta, _, data):
    """Closed-form gradient and Hessian, with respect to theta, of the Gaussian
    prior quadratic term plus the binomial log-likelihood.

    This notebook also defines an autodiff `grad_hess` in an earlier cell;
    whichever cell ran last is the one bound to the name.

    Args:
        theta: rank-3 array indexed (sample, sig2 grid point, arm).
        _: ignored; the sig2 grid is baked into the module-level `neg_precQ`.
        data: last axis is (observed count, total count).
            NOTE(review): assumed shape (sample, arm, 2) — confirm with caller.

    Returns:
        (grad, hess): grad has theta's shape; hess has a trailing
        (narms, narms) block per (sample, sig2 grid point).
    """
    successes = data[..., 0]
    trials = data[..., 1]
    centered = theta - mu_0
    odds = jnp.exp(theta + logit_p1)
    inv_one_plus_odds = 1.0 / (odds + 1)
    # Expected successes: n * odds / (1 + odds).
    expected = trials[:, None] * inv_one_plus_odds * odds
    # Prior contribution: neg_precQ @ (theta - mu_0) per (sample, grid point).
    prior_grad = jnp.matmul(neg_precQ[None], centered[:, :, :, None])[..., 0]
    grad = prior_grad + successes[:, None] - expected
    # The likelihood curvature is diagonal in the arms.
    lik_curvature = (expected * inv_one_plus_odds)[:, :, None, :] * jnp.eye(narms)
    hess = neg_precQ[None] - lik_curvature
    return grad, hess


# The closed-form versions already operate on whole batches, so they are only
# jitted, not vmapped. This rebinds the `*_vmap` names used by other cells.
grad_hess_vmap = jax.jit(grad_hess)
conditional_vmap = jax.jit(conditional)


In [53]:
# Synthetic dataset: N draws of binomial(35, 0.3) counts for each arm, stacked
# with a constant total count of 35 along a trailing axis.
N = 10_000
y = scipy.stats.binom.rvs(35, 0.3, size=(N, narms))
n = np.full(y.shape, 35, dtype=y.dtype)
data = np.stack([y, n], axis=-1)

In [130]:
# Scalar coefficients for the first sig2 grid point: b is the off-diagonal
# entry of its negative precision matrix, a is the diagonal entry minus b.
b = neg_precQ[0,0,1]
a = neg_precQ[0,0,0] - b

In [131]:
@jax.jit
def quad(theta_max, a, b):
    """Quadratic form v' (a * I + b * ones) v along the last axis of `theta_max`.

    It avoids building the matrix: (a * I + b * ones) @ v = b * sum(v) + a * v.
    """
    row_totals = theta_max.sum(axis=-1)
    mat_vec = (row_totals * b)[..., None] + theta_max * a
    return jnp.sum(theta_max * mat_vec, axis=-1)

In [132]:
%%timeit
# Block on the result so the timing covers the computation, not just dispatch.
quad(theta_max, a, b).block_until_ready()

126 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [129]:
# Quadratic form for sample 0 at sig2 grid point 0.
quad(theta_max, a, b)[0,0]

DeviceArray(-0.37294438, dtype=float32)

In [122]:
# Reference computations using the full negative precision matrices:
# dotprod2 is the matrix-vector product, quad2 is the quadratic form in a
# single einsum, and quad3 is the quadratic form built from dotprod2.
# NOTE(review): `theta_max` is assigned in a later cell, so this depends on
# kernel state.
dotprod2 = jnp.einsum("...ij,...j", neg_precQ, theta_max)
quad2 = jnp.einsum("...i,...ij,...j", theta_max, neg_precQ, theta_max)
quad3 = np.sum(theta_max * dotprod2, axis=-1)

In [123]:
# Side-by-side comparison of the structured and the full-matrix results.
# NOTE(review): this needs `dotprod` and `quad` to be arrays, but no visible
# cell assigns `dotprod` at top level and `quad` is bound to the jitted
# function defined in another cell. The cell reflects an earlier kernel state.
dotprod[0,0], dotprod2[0,0], quad[0,0], quad2[0,0], quad3[0,0]

(DeviceArray([-5.1011925,  1.39013  ,  2.319191 ,  1.390357 ], dtype=float32),
 DeviceArray([-5.101193 ,  1.3901292,  2.319191 ,  1.390357 ], dtype=float32),
 DeviceArray(-0.3729442, dtype=float32),
 DeviceArray(-0.37294447, dtype=float32),
 DeviceArray(-0.37294444, dtype=float32))

In [55]:
# Build the project's INLA object from the batched conditional and grad/hess
# functions, the sig2 integration rule and the number of arms.
inla_obj = inla.INLA(conditional_vmap, grad_hess_vmap, sig2_rule, narms)

In [56]:
# Optimise theta at every sig2 grid point, then compute the posterior from the
# optimum and its Hessian.
# NOTE(review): 1e-3 is presumably the convergence tolerance — confirm in inla.INLA.
theta_max, hess, iters = inla_obj.optimize_loop(data, sig2_rule.pts, 1e-3)
post = inla_obj.posterior(theta_max, hess, sig2_rule.pts, sig2_rule.wts, data)

In [58]:
%%timeit -n 20 -r 5
# Same two steps as the previous cell, timed end to end.
theta_max, hess, iters = inla_obj.optimize_loop(data, sig2_rule.pts, 1e-3)
post = inla_obj.posterior(theta_max, hess, sig2_rule.pts, sig2_rule.wts, data)

31.4 ms ± 444 µs per loop (mean ± std. dev. of 5 runs, 20 loops each)
