In [3]:
import oed_toolbox
import numpy as np
import jax
import jax.numpy as jnp

In [33]:
def K_func(d):
    """Design-dependent forward operator: ``-d**2`` promoted to at least 2-D."""
    return jnp.atleast_2d(-1*d**2)


def b_func(d):
    """Design-dependent offset; identically zero in this example."""
    return 0.


def create_linear_model(K_func, b_func):
    """Build ``model(theta, d) = K_func(d) @ theta**2 + b_func(d)``.

    Note that, despite the name, the returned model is quadratic in ``theta``
    (it applies ``K_func(d)`` to ``theta**2``). It is therefore nonlinear in the
    parameters. NOTE(review): presumably intentional, so that the Laplace
    approximation is not exact; confirm.

    ``theta`` is squeezed and promoted to 1-D before the contraction. Only
    inputs that reduce to a single 1-D parameter vector are supported; a batch
    of several problems is not evaluated problem-by-problem.
    """
    def linear_model(theta, d):
        theta_vec = jnp.atleast_1d(theta.squeeze())
        forward_op = K_func(d)
        return jnp.einsum('ij,j->i', forward_op, theta_vec**2) + b_func(d)

    return linear_model

In [34]:
# Gaussian prior and additive Gaussian observation noise
# (one parameter, one observation).
prior_mean = jnp.array([0.0])
prior_cov = jnp.array([[1.0]])
noise_cov = jnp.array([[0.1]])

# Gradient-descent settings. They are passed to the toolbox minimiser below and
# also read as globals by the reference MAP implementation in the next cells.
lr = 1e-5
max_iter = 200
abs_tol = 1e-5
rel_tol = 1e-5

model_func = create_linear_model(K_func, b_func)
# Forward-mode Jacobian of the model output with respect to theta (argument 0).
model_func_dt = jax.jacfwd(model_func, argnums=0)

model = oed_toolbox.models.Model.from_jax_function(model_func)
minimizer = oed_toolbox.optim.gradient_descent_for_map(
    lr=lr,
    max_iter=max_iter,
    abs_tol=abs_tol,
    rel_tol=rel_tol,
)
posterior = oed_toolbox.uncertainties.Posterior.laplace_approximation(
    model, minimizer, noise_cov, prior_mean, prior_cov
)

In [35]:
def theta_map(theta_0, y, d):
    """Batched gradient-descent estimate of the MAP parameters.

    Runs plain gradient descent on ``map_loss_and_grad`` for every problem in
    the batch (leading axis of ``theta_0``) at the same time. A problem stops
    being updated once its loss change satisfies the absolute or relative
    tolerance. The loop ends when every problem has stopped or the iteration
    limit is reached.

    Reads the module-level ``lr`` and ``max_iter``. The tolerance helpers read
    ``abs_tol`` and ``rel_tol``.

    NOTE(review): the iteration-limit check runs after the update and uses the
    pre-increment counter, so up to ``max_iter + 1`` updates are applied. This
    is kept unchanged so results stay comparable with the ``oed_toolbox``
    minimiser configured with the same settings; confirm its convention.

    Args:
        theta_0: initial guesses, one row per optimisation problem.
        y: observations for the batch.
        d: design passed through to the model.

    Returns:
        The final iterate, same shape as ``theta_0``.
    """
    num_opt_problems = theta_0.shape[0]
    converged = jnp.zeros((num_opt_problems,), dtype=bool)
    num_iter = 0
    loss_prev_iter = None
    theta = theta_0
    while not jnp.all(converged):
        loss, grad = map_loss_and_grad(theta, y, d)
        # Freeze problems that have already converged by zeroing their step.
        # The per-problem mask is broadcast over all trailing axes of ``grad``.
        frozen = converged.reshape((-1,) + (1,) * (grad.ndim - 1))
        theta = theta - lr * jnp.where(frozen, 0.0, grad)
        # Convergence is sticky: once a problem has stopped it is never
        # restarted, even if a later tolerance check would fail.
        converged = (converged
                     | less_than_abs_tol(loss, loss_prev_iter, num_opt_problems)
                     | less_than_rel_tol(loss, loss_prev_iter, num_opt_problems)
                     | exceeded_max_iter(num_iter, max_iter, num_opt_problems))
        num_iter += 1
        loss_prev_iter = loss
    return theta

def map_loss_and_grad(theta, y, d):
    """Batched MAP objective and its analytic gradient with respect to theta.

    For each problem ``b`` in the batch, with ``r = y - model_func(theta, d)``:

        loss[b] = r[b]^T noise_icov r[b]
                  + (theta[b] - prior_mean)^T prior_icov (theta[b] - prior_mean)

    For a Gaussian likelihood and prior, this is twice the negative
    log-posterior up to an additive constant.

    Reads the module-level ``model_func``, ``model_func_dt``, ``noise_icov``,
    ``prior_icov`` and ``prior_mean``.

    NOTE(review): the gradient contraction treats ``model_func_dt``'s output as
    (batch, output, param). With the current squeeze-based model this layout
    only lines up because the batch size is 1; confirm before batching.

    Returns:
        ``(loss, loss_del_theta)``, with shapes (batch,) and (batch, param).
    """
    residual = y - model_func(theta, d)
    jac = model_func_dt(theta, d)
    prior_offset = theta - prior_mean

    misfit = jnp.einsum("bm,mn,bn->b", residual, noise_icov, residual)
    regulariser = jnp.einsum("bp,pq,bq->b", prior_offset, prior_icov, prior_offset)

    # Derivative of the misfit: -2 J^T noise_icov r (contracted over outputs).
    misfit_grad = -2*jnp.einsum("bmp,mn,bn->bp", jac, noise_icov, residual)
    # Derivative of the prior term (prior_icov is symmetric).
    regulariser_grad = 2*jnp.einsum("pq,bq->bp", prior_icov, prior_offset)

    return misfit + regulariser, misfit_grad + regulariser_grad

def less_than_abs_tol(loss, loss_prev_iter, num_opt_problems, tol=None):
    """Per-problem absolute-tolerance convergence test.

    Args:
        loss: current losses, one per optimisation problem.
        loss_prev_iter: losses from the previous iteration, or ``None`` on the
            first iteration (there is nothing to compare against yet).
        num_opt_problems: batch size, used to shape the all-False first result.
        tol: absolute tolerance. Defaults to the module-level ``abs_tol`` so
            existing callers are unaffected.

    Returns:
        Boolean array, True where ``|loss - loss_prev_iter| <= tol``.
    """
    if loss_prev_iter is None:
        return jnp.zeros((num_opt_problems,), dtype=bool)
    if tol is None:
        tol = abs_tol
    return jnp.abs(loss - loss_prev_iter) <= tol

def less_than_rel_tol(loss, loss_prev_iter, num_opt_problems, tol=None):
    """Per-problem relative-tolerance convergence test.

    Args:
        loss: current losses, one per optimisation problem.
        loss_prev_iter: losses from the previous iteration, or ``None`` on the
            first iteration (there is nothing to compare against yet).
        num_opt_problems: batch size, used to shape the all-False first result.
        tol: relative tolerance. Defaults to the module-level ``rel_tol`` so
            existing callers are unaffected.

    Returns:
        Boolean array, True where
        ``|loss - loss_prev_iter| <= tol * |loss_prev_iter|``.
    """
    if loss_prev_iter is None:
        return jnp.zeros((num_opt_problems,), dtype=bool)
    if tol is None:
        tol = rel_tol
    # Scale by |previous loss|. Otherwise the right-hand side is negative for a
    # negative loss and the test could never pass.
    return jnp.abs(loss - loss_prev_iter) <= tol * jnp.abs(loss_prev_iter)

def exceeded_max_iter(num_iter, max_iter, num_opt_problems):
    """Broadcast the shared iteration-limit check to every optimisation problem.

    Returns a boolean array of shape ``(num_opt_problems,)`` that is all True
    once ``num_iter >= max_iter`` and all False before that.
    """
    limit_reached = num_iter >= max_iter
    return jnp.full((num_opt_problems,), limit_reached, dtype=bool)

In [36]:
# Precision (inverse covariance) matrices. They are read as module-level
# globals by map_loss_and_grad and mean_cov_and_icov.
prior_icov = jnp.linalg.inv(prior_cov)
noise_icov = jnp.linalg.inv(noise_cov)

def jax_posterior(theta, y, d, theta_0=None):
    """Reference Laplace-approximation log-posterior, written directly in JAX.

    Steps:
      1. Find the MAP estimate with gradient descent (``theta_map``).
      2. Linearise the model there (``y ~= J theta + b``).
      3. Form the resulting Gaussian posterior.
      4. Evaluate its log-density at ``theta``.

    Args:
        theta: parameter values at which to evaluate the log-density.
        y: observations.
        d: design; ``jax_posterior_grad`` differentiates with respect to it.
        theta_0: initial guess for the MAP search. Defaults to ``theta`` itself.

    Returns:
        Log-density values of the Laplace posterior at ``theta``.
    """
    if theta_0 is None:
        theta_0 = theta
    t_map = theta_map(theta_0, y, d)
    y_map = model_func(t_map, d)
    y_map_dt = model_func_dt(t_map, d)
    b = linearisation_constant(y_map, y_map_dt, t_map)
    # The precision matrix (third return value) is not needed here.
    mean, cov, _ = mean_cov_and_icov(y, t_map, y_map, y_map_dt, b)
    return jax.scipy.stats.multivariate_normal.logpdf(theta, mean=mean, cov=cov)

def linearisation_constant(y_map, y_map_del_theta, theta_map):
    """Offset ``b`` of the first-order model expansion at the MAP point.

    With ``J = y_map_del_theta`` (batch, output, param), the model is
    approximated as ``y ~= J theta + b``. Hence ``b = y_map - J @ theta_map``,
    computed per batch entry.
    """
    jac_times_theta = jnp.einsum("bmp,bp->bm", y_map_del_theta, theta_map)
    return y_map - jac_times_theta

def mean_cov_and_icov(y, theta_map, y_map, y_map_del_theta, b):
    """Gaussian posterior of the model linearised at the MAP point.

    With ``y ~= J theta + b`` (``J = y_map_del_theta``, batch x output x param):

        precision = J^T noise_icov J + prior_icov
        mean      = precision^{-1} (J^T noise_icov (y - b) + prior_icov prior_mean)

    The contractions below are written in transposed (row-vector) form. They
    therefore rely on ``noise_icov``, ``prior_icov`` and the covariance being
    symmetric.

    ``theta_map`` and ``y_map`` are accepted but not used.
    Reads the module-level ``noise_icov``, ``prior_icov`` and ``prior_mean``.

    Returns:
        ``(mean, cov, inv_cov)``, batched along the leading axis.
    """
    jac = y_map_del_theta
    precision = jnp.einsum("bmp,mn,bnq->bpq", jac, noise_icov, jac) + prior_icov
    covariance = jnp.linalg.inv(precision)
    data_term = jnp.einsum("bm,mn,bnp->bp", y - b, noise_icov, jac)
    prior_term = jnp.einsum("p,pq->q", prior_mean, prior_icov)
    # Row vector times covariance, equal to covariance @ rhs for a symmetric covariance.
    posterior_mean = jnp.einsum("bq,bqp->bp", data_term + prior_term, covariance)
    return posterior_mean, covariance, precision

In [37]:
# Forward-mode Jacobian of the reference log-posterior with respect to the
# design `d` (argument index 2). theta_map's Python loop is differentiated
# through the iterations it actually executes.
# NOTE(review): this may explain why the result differs from the toolbox's
# logpdf_dd further below; confirm.
jax_posterior_grad = jax.jacfwd(fun=jax_posterior, argnums=2)

In [38]:
# Evaluate the reference implementation at a single point: one problem, one
# parameter, one observation and one design variable.
theta = jnp.array([[1.0]])
y = jnp.array([[1.0]])
d = jnp.array([[1.0]])

print(jax_posterior(theta, y, d))       # log-density
print(jax_posterior_grad(theta, y, d))  # derivative w.r.t. the design d

[-19.296803]
[[[-35.38819]]]


In [39]:
# Toolbox Laplace posterior at the same (theta, y, d), for comparison with the
# reference jax_posterior / jax_posterior_grad results above.
# NOTE(review): `return_dd=True` presumably adds the derivative w.r.t. the
# design (the 'logpdf_dd' entry); confirm against the oed_toolbox API.
posterior.logpdf(theta, y, d, return_dd=True)

{'logpdf': array([-19.296804], dtype=float32),
 'logpdf_dd': array([[-37.031006]], dtype=float32)}