In [9]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl, fit_approximation

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.jacfwd(model), in_axes=0)

In [4]:
# Create artificial data:
true_theta = 5*np.random.rand(ndim)
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [1.53475002]
True x = model(theta): 
 [2.35545763]
Observations x_obs = model(theta) + noise: 
 [[2.24852804]]


In [5]:
# Create Gaussian approximate distribution:
approx = ApproximateDistribution.gaussian(ndim)

In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = np.identity(ndim)
joint = JointDistribution.from_model(data, model, noise_cov, prior_mean, prior_cov, model_grad)

In [10]:
# Fit sddistribution to reverse KL divergence:
loss = reverse_kl(approx, joint, use_reparameterisation=False)
approx = fit_approximation(loss, approx, verbose=True)



Iteration 1:
   Loss = 152.84918212890625 
   Phi = {'chol_diag': DeviceArray([1.1], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.1], dtype=float32)}
Iteration 2:
   Loss = 181.67759704589844 
   Phi = {'chol_diag': DeviceArray([1.0443203], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.06800561], dtype=float32)}
Iteration 3:
   Loss = 171.41094970703125 
   Phi = {'chol_diag': DeviceArray([0.9876883], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.02481545], dtype=float32)}
Iteration 4:
   Loss = 169.91943359375 
   Phi = {'chol_diag': DeviceArray([0.938082], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.02957264], dtype=float32)}
Iteration 5:
   Loss = 153.89671325683594 
   Phi = {'chol_diag': DeviceArray([0.9116411], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.03209034]

Iteration 41:
   Loss = 160.78546142578125 
   Phi = {'chol_diag': DeviceArray([0.7313172], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.75880194], dtype=float32)}
Iteration 42:
   Loss = 182.70245361328125 
   Phi = {'chol_diag': DeviceArray([0.70688576], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.81650996], dtype=float32)}
Iteration 43:
   Loss = 135.3577880859375 
   Phi = {'chol_diag': DeviceArray([0.6804698], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.8872728], dtype=float32)}
Iteration 44:
   Loss = 156.64639282226562 
   Phi = {'chol_diag': DeviceArray([0.6349402], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.9524356], dtype=float32)}
Iteration 45:
   Loss = 133.66165161132812 
   Phi = {'chol_diag': DeviceArray([0.57472694], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': Dev

Iteration 83:
   Loss = 3.258204698562622 
   Phi = {'chol_diag': DeviceArray([0.02842671], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.5224512], dtype=float32)}
Iteration 84:
   Loss = 3.4578745365142822 
   Phi = {'chol_diag': DeviceArray([0.02563158], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.5274135], dtype=float32)}
Iteration 85:
   Loss = 3.6139333248138428 
   Phi = {'chol_diag': DeviceArray([0.02516962], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.526689], dtype=float32)}
Iteration 86:
   Loss = 3.6110434532165527 
   Phi = {'chol_diag': DeviceArray([0.02501637], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.5218621], dtype=float32)}
Iteration 87:
   Loss = 3.4844882488250732 
   Phi = {'chol_diag': DeviceArray([0.02720412], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': Dev

Iteration 125:
   Loss = 3.1421620845794678 
   Phi = {'chol_diag': DeviceArray([0.03279109], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.4971341], dtype=float32)}
Iteration 126:
   Loss = 3.13657808303833 
   Phi = {'chol_diag': DeviceArray([0.03272359], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.4991125], dtype=float32)}
Iteration 127:
   Loss = 3.140172004699707 
   Phi = {'chol_diag': DeviceArray([0.03281008], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.5004072], dtype=float32)}
Iteration 128:
   Loss = 3.1432647705078125 
   Phi = {'chol_diag': DeviceArray([0.03311782], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.5008727], dtype=float32)}
Iteration 129:
   Loss = 3.1516613960266113 
   Phi = {'chol_diag': DeviceArray([0.03313994], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean':

Iteration 167:
   Loss = 3.1535580158233643 
   Phi = {'chol_diag': DeviceArray([0.0328534], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.4949404], dtype=float32)}
Iteration 168:
   Loss = 3.1425163745880127 
   Phi = {'chol_diag': DeviceArray([0.03205229], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.4948267], dtype=float32)}
Iteration 169:
   Loss = 3.1398885250091553 
   Phi = {'chol_diag': DeviceArray([0.03209363], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.4951417], dtype=float32)}
Iteration 170:
   Loss = 3.1433093547821045 
   Phi = {'chol_diag': DeviceArray([0.03256581], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([1.4958428], dtype=float32)}
Iteration 171:
   Loss = 3.141684055328369 
   Phi = {'chol_diag': DeviceArray([0.0331964], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean':

KeyboardInterrupt: 