In [1]:
from inlaw.nb_util import setup_nb
setup_nb()

In [2]:
from inlaw.numpyro_interface import *
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import scipy.special
import numpyro.distributions as dist
# Module-level constants read by the closure that berry_model_fast returns.
mu_0 = -1.34  # first (location) argument of the MultivariateNormal placed on theta
mu_sig2 = 100.0  # value that fills every entry of the d x d base covariance matrix
sig2_alpha = 0.0005  # first argument of the InverseGamma placed on sig2
sig2_beta = 0.000005  # second argument of the InverseGamma placed on sig2
logit_p1 = scipy.special.logit(0.3)  # constant offset added to theta to form the binomial logits
def berry_model_fast(d):
    """Build a log-joint density function for a ``d``-dimensional ``theta``.

    The returned ``model(params, data)`` reads ``params["sig2"]`` and
    ``params["theta"]`` and adds up three log-density terms:

    * an ``InverseGamma(sig2_alpha, sig2_beta)`` term evaluated at ``sig2``,
    * a ``MultivariateNormal(mu_0, cov)`` term evaluated at ``theta``, where
      ``cov`` has ``mu_sig2`` in every entry plus ``sig2`` on the diagonal,
    * the summed ``BinomialLogits`` log-probability of ``data[..., 0]`` with
      ``total_count=data[..., 1]`` and logits ``theta + logit_p1``.

    The InverseGamma term is not reduced, so the result broadcasts against
    the shape of ``params["sig2"]``.
    """

    def model(params, data):
        sig2 = params["sig2"]

        # Constant matrix of mu_sig2 with sig2 added along the diagonal.
        shared_part = jnp.full((d, d), mu_sig2)
        diagonal_part = jnp.diag(jnp.repeat(sig2, d))
        cov = shared_part + diagonal_part

        sig2_logp = dist.InverseGamma(sig2_alpha, sig2_beta).log_prob(sig2)
        theta_logp = dist.MultivariateNormal(mu_0, cov).log_prob(params["theta"])

        likelihood = dist.BinomialLogits(
            params["theta"] + logit_p1, total_count=data[..., 1]
        )
        data_logp = jnp.sum(likelihood.log_prob(data[..., 0]))

        return sig2_logp + theta_logp + data_logp

    return model

In [4]:
import inlaw.inla
# Build the FullLaplace object from an example parameter pytree and the d=4 model.
# NOTE(review): the NaN entry for "sig2" presumably marks it as the parameter that is
# supplied separately (later cells pass sig2 through the second argument) — confirm
# against inlaw.inla.
fl = inlaw.inla.FullLaplace({"sig2": jnp.array([np.nan]), "theta": jnp.zeros(4)}, berry_model_fast(4))

In [9]:
# Trace fl.optimizer with example inputs and display the resulting jaxpr.
# Inputs: a (1, 1, 4) array, a dict carrying sig2 (theta is None), and a (1, 4, 2) array.
jax.make_jaxpr(fl.optimizer)(
    np.zeros((1, 1, 4)), {"sig2": np.array([0.1]), "theta": None}, np.zeros((1, 4, 2))
)


{ [34m[22m[1mlambda [39m[22m[22ma[35m:i64[4][39m b[35m:f64[1][39m c[35m:i64[4][39m d[35m:f64[1][39m e[35m:i64[4][39m f[35m:i64[4][39m g[35m:i64[4][39m h[35m:i64[4][39m; i[35m:f64[1,1,4][39m
    j[35m:f64[1][39m k[35m:f64[1,4,2][39m. [34m[22m[1mlet
    [39m[22m[22m_[35m:f64[1][39m = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] nan
    l[35m:f64[4][39m = broadcast_in_dim[broadcast_dimensions=() shape=(4,)] nan
    m[35m:bool[4][39m = lt a 0
    n[35m:i64[4][39m = add a 4
    o[35m:i64[4][39m = select_n m a n
    p[35m:i32[4][39m = convert_element_type[new_dtype=int32 weak_type=False] o
    q[35m:i32[4,1][39m = broadcast_in_dim[broadcast_dimensions=(0,) shape=(4, 1)] p
    r[35m:f64[4][39m = convert_element_type[new_dtype=float64 weak_type=False] l
    s[35m:f64[1,4][39m = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 4)] r
    t[35m:f64[1,1,4][39m = broadcast_in_dim[broadcast_dimensions=(1, 2) shape=(1, 1, 4)] s
    u[

In [5]:
# Jit-compiled Hessian of fl.log_joint_single; jax.hessian differentiates with
# respect to the first positional argument by default.
h = jax.jit(jax.hessian(fl.log_joint_single))

In [6]:
%%time
# Evaluate the Hessian at theta = 0.1, with sig2 = 0.1 supplied through the second
# argument and an all-zero (4, 2) data array. If this is the first call of `h`,
# the timing likely includes jit compilation.
h(
    {"sig2": None, "theta": np.array([0.1, 0.1, 0.1, 0.1])},
    {"sig2": 0.1, "theta": None},
    np.zeros((4, 2)),
)

CPU times: user 123 ms, sys: 5.39 ms, total: 128 ms
Wall time: 124 ms


{'sig2': None,
 'theta': {'sig2': None,
  'theta': DeviceArray([[-7.50062484,  2.49937516,  2.49937516,  2.49937516],
               [ 2.49937516, -7.50062484,  2.49937516,  2.49937516],
               [ 2.49937516,  2.49937516, -7.50062484,  2.49937516],
               [ 2.49937516,  2.49937516,  2.49937516, -7.50062484]], dtype=float64)}}

In [5]:

%%time
# Same evaluation through the library's build_grad_hess. Here theta is passed as a
# (1, 1, 4) array, sig2 as a length-1 array and the data as a (1, 4, 2) array.
# Building the jitted function and calling it happen in the same timed cell.
h2 = jax.jit(inlaw.inla.build_grad_hess(fl.log_joint_single, fl.spec))
h2(
    np.array([[[0.1, 0.1, 0.1, 0.1]]]),
    {"sig2": np.array([0.1]), "theta": None},
    np.zeros((1, 4, 2)),
)


CPU times: user 376 ms, sys: 9.8 ms, total: 386 ms
Wall time: 382 ms


(DeviceArray([[[-0.0035991, -0.0035991, -0.0035991, -0.0035991]]], dtype=float64, weak_type=True),
 DeviceArray([[[[-7.50062484,  2.49937516,  2.49937516,  2.49937516],
                [ 2.49937516, -7.50062484,  2.49937516,  2.49937516],
                [ 2.49937516,  2.49937516, -7.50062484,  2.49937516],
                [ 2.49937516,  2.49937516,  2.49937516, -7.50062484]]]],            dtype=float64, weak_type=True))

In [5]:
def build_grad_hess(log_joint_single, param_spec):
    """Return a function computing the flattened gradient and Hessian.

    The returned ``grad_hess(x, p_pinned, data)`` converts ``x`` into a
    parameter pytree with ``param_spec.unravel_f``, differentiates
    ``log_joint_single`` with respect to that pytree, and returns
    ``(full_grad, full_hess)``:

    * ``full_grad`` is the gradient pytree passed through ``param_spec.ravel_f``.
    * ``full_hess`` is assembled from the nested Hessian blocks. Blocks that
      are ``None`` are skipped, each remaining block is masked with
      ``param_spec.not_nan`` along its rows and columns, blocks are joined
      along the last axis within a row and rows along the second-to-last axis.
    """

    def grad_hess(x, p_pinned, data):
        p = param_spec.unravel_f(x)
        grad = jax.grad(log_joint_single)(p, p_pinned, data)
        hess = jax.hessian(log_joint_single)(p, p_pinned, data)

        full_grad = param_spec.ravel_f(grad)

        block_rows = []
        for row_key in param_spec.key_order:
            row = hess[row_key]
            if row is None:
                # No derivative blocks for this parameter.
                continue
            row_blocks = []
            for col_key in param_spec.key_order:
                block = row[col_key]
                if block is None:
                    continue
                # Keep only the entries flagged in not_nan, rows first, then columns.
                kept_rows = block[param_spec.not_nan[row_key]]
                row_blocks.append(kept_rows[:, param_spec.not_nan[col_key]])
            block_rows.append(jnp.concatenate(row_blocks, axis=-1))

        full_hess = jnp.concatenate(block_rows, axis=-2)
        return full_grad, full_hess

    return grad_hess

In [6]:
%%time
# Jit and call the build_grad_hess defined in this notebook, using the variant that
# takes a flat parameter array: theta is a flat length-4 array, sig2 is passed in the
# second argument and the data is an unbatched (4, 2) array.
# NOTE(review): build_grad_hess is defined twice in this notebook, so which variant
# runs here depends on cell execution order.
h3 = jax.jit(build_grad_hess(fl.log_joint_single, fl.spec))
h3(
    np.array([0.1, 0.1, 0.1, 0.1]),
    {"sig2": 0.1, "theta": None},
    np.zeros((4, 2)),
)

CPU times: user 210 ms, sys: 7.04 ms, total: 217 ms
Wall time: 213 ms


(DeviceArray([-0.0035991, -0.0035991, -0.0035991, -0.0035991], dtype=float64, weak_type=True),
 DeviceArray([[-7.50062484,  2.49937516,  2.49937516,  2.49937516],
              [ 2.49937516, -7.50062484,  2.49937516,  2.49937516],
              [ 2.49937516,  2.49937516, -7.50062484,  2.49937516],
              [ 2.49937516,  2.49937516,  2.49937516, -7.50062484]], dtype=float64, weak_type=True))

In [5]:
def build_grad_hess(log_joint_single, param_spec):
    """Return a function that evaluates only the Hessian of ``log_joint_single``.

    Despite the name, the returned ``grad_hess(p, p_pinned, data)`` computes no
    gradient and does no flattening: it returns whatever
    ``jax.hessian(log_joint_single)`` yields for the pytree ``p``, i.e. a nested
    pytree of Hessian blocks. ``param_spec`` is accepted but not used.
    """

    def grad_hess(p, p_pinned, data):
        hessian_fn = jax.hessian(log_joint_single)
        nested_hess = hessian_fn(p, p_pinned, data)
        return nested_hess

    return grad_hess

In [6]:
# Jit the function returned by build_grad_hess. The next cell calls it with a
# parameter pytree, which matches the Hessian-only variant of build_grad_hess.
h4 = jax.jit(build_grad_hess(fl.log_joint_single, fl.spec))

In [7]:
%%time
# Call h4 with the same pytree inputs that were given to `h`: theta = 0.1 in the first
# argument, sig2 = 0.1 in the second, and an all-zero (4, 2) data array.
h4(
    {"sig2": None, "theta": np.array([0.1, 0.1, 0.1, 0.1])},
    {"sig2": 0.1, "theta": None},
    np.zeros((4, 2)),
)

CPU times: user 963 ms, sys: 20.2 ms, total: 984 ms
Wall time: 984 ms


{'sig2': None,
 'theta': {'sig2': None,
  'theta': DeviceArray([[-7.50062484,  2.49937516,  2.49937516,  2.49937516],
               [ 2.49937516, -7.50062484,  2.49937516,  2.49937516],
               [ 2.49937516,  2.49937516, -7.50062484,  2.49937516],
               [ 2.49937516,  2.49937516,  2.49937516, -7.50062484]], dtype=float64)}}

In [34]:
# Hessian of the raw model function, differentiated with respect to the whole params
# dict (both sig2 and theta).
# NOTE(review): this rebinds `h`, which an earlier cell assigned to the jitted Hessian
# of fl.log_joint_single.
m = berry_model_fast(4)
h = jax.jit(jax.hessian(m))

In [35]:
%%time
# The model reads data[..., 0] as the observed counts and data[..., 1] as the total
# counts, so the data array needs a trailing axis of length 2 with one row per theta
# entry. An all-zero integer (4, 2) array is used here.
h({"sig2": jnp.array([0.1]), "theta": jnp.zeros(4)}, jnp.zeros((4, 2), dtype=int))

CPU times: user 434 ms, sys: 9.58 ms, total: 443 ms
Wall time: 436 ms


{'sig2': {'sig2': DeviceArray([[[250.04000301]]], dtype=float64),
  'theta': DeviceArray([[[8.37081406e-06, 8.37081407e-06, 8.37081407e-06, 8.37081407e-06]]], dtype=float64)},
 'theta': {'sig2': DeviceArray([[[8.37081406e-06],
                [8.37081407e-06],
                [8.37081407e-06],
                [8.37081407e-06]]], dtype=float64),
  'theta': DeviceArray([[[-7.50062484,  2.49937516,  2.49937516,  2.49937516],
                [ 2.49937516, -7.50062484,  2.49937516,  2.49937516],
                [ 2.49937516,  2.49937516, -7.50062484,  2.49937516],
                [ 2.49937516,  2.49937516,  2.49937516, -7.50062484]]], dtype=float64)}}

In [25]:
# NOTE(review): `berry_model` is not defined anywhere in this notebook (only
# `berry_model_fast` is). It presumably comes from the star import of
# inlaw.numpyro_interface, as does `from_numpyro` — confirm.
fns, ex, ll_fnc = from_numpyro(berry_model(4), "sig2", (4,2))

In [26]:
ex

{'sig2': array([nan]), 'theta': array([0., 0., 0., 0.])}

In [21]:
dir(fns['y'])

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__signature__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_batch_shape',
 '_event_shape',
 '_validate_args',
 '_validate_sample',
 'arg_constraints',
 'batch_shape',
 'cdf',
 'enumerate_support',
 'event_dim',
 'event_shape',
 'expand',
 'expand_by',
 'has_enumerate_support',
 'has_rsample',
 'icdf',
 'infer_shapes',
 'is_discrete',
 'log_prob',
 'logits',
 'mask',
 'mean',
 'probs',
 'reparametrized_params',
 'rsample',
 'sample',
 'sample_with_intermediates',
 'set_default_validate_args',
 'shape',
 'support',
 'to_event',
 'total_count',
 'tree_flatten',
 'tree_unflatten',
 'variance']

In [23]:
fns['y']

<numpyro.distributions.discrete.BinomialLogits at 0x29e747d90>

In [18]:
# NOTE(review): the recorded output of this cell is `KeyError: 'y'`, so the call does
# not currently succeed. The inputs are drawn with np.random.rand without a seed,
# so they differ between runs.
ll_fnc(
    dict(sig2=np.random.rand(1), theta=None),
    dict(sig2=None, theta=np.random.rand(4)),
    np.random.rand(4, 2),
)


KeyError: 'y'