In [1]:
import oed_toolbox
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy import optimize as joptimize

In [2]:
def K_func(d):
    """Design-dependent forward matrix K(d) = -d**2, promoted to at least 2-D."""
    return jnp.atleast_2d(-(d ** 2))


def b_func(d):
    """Model offset b(d); identically zero for every design d."""
    return 0.
def create_linear_model(K_func, b_func):
    """Build the forward model y = K(d) @ g(theta) + b(d), with g(theta) = theta**(-1/2) - theta**(3/2).

    Despite its name, the returned model is linear in g(theta), not in theta.

    Args:
        K_func: callable mapping the design `d` to a 2-D matrix of shape
            (n_outputs, n_theta) (the einsum below requires a 2-D result).
        b_func: callable mapping `d` to an offset added to the model output.

    Returns:
        linear_model(theta, d): returns a 1-D array of model outputs.
    """
    def linear_model(theta, d):
        # Note that this model is NON-LINEAR wrt theta: it is linear only in
        # theta**(-0.5) - theta**(3/2). Both powers are real-valued only for
        # theta > 0 (theta == 0 gives inf, theta < 0 gives nan).
        # jnp.squeeze (unlike the .squeeze() method) also accepts Python
        # scalars and other array-likes, not just arrays.
        theta = jnp.atleast_1d(jnp.squeeze(theta))
        g_theta = -1*theta**(3/2) + theta**(-0.5)
        return jnp.einsum('ij,j->i', K_func(d), g_theta) + b_func(d)
    return linear_model

In [3]:
# Prior mean/covariance on theta and observation-noise covariance (1-D problem).
prior_mean = jnp.array([0.0])
prior_cov = jnp.array([[1.0]])
noise_cov = jnp.array([[0.1]])

# Forward model, its Jacobian w.r.t. theta, and the oed_toolbox wrapper around it.
model_func = create_linear_model(K_func, b_func)
model_func_dt = jax.jacfwd(model_func, argnums=0)
model = oed_toolbox.models.Model.from_jax_function(model_func)

# Settings passed to oed_toolbox's gradient_descent_for_map minimizer.
# NOTE(review): lr=1e-6 with only max_iter=10 steps looks very small for plain
# gradient descent — confirm these values are intended.
lr = 1e-6
max_iter = 10
abs_tol = 1e-6
rel_tol = 1e-6
minimizer = oed_toolbox.optim.gradient_descent_for_map(
    lr=lr,
    max_iter=max_iter,
    abs_tol=abs_tol,
    rel_tol=rel_tol,
)

# Posterior object built by the toolbox's Laplace approximation.
posterior = oed_toolbox.distributions.Posterior.laplace_approximation(
    model, minimizer, noise_cov, prior_mean, prior_cov,
)



In [6]:
# Precision (inverse covariance) matrices used by the pure-JAX reference posterior in this cell.
noise_icov, prior_icov = (jnp.linalg.inv(cov_) for cov_ in (noise_cov, prior_cov))

def loss(theta, y, d):
    """MAP objective: noise-weighted squared data misfit plus prior-weighted squared deviation.

    Both quadratic forms lack the usual factor 1/2 of a Gaussian negative
    log-posterior. This doubles the objective but leaves its minimiser unchanged.

    Args:
        theta: 1-D parameter vector.
        y: observations; the einsum sums residuals over a leading batch axis `a`.
        d: design passed through to `model_func`.
    """
    residual = y - model_func(theta, d)
    deviation = theta - prior_mean
    misfit = jnp.einsum("ai,ij,aj->", residual, noise_icov, residual)
    regulariser = jnp.einsum("i,ij,j->", deviation, prior_icov, deviation)
    return misfit + regulariser

def jax_posterior(theta, y, d):
    """Reference posterior log-density computed directly with JAX.

    Minimises `loss` with BFGS (initialised at the first row of `theta`) to get
    the MAP, linearises `model_func` about it, and evaluates the resulting
    Gaussian log-density at every row of `theta`.
    """
    start = theta[0, :]
    map_point = joptimize.minimize(loss, start, args=(y, d), method='BFGS').x[None, :]
    jac_at_map = model_func_dt(map_point, d)
    offset = linearisation_constant(model_func(map_point, d), jac_at_map, map_point)
    # The precision matrix (third output) is not needed here.
    post_mean, post_cov, _ = mean_cov_and_icov(y, map_point, jac_at_map, offset)
    return jax.scipy.stats.multivariate_normal.logpdf(theta, mean=post_mean, cov=post_cov)

def linearisation_constant(y_map, y_map_del_theta, theta_map):
    """Offset b of the first-order model expansion about the MAP.

    With J = y_map_del_theta, the linearised model is y ≈ J @ theta + b, so
    b = y_map - J @ theta_map, computed independently for each leading index `a`.
    """
    jac_times_map = jnp.einsum("aij,aj->ai", y_map_del_theta, theta_map)
    return y_map - jac_times_map

def mean_cov_and_icov(y, theta_map, y_map_del_theta, b):
    """Gaussian posterior of the linearised model y ≈ J @ theta + b.

    With J = y_map_del_theta (batched over a leading index `a`):
        inv_cov = J^T noise_icov J + prior_icov
        cov     = inv(inv_cov)
        h       = J^T noise_icov (y - b) + prior_icov prior_mean
        mean_i  = sum_k h_k cov_ki   (equal to cov @ h because cov is symmetric)
    The einsums contract the first index of noise_icov/prior_icov, which matches
    the formulas above because those matrices are symmetric.
    `theta_map` is accepted for interface compatibility but is not used.

    Returns:
        (mean, cov, inv_cov)
    """
    jac = y_map_del_theta
    precision = prior_icov + jnp.einsum("aki,kl,alj->aij", jac, noise_icov, jac)
    covariance = jnp.linalg.inv(precision)
    data_term = jnp.einsum("aj,jk,aki->ai", y - b, noise_icov, jac)
    prior_term = jnp.einsum('i,ij->j', prior_mean, prior_icov)
    posterior_mean = jnp.einsum("ak,aki->ai", data_term + prior_term, covariance)
    return posterior_mean, covariance, precision

# Reference log-density and its forward-mode Jacobians w.r.t. y (argnums=1) and d (argnums=2).
true_posterior = dict(
    logpdf=jax_posterior,
    logpdf_dy=jax.jacfwd(jax_posterior, argnums=1),
    logpdf_dd=jax.jacfwd(jax_posterior, argnums=2),
)

In [7]:
# Draw one noisy observation at theta = 1, d = 1, then compare the toolbox
# posterior (and its y- and d-derivatives) against the pure-JAX reference.
key = jax.random.PRNGKey(2)
# Zero-mean (float) noise whose dimension is taken from noise_cov.
noise = jax.random.multivariate_normal(key, mean=jnp.zeros(noise_cov.shape[0]), cov=noise_cov)
theta = jnp.array([[1.]])
d = jnp.array([[1.]])
# NB: if y is far away from model_func(theta, d), gradient descent will fail:
y = jnp.array([model_func(theta, d) + noise])
post_val = posterior.logpdf(theta, y, d, return_dd=True, return_dy=True)
# The loop variable is `name`, not `key`, so the PRNG key above is not
# overwritten by a string (which would break re-runs or later cells using it).
for name, func in true_posterior.items():
    val = func(theta, y, d)
    print(f'True {name}: {val}, Computed {name}: {post_val[name]}, Difference: {val-post_val[name]}')

True logpdf: [0.87818366], Computed logpdf: [0.87818366], Difference: [0.]
True logpdf_dy: [[[1.0789679]]], Computed logpdf_dy: [[1.07898716]], Difference: [[[-1.9192696e-05]]]
True logpdf_dd: [[[2.0532954]]], Computed logpdf_dd: [[2.053299]], Difference: [[[-3.5762787e-06]]]
