In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl, fit_approximation

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model), in_axes=0), in_axes=0)

In [4]:
# Create artificial data:
true_theta = 2*np.random.rand(ndim)
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [1.54248085]
True x = model(theta): 
 [2.37924719]
Observations x_obs = model(theta) + noise: 
 [[2.45021876]]


In [5]:
# Create Gaussian approximate distribution:
approx = ApproximateDistribution.gaussian(ndim)

In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = 0.1*np.identity(ndim)
joint = JointDistribution.from_model(model, noise_cov, prior_mean, prior_cov, model_grad)

In [None]:
# Fit sddistribution to reverse KL divergence:
# loss = reverse_kl(approx, joint, use_reparameterisation=True)
loss = forward_kl(approx, joint, use_reparameterisation=True)
approx = fit_approximation(loss, approx, data, verbose=True)



Iteration 1:
   Loss = 0.0020802649669349194, Params = {'mean': DeviceArray([5.2154064e-08], dtype=float32), 'chol_diag': DeviceArray([0.9], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 2:
   Loss = 0.0022976312320679426, Params = {'mean': DeviceArray([-0.09870977], dtype=float32), 'chol_diag': DeviceArray([0.7999544], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 3:
   Loss = 0.002658806974068284, Params = {'mean': DeviceArray([-0.19744402], dtype=float32), 'chol_diag': DeviceArray([0.6997573], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 4:
   Loss = 0.003048220183700323, Params = {'mean': DeviceArray([-0.29702097], dtype=float32), 'chol_diag': DeviceArray([0.6022123], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 5:
   Loss = 0.00392749160528183, Params = {'mean': DeviceArray([-0.3764578], dtype=float32), 'chol_diag': DeviceArray([0.5127401], dtype=float32), 'cho

Iteration 41:
   Loss = -9.976640285458416e-05, Params = {'mean': DeviceArray([-0.37289762], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 42:
   Loss = -6.254386971704662e-05, Params = {'mean': DeviceArray([-0.38911486], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 43:
   Loss = -3.0836716177873313e-05, Params = {'mean': DeviceArray([-0.40397888], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 44:
   Loss = -3.8116575069579994e-06, Params = {'mean': DeviceArray([-0.4174695], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 45:
   Loss = 1.9104376406176016e-05, Params = {'mean': DeviceArray([-0.42958018], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_low

Iteration 81:
   Loss = -2.7854070140165277e-05, Params = {'mean': DeviceArray([-0.38896775], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 82:
   Loss = -3.1095325539354235e-05, Params = {'mean': DeviceArray([-0.38755134], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 83:
   Loss = -3.376774475327693e-05, Params = {'mean': DeviceArray([-0.38644144], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 84:
   Loss = -3.588897016015835e-05, Params = {'mean': DeviceArray([-0.385625], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 85:
   Loss = -3.744429341168143e-05, Params = {'mean': DeviceArray([-0.3850864], dtype=float32), 'chol_diag': DeviceArray([0.01], dtype=float32), 'chol_lower