In [1]:
import oed_toolbox
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
K_func = lambda d: jnp.atleast_2d(-1*(d-5)**2 + 20)
b_func = lambda d: 0.
def create_linear_model(K_func, b_func):
    def linear_model(theta, d):
        theta = jnp.atleast_1d(theta.squeeze())
        return jnp.einsum('ij,j->i', K_func(d), theta) + b_func(d)
    return linear_model

In [3]:
prior_mean = jnp.array([0.0])
prior_cov = jnp.array([[1.0]])
noise_cov = jnp.array([[0.1]])
model_func = create_linear_model(K_func, b_func)
model_func_dt = jax.jacfwd(model_func, argnums=0)
model = oed_toolbox.models.Model.from_jax_function(model_func)



# Likelihood check:

In [30]:
def true_likelihood(y, theta, d):
    y_mean = model_func(theta, d)
    return jax.scipy.stats.multivariate_normal.logpdf(y, mean=y_mean, cov=noise_cov)
true_likelihood_vmap = jax.vmap(true_likelihood, in_axes=(0,0,0))
true_likelihood_dd = jax.vmap(jax.jacfwd(true_likelihood, argnums=2), in_axes=(0,0,0))

In [31]:
likelihood = oed_toolbox.uncertainties.Likelihood.from_model_plus_constant_gaussian_noise(model, noise_cov)

In [34]:
theta = jnp.array([[1.], [2.], [3.]])
y = jnp.array([[1.], [2.], [3.]])
d = jnp.array([[1.], [2.],[ 3.]])
print(f'True values: logpdf = {true_likelihood_vmap(y, theta, d)}, logpdf_dd = {true_likelihood_dd(y, theta, d)}')
like_vals = likelihood.logpdf(y, theta, d, return_dd=True)
print(f"Computed Values: logpdf = {like_vals['logpdf']}, logpdf_dd = {like_vals['logpdf_dd']}")

True values: logpdf = [   -44.76765  -1999.7676  -10124.767  ], logpdf_dd = [[ -240.00002]
 [-2400.     ]
 [-5400.     ]]
Computed Values: logpdf = [   -44.767647  -1999.7677   -10124.768   ], logpdf_dd = [[ -240.]
 [-2400.]
 [-5400.]]


# Posterior check:

In [7]:
minimizer = oed_toolbox.optim.gradient_descent_for_map()
posterior = oed_toolbox.uncertainties.Posterior.laplace_approximation(model, minimizer, noise_cov, prior_mean, prior_cov)

In [8]:
noise_icov = jnp.linalg.inv(noise_cov)
prior_icov = jnp.linalg.inv(prior_cov)
def true_posterior(theta, y, d):
    K = K_func(d)
    b = b_func(d)
    icov = K.T @ noise_icov @ K + prior_icov
    cov = jnp.linalg.inv(icov)
    mean = ((y-b).T @ noise_icov @ K + prior_mean.T @ prior_icov) @ cov
    print(mean)
    print(cov)
    return jax.scipy.stats.multivariate_normal.logpdf(theta, mean=mean, cov=cov)
true_posterior_vmap = jax.vmap(true_posterior, in_axes=(0,0,0))
true_posterior_grad = jax.jacfwd(true_posterior, argnums=2)

In [20]:
theta = jnp.array([[1.3315865]])
y = jnp.array([[10.74088]])
d = jnp.array([1.5])
print(f'True values: logpdf = {true_posterior(theta, y, d)}, logpdf_dd = {true_posterior_grad(theta, y, d)}')
post_vals = posterior.logpdf(theta, y, d, return_dd=True)
print(f"Computed Values: logpdf = {post_vals['logpdf']}, logpdf_dd = {post_vals['logpdf_dd']}")

[[1.3836163]]
[[0.00166216]]
Traced<ConcreteArray([[1.3836163]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray([[1.3836163]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,1])>with<BatchTrace(level=1/0)> with
    val = DeviceArray([[[-1.2455635]]], dtype=float32)
    batch_dim = 1
Traced<ConcreteArray([[0.00166216]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray([[0.00166216]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,1])>with<BatchTrace(level=1/0)> with
    val = DeviceArray([[[-0.00299763]]], dtype=float32)
    batch_dim = 0
True values: logpdf = [1.4665475], logpdf_dd = [[38.42232]]
[[1.3836163]]
[[[0.00166216]]]
Computed Values: logpdf = [1.4665475], logpdf_dd = [[38.422318]]


# Ape check:

In [10]:
ape = \
oed_toolbox.losses.APE.from_model_plus_constant_gaussian_noise(model, minimizer, prior_mean, \
                                                               prior_cov, noise_cov, apply_control_variates=True)

In [22]:
vmap_true_posterior = jax.vmap(true_posterior, in_axes=(0,0,None))

def sample_joint(d, num_samples, rng):
    theta = jax.random.multivariate_normal(rng, mean=prior_mean, cov=prior_cov, shape=(num_samples,))
    y = model_func(theta, d) + \
    jax.random.multivariate_normal(rng, mean=jnp.zeros(noise_cov.shape[0]), cov=noise_cov, shape=(num_samples,))
    return theta, y

def true_ape(d, num_samples, rng):
#     theta, y = sample_joint(d, num_samples, rng)
    theta = jnp.array([[1.3315865 ], [0.71527897]])
    y = jnp.array([[10.74088  ], [ 5.9644966]])
    post_vals = vmap_true_posterior(theta, y, d)
    return jnp.mean(post_vals)

true_ape_grad = jax.jacfwd(true_ape, argnums=0)

In [23]:
rng = jax.random.PRNGKey(20)
d = jnp.array([1.5])
num_samples = 2
print(f'True values: APE = {true_ape(d, num_samples, rng)}, APE_dd = {true_ape_grad(d, num_samples, rng)}')
ape_loss = ape(d, num_samples)
print(f"Computed Values: APE = {-1*ape_loss[0]}, APE_dd = {-1*ape_loss[1]}")

Traced<ShapedArray(float32[1])>with<BatchTrace(level=1/0)> with
  val = DeviceArray([[1.3836163],
             [0.7683332]], dtype=float32)
  batch_dim = 0
[[0.00166216]]
Traced<ShapedArray(float32[1])>with<BatchTrace(level=3/0)> with
  val = Traced<ConcreteArray([[1.3836163]
 [0.7683332]], dtype=float32)>with<JVPTrace(level=2/0)> with
    primal = DeviceArray([[1.3836163],
             [0.7683332]], dtype=float32)
    tangent = Traced<ShapedArray(float32[2,1])>with<BatchTrace(level=1/0)> with
      val = DeviceArray([[[-1.2455635 ]],

             [[-0.69167125]]], dtype=float32)
      batch_dim = 1
  batch_dim = 0
Traced<ConcreteArray([[0.00166216]], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray([[0.00166216]], dtype=float32)
  tangent = Traced<ShapedArray(float32[1,1])>with<BatchTrace(level=1/0)> with
    val = DeviceArray([[[-0.00299763]]], dtype=float32)
    batch_dim = 0
True values: APE = 1.45035719871521, APE_dd = [29.93716]
[[ 0.7913362 ]
 [-0.07422391]]

In [25]:
theta = jnp.array([[1.3315865 ], [0.71527897]])
y = jnp.array([[10.74088  ], [ 5.9644966]])
true_posterior(theta[0,:], y[0,:], d)

[1.3836163]
[[0.00166216]]


DeviceArray(1.4665475, dtype=float32)

In [26]:
posterior.logpdf(theta, y, d)

[[1.3836163]
 [0.7683332]]
[[[0.00166216]]

 [[0.00166216]]]


{'logpdf': array([0.5476091, 0.5152283], dtype=float32)}