In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 3
model = lambda theta: theta**2
model_grad = jax.vmap(jax.jacfwd(model), in_axes=0)

In [4]:
# Create artificial data:
true_theta = 5*np.random.rand(ndim)
noise_cov = 0.01*np.identity(ndim)
num_samples = 3
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [2.40320681 1.52681422 1.96291679]
True x = model(theta): 
 [5.77540298 2.33116165 3.85304233]
Observations x_obs = model(theta) + noise: 
 [[5.87051362 2.44448495 3.94581996]
 [5.7086499  2.21821999 3.93203575]
 [5.71342434 2.26671093 3.76674922]]


In [5]:
# Create Gaussian approximate distribution:
approx = ApproximateDistribution.gaussian(ndim)

In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = np.identity(ndim)
joint = JointDistribution.from_model(data, model, noise_cov, prior_mean, prior_cov, model_grad)

In [9]:
# Fit sddistribution to reverse KL divergence:
results_dict = reverse_kl.fit(approx, joint, use_reparameterisation=True, verbose=True, num_samples=1000)

Now fitting approximate distribution by minimising reverse KL divergence.
Iteration 1:
   Loss = 13488595.0 
   Phi = {'chol_diag': DeviceArray([9.9, 9.9, 9.9], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.1,  0.1,  0.1], dtype=float32), 'mean': DeviceArray([-0.1       ,  0.1       , -0.10000001], dtype=float32)}
Iteration 2:
   Loss = 13066723.0 
   Phi = {'chol_diag': DeviceArray([9.80028 , 9.799919, 9.800565], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.15455326,  0.0829065 ,  0.07807915], dtype=float32), 'mean': DeviceArray([-0.05385482,  0.02923425, -0.16014206], dtype=float32)}
Iteration 3:
   Loss = 12027097.0 
   Phi = {'chol_diag': DeviceArray([9.70159 , 9.700181, 9.700792], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.19713336,  0.04039015,  0.01198923], dtype=float32), 'mean': DeviceArray([-0.06028127, -0.0078432 , -0.2182903 ], dtype=float32)}
Iteration 4:
   Loss = 12747807.0 
   Phi = {'chol_diag': DeviceArray([9.603735, 9.599959, 9.600634], dtype=float32)

Iteration 30:
   Loss = 4228452.0 
   Phi = {'chol_diag': DeviceArray([7.3383713, 7.2700515, 7.2840657], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.02403718,  0.11697512,  0.06182062], dtype=float32), 'mean': DeviceArray([ 0.04106906, -0.257508  ,  0.19831386], dtype=float32)}
Iteration 31:
   Loss = 3697585.0 
   Phi = {'chol_diag': DeviceArray([7.2629433, 7.195138 , 7.2107186], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.04546546,  0.0982356 ,  0.07411161], dtype=float32), 'mean': DeviceArray([ 0.03557166, -0.23925522,  0.1881381 ], dtype=float32)}
Iteration 32:
   Loss = 3465987.0 
   Phi = {'chol_diag': DeviceArray([7.1880093, 7.121482 , 7.1394396], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.06139163,  0.08367477,  0.08316988], dtype=float32), 'mean': DeviceArray([ 0.03242374, -0.21876283,  0.16914447], dtype=float32)}
Iteration 33:
   Loss = 3348369.75 
   Phi = {'chol_diag': DeviceArray([7.11477 , 7.049393, 7.068563], dtype=float32), 'chol_lowerdiag': DeviceAr

Iteration 59:
   Loss = 1251688.5 
   Phi = {'chol_diag': DeviceArray([5.636788 , 5.5349045, 5.6084476], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.03291577,  0.0588314 , -0.04320851], dtype=float32), 'mean': DeviceArray([-0.06637874, -0.1028661 , -0.15204325], dtype=float32)}
Iteration 60:
   Loss = 1208672.5 
   Phi = {'chol_diag': DeviceArray([5.593718 , 5.4899325, 5.5655594], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.04620408,  0.06664146, -0.03932977], dtype=float32), 'mean': DeviceArray([-0.06467816, -0.11255858, -0.14585532], dtype=float32)}
Iteration 61:
   Loss = 1313731.5 
   Phi = {'chol_diag': DeviceArray([5.550734 , 5.445563 , 5.5234346], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0600598 ,  0.07165306, -0.03980357], dtype=float32), 'mean': DeviceArray([-0.06319935, -0.11937749, -0.13665314], dtype=float32)}
Iteration 62:
   Loss = 1069796.125 
   Phi = {'chol_diag': DeviceArray([5.5092154, 5.402146 , 5.481989 ], dtype=float32), 'chol_lowerdiag': Devi

Iteration 88:
   Loss = 507936.09375 
   Phi = {'chol_diag': DeviceArray([4.674088 , 4.5272017, 4.617197 ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.09562985, -0.06435157,  0.03461273], dtype=float32), 'mean': DeviceArray([-0.00292352, -0.045761  ,  0.12409248], dtype=float32)}
Iteration 89:
   Loss = 574026.9375 
   Phi = {'chol_diag': DeviceArray([4.6486397, 4.500638 , 4.590156 ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.09358332, -0.07027133,  0.03022184], dtype=float32), 'mean': DeviceArray([ 0.00530808, -0.04680131,  0.11736035], dtype=float32)}
Iteration 90:
   Loss = 514214.53125 
   Phi = {'chol_diag': DeviceArray([4.6233377, 4.4746637, 4.5636215], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.09053762, -0.07578923,  0.02562445], dtype=float32), 'mean': DeviceArray([ 0.01105113, -0.04760285,  0.11113584], dtype=float32)}
Iteration 91:
   Loss = 448939.78125 
   Phi = {'chol_diag': DeviceArray([4.599014 , 4.4489436, 4.5376472], dtype=float32), 'chol_lowerdi

Iteration 117:
   Loss = 297146.6875 
   Phi = {'chol_diag': DeviceArray([4.047919 , 3.8633509, 3.9728856], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.06559094,  0.07372686, -0.0229885 ], dtype=float32), 'mean': DeviceArray([ 0.02708839, -0.02530622, -0.05201748], dtype=float32)}
Iteration 118:
   Loss = 288847.46875 
   Phi = {'chol_diag': DeviceArray([4.029602 , 3.845096 , 3.9549851], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.06469465,  0.07530625, -0.02163678], dtype=float32), 'mean': DeviceArray([ 0.03003438, -0.02459053, -0.0582595 ], dtype=float32)}
Iteration 119:
   Loss = 276557.6875 
   Phi = {'chol_diag': DeviceArray([4.011231 , 3.8271897, 3.937452 ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.06272389,  0.07304724, -0.01892442], dtype=float32), 'mean': DeviceArray([ 0.03063217, -0.023553  , -0.06644365], dtype=float32)}
Iteration 120:
   Loss = 246944.453125 
   Phi = {'chol_diag': DeviceArray([3.993297 , 3.8094535, 3.9203324], dtype=float32), 'chol_low

Iteration 146:
   Loss = 162280.765625 
   Phi = {'chol_diag': DeviceArray([3.597017 , 3.4128857, 3.5208402], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01943142, -0.03182882, -0.00097078], dtype=float32), 'mean': DeviceArray([-0.02582662,  0.00679753, -0.06101624], dtype=float32)}
Iteration 147:
   Loss = 184582.390625 
   Phi = {'chol_diag': DeviceArray([3.5844061, 3.3995907, 3.5070064], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01737975, -0.03439401,  0.00250498], dtype=float32), 'mean': DeviceArray([-0.02867071,  0.00254663, -0.05446161], dtype=float32)}
Iteration 148:
   Loss = 153797.046875 
   Phi = {'chol_diag': DeviceArray([3.571883 , 3.3866591, 3.4933212], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01577711, -0.03477038,  0.0054091 ], dtype=float32), 'mean': DeviceArray([-0.03074228, -0.00257466, -0.04914777], dtype=float32)}
Iteration 149:
   Loss = 166412.25 
   Phi = {'chol_diag': DeviceArray([3.5593808, 3.3737702, 3.4798005], dtype=float32), 'chol_lo

Iteration 175:
   Loss = 108970.4375 
   Phi = {'chol_diag': DeviceArray([3.2621584, 3.0782118, 3.17488  ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.00168352, -0.02283547,  0.00166452], dtype=float32), 'mean': DeviceArray([-0.05407792,  0.00340078, -0.00429971], dtype=float32)}
Iteration 176:
   Loss = 101375.140625 
   Phi = {'chol_diag': DeviceArray([3.2522006, 3.0678968, 3.1647606], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0003083 , -0.02550976, -0.00072013], dtype=float32), 'mean': DeviceArray([-0.05228635,  0.00605991, -0.0062558 ], dtype=float32)}
Iteration 177:
   Loss = 109218.90625 
   Phi = {'chol_diag': DeviceArray([3.24219  , 3.0576525, 3.1547325], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00266009, -0.02685313, -0.00261167], dtype=float32), 'mean': DeviceArray([-0.04954441,  0.00894505, -0.00847747], dtype=float32)}
Iteration 178:
   Loss = 109545.828125 
   Phi = {'chol_diag': DeviceArray([3.2320867, 3.047521 , 3.1447608], dtype=float32), 'chol_l

Iteration 204:
   Loss = 77933.8203125 
   Phi = {'chol_diag': DeviceArray([2.9950058, 2.8090613, 2.911164 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.04410037, -0.01438534, -0.01548514], dtype=float32), 'mean': DeviceArray([-0.00142725,  0.00577863, -0.01449791], dtype=float32)}
Iteration 205:
   Loss = 74921.65625 
   Phi = {'chol_diag': DeviceArray([2.986881 , 2.8008685, 2.902837 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0475912 , -0.0148197 , -0.01461605], dtype=float32), 'mean': DeviceArray([-0.0008992 ,  0.00578794, -0.01503684], dtype=float32)}
Iteration 206:
   Loss = 79746.8515625 
   Phi = {'chol_diag': DeviceArray([2.9787648, 2.7926087, 2.894515 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.05009599, -0.01570188, -0.01384832], dtype=float32), 'mean': DeviceArray([ 0.00070438,  0.0056525 , -0.01438276], dtype=float32)}
Iteration 207:
   Loss = 72005.40625 
   Phi = {'chol_diag': DeviceArray([2.9705818, 2.7845395, 2.886245 ], dtype=float32), 'chol_lo

Iteration 233:
   Loss = 51774.8203125 
   Phi = {'chol_diag': DeviceArray([2.7738085, 2.5870564, 2.6876006], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01003791, -0.02277115, -0.00771447], dtype=float32), 'mean': DeviceArray([ 0.02814922,  0.02832392, -0.04288527], dtype=float32)}
Iteration 234:
   Loss = 48468.484375 
   Phi = {'chol_diag': DeviceArray([2.767154 , 2.5801623, 2.6805182], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00782137, -0.02268052, -0.00697852], dtype=float32), 'mean': DeviceArray([ 0.02921423,  0.03038521, -0.04167541], dtype=float32)}
Iteration 235:
   Loss = 48984.375 
   Phi = {'chol_diag': DeviceArray([2.760521 , 2.5733516, 2.6735835], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0052123 , -0.02266316, -0.00621861], dtype=float32), 'mean': DeviceArray([ 0.03056088,  0.0320055 , -0.03925348], dtype=float32)}
Iteration 236:
   Loss = 48753.55859375 
   Phi = {'chol_diag': DeviceArray([2.7539313, 2.5665565, 2.666803 ], dtype=float32), 'chol_lo

Iteration 262:
   Loss = 37568.671875 
   Phi = {'chol_diag': DeviceArray([2.5962288, 2.4057999, 2.507765 ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.02274254, -0.00388972, -0.00270462], dtype=float32), 'mean': DeviceArray([ 0.01824484,  0.01216065, -0.01402452], dtype=float32)}
Iteration 263:
   Loss = 40864.12890625 
   Phi = {'chol_diag': DeviceArray([2.5907001, 2.3998816, 2.5020907], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.02431795, -0.00212853, -0.00356326], dtype=float32), 'mean': DeviceArray([ 0.01851034,  0.01040964, -0.01442235], dtype=float32)}
Iteration 264:
   Loss = 40835.2265625 
   Phi = {'chol_diag': DeviceArray([2.5852208, 2.393966 , 2.4963667], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.02474317,  0.00011203, -0.00407176], dtype=float32), 'mean': DeviceArray([ 0.01845331,  0.00879901, -0.01442891], dtype=float32)}
Iteration 265:
   Loss = 39865.54296875 
   Phi = {'chol_diag': DeviceArray([2.579777 , 2.3880339, 2.4906752], dtype=float32), 'ch

Iteration 291:
   Loss = 27254.646484375 
   Phi = {'chol_diag': DeviceArray([2.4465752, 2.2475274, 2.3543332], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.01721411,  0.0178589 ,  0.00182456], dtype=float32), 'mean': DeviceArray([ 0.02291739,  0.01588679, -0.03406094], dtype=float32)}
Iteration 292:
   Loss = 28280.1875 
   Phi = {'chol_diag': DeviceArray([2.4418101, 2.242449 , 2.3495913], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.01684623,  0.01870254,  0.00154537], dtype=float32), 'mean': DeviceArray([ 0.02239125,  0.01643943, -0.03309759], dtype=float32)}
Iteration 293:
   Loss = 29089.484375 
   Phi = {'chol_diag': DeviceArray([2.4370432, 2.2373989, 2.3449187], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.01621936,  0.01989474,  0.0010517 ], dtype=float32), 'mean': DeviceArray([ 0.02174961,  0.01702963, -0.03235475], dtype=float32)}
Iteration 294:
   Loss = 26528.328125 
   Phi = {'chol_diag': DeviceArray([2.4322653, 2.2324245, 2.3404052], dtype=float32), 'chol_l

Iteration 320:
   Loss = 27072.904296875 
   Phi = {'chol_diag': DeviceArray([2.3185577, 2.11562  , 2.2245567], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01500342,  0.00026168, -0.00049216], dtype=float32), 'mean': DeviceArray([-0.01764445,  0.00721606, -0.01513401], dtype=float32)}
Iteration 321:
   Loss = 22115.2578125 
   Phi = {'chol_diag': DeviceArray([2.3145308, 2.1116092, 2.2202694], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01649513, -0.00096838,  0.00027841], dtype=float32), 'mean': DeviceArray([-0.0185536 ,  0.00694973, -0.0138065 ], dtype=float32)}
Iteration 322:
   Loss = 25088.919921875 
   Phi = {'chol_diag': DeviceArray([2.310473 , 2.1075535, 2.2160118], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01765031, -0.00102125,  0.0003251 ], dtype=float32), 'mean': DeviceArray([-0.01887341,  0.0066265 , -0.01242056], dtype=float32)}
Iteration 323:
   Loss = 21605.669921875 
   Phi = {'chol_diag': DeviceArray([2.3064718, 2.1035287, 2.2117956], dtype=float32)

Iteration 349:
   Loss = 17873.755859375 
   Phi = {'chol_diag': DeviceArray([2.2119648, 2.0022957, 2.1133964], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.03216529, -0.02003323, -0.00405495], dtype=float32), 'mean': DeviceArray([-0.01195489,  0.01428196,  0.02263922], dtype=float32)}
Iteration 350:
   Loss = 17728.669921875 
   Phi = {'chol_diag': DeviceArray([2.208575 , 1.9986806, 2.1100051], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.03113258, -0.01998761, -0.00402824], dtype=float32), 'mean': DeviceArray([-0.01295844,  0.01413643,  0.0238343 ], dtype=float32)}
Iteration 351:
   Loss = 18409.970703125 
   Phi = {'chol_diag': DeviceArray([2.205221 , 1.995062 , 2.1066241], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.02989293, -0.02046432, -0.00370024], dtype=float32), 'mean': DeviceArray([-0.01340714,  0.01321212,  0.02578289], dtype=float32)}
Iteration 352:
   Loss = 18839.650390625 
   Phi = {'chol_diag': DeviceArray([2.201897 , 1.9914134, 2.1032548], dtype=float3

Iteration 378:
   Loss = 14511.111328125 
   Phi = {'chol_diag': DeviceArray([2.1188898, 1.9028043, 2.0182228], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00372026, -0.02099271, -0.00056735], dtype=float32), 'mean': DeviceArray([-0.00762901,  0.00494406,  0.0316621 ], dtype=float32)}
Iteration 379:
   Loss = 15627.677734375 
   Phi = {'chol_diag': DeviceArray([2.1159806, 1.8997533, 2.015067 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00405971, -0.02010352, -0.00100951], dtype=float32), 'mean': DeviceArray([-0.00764947,  0.0053672 ,  0.03080264], dtype=float32)}
Iteration 380:
   Loss = 15760.064453125 
   Phi = {'chol_diag': DeviceArray([2.1130903, 1.8966701, 2.0119424], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00402307, -0.01937999, -0.0010289 ], dtype=float32), 'mean': DeviceArray([-0.00783186,  0.00565816,  0.02996084], dtype=float32)}
Iteration 381:
   Loss = 13702.26171875 
   Phi = {'chol_diag': DeviceArray([2.1102304, 1.8936478, 2.0088778], dtype=float32

Iteration 406:
   Loss = 13116.7978515625 
   Phi = {'chol_diag': DeviceArray([2.0406542, 1.8201603, 1.9376496], dtype=float32), 'chol_lowerdiag': DeviceArray([-2.1750500e-05,  6.5302080e-04,  1.1296546e-02], dtype=float32), 'mean': DeviceArray([-0.01060206,  0.00832701,  0.02558687], dtype=float32)}
Iteration 407:
   Loss = 12483.3427734375 
   Phi = {'chol_diag': DeviceArray([2.0379915, 1.8173925, 1.9350601], dtype=float32), 'chol_lowerdiag': DeviceArray([-9.6780219e-05,  2.0298231e-03,  1.2252400e-02], dtype=float32), 'mean': DeviceArray([-0.0108015 ,  0.00813303,  0.024428  ], dtype=float32)}
Iteration 408:
   Loss = 13007.16015625 
   Phi = {'chol_diag': DeviceArray([2.0353649, 1.8146223, 1.9324799], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00011858, 0.00272048, 0.01324722], dtype=float32), 'mean': DeviceArray([-0.0107587 ,  0.00755149,  0.02367715], dtype=float32)}
Iteration 409:
   Loss = 12698.20703125 
   Phi = {'chol_diag': DeviceArray([2.0327928, 1.8118603, 1.9298947

Iteration 435:
   Loss = 12193.96484375 
   Phi = {'chol_diag': DeviceArray([1.9714992, 1.7444834, 1.8645248], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00386366, 0.0011933 , 0.01321824], dtype=float32), 'mean': DeviceArray([-0.0002137 , -0.00491278,  0.02522259], dtype=float32)}
Iteration 436:
   Loss = 11344.72265625 
   Phi = {'chol_diag': DeviceArray([1.9693329, 1.7420382, 1.8621145], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00396199, 0.00163243, 0.01232212], dtype=float32), 'mean': DeviceArray([-2.8239025e-05, -4.7717053e-03,  2.6010580e-02], dtype=float32)}
Iteration 437:
   Loss = 11079.64453125 
   Phi = {'chol_diag': DeviceArray([1.967194 , 1.7396404, 1.8596841], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00414071, 0.00206546, 0.01139434], dtype=float32), 'mean': DeviceArray([-0.00020051, -0.00436706,  0.02730455], dtype=float32)}
Iteration 438:
   Loss = 10723.095703125 
   Phi = {'chol_diag': DeviceArray([1.9650948, 1.7372955, 1.8572361], dtype=float32),

Iteration 464:
   Loss = 10492.787109375 
   Phi = {'chol_diag': DeviceArray([1.9097662, 1.6788489, 1.7983088], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0007476 , -0.00236778,  0.00914908], dtype=float32), 'mean': DeviceArray([ 0.00230647, -0.00390544,  0.02102797], dtype=float32)}
Iteration 465:
   Loss = 10198.7880859375 
   Phi = {'chol_diag': DeviceArray([1.9076644, 1.6766834, 1.7961804], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00116028, -0.00326075,  0.00870613], dtype=float32), 'mean': DeviceArray([ 0.00247595, -0.00411957,  0.02066516], dtype=float32)}
Iteration 466:
   Loss = 9930.197265625 
   Phi = {'chol_diag': DeviceArray([1.905572 , 1.6745338, 1.7940781], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00169035, -0.00393989,  0.00808237], dtype=float32), 'mean': DeviceArray([ 0.0024709 , -0.00412484,  0.01998513], dtype=float32)}
Iteration 467:
   Loss = 9002.7509765625 
   Phi = {'chol_diag': DeviceArray([1.90354  , 1.6723955, 1.7920258], dtype=float3

Iteration 493:
   Loss = 9960.2001953125 
   Phi = {'chol_diag': DeviceArray([1.8522694, 1.61675  , 1.7413553], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00376427, -0.00885582,  0.00729454], dtype=float32), 'mean': DeviceArray([ 0.00400822,  0.00844568, -0.01228261], dtype=float32)}
Iteration 494:
   Loss = 9081.626953125 
   Phi = {'chol_diag': DeviceArray([1.8503656, 1.6147237, 1.7394539], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00419108, -0.0091231 ,  0.00772706], dtype=float32), 'mean': DeviceArray([ 0.00283029,  0.00785043, -0.0134244 ], dtype=float32)}
Iteration 495:
   Loss = 9626.0390625 
   Phi = {'chol_diag': DeviceArray([1.8484998, 1.6126946, 1.7375277], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0049746 , -0.00933437,  0.0079258 ], dtype=float32), 'mean': DeviceArray([ 0.00136449,  0.00690575, -0.01457532], dtype=float32)}
Iteration 496:
   Loss = 9406.8486328125 
   Phi = {'chol_diag': DeviceArray([1.8466629, 1.6106603, 1.7355989], dtype=float32), 

Iteration 522:
   Loss = 7807.412109375 
   Phi = {'chol_diag': DeviceArray([1.8028834, 1.5599691, 1.687626 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01068137, -0.01189979,  0.01412234], dtype=float32), 'mean': DeviceArray([-0.01523325, -0.0063742 , -0.02505573], dtype=float32)}
Iteration 523:
   Loss = 8566.2880859375 
   Phi = {'chol_diag': DeviceArray([1.8012278, 1.5581138, 1.685904 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01097081, -0.0118375 ,  0.01379024], dtype=float32), 'mean': DeviceArray([-0.01482688, -0.00675562, -0.02560416], dtype=float32)}
Iteration 524:
   Loss = 8983.8857421875 
   Phi = {'chol_diag': DeviceArray([1.799537 , 1.5562512, 1.6842194], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.01106117, -0.01168295,  0.01331912], dtype=float32), 'mean': DeviceArray([-0.01455696, -0.00719964, -0.02595476], dtype=float32)}
Iteration 525:
   Loss = 8874.3623046875 
   Phi = {'chol_diag': DeviceArray([1.7978369, 1.5543988, 1.6825279], dtype=float32

Iteration 551:
   Loss = 8559.17578125 
   Phi = {'chol_diag': DeviceArray([1.7578357, 1.5095719, 1.6386955], dtype=float32), 'chol_lowerdiag': DeviceArray([0.01095118, 0.00042818, 0.00935587], dtype=float32), 'mean': DeviceArray([-0.01801814, -0.01170012, -0.02312667], dtype=float32)}
Iteration 552:
   Loss = 7854.71484375 
   Phi = {'chol_diag': DeviceArray([1.7563665, 1.5079294, 1.6370876], dtype=float32), 'chol_lowerdiag': DeviceArray([0.0101329 , 0.00074377, 0.00936323], dtype=float32), 'mean': DeviceArray([-0.01812152, -0.01156938, -0.02289253], dtype=float32)}
Iteration 553:
   Loss = 8005.978515625 
   Phi = {'chol_diag': DeviceArray([1.7549015, 1.5062715, 1.6354936], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00943908, 0.00085693, 0.00922701], dtype=float32), 'mean': DeviceArray([-0.01847572, -0.01155093, -0.02251086], dtype=float32)}
Iteration 554:
   Loss = 7727.63037109375 
   Phi = {'chol_diag': DeviceArray([1.7534895, 1.5046015, 1.6339055], dtype=float32), 'chol_low

Iteration 580:
   Loss = 7514.271484375 
   Phi = {'chol_diag': DeviceArray([1.7182755, 1.4636034, 1.5955011], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00122713,  0.00348618, -0.00160186], dtype=float32), 'mean': DeviceArray([-0.00944506, -0.00611861, -0.01351548], dtype=float32)}
Iteration 581:
   Loss = 7105.2919921875 
   Phi = {'chol_diag': DeviceArray([1.7170386, 1.4621378, 1.5940906], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00103863,  0.00348654, -0.00229964], dtype=float32), 'mean': DeviceArray([-0.00938519, -0.00606662, -0.01381041], dtype=float32)}
Iteration 582:
   Loss = 7209.21337890625 
   Phi = {'chol_diag': DeviceArray([1.7158211, 1.4606726, 1.592683 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00123858,  0.00313564, -0.00263244], dtype=float32), 'mean': DeviceArray([-0.00931658, -0.00586132, -0.01395647], dtype=float32)}
Iteration 583:
   Loss = 7208.82470703125 
   Phi = {'chol_diag': DeviceArray([1.7146003, 1.4592228, 1.5912939], dtype=float

Iteration 609:
   Loss = 7101.578125 
   Phi = {'chol_diag': DeviceArray([1.6833341, 1.4239533, 1.5564173], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0052004 ,  0.00501565, -0.00161912], dtype=float32), 'mean': DeviceArray([-0.00682157, -0.00352356, -0.01534649], dtype=float32)}
Iteration 610:
   Loss = 7372.6337890625 
   Phi = {'chol_diag': DeviceArray([1.6821127, 1.42258  , 1.5551287], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.0051271 ,  0.00524807, -0.00151533], dtype=float32), 'mean': DeviceArray([-0.00719278, -0.00308105, -0.01553231], dtype=float32)}
Iteration 611:
   Loss = 6649.216796875 
   Phi = {'chol_diag': DeviceArray([1.6809182, 1.4212288, 1.5538621], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00483663,  0.00542597, -0.00140907], dtype=float32), 'mean': DeviceArray([-0.00733994, -0.00260506, -0.01573931], dtype=float32)}
Iteration 612:
   Loss = 7284.37939453125 
   Phi = {'chol_diag': DeviceArray([1.6796987, 1.4198956, 1.5525901], dtype=float32), 

Iteration 638:
   Loss = 6680.82958984375 
   Phi = {'chol_diag': DeviceArray([1.6525606, 1.3860673, 1.5219703], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00252221,  0.01072091, -0.00151672], dtype=float32), 'mean': DeviceArray([-0.00359336, -0.0004004 , -0.02343275], dtype=float32)}
Iteration 639:
   Loss = 6499.9599609375 
   Phi = {'chol_diag': DeviceArray([1.6516753, 1.3847835, 1.5208406], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00251397,  0.01076566, -0.0013495 ], dtype=float32), 'mean': DeviceArray([-0.00368517, -0.00047436, -0.02308712], dtype=float32)}
Iteration 640:
   Loss = 6691.91943359375 
   Phi = {'chol_diag': DeviceArray([1.6507963, 1.3835008, 1.5197043], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00222689,  0.01103178, -0.00114857], dtype=float32), 'mean': DeviceArray([-0.00367822, -0.00067744, -0.02279097], dtype=float32)}
Iteration 641:
   Loss = 6486.66748046875 
   Phi = {'chol_diag': DeviceArray([1.6498971, 1.3822416, 1.5185987], dtype=flo

Iteration 667:
   Loss = 6228.99560546875 
   Phi = {'chol_diag': DeviceArray([1.626501 , 1.3506583, 1.489793 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00620794,  0.00992455, -0.00288026], dtype=float32), 'mean': DeviceArray([-0.00514776,  0.00121901, -0.01985047], dtype=float32)}
Iteration 668:
   Loss = 6382.21923828125 
   Phi = {'chol_diag': DeviceArray([1.6256468, 1.3495277, 1.4887514], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00548531,  0.00904908, -0.00312207], dtype=float32), 'mean': DeviceArray([-0.00558336,  0.00127988, -0.02021943], dtype=float32)}
Iteration 669:
   Loss = 6329.84228515625 
   Phi = {'chol_diag': DeviceArray([1.6247958, 1.348419 , 1.4877203], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00466311,  0.0084453 , -0.00333961], dtype=float32), 'mean': DeviceArray([-0.0059308 ,  0.00142588, -0.02049738], dtype=float32)}
Iteration 670:
   Loss = 6314.55908203125 
   Phi = {'chol_diag': DeviceArray([1.6239486, 1.3473222, 1.4866953], dtype=fl

Iteration 695:
   Loss = 6410.9921875 
   Phi = {'chol_diag': DeviceArray([1.6043758, 1.3204547, 1.4603933], dtype=float32), 'chol_lowerdiag': DeviceArray([-1.5081447e-03,  7.1410537e-03, -3.4272874e-05], dtype=float32), 'mean': DeviceArray([-0.0051005 , -0.00015538, -0.01249978], dtype=float32)}
Iteration 696:
   Loss = 5925.3193359375 
   Phi = {'chol_diag': DeviceArray([1.6036453, 1.3194268, 1.459354 ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.0014868 ,  0.00669123, -0.00023189], dtype=float32), 'mean': DeviceArray([-5.3450931e-03, -8.0346334e-05, -1.1976535e-02], dtype=float32)}
Iteration 697:
   Loss = 6057.23583984375 
   Phi = {'chol_diag': DeviceArray([1.6029315, 1.3184077, 1.4583253], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.00160555,  0.00622719, -0.00043289], dtype=float32), 'mean': DeviceArray([-5.6567984e-03,  1.5120248e-05, -1.1490107e-02], dtype=float32)}
Iteration 698:
   Loss = 6669.3994140625 
   Phi = {'chol_diag': DeviceArray([1.602204 , 1.3173614,

Iteration 723:
   Loss = 5762.99072265625 
   Phi = {'chol_diag': DeviceArray([1.5833952, 1.2907807, 1.4340401], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.00216698,  0.00448945,  0.00706579], dtype=float32), 'mean': DeviceArray([-0.00922257, -0.00028706,  0.00350685], dtype=float32)}
Iteration 724:
   Loss = 5708.92822265625 
   Phi = {'chol_diag': DeviceArray([1.5826327, 1.289795 , 1.4331967], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.00161118,  0.00459896,  0.0073668 ], dtype=float32), 'mean': DeviceArray([-0.01001355, -0.00050901,  0.00413283], dtype=float32)}
Iteration 725:
   Loss = 6356.9677734375 
   Phi = {'chol_diag': DeviceArray([1.5818506, 1.2887855, 1.432342 ], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.00122329,  0.00449342,  0.00785427], dtype=float32), 'mean': DeviceArray([-0.01070956, -0.00059345,  0.00453121], dtype=float32)}
Iteration 726:
   Loss = 6264.1806640625 
   Phi = {'chol_diag': DeviceArray([1.5810614, 1.2877761, 1.4314902], dtype=floa

Iteration 752:
   Loss = 5964.1328125 
   Phi = {'chol_diag': DeviceArray([1.5619838, 1.2631532, 1.4099663], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00644424, 0.00311383, 0.01518117], dtype=float32), 'mean': DeviceArray([-0.00070276, -0.00865672,  0.00089304], dtype=float32)}
Iteration 753:
   Loss = 5901.2958984375 
   Phi = {'chol_diag': DeviceArray([1.5612673, 1.262274 , 1.4091555], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00653432, 0.00326124, 0.01501584], dtype=float32), 'mean': DeviceArray([ 6.6043285e-05, -9.1434279e-03,  1.1166934e-03], dtype=float32)}
Iteration 754:
   Loss = 5711.20947265625 
   Phi = {'chol_diag': DeviceArray([1.5605828, 1.2613803, 1.408356 ], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00685141, 0.00327963, 0.01503434], dtype=float32), 'mean': DeviceArray([ 0.00083133, -0.00969117,  0.00126627], dtype=float32)}
Iteration 755:
   Loss = 6434.767578125 
   Phi = {'chol_diag': DeviceArray([1.5598961, 1.2604506, 1.4075189], dtype=float32),

Iteration 781:
   Loss = 5698.2021484375 
   Phi = {'chol_diag': DeviceArray([1.5429941, 1.2367191, 1.3873614], dtype=float32), 'chol_lowerdiag': DeviceArray([0.0118479 , 0.00594515, 0.01566879], dtype=float32), 'mean': DeviceArray([ 0.01733926, -0.00786389, -0.00292071], dtype=float32)}
Iteration 782:
   Loss = 5505.45556640625 
   Phi = {'chol_diag': DeviceArray([1.5423388, 1.2358797, 1.3866909], dtype=float32), 'chol_lowerdiag': DeviceArray([0.01165258, 0.00565889, 0.01537348], dtype=float32), 'mean': DeviceArray([ 0.0173866 , -0.00705414, -0.00348414], dtype=float32)}
Iteration 783:
   Loss = 5792.9912109375 
   Phi = {'chol_diag': DeviceArray([1.5417422, 1.235016 , 1.3860214], dtype=float32), 'chol_lowerdiag': DeviceArray([0.01156755, 0.00523512, 0.01491725], dtype=float32), 'mean': DeviceArray([ 0.01761002, -0.00636033, -0.00401082], dtype=float32)}
Iteration 784:
   Loss = 5840.30615234375 
   Phi = {'chol_diag': DeviceArray([1.5411564, 1.2341518, 1.3853519], dtype=float32), 'ch

Iteration 810:
   Loss = 5650.86767578125 
   Phi = {'chol_diag': DeviceArray([1.526833 , 1.2123576, 1.3676388], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00483822, 0.00160736, 0.00631132], dtype=float32), 'mean': DeviceArray([ 0.01838849,  0.00210283, -0.00618352], dtype=float32)}
Iteration 811:
   Loss = 5997.416015625 
   Phi = {'chol_diag': DeviceArray([1.5262203, 1.211547 , 1.3670012], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00474162, 0.00118977, 0.00581411], dtype=float32), 'mean': DeviceArray([ 0.01869321,  0.00199888, -0.00634574], dtype=float32)}
Iteration 812:
   Loss = 5879.1796875 
   Phi = {'chol_diag': DeviceArray([1.5255916, 1.2107289, 1.3663788], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00471875, 0.00083758, 0.00545199], dtype=float32), 'mean': DeviceArray([ 0.01886153,  0.00200914, -0.00657962], dtype=float32)}
Iteration 813:
   Loss = 5644.64697265625 
   Phi = {'chol_diag': DeviceArray([1.5250096, 1.2099124, 1.3657525], dtype=float32), 'chol_l

Iteration 839:
   Loss = 6193.56884765625 
   Phi = {'chol_diag': DeviceArray([1.5104502, 1.1901507, 1.3494054], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00485991, -0.00446606,  0.0002022 ], dtype=float32), 'mean': DeviceArray([ 0.02547229,  0.00440932, -0.00598119], dtype=float32)}
Iteration 840:
   Loss = 5890.0869140625 
   Phi = {'chol_diag': DeviceArray([1.5098548, 1.1893979, 1.3487706], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00479365, -0.00416019,  0.00048554], dtype=float32), 'mean': DeviceArray([ 0.02570242,  0.00435772, -0.00529663], dtype=float32)}
Iteration 841:
   Loss = 5495.681640625 
   Phi = {'chol_diag': DeviceArray([1.50929  , 1.1886581, 1.3481456], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00490243, -0.00382787,  0.0007408 ], dtype=float32), 'mean': DeviceArray([ 0.02584892,  0.004395  , -0.0046384 ], dtype=float32)}
Iteration 842:
   Loss = 5915.0068359375 
   Phi = {'chol_diag': DeviceArray([1.5086964, 1.1879127, 1.3475355], dtype=float3

Iteration 868:
   Loss = 5464.4794921875 
   Phi = {'chol_diag': DeviceArray([1.4941368, 1.1695603, 1.3317206], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00667111,  0.00575524, -0.00127242], dtype=float32), 'mean': DeviceArray([0.02992694, 0.00723294, 0.0038694 ], dtype=float32)}
Iteration 869:
   Loss = 5752.82421875 
   Phi = {'chol_diag': DeviceArray([1.4936628, 1.1688954, 1.331132 ], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00660723,  0.00588571, -0.00171704], dtype=float32), 'mean': DeviceArray([0.03009895, 0.00706831, 0.00415016], dtype=float32)}
Iteration 870:
   Loss = 5665.3662109375 
   Phi = {'chol_diag': DeviceArray([1.4931995, 1.1682272, 1.3305405], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00657077,  0.00593095, -0.00204566], dtype=float32), 'mean': DeviceArray([0.03040317, 0.0069011 , 0.00417105], dtype=float32)}
Iteration 871:
   Loss = 5632.90576171875 
   Phi = {'chol_diag': DeviceArray([1.4927465, 1.1675515, 1.3299638], dtype=float32), 'chol_

Iteration 897:
   Loss = 5439.17919921875 
   Phi = {'chol_diag': DeviceArray([1.4818108, 1.1510295, 1.3155849], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00465819,  0.00500431, -0.00248211], dtype=float32), 'mean': DeviceArray([ 0.02672518, -0.00074305,  0.00489552], dtype=float32)}
Iteration 898:
   Loss = 5653.05517578125 
   Phi = {'chol_diag': DeviceArray([1.4813918, 1.1504297, 1.3150241], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00448964,  0.00478795, -0.00231874], dtype=float32), 'mean': DeviceArray([ 0.02660285, -0.00090111,  0.00525032], dtype=float32)}
Iteration 899:
   Loss = 5921.5009765625 
   Phi = {'chol_diag': DeviceArray([1.4809124, 1.1498337, 1.3144678], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00434176,  0.00456612, -0.00207869], dtype=float32), 'mean': DeviceArray([ 0.02624285, -0.00107812,  0.00558085], dtype=float32)}
Iteration 900:
   Loss = 5595.0625 
   Phi = {'chol_diag': DeviceArray([1.480456 , 1.1492283, 1.3139118], dtype=float32), 

Iteration 926:
   Loss = 5802.24169921875 
   Phi = {'chol_diag': DeviceArray([1.4698639, 1.133653 , 1.3009508], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00139174, 0.00345867, 0.00675601], dtype=float32), 'mean': DeviceArray([ 0.0190408 , -0.00942199,  0.01481012], dtype=float32)}
Iteration 927:
   Loss = 5653.7412109375 
   Phi = {'chol_diag': DeviceArray([1.4694972, 1.1330014, 1.3004726], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00110591, 0.00347164, 0.006864  ], dtype=float32), 'mean': DeviceArray([ 0.01872448, -0.0093546 ,  0.01506198], dtype=float32)}
Iteration 928:
   Loss = 5568.90771484375 
   Phi = {'chol_diag': DeviceArray([1.4691335, 1.1323394, 1.3000201], dtype=float32), 'chol_lowerdiag': DeviceArray([0.00089609, 0.0034325 , 0.00705801], dtype=float32), 'mean': DeviceArray([ 0.01803995, -0.0093399 ,  0.01518686], dtype=float32)}
Iteration 929:
   Loss = 5571.17236328125 
   Phi = {'chol_diag': DeviceArray([1.4687603, 1.1316735, 1.2995732], dtype=float32), 'c

Iteration 955:
   Loss = 5526.78662109375 
   Phi = {'chol_diag': DeviceArray([1.4601882, 1.1166257, 1.2873026], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.0005104 ,  0.00079331,  0.00533189], dtype=float32), 'mean': DeviceArray([ 0.01854331, -0.00648943,  0.00089503], dtype=float32)}
Iteration 956:
   Loss = 5393.3984375 
   Phi = {'chol_diag': DeviceArray([1.4598582, 1.1160994, 1.2868564], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.00029343,  0.00067245,  0.00510439], dtype=float32), 'mean': DeviceArray([ 0.01891613, -0.00613074,  0.00028946], dtype=float32)}
Iteration 957:
   Loss = 5564.4423828125 
   Phi = {'chol_diag': DeviceArray([1.4595313, 1.1155715, 1.2864001], dtype=float32), 'chol_lowerdiag': DeviceArray([-7.6078693e-05,  6.1217340e-04,  4.9153636e-03], dtype=float32), 'mean': DeviceArray([ 0.01922624, -0.00577237, -0.00042445], dtype=float32)}
Iteration 958:
   Loss = 5766.00634765625 
   Phi = {'chol_diag': DeviceArray([1.4591887, 1.1150324, 1.2859437], dtyp

Iteration 984:
   Loss = 5438.05712890625 
   Phi = {'chol_diag': DeviceArray([1.4524078, 1.1019442, 1.2751305], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00222672, -0.00574544,  0.00208283], dtype=float32), 'mean': DeviceArray([ 0.02162471, -0.00072461, -0.01015406], dtype=float32)}
Iteration 985:
   Loss = 5344.140625 
   Phi = {'chol_diag': DeviceArray([1.4522157, 1.1014354, 1.2747293], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00204357, -0.0062075 ,  0.00210873], dtype=float32), 'mean': DeviceArray([ 0.02164211, -0.0004869 , -0.01019295], dtype=float32)}
Iteration 986:
   Loss = 5740.56787109375 
   Phi = {'chol_diag': DeviceArray([1.4519788, 1.1008935, 1.2743475], dtype=float32), 'chol_lowerdiag': DeviceArray([ 0.00186233, -0.00639039,  0.00205695], dtype=float32), 'mean': DeviceArray([ 2.1446610e-02, -9.4111369e-05, -1.0172525e-02], dtype=float32)}
Iteration 987:
   Loss = 5426.6142578125 
   Phi = {'chol_diag': DeviceArray([1.4517368, 1.1003572, 1.2739727], dtype