In [4]:
import oed_toolbox
import numpy as np
import jax
import jax.numpy as jnp

In [32]:
K_func = lambda d: jnp.einsum('...i,...j->...ij', -1*(d-1)**2, (d+1)**0.5)  # jnp.atleast_2d(-1*(d-5)**2 + 20)
b_func = lambda d: 0.2*d**(1/2) + 2
def create_linear_model(K_func, b_func):
    def linear_model(theta, d):
        theta = jnp.atleast_1d(theta.squeeze())
        return jnp.einsum('ij,j->i', K_func(d), theta) + b_func(d)
    return linear_model

In [38]:
prior_mean = jnp.array([1.0, -1.2])
prior_cov = jnp.identity(2)
noise_cov = 0.1*jnp.identity(2)
model_func = create_linear_model(K_func, b_func)
model_func_dt = jax.jacfwd(model_func, argnums=0)
model = oed_toolbox.models.Model.from_jax_function(model_func)

# Likelihood check:

In [39]:
def true_likelihood(y, theta, d):
    y_mean = model_func(theta, d)
    return jax.scipy.stats.multivariate_normal.logpdf(y, mean=y_mean, cov=noise_cov)

true_likelihood_funcs = \
{'logpdf': jax.vmap(true_likelihood, in_axes=(0,0,0)),
 'logpdf_dd': jax.vmap(jax.jacfwd(true_likelihood, argnums=2), in_axes=(0,0,0)), 
 'logpdf_dy': jax.vmap(jax.jacfwd(true_likelihood, argnums=0), in_axes=(0,0,0)),
 'logpdf_dt': jax.vmap(jax.jacfwd(true_likelihood, argnums=1), in_axes=(0,0,0)),
 'logpdf_dt_dy': jax.vmap(jax.jacfwd(jax.jacfwd(true_likelihood, argnums=1), argnums=0), in_axes=(0,0,0)),
 'logpdf_dt_dd': jax.vmap(jax.jacfwd(jax.jacfwd(true_likelihood, argnums=1), argnums=2), in_axes=(0,0,0)),
 'logpdf_dt_dt': jax.vmap(jax.jacfwd(jax.jacfwd(true_likelihood, argnums=1), argnums=1), in_axes=(0,0,0))}

In [40]:
likelihood = oed_toolbox.distributions.Likelihood.from_model_plus_constant_gaussian_noise(model, noise_cov)

In [41]:
y = jnp.array([[1., 0.5], [1., -1.], [-1., -0.6]])
theta = jnp.array([[1., 0.1], [0.33, 0.2], [1., 2.]])
d = jnp.array([[0.5, 1.], [0.33, 0.2],[1.2, 0.33]])
like_vals = likelihood.logpdf(y, theta, d, return_dd=True, return_dt=True, return_dy=True, return_dt_dt=True,
                             return_dt_dy=True, return_dt_dd=True)
for key, func in true_likelihood_funcs.items():
    print(f'Difference for {key}: \n {func(y,theta,d) - like_vals[key]}')

Difference for logpdf: 
 [0. 0. 0.]
Difference for logpdf_dd: 
 [[0. 0.]
 [0. 0.]
 [0. 0.]]
Difference for logpdf_dy: 
 [[0. 0.]
 [0. 0.]
 [0. 0.]]
Difference for logpdf_dt: 
 [[ 0.000000e+00  0.000000e+00]
 [ 0.000000e+00  0.000000e+00]
 [ 0.000000e+00 -9.536743e-07]]
Difference for logpdf_dt_dy: 
 [[[ 0. -0.]
  [ 0.  0.]]

 [[ 0.  0.]
  [ 0.  0.]]

 [[ 0.  0.]
  [ 0.  0.]]]
Difference for logpdf_dt_dd: 
 [[[0.000000e+00 0.000000e+00]
  [0.000000e+00 0.000000e+00]]

 [[0.000000e+00 0.000000e+00]
  [0.000000e+00 0.000000e+00]]

 [[0.000000e+00 0.000000e+00]
  [0.000000e+00 9.536743e-07]]]
Difference for logpdf_dt_dt: 
 [[[ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]]

 [[ 0.0000000e+00  0.0000000e+00]
  [ 0.0000000e+00  0.0000000e+00]]

 [[ 0.0000000e+00 -1.1920929e-07]
  [-1.1920929e-07  0.0000000e+00]]]


# Posterior check:

In [42]:
minimizer = oed_toolbox.optim.gradient_descent_for_map()
posterior = oed_toolbox.distributions.Posterior.laplace_approximation(model, minimizer, noise_cov, prior_mean, prior_cov)

In [43]:
noise_icov = jnp.linalg.inv(noise_cov)
prior_icov = jnp.linalg.inv(prior_cov)
def true_posterior(theta, y, d):
    K = K_func(d)
    b = b_func(d)
    icov = K.T @ noise_icov @ K + prior_icov
    cov = jnp.linalg.inv(icov)
    mean = ((y-b).T @ noise_icov @ K + prior_mean.T @ prior_icov) @ cov
    return jax.scipy.stats.multivariate_normal.logpdf(theta, mean=mean, cov=cov)

true_posterior_funcs = \
{'logpdf': jax.vmap(true_posterior, in_axes=(0,0,0)),
 'logpdf_dd': jax.vmap(jax.jacfwd(true_posterior, argnums=2), in_axes=(0,0,0)), 
 'logpdf_dy': jax.vmap(jax.jacfwd(true_posterior, argnums=1), in_axes=(0,0,0))}

In [None]:
# Nb: can get overflow-related errors if y - f(theta, d) is large
theta = jnp.array([[-0.33, 0.2], [1., 2.], [1., 0.1]]) 
y = jnp.array([ [1., -1.], [-1., -0.6], [1., 0.5]]) 
d = jnp.array([ [0.33, 0.2], [1.2, 0.33], [0.5, 1.]]) 
post_vals = posterior.logpdf(theta, y, d, return_dd=True, return_dy=True)
for key, func in true_posterior_funcs.items():
    print(f'Difference for {key}: \n {func(theta,y,d) - post_vals[key]}')

# Ape check:

In [10]:
ape = \
oed_toolbox.losses.APE.from_model_plus_constant_gaussian_noise(model, \
minimizer, prior_mean,prior_cov, noise_cov, apply_control_variates=False, use_reparameterisation=True)

In [11]:
vmap_true_posterior = jax.vmap(true_posterior, in_axes=(0,0,None))
vmap_model_func = jax.vmap(model_func, in_axes=(0,None))
noise_chol = jnp.linalg.cholesky(noise_cov)

def compute_y(d, epsilon_samples, theta_samples):
    return vmap_model_func(theta_samples, d) + jnp.einsum('ij,ai->aj', noise_chol, epsilon_samples)

def true_ape(d, theta_samples, epsilon_samples):
    y = compute_y(d, epsilon_samples, theta_samples)
    post_vals = vmap_true_posterior(theta_samples, y, d)
    return jnp.mean(post_vals)

true_ape_grad = jax.jacfwd(true_ape, argnums=0)

In [12]:
rng = jax.random.PRNGKey(22)
d = jnp.array([1.5, 0.2])
num_samples = 1000
theta_samples = jax.random.multivariate_normal(rng, mean=prior_mean, cov=prior_cov, shape=(num_samples,))
epsilon_samples = jax.random.multivariate_normal(rng, mean=jnp.zeros(noise_cov.shape[0]), \
                                                 cov=jnp.identity(noise_cov.shape[0]), shape=(num_samples,))
print(f'True values: APE = {true_ape(d, theta_samples, epsilon_samples)}, APE_dd = {true_ape_grad(d, theta_samples, epsilon_samples)}')
ape_loss = ape(d, samples={'epsilon': epsilon_samples, 'theta': theta_samples})
print(f"Computed Values: APE = {-1*ape_loss[0]}, APE_dd = {-1*ape_loss[1]}")
print(f'Difference: APE = {-1*ape_loss[0] - true_ape(d, theta_samples, epsilon_samples)}, APE_dd = {-1*ape_loss[1] - true_ape_grad(d, theta_samples, epsilon_samples)}')

True values: APE = -1.535464882850647, APE_dd = [ 0.52723044 -2.4393964 ]
Computed Values: APE = -1.5354645252227783, APE_dd = [ 0.52722939 -2.43936651]
Difference: APE = 3.5762786865234375e-07, APE_dd = [-1.0728836e-06  2.9802322e-05]


In [37]:
y_samples = compute_y(d, epsilon_samples, theta_samples)
ape_cv = \
oed_toolbox.losses.APE.from_model_plus_constant_gaussian_noise(model, \
minimizer, prior_mean, prior_cov, noise_cov, apply_control_variates=True)
ape_loss = ape_cv(d, num_samples)
print(f"Computed Values: APE = {-1*ape_loss[0]}, APE_dd = {-1*ape_loss[1]}")

Computed Values: APE = [-1.40510219], APE_dd = [ 0.54795729 -2.1830221 ]


In [32]:
ape_cv(d, num_samples)

(1.3739977791860656, array([-0.99735295,  2.1177567 ]))

In [33]:
ape(d, num_samples)

(1.3682229626601548, array([-0.55734692,  1.97240645]))