In [1]:
import oed_toolbox
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
def create_linear_model(K_func, b_func):
    """Build a model that is affine in the parameters: y = K(d) @ theta + b(d).

    Args:
        K_func: callable mapping a design ``d`` to a 2-D sensitivity matrix.
        b_func: callable mapping ``d`` to an offset added to the product.

    Returns:
        A function ``linear_model(theta, d)`` returning the model prediction.
    """
    def linear_model(theta, d):
        sensitivity = K_func(d)
        product = jnp.einsum('ij,j->i', sensitivity, theta)
        return product + b_func(d)

    return linear_model

In [3]:
def K_func(d):
    """Sensitivity matrix K(d) = 20 - (d - 5)**2 of the toy linear model.

    K is a downward parabola in d with its maximum value of 20 at d = 5.
    ``atleast_2d`` promotes the result to a matrix so it can be contracted with
    theta by ``create_linear_model``; for a scalar or single-element design this
    is a 1x1 matrix.
    """
    return jnp.atleast_2d(20 - (d - 5) ** 2)


def b_func(d):
    """Constant zero offset; the design ``d`` is ignored."""
    return 0.


# y = K(d) @ theta: the forward model, as a plain JAX function and wrapped for oed_toolbox.
model_func = create_linear_model(K_func, b_func)
model = oed_toolbox.models.Model.from_jax_function(model_func)

In [4]:
# Gaussian observation noise with constant covariance. The toolbox Fisher information
# computed later (3610) agrees with the hand-rolled JAX version, which uses cov = 0.1 * I,
# so this 0.1 acts as the noise *variance*.
noise_variance = 0.1
likelihood = oed_toolbox.uncertainties.Likelihood.from_model_plus_constant_gaussian_noise(
    model,
    noise_variance,
)

In [5]:
def my_func(y, theta, d):
    """Gaussian log-likelihood log p(y | theta, d).

    The observation ``y`` is modelled as ``model_func(theta, d)`` plus zero-mean
    Gaussian noise with covariance 0.1 * I, where I matches the output dimension.
    """
    mean = model_func(theta, d)
    noise_cov = 0.1 * jnp.eye(mean.shape[0])
    return jax.scipy.stats.multivariate_normal.logpdf(y, mean=mean, cov=noise_cov)

In [6]:
# Hessian of the log-likelihood w.r.t. theta, vectorised over a batch of observations y
# (axis 0) while theta and d are held fixed.
ll_dt_dt = jax.vmap(jax.jacfwd(jax.jacfwd(my_func, argnums=1), argnums=1), in_axes=(0, None, None))


def jax_fisher_info(d, theta, num_samples, rng):
    """Monte Carlo estimate of the Fisher information matrix at (theta, d).

    Draws ``num_samples`` observations y = model_func(theta, d) + eps with
    eps ~ N(0, 0.1 * I). It returns -mean_y of the Hessian of ``my_func`` w.r.t.
    theta. For a model that is linear in theta with fixed noise covariance (as in
    this notebook), that Hessian does not depend on y, so the estimate is exact.

    Args:
        d: design, passed through to ``model_func``.
        theta: parameter vector at which the information is evaluated.
        num_samples: number of simulated observations.
        rng: JAX PRNG key used to draw the observation noise.

    Returns:
        Array of shape (len(theta), len(theta)) for 1-D theta.
    """
    y_pred = model_func(theta, d)
    n_out = y_pred.shape[0]
    # Size the noise to the model output (it was previously hard-coded to 1-D).
    # Use the same 0.1 * I covariance as ``my_func``.
    noise = jax.random.multivariate_normal(
        rng,
        mean=jnp.zeros(n_out),
        cov=0.1 * jnp.identity(n_out),
        shape=(num_samples,),
    )
    y = y_pred + noise
    return -jnp.mean(ll_dt_dt(y, theta, d), axis=0)


# Derivative of the Fisher information with respect to the design d.
fisher_info_dd = jax.jacfwd(jax_fisher_info, argnums=0)

In [7]:
# Evaluate at d = 4, theta = 100. Analytically F = K(d)**2 / 0.1 with K(4) = 19, giving 3610.
# dF/dd = 2 * K * K'(d) / 0.1 with K'(4) = 2, giving 760.
d_eval = jnp.array([4.])
theta_eval = jnp.array([100.])
n_samples = 10
eval_key = jax.random.PRNGKey(20)
for quantity in (jax_fisher_info, fisher_info_dd):
    print(quantity(d_eval, theta_eval, n_samples, eval_key))



[[3610.]]
[[[760.]]]


In [8]:
# Same quantity via oed_toolbox. The positional arguments are presumably (d, theta,
# num_samples), mirroring jax_fisher_info above. NOTE(review): confirm against the toolbox
# signature. The returned dict holds the Fisher information under 'cov' and its design
# derivative under 'cov_dd'. In the saved outputs both match the JAX values printed above.
fisher_info = oed_toolbox.uncertainties.FisherInformation(likelihood)
fisher_info(4.0, 100, 10)

{'cov': array([[3610.]]), 'cov_dd': array([[[760.]]])}

In [9]:
# Jacobian of the model output w.r.t. theta, shape (n_outputs, n_theta).
model_func_dt = jax.jacfwd(model_func, argnums=0)


def jax_predictive_cov(d, theta, num_samples, rng):
    """Linearised predictive covariance J F^{-1} J^T at (theta, d).

    J is the Jacobian of ``model_func`` w.r.t. theta. F is the Monte Carlo Fisher
    information from ``jax_fisher_info`` (same ``num_samples`` and ``rng``).

    Returns:
        Array of shape (n_outputs, n_outputs).
    """
    # Named so it does not shadow the global toolbox ``fisher_info`` object.
    fim = jax_fisher_info(d, theta, num_samples, rng)
    jac = model_func_dt(theta, d)
    # Solve F X = J^T instead of forming F^{-1} explicitly: better conditioned and cheaper.
    return jac @ jnp.linalg.solve(fim, jac.T)


# Derivative of the predictive covariance with respect to the design d.
pred_cov_dd = jax.jacfwd(jax_predictive_cov, argnums=0)

In [10]:
# For this 1x1 linear model J F^{-1} J^T = K * (0.1 / K**2) * K = 0.1 for every d.
# So the covariance should be 0.1 and its d-derivative zero, up to floating-point round-off.
d_eval = jnp.array([4.])
theta_eval = jnp.array([100.])
n_samples = 10
eval_key = jax.random.PRNGKey(20)
for quantity in (jax_predictive_cov, pred_cov_dd):
    print(quantity(d_eval, theta_eval, n_samples, eval_key))

[[0.1]]
[[[8.1490725e-10]]]


In [11]:
# Toolbox version of jax_predictive_cov / pred_cov_dd, built from the model and the global
# FisherInformation object ``fisher_info`` defined earlier.
# For this 1x1 linear model J F^{-1} J^T = K * (0.1 / K**2) * K = 0.1 for every d.
# So 'cov' should be 0.1 and 'cov_dd' zero; any tiny nonzero derivative in the JAX cell
# above is floating-point round-off.
pred_cov = oed_toolbox.uncertainties.PredictiveCovariance(model, fisher_info)
pred_cov(4.0, 100, 10)

{'cov': array([[0.1]]), 'cov_dd': array([[[0.]]])}