In [8]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl

In [9]:
def create_data(model, true_theta, noise_cov, num_samples):
    """Draw noisy observations of ``model(true_theta)``.

    The model output is used as the mean of a multivariate normal with
    covariance ``noise_cov``; ``num_samples`` draws are taken from it using
    NumPy's global random state.

    Returns the draws reshaped to two dimensions, with ``num_samples`` rows.
    """
    noisy_draws = mvn(model(true_theta), noise_cov, num_samples)
    # Force a (num_samples, -1) layout so each row is one observation.
    return noisy_draws.reshape(num_samples, -1)

In [10]:
# First, let's define a model:
ndim = 1


def model(theta):
    """Forward model: the element-wise square of ``theta``."""
    return theta**2


# Jacobian of the model for a batch of parameter vectors:
# jax.jacfwd differentiates the model for a single theta, and jax.vmap maps
# that Jacobian over axis 0 of the input.
model_grad = jax.vmap(jax.jacfwd(model), in_axes=0)

In [11]:
# Create artificial data:
# A parameter vector is drawn uniformly from [0, 5), pushed through the model,
# and perturbed with Gaussian noise by create_data.
# NOTE(review): no RNG seed is set in the visible cells, so the values printed
# below change from run to run — confirm whether a seed is set earlier.
true_theta = np.random.rand(ndim) * 5
noise_cov = np.eye(ndim) * 0.01
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
# Report the ground truth alongside the noisy observation.
for report_line in (
    f'True theta: \n {true_theta}',
    f'True x = model(theta): \n {model(true_theta)}',
    f'Observations x_obs = model(theta) + noise: \n {data}',
):
    print(report_line)

True theta: 
 [4.44707336]
True x = model(theta): 
 [19.77646151]
Observations x_obs = model(theta) + noise: 
 [[19.68350256]]


In [12]:
# Create Gaussian approximate distribution:
# Built with ApproximateDistribution.gaussian(ndim); presumably ndim is the
# dimension of theta — TODO confirm against approx_post.
# The verbose fit log further down prints its parameters (Phi) under the keys
# 'mean', 'chol_diag' and 'chol_lowerdiag'.
approx = ApproximateDistribution.gaussian(ndim)

In [13]:
# Create Joint distribution from forward model:
# The prior is described by a zero mean vector and an identity covariance
# (presumably a Gaussian prior over theta — confirm against approx_post).
prior_mean = np.zeros(ndim)
prior_cov = np.eye(ndim)
joint = JointDistribution.from_model(
    data, model, noise_cov,
    prior_mean, prior_cov,
    model_grad,
)

In [14]:
# Fit the approximate distribution by minimising the forward KL divergence
# (the call below is forward_kl.fit, and the run log reports "forward KL").
# NOTE(review): reverse_kl is also imported — confirm forward KL is the
# intended objective here.
# With verbose=True the loss and the current parameters (Phi) are printed at
# every iteration.
results_dict = forward_kl.fit(approx, joint, use_reparameterisation=False, verbose=True, num_samples=1000)

Now fitting approximate distribution by minimising forward KL divergence.
Iteration 1:
   Loss = 0.005104810930788517 
   Phi = {'chol_diag': DeviceArray([1.0999999], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.09999961], dtype=float32)}
Iteration 2:
   Loss = 0.00742258969694376 
   Phi = {'chol_diag': DeviceArray([1.1997366], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.08334897], dtype=float32)}
Iteration 3:
   Loss = 0.005104662384837866 
   Phi = {'chol_diag': DeviceArray([1.2972935], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.04391769], dtype=float32)}
Iteration 4:
   Loss = 0.0038864798843860626 
   Phi = {'chol_diag': DeviceArray([1.3905392], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.00465753], dtype=float32)}
Iteration 5:
   Loss = 0.004982221405953169 
   Phi = {'chol_diag': DeviceArray([1.482307

Iteration 41:
   Loss = 0.0030147365760058165 
   Phi = {'chol_diag': DeviceArray([3.1866941], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.5650211], dtype=float32)}
Iteration 42:
   Loss = 0.002933369716629386 
   Phi = {'chol_diag': DeviceArray([3.2047014], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.569169], dtype=float32)}
Iteration 43:
   Loss = 0.0030706520192325115 
   Phi = {'chol_diag': DeviceArray([3.222211], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.5715407], dtype=float32)}
Iteration 44:
   Loss = 0.0029279489535838366 
   Phi = {'chol_diag': DeviceArray([3.2390223], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.5739606], dtype=float32)}
Iteration 45:
   Loss = 0.0028084663208574057 
   Phi = {'chol_diag': DeviceArray([3.2547948], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)

Iteration 81:
   Loss = 0.0030622854828834534 
   Phi = {'chol_diag': DeviceArray([3.6279287], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.6159949], dtype=float32)}
Iteration 82:
   Loss = 0.0029949708841741085 
   Phi = {'chol_diag': DeviceArray([3.6356745], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.6133456], dtype=float32)}
Iteration 83:
   Loss = 0.003003180492669344 
   Phi = {'chol_diag': DeviceArray([3.64344], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.6098309], dtype=float32)}
Iteration 84:
   Loss = 0.0031550577841699123 
   Phi = {'chol_diag': DeviceArray([3.6515114], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.6032527], dtype=float32)}
Iteration 85:
   Loss = 0.0028549274429678917 
   Phi = {'chol_diag': DeviceArray([3.6592305], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32)

Iteration 120:
   Loss = 0.002884687390178442 
   Phi = {'chol_diag': DeviceArray([3.891278], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.4185523], dtype=float32)}
Iteration 121:
   Loss = 0.0031103219371289015 
   Phi = {'chol_diag': DeviceArray([3.8965187], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.41627303], dtype=float32)}
Iteration 122:
   Loss = 0.002934612799435854 
   Phi = {'chol_diag': DeviceArray([3.9017253], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.41371295], dtype=float32)}
Iteration 123:
   Loss = 0.002938697347417474 
   Phi = {'chol_diag': DeviceArray([3.9069033], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.41077125], dtype=float32)}
Iteration 124:
   Loss = 0.002770762424916029 
   Phi = {'chol_diag': DeviceArray([3.9117851], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=fl

Iteration 160:
   Loss = 0.0029946458525955677 
   Phi = {'chol_diag': DeviceArray([4.068454], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.31283915], dtype=float32)}
Iteration 161:
   Loss = 0.0030861985869705677 
   Phi = {'chol_diag': DeviceArray([4.0721807], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.3093672], dtype=float32)}
Iteration 162:
   Loss = 0.0028930725529789925 
   Phi = {'chol_diag': DeviceArray([4.075805], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.3071758], dtype=float32)}
Iteration 163:
   Loss = 0.0029232697561383247 
   Phi = {'chol_diag': DeviceArray([4.0793443], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.30589646], dtype=float32)}
Iteration 164:
   Loss = 0.002944731153547764 
   Phi = {'chol_diag': DeviceArray([4.082962], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=fl

Iteration 200:
   Loss = 0.0028495769947767258 
   Phi = {'chol_diag': DeviceArray([4.1777916], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.46096605], dtype=float32)}
Iteration 201:
   Loss = 0.002989226719364524 
   Phi = {'chol_diag': DeviceArray([4.18073], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.45455042], dtype=float32)}
Iteration 202:
   Loss = 0.002919553080573678 
   Phi = {'chol_diag': DeviceArray([4.183621], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.44838694], dtype=float32)}
Iteration 203:
   Loss = 0.0029186939354985952 
   Phi = {'chol_diag': DeviceArray([4.1862597], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.44524357], dtype=float32)}
Iteration 204:
   Loss = 0.002931455848738551 
   Phi = {'chol_diag': DeviceArray([4.1889105], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=fl

Iteration 240:
   Loss = 0.0029513887129724026 
   Phi = {'chol_diag': DeviceArray([4.281018], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.11133578], dtype=float32)}
Iteration 241:
   Loss = 0.0028470088727772236 
   Phi = {'chol_diag': DeviceArray([4.283082], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.09616515], dtype=float32)}
Iteration 242:
   Loss = 0.0028930155094712973 
   Phi = {'chol_diag': DeviceArray([4.28506], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.08343019], dtype=float32)}
Iteration 243:
   Loss = 0.0029034430626779795 
   Phi = {'chol_diag': DeviceArray([4.286955], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.07378751], dtype=float32)}
Iteration 244:
   Loss = 0.0030785859562456608 
   Phi = {'chol_diag': DeviceArray([4.2888713], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=f

Iteration 280:
   Loss = 0.002910463372245431 
   Phi = {'chol_diag': DeviceArray([4.3387704], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.0794934], dtype=float32)}
Iteration 281:
   Loss = 0.0028837863355875015 
   Phi = {'chol_diag': DeviceArray([4.339862], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.08224953], dtype=float32)}
Iteration 282:
   Loss = 0.002997536910697818 
   Phi = {'chol_diag': DeviceArray([4.341008], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.08042337], dtype=float32)}
Iteration 283:
   Loss = 0.002850979333743453 
   Phi = {'chol_diag': DeviceArray([4.342079], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.08298766], dtype=float32)}
Iteration 284:
   Loss = 0.0029037557542324066 
   Phi = {'chol_diag': DeviceArray([4.3430924], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=flo

Iteration 320:
   Loss = 0.0028968967963010073 
   Phi = {'chol_diag': DeviceArray([4.3770623], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.11631064], dtype=float32)}
Iteration 321:
   Loss = 0.0029184932354837656 
   Phi = {'chol_diag': DeviceArray([4.377802], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.11881715], dtype=float32)}
Iteration 322:
   Loss = 0.0029086750000715256 
   Phi = {'chol_diag': DeviceArray([4.3785157], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.12177674], dtype=float32)}
Iteration 323:
   Loss = 0.0028224464040249586 
   Phi = {'chol_diag': DeviceArray([4.3793244], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([0.11875077], dtype=float32)}
Iteration 324:
   Loss = 0.002854488091543317 
   Phi = {'chol_diag': DeviceArray([4.3800645], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=fl

Iteration 360:
   Loss = 0.0029283633921295404 
   Phi = {'chol_diag': DeviceArray([4.403776], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.04936878], dtype=float32)}
Iteration 361:
   Loss = 0.0030475398525595665 
   Phi = {'chol_diag': DeviceArray([4.4042673], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.05169497], dtype=float32)}
Iteration 362:
   Loss = 0.0029172254726290703 
   Phi = {'chol_diag': DeviceArray([4.4047422], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.05405873], dtype=float32)}
Iteration 363:
   Loss = 0.0029213500674813986 
   Phi = {'chol_diag': DeviceArray([4.405195], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([-0.05722437], dtype=float32)}
Iteration 364:
   Loss = 0.0029122810810804367 
   Phi = {'chol_diag': DeviceArray([4.405637], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype

KeyboardInterrupt: 