In [1]:
import oed_toolbox
import numpy as np
import jax
import jax.numpy as jnp

In [3]:
def create_linear_model(K_func, b_func):
    """Build the test model ``model(theta, d) = K(d) @ phi(theta) + b(d) * theta**2``.

    ``phi(theta) = 2*log(theta) - theta**-0.5`` is applied elementwise, so the
    model is linear in the features ``phi(theta)`` but not in ``theta`` itself.

    Args:
        K_func: callable ``d -> K(d)``; must return a 2-D matrix whose second
            axis matches the length of ``theta``.
        b_func: callable ``d -> b(d)``; multiplied elementwise by ``theta**2``.

    Returns:
        A JAX-traceable function ``linear_model(theta, d)``.
    """
    def linear_model(theta, d):
        phi = 2 * jnp.log(theta) - theta ** -0.5
        offset = b_func(d) * theta ** 2
        return jnp.einsum('ij,j->i', K_func(d), phi) + offset
    return linear_model

In [25]:
def K_func(d):
    """Design-dependent linear coefficient: ``20 - (d - 5)**2``, promoted to 2-D."""
    return jnp.atleast_2d(20 - (d - 5) ** 2)


def b_func(d):
    """Design-dependent offset coefficient: ``0.5 * d**2.5``."""
    return 0.5 * d ** 2.5
# Pure-JAX callable used by the reference computations below ...
model_func = create_linear_model(K_func=K_func, b_func=b_func)
# ... and the same function wrapped as a toolbox Model.
model = oed_toolbox.models.Model.from_jax_function(model_func)

In [30]:
# Observation dimension and the (constant) Gaussian noise covariance matrix.
ndim = 1
noise = jnp.eye(ndim)
# Toolbox likelihood: `model` plus constant Gaussian noise with covariance `noise`
# (per the constructor name).
likelihood = oed_toolbox.distributions.Likelihood.from_model_plus_constant_gaussian_noise(
    model, noise
)

In [31]:
def true_likelihood(y, theta, d):
    """Reference Gaussian log-density ``log p(y | theta, d)``.

    Observation model: ``y = model_func(theta, d) + e`` with ``e`` zero-mean
    Gaussian whose covariance matrix is the global ``noise``. This is the same
    matrix that ``transform`` factorises (via ``jnp.linalg.cholesky(noise)``)
    to generate samples, and the one passed to the toolbox likelihood.

    Args:
        y: a single observation, same shape as ``model_func(theta, d)``.
        theta: model parameters.
        d: design.

    Returns:
        Scalar log-density. Despite the function name, this is a *log*
        likelihood.
    """
    y_pred = model_func(theta, d)
    # Pass `noise` as-is: multiplying it elementwise by an identity matrix
    # would silently zero any off-diagonal (correlated) noise terms and make
    # this reference disagree with the samples produced by `transform`.
    return jax.scipy.stats.multivariate_normal.logpdf(y, mean=y_pred, cov=noise)


# Score d/dtheta log p(y | theta, d), vectorised over a leading batch axis of y.
true_likelihood_dt = jax.vmap(jax.jacfwd(true_likelihood, argnums=1), in_axes=(0, None, None))

In [32]:
# Lower Cholesky factor of the noise covariance, used to colour base samples.
noise_chol = jnp.linalg.cholesky(noise)


def transform(epsilon, theta, d):
    """Reparameterise base samples into observations.

    Row ``a`` of the result is ``model_func(theta, d) + noise_chol @ epsilon[a]``.
    ``epsilon`` is indexed as (sample, dim); presumably standard-normal draws
    (see the sampling cell) - TODO confirm for other callers.
    """
    coloured = jnp.einsum('ij,aj->ai', noise_chol, epsilon)
    return coloured + model_func(theta, d)

def jax_fisher_info(d, theta, epsilon):
    """Monte-Carlo reference Fisher information at ``(theta, d)``.

    Observations are generated from the base samples ``epsilon`` with
    ``transform``. The outer product of the score
    ``d/dtheta log p(y | theta, d)`` is then averaged over samples. For 1-D
    ``theta`` the result has shape ``(len(theta), len(theta))``.
    """
    scores = true_likelihood_dt(transform(epsilon, theta, d), theta, d)
    outer = jnp.einsum('ai,aj->aij', scores, scores)
    return outer.mean(axis=0)

# Reference values keyed like the toolbox output: the Fisher information
# itself and its Jacobian with respect to the design d.
true_fisher_info = dict(
    cov=jax_fisher_info,
    cov_dd=jax.jacfwd(jax_fisher_info, argnums=0),
)

In [33]:
# Fixed seed so the Monte-Carlo estimates below are reproducible.
rng = jax.random.PRNGKey(6)
num_samples = 1000

# Design and parameter values at which everything is evaluated.
d = jnp.array([1.6])
theta = jnp.array([1.])

# Base samples of shape (num_samples, 1), shared by every estimator below.
epsilon = jax.random.multivariate_normal(
    rng,
    mean=jnp.array([0.]),
    cov=jnp.array([[1.]]),
    shape=(num_samples,),
)

In [34]:
# Reparameterised estimator: the toolbox receives the base samples `epsilon`
# directly, so it is expected to match the pure-JAX reference (the recorded
# output shows zero difference).
fisher_info = oed_toolbox.covariances.FisherInformation(likelihood, use_reparameterisation=True)
fi_vals = fisher_info(d, theta, num_samples, samples=epsilon)
for key, func in true_fisher_info.items():
    true_val = func(d, theta, epsilon)
    print(f'True {key}: {true_val}, \n Computed {key}: {fi_vals[key]}, \n Difference: {true_val-fi_vals[key]}')
    # Fail loudly on a regression instead of relying on a reader to inspect the printout.
    np.testing.assert_allclose(true_val, fi_vals[key], rtol=1e-5, atol=1e-5)

True cov: [[601.1879]], 
 Computed cov: [[601.1879]], 
 Difference: [[0.]]
True cov_dd: [[[1089.8102]]], 
 Computed cov_dd: [[[1089.8102]]], 
 Difference: [[[0.]]]


In [35]:
# Non-reparameterised estimator: the toolbox needs observations rather than
# base samples, so generate them with the same epsilon as the reference.
fisher_info = oed_toolbox.covariances.FisherInformation(likelihood)
y_samples = transform(epsilon, theta, d)
fi_vals = fisher_info(d, theta, num_samples, samples=y_samples)
for key, func in true_fisher_info.items():
    true_val = func(d, theta, epsilon)
    print(
        f'True {key}: {true_val}, \n'
        f' Computed {key}: {fi_vals[key]}, \n'
        f' Difference: {true_val-fi_vals[key]}'
    )

True cov: [[601.1879]], 
 Computed cov: [[601.25308413]], 
 Difference: [[-0.06518555]]
True cov_dd: [[[1089.8102]]], 
 Computed cov_dd: [[[1167.67463642]]], 
 Difference: [[[-77.8645]]]


In [38]:
# Jacobian of the model output with respect to theta.
model_func_dt = jax.jacfwd(model_func, argnums=0)


def jax_predictive_cov(d, theta, epsilon):
    """Reference predictive covariance ``J F^{-1} J^T``.

    ``J`` is the Jacobian of ``model_func`` with respect to ``theta``.
    ``F`` is the Monte-Carlo Fisher information from ``jax_fisher_info``.
    """
    jac = model_func_dt(theta, d)
    fim_inv = jnp.linalg.inv(jax_fisher_info(d, theta, epsilon))
    return jnp.einsum('ij,jk,kl->il', jac, fim_inv, jac.T)


true_pred_cov = {
    'cov': jax_predictive_cov,
    'cov_dd': jax.jacfwd(jax_predictive_cov, argnums=0),
}

In [39]:
# Reparameterised predictive covariance: expected to match the pure-JAX
# reference up to float32 round-off (see the recorded output).
fisher_info = oed_toolbox.covariances.FisherInformation(likelihood, use_reparameterisation=True)
pred_cov = oed_toolbox.covariances.PredictiveCovariance(model, fisher_info)
pc_vals = pred_cov(d, theta, samples=epsilon)
for key, func in true_pred_cov.items():
    true_val = func(d, theta, epsilon)
    print(f'True {key}: {true_val}, \n Computed {key}: {pc_vals[key]}, \n Difference: {true_val-pc_vals[key]}')
    # cov_dd is ~0 here (round-off noise), so an absolute tolerance is needed
    # alongside the relative one.
    np.testing.assert_allclose(true_val, pc_vals[key], rtol=1e-5, atol=1e-5)

True cov: [[0.9852935]], 
 Computed cov: [[0.98529357]], 
 Difference: [[-5.9604645e-08]]
True cov_dd: [[[-2.117116e-07]]], 
 Computed cov_dd: [[[-2.3841858e-07]]], 
 Difference: [[[2.6706985e-08]]]


In [41]:
# Non-reparameterised predictive covariance: feed observations generated from
# the same epsilon that the reference uses.
fisher_info = oed_toolbox.covariances.FisherInformation(likelihood)
pred_cov = oed_toolbox.covariances.PredictiveCovariance(model, fisher_info)
y_samples = transform(epsilon, theta, d)
pc_vals = pred_cov(d, theta, samples=y_samples)
for key, func in true_pred_cov.items():
    true_val = func(d, theta, epsilon)
    print(
        f'True {key}: {true_val}, \n'
        f' Computed {key}: {pc_vals[key]}, \n'
        f' Difference: {true_val-pc_vals[key]}'
    )

True cov: [[0.9852935]], 
 Computed cov: [[0.98518674]], 
 Difference: [[0.00010675]]
True cov_dd: [[[-2.117116e-07]]], 
 Computed cov_dd: [[[-0.12739188]]], 
 Difference: [[[0.12739168]]]
