In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model), in_axes=0), in_axes=0)

In [4]:
# Create artificial data:
true_theta = 2*np.random.rand(ndim)
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [0.70060952]
True x = model(theta): 
 [0.4908537]
Observations x_obs = model(theta) + noise: 
 [[0.47258579]]


In [5]:
# Create Gaussian approximate distribution:
approx = approx.Gaussian(ndim)



In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = 0.1*np.identity(ndim)
joint = joint.ModelPlusGaussian(model, noise_cov, prior_mean, prior_cov, model_grad)

In [7]:
# Fit sddistribution to reverse KL divergence:
loss = losses.ReverseKL(joint, use_reparameterisation=False)
optimiser = optimisers.Adam()

In [8]:
prngkey = jax.random.PRNGKey(20)
optimiser.fit(approx, loss, data, prngkey)

AttributeError: 'Jaxtainer' object has no attribute 'list_elements'