In [1]:
import os
import sys

# Turn on JAX's leaked-tracer checking before anything below imports jax
# (berrylib.util may), so the setting is already in place when jax reads it.
# NOTE(review): leak checking may add overhead; consider disabling it for the
# timing cells in this notebook.
os.environ.update(JAX_CHECK_TRACER_LEAKS='1')

# Make the berry research package importable relative to this notebook.
sys.path.append('../imprint/research/berry/')
import berrylib.util as util

util.setup_nb()

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import time
import numpyro
import numpyro.distributions as dist
import numpyro.handlers as handlers
import scipy.stats
import matplotlib.pyplot as plt
from scipy.special import logit, expit

import inla
from berry_model import berry_model, fast_berry


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Benchmark the batched Berry-model solver (fast_berry) on simulated binomial data.
n_arms = 4
n_patients = 35  # trials per arm in each simulated dataset
p_true = 0.3  # response probability used to simulate y
n_repeats = 20  # timed repetitions per dataset count
dtype = np.float64

sig2_rule = util.log_gauss_rule(15, 1e-2, 1e2)
fl = fast_berry(sig2_rule.pts, n_arms)
# Cast once, outside the timed call.
sig2_pts = sig2_rule.pts.astype(dtype)


def _block_until_ready(tree):
    """Wait until every JAX array in ``tree`` has been computed, then return ``tree``.

    JAX dispatches work asynchronously, so timing a call without blocking can
    measure only the dispatch. Leaves without ``block_until_ready`` are returned as-is.
    """
    return jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready() if hasattr(leaf, "block_until_ready") else leaf,
        tree,
    )


# For a sweep over dataset counts use e.g. 2 ** np.array([4, 9, 14, 16]).
for N in 2 ** np.array([14]):
    # Seeded so every run benchmarks the same data.
    y = scipy.stats.binom.rvs(n_patients, p_true, size=(N, n_arms), random_state=0)
    n = np.full_like(y, n_patients)
    D = np.stack((y, n), axis=-1).astype(dtype)
    x0 = np.zeros((D.shape[0], sig2_pts.shape[0], n_arms), dtype=dtype)
    f = lambda: fl(dict(sig2=sig2_pts, theta=None), D, x0, should_batch=False)
    # Warm-up call so one-time costs (presumably JIT compilation) are not timed.
    _block_until_ready(f())
    for rep in range(n_repeats):
        start = time.perf_counter()
        post, x_max, hess, iters = _block_until_ready(f())
        end = time.perf_counter()
        print(
            f"{N} datasets, {(end - start) / N * 1e6:.3f} us per dataset, {end - start:.2f}s total"
        )

16384 datasets, 3.040 us per dataset, 0.05s total
16384 datasets, 3.030 us per dataset, 0.05s total
16384 datasets, 2.992 us per dataset, 0.05s total
16384 datasets, 2.939 us per dataset, 0.05s total
16384 datasets, 2.940 us per dataset, 0.05s total
16384 datasets, 2.919 us per dataset, 0.05s total
16384 datasets, 3.003 us per dataset, 0.05s total
16384 datasets, 2.888 us per dataset, 0.05s total
16384 datasets, 2.870 us per dataset, 0.05s total
16384 datasets, 2.697 us per dataset, 0.04s total
16384 datasets, 2.880 us per dataset, 0.05s total
16384 datasets, 2.844 us per dataset, 0.05s total
16384 datasets, 2.865 us per dataset, 0.05s total
16384 datasets, 2.776 us per dataset, 0.05s total
16384 datasets, 2.684 us per dataset, 0.04s total
16384 datasets, 2.884 us per dataset, 0.05s total
16384 datasets, 2.902 us per dataset, 0.05s total
16384 datasets, 2.792 us per dataset, 0.05s total
16384 datasets, 2.884 us per dataset, 0.05s total
16384 datasets, 2.823 us per dataset, 0.05s total


In [4]:
# Re-run the solver and keep only its Hessian, the third of (post, x_max, hess, iters).
solver_out = f()
hess = solver_out[2]
del solver_out

In [19]:
# Take one Hessian, presumably indexed [dataset, sig2 grid point] like x0 in the
# benchmark cell (TODO confirm), and split it as H = diag(a) + b * 1 1^T, where b
# is the shared off-diagonal entry.
H = hess[0, 7]
n_dim = H.shape[-1]
# Seeded so the Sherman-Morrison vs. dense comparison below is reproducible.
grad = np.random.default_rng(0).random(n_dim)
b = H[0, 1]
a = np.diag(H) - b
# The split above is only exact when every off-diagonal entry equals b.
assert np.allclose(np.asarray(H)[~np.eye(n_dim, dtype=bool)], b), "H has a non-constant off-diagonal"
a, b

(DeviceArray([-8.18910926, -8.53883568, -7.79770498, -9.32746209], dtype=float64),
 DeviceArray(0.24937656, dtype=float64))

In [25]:
# Solve H x = grad in O(n) with Sherman-Morrison for H = diag(a) + b * 1 1^T:
#   H^{-1} = diag(1/a) - b * outer(1/a, 1/a) / (1 + b * sum(1/a))
# Using b directly (instead of u = sqrt(b) * 1) keeps this valid when the
# off-diagonal b is negative, where sqrt(b) would give NaN.
inv_a = 1.0 / a
grad_over_a = grad * inv_a
grad_over_a - inv_a * (b * grad_over_a.sum() / (1 + b * inv_a.sum()))

DeviceArray([-0.11571474, -0.06712648, -0.11595943, -0.028569  ], dtype=float64)

In [26]:
# Dense reference for the Sherman-Morrison result above. solve() avoids forming
# inv(H) explicitly, which is cheaper and more accurate than inv(H).dot(grad).
np.linalg.solve(H, grad)

array([-0.11571474, -0.06712648, -0.11595943, -0.028569  ])

In [5]:
# Toy matrix with the same structure as the Hessian: M = diag(a) + b * 1 1^T.
# Seeded so the inverse displayed below is reproducible.
n_demo = 4
b = 1.0
a = np.random.default_rng(0).random(n_demo)
M = np.full((n_demo, n_demo), b) + np.diag(a)

In [6]:
Minv = np.linalg.inv(M)
# Cross-check against the Sherman-Morrison closed form for M = diag(a) + b * 1 1^T,
# with a and b read back from M itself:
#   M^{-1} = diag(1/a) - b * outer(1/a, 1/a) / (1 + b * sum(1/a))
b_M = M[0, 1]
inv_a_M = 1.0 / (np.diag(M) - b_M)
Minv_closed = np.diag(inv_a_M) - b_M * np.outer(inv_a_M, inv_a_M) / (1 + b_M * inv_a_M.sum())
np.testing.assert_allclose(Minv, Minv_closed, rtol=1e-6, atol=1e-9)
Minv

array([[ 1.31148807, -0.06612969, -0.92047326, -0.26805707],
       [-0.06612969,  1.10833818, -0.77032047, -0.22433009],
       [-0.92047326, -0.77032047,  5.47526137, -3.12249854],
       [-0.26805707, -0.22433009, -3.12249854,  3.80766208]])

In [130]:
# Split the first block of neg_precQ (presumably the first sig2 grid point - confirm)
# as diag(a) + b * 1 1^T. Scalar a and b are read from entries (0, 1) and (0, 0),
# so this assumes a constant diagonal and off-diagonal - TODO confirm.
Q0 = neg_precQ[0]
b = Q0[0, 1]
a = Q0[0, 0] - b
del Q0

In [131]:
@jax.jit
def quad(theta_max, a, b):
    """Batched quadratic form theta^T Q theta for Q = diag(a) + b * 1 1^T.

    Q is never materialised: Q @ theta = b * sum(theta) + a * theta. The sum
    runs over the last axis of ``theta_max``. ``a`` must broadcast against
    ``theta_max``, and ``b`` against ``theta_max.sum(axis=-1)``. Returns an array
    with the leading (batch) shape of ``theta_max``.
    """
    row_sums = theta_max.sum(axis=-1)
    q_theta = (row_sums * b)[..., None] + theta_max * a
    return (theta_max * q_theta).sum(axis=-1)

In [132]:
%%timeit
# block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
# timing covers the computation rather than just the launch.
quad(theta_max, a, b).block_until_ready()

126 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [129]:
# Structured quadratic form at leading index (0, 0), for comparison with the
# dense-matrix einsum results below.
quad(theta_max, a, b)[0,0]

DeviceArray(-0.37294438, dtype=float32)

In [122]:
# Reference results from the full matrix neg_precQ:
#   dotprod2 = Q @ theta, quad2 = theta^T Q theta (one einsum), quad3 = theta . dotprod2
dotprod2 = jnp.einsum("...ij,...j", neg_precQ, theta_max)
quad2 = jnp.einsum("...i,...ij,...j", theta_max, neg_precQ, theta_max)
quad3 = (theta_max * dotprod2).sum(axis=-1)

In [123]:
# `quad` is the jitted function defined above, not an array, so evaluate it rather
# than index it. `dotprod` is recomputed with the same formula quad uses internally,
# so this cell no longer depends on values left over from deleted cells.
dotprod = (theta_max.sum(axis=-1) * b)[..., None] + theta_max * a
quad_val = quad(theta_max, a, b)
dotprod[0, 0], dotprod2[0, 0], quad_val[0, 0], quad2[0, 0], quad3[0, 0]

(DeviceArray([-5.1011925,  1.39013  ,  2.319191 ,  1.390357 ], dtype=float32),
 DeviceArray([-5.101193 ,  1.3901292,  2.319191 ,  1.390357 ], dtype=float32),
 DeviceArray(-0.3729442, dtype=float32),
 DeviceArray(-0.37294447, dtype=float32),
 DeviceArray(-0.37294444, dtype=float32))

In [55]:
# Build the INLA solver. `narms` is not defined in any visible cell; n_arms
# (set in the benchmark cell) is this notebook's arm count.
# NOTE(review): conditional_vmap, grad_hess_vmap and data are not defined in any
# visible cell either, so this cell still relies on state from deleted cells.
inla_obj = inla.INLA(conditional_vmap, grad_hess_vmap, sig2_rule, n_arms)

In [56]:
# Optimise theta at every sig2 grid point (yields theta_max, hess, iters), then
# compute the posterior from the sig2 rule's points and weights.
theta_max, hess, iters = inla_obj.optimize_loop(
    data,
    sig2_rule.pts,
    1e-3,  # presumably the convergence tolerance - confirm against inla.INLA
)
post = inla_obj.posterior(
    theta_max, hess, sig2_rule.pts, sig2_rule.wts, data
)

In [58]:
%%timeit -n 20 -r 5
theta_max, hess, iters = inla_obj.optimize_loop(data, sig2_rule.pts, 1e-3)
post = inla_obj.posterior(theta_max, hess, sig2_rule.pts, sig2_rule.wts, data)

31.4 ms ± 444 µs per loop (mean ± std. dev. of 5 runs, 20 loops each)
