In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl, fit_approximation

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 3
model = lambda theta: theta**2
model_grad = jax.vmap(jax.jacfwd(model), in_axes=1)

In [4]:
# Create artificial data:
true_theta = 2*np.random.rand(ndim)
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [1.06421586 0.2386097  1.79916763]
True x = model(theta): 
 [1.13255539 0.05693459 3.23700418]
Observations x_obs = model(theta) + noise: 
 [[1.06916098 0.043666   3.12491514]]


In [5]:
# Create Gaussian approximate distribution:
approx = ApproximateDistribution.gaussian(ndim)

In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = np.identity(ndim)
joint = JointDistribution.from_model(model, noise_cov, prior_mean, prior_cov, model_grad)

In [8]:
# Fit sddistribution to reverse KL divergence:
# loss = reverse_kl(approx, joint, use_reparameterisation=True)
loss = reverse_kl(approx, joint, use_reparameterisation=True)
approx = fit_approximation(loss, approx, data, verbose=True)

Iteration 1:
   Loss = 562.9141845703125, Params = {'mean': DeviceArray([ 0.1, -0.1,  0.1], dtype=float32), 'chol_diag': DeviceArray([0.9, 0.9, 1.1], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.1       ,  0.1       , -0.10000001], dtype=float32)}
Iteration 2:
   Loss = 483.5230712890625, Params = {'mean': DeviceArray([ 0.06742443, -0.08673996,  0.18843424], dtype=float32), 'chol_diag': DeviceArray([0.81255907, 0.80599606, 1.0266991 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.03485377,  0.05550434, -0.03187067], dtype=float32)}
Iteration 3:
   Loss = 387.9850769042969, Params = {'mean': DeviceArray([ 0.05548366, -0.05962436,  0.24967538], dtype=float32), 'chol_diag': DeviceArray([0.74240565, 0.7210684 , 0.9590349 ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.02766427,  0.01312061,  0.01706082], dtype=float32)}
Iteration 4:
   Loss = 372.2558898925781, Params = {'mean': DeviceArray([ 0.04202702, -0.02861032,  0.32055786], dtype=float32), 'chol_diag': DeviceArray([0.

Iteration 30:
   Loss = 9.688433647155762, Params = {'mean': DeviceArray([0.9643949 , 0.10734767, 1.7333187 ], dtype=float32), 'chol_diag': DeviceArray([0.15997909, 0.03481197, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.14943275, -0.10555964, -0.02209058], dtype=float32)}
Iteration 31:
   Loss = 10.01645278930664, Params = {'mean': DeviceArray([0.9791778 , 0.11138786, 1.6939538 ], dtype=float32), 'chol_diag': DeviceArray([0.1373505 , 0.02934256, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.16003641, -0.10033514, -0.01227864], dtype=float32)}
Iteration 32:
   Loss = 12.249408721923828, Params = {'mean': DeviceArray([0.99972904, 0.11581596, 1.6638018 ], dtype=float32), 'chol_diag': DeviceArray([0.11644606, 0.02435225, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.16849193, -0.09396593, -0.00317117], dtype=float32)}
Iteration 33:
   Loss = 15.293722152709961, Params = {'mean': DeviceArray([1.0230503 , 0.12064835, 1.6442116 ], dtype=fl

Iteration 60:
   Loss = 12.15556812286377, Params = {'mean': DeviceArray([1.0203534, 0.2245911, 1.7543628], dtype=float32), 'chol_diag': DeviceArray([0.01, 0.01, 0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([0.20712517, 0.01687631, 0.0632486 ], dtype=float32)}
Iteration 61:
   Loss = 11.927322387695312, Params = {'mean': DeviceArray([1.0221874 , 0.22249134, 1.7655401 ], dtype=float32), 'chol_diag': DeviceArray([0.01, 0.01, 0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([0.20680787, 0.01429151, 0.06174934], dtype=float32)}
Iteration 62:
   Loss = 11.823270797729492, Params = {'mean': DeviceArray([1.0264691 , 0.22017613, 1.7760644 ], dtype=float32), 'chol_diag': DeviceArray([0.01, 0.01, 0.01], dtype=float32), 'chol_lowerdiag': DeviceArray([0.20648853, 0.01160072, 0.06020344], dtype=float32)}
Iteration 63:
   Loss = 11.862732887268066, Params = {'mean': DeviceArray([1.0321473 , 0.21777406, 1.7849267 ], dtype=float32), 'chol_diag': DeviceArray([0.01, 0.01, 0.01], dtype=float3

Iteration 90:
   Loss = 11.790973663330078, Params = {'mean': DeviceArray([1.0351316 , 0.20718177, 1.7678586 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01000916, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.21707904, -0.00267727,  0.03365628], dtype=float32)}
Iteration 91:
   Loss = 11.782376289367676, Params = {'mean': DeviceArray([1.0367763 , 0.20752257, 1.7654729 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01001403, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.21771617, -0.0019845 ,  0.03331807], dtype=float32)}
Iteration 92:
   Loss = 11.782732963562012, Params = {'mean': DeviceArray([1.0382055 , 0.20774084, 1.7637465 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01001843, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.21835072, -0.00136587,  0.03300126], dtype=float32)}
Iteration 93:
   Loss = 11.790544509887695, Params = {'mean': DeviceArray([1.03911   , 0.20785101, 1.7628256 ], dtype=

Iteration 119:
   Loss = 11.795251846313477, Params = {'mean': DeviceArray([1.0370226 , 0.19930142, 1.7708553 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01000287, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.23342788, -0.00216455,  0.02850325], dtype=float32)}
Iteration 120:
   Loss = 11.7980318069458, Params = {'mean': DeviceArray([1.0369521 , 0.19929916, 1.7708241 ], dtype=float32), 'chol_diag': DeviceArray([0.01     , 0.0100055, 0.01     ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.2340333 , -0.00211943,  0.02843894], dtype=float32)}
Iteration 121:
   Loss = 11.797663688659668, Params = {'mean': DeviceArray([1.0367185 , 0.19929603, 1.7705693 ], dtype=float32), 'chol_diag': DeviceArray([0.01     , 0.0100085, 0.01     ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.23463097, -0.00206234,  0.02837991], dtype=float32)}
Iteration 122:
   Loss = 11.794632911682129, Params = {'mean': DeviceArray([1.036493  , 0.19932824, 1.7701457 ], dtype=floa

Iteration 147:
   Loss = 11.782050132751465, Params = {'mean': DeviceArray([1.0364482 , 0.19479021, 1.7689997 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01007277, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.24576241, -0.00118831,  0.02741449], dtype=float32)}
Iteration 148:
   Loss = 11.78105640411377, Params = {'mean': DeviceArray([1.036428  , 0.19456525, 1.7691622 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01007377, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.24614199, -0.00122075,  0.02738905], dtype=float32)}
Iteration 149:
   Loss = 11.783158302307129, Params = {'mean': DeviceArray([1.0363544 , 0.19438215, 1.7693081 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01007487, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.24651593, -0.00124822,  0.02736459], dtype=float32)}
Iteration 150:
   Loss = 11.784263610839844, Params = {'mean': DeviceArray([1.0363338 , 0.19422995, 1.7694169 ], dty

Iteration 175:
   Loss = 11.77888298034668, Params = {'mean': DeviceArray([1.0363673 , 0.18896547, 1.7690581 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01011848, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.2544132 , -0.00103522,  0.02693442], dtype=float32)}
Iteration 176:
   Loss = 11.78007698059082, Params = {'mean': DeviceArray([1.0363362 , 0.18891951, 1.769074  ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01011936, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.2545822 , -0.00104237,  0.02692212], dtype=float32)}
Iteration 177:
   Loss = 11.777966499328613, Params = {'mean': DeviceArray([1.0363015, 0.1889067, 1.769102 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01012029, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.2547743 , -0.00105113,  0.02690997], dtype=float32)}
Iteration 178:
   Loss = 11.77889347076416, Params = {'mean': DeviceArray([1.0363034 , 0.18887377, 1.7691363 ], dtype=fl

Iteration 203:
   Loss = 11.776687622070312, Params = {'mean': DeviceArray([1.036329  , 0.18603922, 1.7691362 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01015853, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.26213667, -0.00104704,  0.02667686], dtype=float32)}
Iteration 204:
   Loss = 11.776001930236816, Params = {'mean': DeviceArray([1.0363693 , 0.18573368, 1.7691351 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01015978, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.26232237, -0.00104121,  0.02667057], dtype=float32)}
Iteration 205:
   Loss = 11.776928901672363, Params = {'mean': DeviceArray([1.0364071 , 0.18546785, 1.769138  ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01016096, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.26247206, -0.00103325,  0.02666447], dtype=float32)}
Iteration 206:
   Loss = 11.775200843811035, Params = {'mean': DeviceArray([1.0363759 , 0.18524161, 1.7691437 ], dt

Iteration 231:
   Loss = 11.776293754577637, Params = {'mean': DeviceArray([1.0357894 , 0.18149121, 1.7691511 ], dtype=float32), 'chol_diag': DeviceArray([0.01     , 0.0101945, 0.01     ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.26852173, -0.00105276,  0.02653992], dtype=float32)}
Iteration 232:
   Loss = 11.771641731262207, Params = {'mean': DeviceArray([1.0359693 , 0.18135016, 1.7691507 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01019521, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.26859587, -0.00105635,  0.02653634], dtype=float32)}
Iteration 233:
   Loss = 11.76970100402832, Params = {'mean': DeviceArray([1.036207  , 0.18124758, 1.7691512 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01019629, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.26875263, -0.00105592,  0.02653302], dtype=float32)}
Iteration 234:
   Loss = 11.773599624633789, Params = {'mean': DeviceArray([1.0364412 , 0.18110089, 1.7691525 ], dtype=

Iteration 259:
   Loss = 11.76838493347168, Params = {'mean': DeviceArray([1.0364693 , 0.17945981, 1.7691541 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01022517, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.27266812, -0.00101925,  0.02647014], dtype=float32)}
Iteration 260:
   Loss = 11.767197608947754, Params = {'mean': DeviceArray([1.0365076 , 0.17975679, 1.769154  ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01022733, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.27307728, -0.00100952,  0.02646878], dtype=float32)}
Iteration 261:
   Loss = 11.769912719726562, Params = {'mean': DeviceArray([1.0365342 , 0.17994875, 1.7691544 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01022936, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.27347487, -0.00100461,  0.0264673 ], dtype=float32)}
Iteration 262:
   Loss = 11.771353721618652, Params = {'mean': DeviceArray([1.0364776 , 0.18004471, 1.7691551 ], dty

Iteration 287:
   Loss = 11.770646095275879, Params = {'mean': DeviceArray([1.0366575 , 0.17525925, 1.769156  ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01025747, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.2773003 , -0.00100115,  0.0264382 ], dtype=float32)}
Iteration 288:
   Loss = 11.768966674804688, Params = {'mean': DeviceArray([1.0365177 , 0.17520355, 1.7691557 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01025884, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.27740654, -0.00097924,  0.0264383 ], dtype=float32)}
Iteration 289:
   Loss = 11.768460273742676, Params = {'mean': DeviceArray([1.0363708, 0.1752233, 1.7691557], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01026013, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.27750602, -0.00096062,  0.02643831], dtype=float32)}
Iteration 290:
   Loss = 11.76981258392334, Params = {'mean': DeviceArray([1.0363373 , 0.17512256, 1.769156  ], dtype=

Iteration 315:
   Loss = 11.76547622680664, Params = {'mean': DeviceArray([1.0359983, 0.1722931, 1.7691555], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01029014, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.2820447 , -0.00101827,  0.02642261], dtype=float32)}
Iteration 316:
   Loss = 11.76479721069336, Params = {'mean': DeviceArray([1.0360656 , 0.17205012, 1.7691555 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01029097, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.28224364, -0.00104284,  0.0264214 ], dtype=float32)}
Iteration 317:
   Loss = 11.765684127807617, Params = {'mean': DeviceArray([1.036183  , 0.17176309, 1.7691555 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01029179, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.28240126, -0.00105939,  0.0264205 ], dtype=float32)}
Iteration 318:
   Loss = 11.763421058654785, Params = {'mean': DeviceArray([1.0365181 , 0.17142403, 1.7691559 ], dtype=f

Iteration 343:
   Loss = 11.760858535766602, Params = {'mean': DeviceArray([1.036333  , 0.17056313, 1.7691565 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01032566, 0.01000002], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.28690046, -0.00092694,  0.02642066], dtype=float32)}
Iteration 344:
   Loss = 11.764240264892578, Params = {'mean': DeviceArray([1.0361011 , 0.17038195, 1.7691565 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01032695, 0.01000002], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.28716317, -0.00093926,  0.02642016], dtype=float32)}
Iteration 345:
   Loss = 11.76146125793457, Params = {'mean': DeviceArray([1.0360363 , 0.17014845, 1.7691565 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01032826, 0.01000002], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.28742433, -0.00095005,  0.02641971], dtype=float32)}
Iteration 346:
   Loss = 11.759718894958496, Params = {'mean': DeviceArray([1.0361418 , 0.17008592, 1.7691563 ], dty

Iteration 371:
   Loss = 11.757736206054688, Params = {'mean': DeviceArray([1.0364377 , 0.15923592, 1.7691534 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01037696, 0.01000007], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.29770777, -0.00103659,  0.02641515], dtype=float32)}
Iteration 372:
   Loss = 11.755997657775879, Params = {'mean': DeviceArray([1.0364312, 0.1589793, 1.7691537], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01038096, 0.01000008], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.29874197, -0.00103496,  0.02641519], dtype=float32)}
Iteration 373:
   Loss = 11.758469581604004, Params = {'mean': DeviceArray([1.0363948 , 0.15861921, 1.7691553 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01038473, 0.01000008], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.29965994, -0.00102522,  0.02641556], dtype=float32)}
Iteration 374:
   Loss = 11.75888442993164, Params = {'mean': DeviceArray([1.0362295 , 0.15839295, 1.7691574 ], dtype=

Iteration 399:
   Loss = 11.746676445007324, Params = {'mean': DeviceArray([1.0360035 , 0.13468532, 1.7691562 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01048369, 0.01000003], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.3247292 , -0.00086096,  0.02642071], dtype=float32)}
Iteration 400:
   Loss = 11.743330001831055, Params = {'mean': DeviceArray([1.0361301 , 0.13333777, 1.7691561 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01048993, 0.01000002], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.3266201 , -0.00086961,  0.02642037], dtype=float32)}
Iteration 401:
   Loss = 11.749995231628418, Params = {'mean': DeviceArray([1.0363934 , 0.13150865, 1.7691561 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01049634, 0.01000002], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.32840773, -0.00084753,  0.02642102], dtype=float32)}
Iteration 402:
   Loss = 11.743818283081055, Params = {'mean': DeviceArray([1.0367576 , 0.12971696, 1.7691563 ], dt

Iteration 427:
   Loss = 11.718459129333496, Params = {'mean': DeviceArray([1.0361221 , 0.01407048, 1.7691551 ], dtype=float32), 'chol_diag': DeviceArray([0.01     , 0.0107239, 0.01     ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.4028481, -0.0006545,  0.0264251], dtype=float32)}
Iteration 428:
   Loss = 11.71627426147461, Params = {'mean': DeviceArray([1.0358698 , 0.00748092, 1.7691547 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01073198, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.40566355, -0.00065654,  0.02642492], dtype=float32)}
Iteration 429:
   Loss = 11.715444564819336, Params = {'mean': DeviceArray([1.0360003e+00, 1.0647201e-03, 1.7691548e+00], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01073908, 0.01      ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.40815675, -0.00066923,  0.02642453], dtype=float32)}
Iteration 430:
   Loss = 11.715876579284668, Params = {'mean': DeviceArray([ 1.0361444 , -0.00498496,  1.769155  

Iteration 455:
   Loss = 11.719815254211426, Params = {'mean': DeviceArray([1.035621  , 0.01069485, 1.7691563 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01074843, 0.01000022], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.4014098 , -0.00088566,  0.02642098], dtype=float32)}
Iteration 456:
   Loss = 11.710875511169434, Params = {'mean': DeviceArray([1.0361671 , 0.01190132, 1.7691578 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01075028, 0.01000022], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.4015989 , -0.00085048,  0.0264218 ], dtype=float32)}
Iteration 457:
   Loss = 11.719439506530762, Params = {'mean': DeviceArray([1.0369561 , 0.01261779, 1.7691584 ], dtype=float32), 'chol_diag': DeviceArray([0.01      , 0.01075219, 0.01000022], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.40164012, -0.00078267,  0.02642351], dtype=float32)}
Iteration 458:
   Loss = 11.71174430847168, Params = {'mean': DeviceArray([1.0371782 , 0.01317338, 1.7691579 ], dty

KeyboardInterrupt: 