In [1]:
import oed_toolbox
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
def create_linear_model(K_func, b_func):
    """Build an affine forward model ``y = K(d) @ theta + b(d)``.

    Args:
        K_func: callable ``d -> K`` returning a 2-D matrix whose rows index
            model outputs and whose columns index parameters (contracted with
            ``theta`` via ``'ij,j->i'``).
        b_func: callable ``d -> b`` returning an offset added to the product.

    Returns:
        ``linear_model(theta, d)`` evaluating the affine map.
    """
    def linear_model(theta, d):
        sensitivity = K_func(d)
        offset = b_func(d)
        return jnp.einsum('ij,j->i', sensitivity, theta) + offset
    return linear_model

In [18]:
def K_func(d):
    """Toy sensitivity K(d) = -(d - 5)**2 + 20, promoted to a 2-D matrix."""
    return jnp.atleast_2d(-1 * (d - 5) ** 2 + 20)


def b_func(d):
    """Constant zero offset: the toy model has no intercept."""
    return 0.


# Wrap the affine JAX function as a toolbox model.
model_func = create_linear_model(K_func, b_func)
model = oed_toolbox.models.Model.from_jax_function(model_func)

In [19]:
# Likelihood for y = model(theta, d) + constant Gaussian noise.
# NOTE(review): 0.1 is presumably the noise *variance* (it matches the 0.1 * I
# covariance used in my_func) — confirm the toolbox's convention (variance vs. std).
likelihood = oed_toolbox.uncertainties.Likelihood.from_model_plus_constant_gaussian_noise(model, 0.1)

In [20]:
def my_func(y, theta, d):
    """Log-density of ``y`` under N(model_func(theta, d), 0.1 * I).

    The noise covariance is the identity scaled by 0.1, sized to the length of
    the model's output vector.
    """
    mean = model_func(theta, d)
    n_out = mean.shape[0]
    noise_cov = 0.1 * jnp.eye(n_out)
    return jax.scipy.stats.multivariate_normal.logpdf(y, mean=mean, cov=noise_cov)

In [21]:
# Hessian of the log-likelihood w.r.t. theta (argnums=1), vmapped over a batch of
# observations y (axis 0) while theta and d are shared across the batch.
ll_dt_dt = jax.vmap(jax.jacfwd(jax.jacfwd(my_func, argnums=1), argnums=1), in_axes=(0, None, None))


def jax_fisher_info(d, theta, num_samples, rng):
    """Monte Carlo estimate of the Fisher information at (theta, d).

    Simulates ``num_samples`` observations y = model_func(theta, d) + noise with
    noise ~ N(0, 0.1 * I) (the same covariance as in ``my_func``) and returns
    ``-mean_y[d^2 log p(y | theta, d) / dtheta^2]``.

    Args:
        d: design, passed through to ``model_func``.
        theta: parameter vector at which the information is evaluated
            (assumed 1-D).
        num_samples: number of simulated observations (static Python int).
        rng: ``jax.random`` PRNG key.

    Returns:
        Array of shape (len(theta), len(theta)).
    """
    y_pred = model_func(theta, d)
    n_out = y_pred.shape[0]
    # Noise must have the model-output dimension; a fixed 1-D draw would be
    # broadcast identically onto every output component.
    noise = jax.random.multivariate_normal(
        rng,
        mean=jnp.zeros(n_out),
        cov=0.1 * jnp.identity(n_out),
        shape=(num_samples,),
    )
    y = y_pred + noise
    return -1 * jnp.mean(ll_dt_dt(y, theta, d), axis=0)


# Derivative of the Fisher information w.r.t. the design d (argnums=0).
fisher_info_dd = jax.jacfwd(jax_fisher_info, argnums=0)

In [22]:
# Evaluate the JAX reference Fisher information and its design derivative
# at d = 4, theta = 100, using 10 Monte Carlo samples and a fixed PRNG key.
d_ref = jnp.array([4.])
theta_ref = jnp.array([100.])
key_ref = jax.random.PRNGKey(20)
print(jax_fisher_info(d_ref, theta_ref, 10, key_ref))
print(fisher_info_dd(d_ref, theta_ref, 10, key_ref))

[[3610.]]
[[[760.]]]


In [23]:
# Toolbox Fisher information for the same likelihood, as a cross-check of the
# hand-rolled JAX reference (compare with the printed 3610 / 760 values).
# NOTE(review): call arguments are presumably (d, theta, num_samples) — confirm
# against FisherInformation's call signature.
fisher_info = oed_toolbox.uncertainties.FisherInformation(likelihood)
fisher_info(4.0, 100, 10)

{'cov': array([[3610.]]), 'cov_dd': array([[[760.]]])}

In [None]:
def jax_predictive_cov(d, theta, num_samples, rng):
    """Placeholder for a Monte Carlo predictive covariance at (theta, d).

    Currently returns the Monte Carlo Fisher information from
    ``jax_fisher_info`` for the same arguments (same samples, same PRNG key).
    It delegates instead of re-implementing the sampling and Hessian code.

    TODO(review): implement the intended predictive covariance. Its design
    derivative can then be taken with ``jax.jacfwd(jax_predictive_cov, argnums=0)``.
    """
    return jax_fisher_info(d, theta, num_samples, rng)

In [8]:
# Scratch check: joining a (3, 2) and a (3, 5) array along the last axis gives (3, 7).
np.concatenate((np.ones((3, 2)), np.ones((3, 5))), axis=-1)

array([[1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1.]])

In [None]:
def create_posterior_cov(model, prior_cov, noise_cov):
    """Build a function returning the linearised Gaussian posterior covariance and its design gradient.

    With J = dy/dtheta, the posterior precision is
        post_icov = J^T noise_icov J + prior_icov,
    so post_cov = inv(post_icov). Its derivative w.r.t. the design d uses
        d inv(A) / dd = -inv(A) (dA/dd) inv(A).

    Args:
        model: JAX-traceable callable ``model(theta, d) -> y``. Both Jacobians
            are taken with ``jax.jacfwd``. Assumes 1-D theta, d and y — TODO confirm.
        prior_cov: prior covariance of theta, shape (n_theta, n_theta).
        noise_cov: observation-noise covariance, shape (n_out, n_out).

    Returns:
        ``posterior_cov_and_grad(d, theta, num_samples, rng)`` returning
        ``{'cov': (n_theta, n_theta), 'cov_dd': (n_theta, n_theta, n_d)}``.
    """
    prior_icov = jnp.linalg.inv(prior_cov)
    noise_icov = jnp.linalg.inv(noise_cov)
    # J[l, i] = dy_l / dtheta_i
    model_dt = jax.jacfwd(model, argnums=0)
    # J_dd[l, i, k] = d^2 y_l / (dtheta_i dd_k)
    model_dt_dd = jax.jacfwd(model_dt, argnums=1)

    def posterior_cov_and_grad(d, theta, num_samples, rng):
        # num_samples / rng are unused: the linearised posterior is closed-form.
        # Presumably kept so the signature matches jax_fisher_info — TODO confirm.
        y_dt = model_dt(theta, d)
        y_dt_dd = model_dt_dd(theta, d)
        post_icov = jnp.einsum("ki,kl,lj->ij", y_dt, noise_icov, y_dt) + prior_icov
        post_cov = jnp.linalg.inv(post_icov)
        # d(J^T N J)/dd_k = (dJ/dd_k)^T N J + J^T N (dJ/dd_k). With N symmetric, the
        # second term is the (i, j)-transpose of the first. Doubling the first
        # term is only valid when it is itself symmetric (e.g. scalar theta).
        half = jnp.einsum("lik,lm,mj->ijk", y_dt_dd, noise_icov, y_dt)
        post_icov_dd = half + jnp.swapaxes(half, 0, 1)
        post_cov_dd = -1 * jnp.einsum("il,lmk,mj->ijk", post_cov, post_icov_dd, post_cov)
        return {'cov': post_cov, 'cov_dd': post_cov_dd}

    return posterior_cov_and_grad


# Observation-noise covariance: the same 0.1 * I used in my_func (1 model output).
noise_cov = 0.1 * jnp.identity(1)
# TODO(review): placeholder prior for the single parameter — replace with the intended prior covariance.
prior_cov = jnp.identity(1)
# Pass the JAX function (not the toolbox wrapper) so jax.jacfwd can differentiate it.
cov_and_grad = create_posterior_cov(model_func, prior_cov, noise_cov)
cov = oed_toolbox.covariances.Covariance(cov_and_grad)