In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl, fit_approximation

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model), in_axes=0), in_axes=0)

In [4]:
# Create artificial data:
true_theta = 2*np.random.rand(ndim)
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [0.51067217]
True x = model(theta): 
 [0.26078606]
Observations x_obs = model(theta) + noise: 
 [[0.05890958]]


In [5]:
# Create Gaussian approximate distribution:
approx = ApproximateDistribution.gaussian(ndim)

In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = 0.1*np.identity(ndim)
joint = JointDistribution.from_model(model, noise_cov, prior_mean, prior_cov, model_grad)

In [7]:
# Fit sddistribution to reverse KL divergence:
loss = forward_kl(approx, joint, use_reparameterisation=True)
approx.fit(loss, data, verbose=True, num_samples=1000)



Iteration 1:
   Loss = 0.009492909535765648, Params = {'mean': DeviceArray([1.6391277e-07], dtype=float32), 'chol_diag': DeviceArray([1.0999999], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 2:
   Loss = 0.010376732796430588, Params = {'mean': DeviceArray([0.02442572], dtype=float32), 'chol_diag': DeviceArray([1.1991854], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 3:
   Loss = 0.011216770857572556, Params = {'mean': DeviceArray([0.07889289], dtype=float32), 'chol_diag': DeviceArray([1.2978662], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 4:
   Loss = 0.012006820179522038, Params = {'mean': DeviceArray([0.12628055], dtype=float32), 'chol_diag': DeviceArray([1.3962562], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 5:
   Loss = 0.012747146189212799, Params = {'mean': DeviceArray([0.13645577], dtype=float32), 'chol_diag': DeviceArray([1.4944677], dtype=float32), 'c

Iteration 41:
   Loss = 0.023422246798872948, Params = {'mean': DeviceArray([0.5332675], dtype=float32), 'chol_diag': DeviceArray([4.1635885], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 42:
   Loss = 0.02359391376376152, Params = {'mean': DeviceArray([0.5420374], dtype=float32), 'chol_diag': DeviceArray([4.237074], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 43:
   Loss = 0.023767564445734024, Params = {'mean': DeviceArray([0.5434599], dtype=float32), 'chol_diag': DeviceArray([4.3135924], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 44:
   Loss = 0.023942673578858376, Params = {'mean': DeviceArray([0.5495177], dtype=float32), 'chol_diag': DeviceArray([4.3913007], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 45:
   Loss = 0.02411872334778309, Params = {'mean': DeviceArray([0.56842154], dtype=float32), 'chol_diag': DeviceArray([4.468933], dtype=float32), 'chol_lo

Iteration 81:
   Loss = 0.0291247870773077, Params = {'mean': DeviceArray([1.1381892], dtype=float32), 'chol_diag': DeviceArray([7.3051987], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 82:
   Loss = 0.029231976717710495, Params = {'mean': DeviceArray([1.1555083], dtype=float32), 'chol_diag': DeviceArray([7.383185], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 83:
   Loss = 0.029338758438825607, Params = {'mean': DeviceArray([1.1703265], dtype=float32), 'chol_diag': DeviceArray([7.461845], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 84:
   Loss = 0.029444755986332893, Params = {'mean': DeviceArray([1.1817555], dtype=float32), 'chol_diag': DeviceArray([7.5411634], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 85:
   Loss = 0.0295497365295887, Params = {'mean': DeviceArray([1.1936375], dtype=float32), 'chol_diag': DeviceArray([7.620557], dtype=float32), 'chol_lowerd

Iteration 121:
   Loss = 0.032323747873306274, Params = {'mean': DeviceArray([1.5546367], dtype=float32), 'chol_diag': DeviceArray([10.020254], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 122:
   Loss = 0.032374437898397446, Params = {'mean': DeviceArray([1.5607214], dtype=float32), 'chol_diag': DeviceArray([10.071116], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 123:
   Loss = 0.03242434933781624, Params = {'mean': DeviceArray([1.5661066], dtype=float32), 'chol_diag': DeviceArray([10.1216135], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 124:
   Loss = 0.03247346729040146, Params = {'mean': DeviceArray([1.5712572], dtype=float32), 'chol_diag': DeviceArray([10.171789], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 125:
   Loss = 0.03252197057008743, Params = {'mean': DeviceArray([1.5767492], dtype=float32), 'chol_diag': DeviceArray([10.221685], dtype=float32), 'c

Iteration 161:
   Loss = 0.03443359211087227, Params = {'mean': DeviceArray([1.8681518], dtype=float32), 'chol_diag': DeviceArray([12.414402], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 162:
   Loss = 0.034498803317546844, Params = {'mean': DeviceArray([1.8175757], dtype=float32), 'chol_diag': DeviceArray([12.495785], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 163:
   Loss = 0.03455036133527756, Params = {'mean': DeviceArray([1.8686068], dtype=float32), 'chol_diag': DeviceArray([12.559286], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 164:
   Loss = 0.034609951078891754, Params = {'mean': DeviceArray([1.8929464], dtype=float32), 'chol_diag': DeviceArray([12.6278715], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)}
Iteration 165:
   Loss = 0.03466767445206642, Params = {'mean': DeviceArray([1.8628178], dtype=float32), 'chol_diag': DeviceArray([12.707159], dtype=float32), 'c

KeyboardInterrupt: 

In [None]:
loss = forward_kl(approx, joint, use_reparameterisation=True)
params = approx.params
params['mean'] = np.array([0.2])
params['chol_diag'] = np.array([0.62])
loss(approx.params, data, num_samples=100000)

In [None]:
loss = forward_kl(approx, joint, use_reparameterisation=False)
params = approx.params
params['mean'] = np.array([0.2])
params['chol_diag'] = np.array([0.62])
loss(approx.params, data, num_samples=100000)