In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl, fit_approximation

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.jacfwd(model), in_axes=0)

In [4]:
# Create artificial data:
true_theta = 5*np.random.rand(ndim)
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [2.2304303]
True x = model(theta): 
 [4.97481931]
Observations x_obs = model(theta) + noise: 
 [[4.86591333]]


In [5]:
# Create Gaussian approximate distribution:
approx = ApproximateDistribution.gaussian(ndim)

In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = np.identity(ndim)
joint = JointDistribution.from_model(data, model, noise_cov, prior_mean, prior_cov, model_grad)

In [7]:
# Fit sddistribution to reverse KL divergence:
loss = reverse_kl(approx, joint, use_reparameterisation=False)
approx = fit_approximation(loss, approx, data.reshape(1,1,1), verbose=True)



(1, 100, 1)


ValueError: einstein sum subscripts string contains too many subscripts for operand 0