In [None]:
@jax.jit
def log_joint(data, rho, neg_precQ, logprecQdet, log_prior):
    """Evaluate the log joint density for a batch of datasets over the sigma2 grid.

    Parameters
    ----------
    data
        Array whose last axis holds ``(n_events, total_obs_time)``. With the
        broadcasting used below, the leading axes are presumably
        ``(N, n_arms)`` — TODO confirm against callers. Additional leading
        batch axes are also accepted.
    rho
        Latent values, expected shape ``(N, n_sigma2, n_arms)``.
    neg_precQ
        Matrices contracted with ``rho - mu_mean`` on both sides; must
        broadcast as ``(..., n_arms, n_arms)`` against ``rho``. Presumably one
        negated precision matrix per sigma2 grid point — verify.
    logprecQdet
        Added as-is; must broadcast against the ``(N, n_sigma2)`` result.
    log_prior
        Added as-is; must broadcast against the ``(N, n_sigma2)`` result.

    Returns
    -------
    Array with the broadcast batch shape of ``rho`` minus its last axis,
    i.e. ``(N, n_sigma2)`` for the expected inputs.

    Notes
    -----
    Reads the globals ``mu_mean`` and ``lambdaj``, which must be defined
    before this function is first called.
    """
    n_events = data[..., 0]
    total_obs_time = data[..., 1]
    rho_m0 = rho - mu_mean
    # hazard = 1 / (exp(rho) * lambdaj). Build its log directly: forming
    # exp(rho) first overflows/underflows for large |rho|, which turns
    # log(hazard) into +/-inf and the likelihood into nan.
    log_hazard = -(rho + jnp.log(lambdaj))
    hazard = jnp.exp(log_hazard)
    # Insert a sigma2 axis so each dataset's counts broadcast across the grid.
    # For 2-D data this matches ``n_events[:, None]``; it also tolerates extra
    # leading batch axes.
    likelihood = jnp.sum(
        log_hazard * n_events[..., None, :]
        - hazard * total_obs_time[..., None, :],
        axis=-1,
    )
    # Implicit-mode einsum: i and j are summed, only the broadcast "..."
    # batch axes survive.
    quad = 0.5 * jnp.einsum("...i,...ij,...j", rho_m0, neg_precQ, rho_m0)
    return quad + logprecQdet + likelihood + log_prior


def scalar_log_joint_opt(rho, neg_precQ, n_events, total_obs_time):
    """Log joint for a single latent vector, up to terms constant in ``rho``.

    This is the per-(dataset, sigma2 grid point) function that ``jax.grad``
    and ``jax.hessian`` differentiate with respect to ``rho``. The
    ``logprecQdet`` and ``log_prior`` terms that ``log_joint`` adds do not
    depend on ``rho``, so they are omitted here.

    Parameters
    ----------
    rho
        Latent vector, expected shape ``(n_arms,)``.
    neg_precQ
        Matrix of shape ``(n_arms, n_arms)`` contracted with ``rho - mu_mean``
        on both sides.
    n_events, total_obs_time
        Per-arm event counts and observation times, expected shape
        ``(n_arms,)``.

    Returns
    -------
    Scalar log joint value (up to the omitted constants).

    Notes
    -----
    Reads the globals ``mu_mean`` and ``lambdaj``.
    """
    rho_m0 = rho - mu_mean
    # log(1 / (exp(rho) * lambdaj)) computed without exponentiating first.
    # Otherwise a saturated exp() gives log(0) / log(inf), and autodiff
    # through log(1/...) gives nan gradients and Hessians.
    log_hazard = -(rho + jnp.log(lambdaj))
    hazard = jnp.exp(log_hazard)
    quad = 0.5 * (neg_precQ @ rho_m0) @ rho_m0
    return quad + jnp.sum(log_hazard * n_events - hazard * total_obs_time)


def _vmap_over_datasets_and_grid(fn):
    """Lift a per-(dataset, grid point) function ``fn`` to the full batch.

    ``fn`` takes ``(rho, neg_precQ, n_events, total_obs_time)``.

    - Inner map: over axis 0 of ``rho`` and ``neg_precQ``, sharing one
      dataset's ``n_events`` / ``total_obs_time``.
    - Outer map: over axis 0 of ``rho``, ``n_events`` and
      ``total_obs_time``, sharing ``neg_precQ``.

    The result is wrapped in ``jax.jit``.
    """
    per_grid_point = jax.vmap(fn, in_axes=(0, 0, None, None))
    per_dataset = jax.vmap(per_grid_point, in_axes=(0, None, 0, 0))
    return jax.jit(per_dataset)


grad_opt = _vmap_over_datasets_and_grid(jax.grad(scalar_log_joint_opt))
hessian_opt = _vmap_over_datasets_and_grid(jax.hessian(scalar_log_joint_opt))

def grad_hess(fi, data, rho, arms_opt):
    """Return ``(gradient, hessian)`` of the log joint with respect to ``rho``.

    ``fi.neg_precQ`` supplies the matrices. ``data[..., 0]`` and
    ``data[..., 1]`` are forwarded as event counts and total observation
    times. ``arms_opt`` is not used by this implementation.
    NOTE(review): confirm that ignoring ``arms_opt`` is intended.
    """
    n_events, total_obs_time = data[..., 0], data[..., 1]
    shared_args = (rho, fi.neg_precQ, n_events, total_obs_time)
    gradient = grad_opt(*shared_args)
    hessian = hessian_opt(*shared_args)
    return gradient, hessian


import berrylib.fast_inla as fast_inla

def log_joint_wrapper(fi, data, rho):
    """Adapt ``log_joint`` to a ``(fi, data, rho)`` signature.

    This form is what gets passed to ``fast_inla.FastINLAModel``. The
    precomputed ``neg_precQ``, ``logprecQdet`` and ``log_prior`` attributes
    are read off ``fi`` and forwarded positionally.
    """
    precomputed = (fi.neg_precQ, fi.logprecQdet, fi.log_prior)
    return log_joint(data, rho, *precomputed)

# Bundle the log-joint callback and its derivative callback into a model.
model = fast_inla.FastINLAModel(log_joint_wrapper, grad_hess)

# mu_mean, mu_sig2, sig2_alpha and sig2_beta are globals defined elsewhere
# in the notebook.
fi = fast_inla.FastINLA(
    n_arms=3,
    model=model,
    # sigma2 settings: presumably grid size, grid bounds and prior
    # hyperparameters — confirm against fast_inla.FastINLA.
    # NOTE(review): origin of these exact bound values is not shown here.
    sigma2_n=20,
    sigma2_bounds=[0.006401632420120484, 2.8421994410275007],
    sigma2_alpha=sig2_alpha,
    sigma2_beta=sig2_beta,
    # mu settings: presumably prior mean and variance.
    mu_0=mu_mean,
    mu_sig2=mu_sig2,
)