In [2]:
import inlaw
import inlaw.berry as berry
import inlaw.quad as quad
import numpy as np
import jax.numpy as jnp
import jax
import time
import inlaw.inla as inla

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
def my_timeit(N, f, iter=5, inner_iter=10, should_print=True):
    """Time repeated calls of ``f`` and optionally report per-sample cost.

    ``f`` is called once untimed as a warm-up (e.g. to trigger JIT
    compilation), then ``iter`` more times, each call timed individually.

    Args:
        N: number of samples processed by one call of ``f``; used only to
            convert runtimes into microseconds per sample when printing.
        f: zero-argument callable to time. For JAX code, ``f`` should block on
            its result (``block_until_ready``) so asynchronous dispatch does
            not hide the actual compute time.
        iter: number of timed calls.
        inner_iter: currently unused; kept so existing call sites still work.
        should_print: if True, print the median runtime and the min/median
            microseconds per sample.

    Returns:
        list of ``iter`` elapsed durations in seconds, one per timed call.
    """
    f()  # warm-up call, excluded from the measurements
    runtimes = []
    for _ in range(iter):
        # perf_counter is monotonic and high-resolution; time.time can jump
        # with system clock adjustments and is not meant for benchmarking.
        start = time.perf_counter()
        f()
        runtimes.append(time.perf_counter() - start)
    if should_print:
        median = np.median(runtimes)
        print("median runtime", median)
        print("min us per sample ", np.min(runtimes) * 1e6 / N)
        print("median us per sample", median * 1e6 / N)
    return runtimes

def benchmark(N=10000, iter=5, dtype=np.float32):
    """Print timings of the Berry-model posterior pipelines on ``N`` datasets.

    Sections, each timed with ``my_timeit``:
      * ``custom dirty bayes``: ``berry.build_dirty_bayes`` vmapped over data.
      * ``<name> gaussian``: ``ops.laplace_logpost`` vmapped over data.
      * ``<name> laplace``: per-arm ``ops.cond_laplace_logpost`` evaluated on a
        Gauss-Hermite grid built from the ``x_max``/Hessian output of the
        previous stage.
    The last two sections run for both the hand-written Berry ops
    (``berry.optimized``) and ops derived from the log joint
    (``inla.from_log_joint``).

    Args:
        N: number of datasets generated by ``berry.figure2_data``; also the
            divisor used for per-sample timings.
        iter: number of timed repetitions per section.
        dtype: floating dtype for the data, sig2 grid and initial guess.
            Defaults to float32, the previously hard-coded value.
    """
    n_arms = 4  # arm count shared by every model builder below
    data = berry.figure2_data(N).astype(dtype)
    sig2_rule = quad.log_gauss_rule(15, 1e-6, 1e3)
    sig2 = sig2_rule.pts.astype(dtype)
    x0 = jnp.zeros((sig2.shape[0], n_arms), dtype=dtype)

    # Timed once: my_timeit already makes an untimed warm-up (compile) call,
    # so running this section twice verbatim adds no information.
    print("\ncustom dirty bayes")
    db = jax.jit(
        jax.vmap(berry.build_dirty_bayes(sig2, n_arms=n_arms, dtype=dtype))
    )
    my_timeit(N, lambda: db(data)[0].block_until_ready(), iter=iter)

    def bench_ops(name, ops):
        """Time the gaussian and laplace stages for one INLA ops object."""
        print(f"\n{name} gaussian")
        hyperpost = jax.jit(jax.vmap(ops.laplace_logpost, in_axes=(None, None, 0)))
        p_pinned = dict(sig2=sig2, theta=None)
        my_timeit(
            N, lambda: hyperpost(x0, p_pinned, data)[0].block_until_ready(), iter=iter
        )

        print(f"\n{name} laplace")
        # Untimed: compute x_max and the Hessian info the laplace stage uses.
        _, x_max, hess_info, _ = hyperpost(x0, p_pinned, data)
        # Outer vmap maps only over axis 0 of the grid points `cx`; the inner
        # vmap maps jointly over x_max, the inverse-Hessian slice, data and cx.
        # arm_idx and the trailing flag are static (static_argnums=(5, 6)).
        arm_logpost_f = jax.jit(
            jax.vmap(
                jax.vmap(
                    ops.cond_laplace_logpost, in_axes=(0, 0, None, 0, 0, None, None)
                ),
                in_axes=(None, None, None, None, 0, None, None),
            ),
            static_argnums=(5, 6),
        )
        invv = jax.jit(jax.vmap(jax.vmap(ops.invert)))

        def f():
            inv_hess = invv(hess_info)
            arm_post = []
            for arm_idx in range(n_arms):
                # NOTE(review): inv_hess[..., arm_idx, :] and
                # inv_hess[:, :, arm_idx] pick the same slice only when
                # inv_hess is 4-D — confirm against ops.invert's output.
                cx, wts = inla.gauss_hermite_grid(
                    x_max, inv_hess[..., arm_idx, :], arm_idx, n=25
                )
                arm_logpost = arm_logpost_f(
                    x_max, inv_hess[:, :, arm_idx], p_pinned, data, cx, arm_idx, True
                )
                arm_post.append(inla.exp_and_normalize(arm_logpost, wts, axis=0))
            return jnp.array(arm_post)

        f_jit = jax.jit(f)
        # Block on the result like the other sections: JAX dispatches
        # asynchronously, so without blocking the timing can reflect dispatch
        # rather than the computation itself.
        my_timeit(N, lambda: f_jit().block_until_ready(), iter=iter)

    custom_ops = berry.optimized(sig2, dtype=dtype).config(max_iter=10)
    bench_ops("custom berry", custom_ops)

    ad_ops = inla.from_log_joint(
        berry.log_joint(n_arms),
        dict(sig2=np.array([np.nan]), theta=np.full(n_arms, 0.0)),
    ).config(max_iter=10)
    bench_ops("numpyro berry", ad_ops)

In [11]:
# Single-dataset (N = 1), float64 check of the log-joint-derived Berry ops:
# evaluate the vmapped `laplace_logpost` once and keep the result in `out`.
N = 1
dtype = np.float64
data = berry.figure2_data(N).astype(dtype)
sig2_rule = quad.log_gauss_rule(15, 1e-6, 1e3)
sig2 = sig2_rule.pts.astype(dtype)
x0 = jnp.zeros((sig2.shape[0], 4), dtype=dtype)

# Template parameter dict: sig2 is a one-element NaN placeholder, theta four zeros.
ad_ops = inla.from_log_joint(
    berry.log_joint(4),
    {"sig2": np.full(1, np.nan), "theta": np.zeros(4)},
).config(max_iter=10)

# sig2 pinned to the quadrature points; theta left as None
# (presumably the quantity laplace_logpost solves for — confirm).
p_pinned = {"sig2": sig2, "theta": None}
# Only the third argument (data) is mapped over.
hyperpost = jax.jit(
    jax.vmap(ad_ops.laplace_logpost, in_axes=(None, None, 0))
)
out = hyperpost(x0, p_pinned, data)

In [12]:
# Display the tuple returned by `hyperpost` above. In the recorded output, the
# first entry has a single row (N = 1) holding one value per sig2 grid point.
out

(DeviceArray([[-17.76098252, -16.48206911, -15.83229637, -16.36576571,
               -17.72656449, -19.50403024, -21.44891798, -23.192983  ,
               -23.9123101 , -24.68772225, -27.07693559, -30.02418804,
               -32.72084883, -34.81752249, -36.06468754]], dtype=float64),
 DeviceArray([[[-0.65720195, -0.65720081, -0.65719484, -0.65719371],
               [-0.65720546, -0.65720355, -0.65719345, -0.65719153],
               [-0.65721851, -0.65721369, -0.65718827, -0.65718345],
               [-0.65727494, -0.65725755, -0.65716589, -0.65714851],
               [-0.65757963, -0.65749441, -0.65704505, -0.65695985],
               [-0.65958489, -0.65905323, -0.65625057, -0.65571955],
               [-0.67465157, -0.67076447, -0.65032674, -0.64647392],
               [-0.78705011, -0.75796462, -0.60851132, -0.58150985],
               [-1.34091076, -1.17062088, -0.44985596, -0.34985839],
               [-2.47957009, -1.79535068, -0.28588712, -0.14714995],
               [-3.788