In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from numpy.random import multivariate_normal as mvn_sample
from approx_post import ApproximateDistribution, JointDistribution, reverse_kl, forward_kl
from arraytainers import Jaxtainer

In [2]:
def create_data(model, true_theta, noise_cov, num_samples, rng=None):
    """Simulate noisy observations of ``model(true_theta)``.

    Draws ``num_samples`` samples from a multivariate normal distribution
    centred on the model output with covariance ``noise_cov``.

    Parameters
    ----------
    model : callable
        Forward model; called once as ``model(true_theta)``.
    true_theta : array-like or scalar
        Parameter value at which the model is evaluated.
    noise_cov : array-like or scalar
        Observation-noise covariance. A scalar is treated as a 1x1 matrix.
    num_samples : int
        Number of samples to draw (0 yields an empty ``(0, d)`` array).
    rng : numpy.random.Generator, optional
        Random source. When None (default) NumPy's legacy global RNG is used,
        exactly as before; pass a seeded Generator for reproducible data.

    Returns
    -------
    numpy.ndarray
        Samples of shape ``(num_samples, d)``, where ``d`` is the length of the
        model output.
    """
    # multivariate_normal requires a 1-D mean and a 2-D covariance; promote
    # scalar model outputs / scalar noise (e.g. theta**2 for scalar theta)
    # instead of failing. Already-valid inputs pass through unchanged.
    mean = np.atleast_1d(np.asarray(model(true_theta)))
    cov = np.atleast_2d(np.asarray(noise_cov))
    if rng is None:
        samples = mvn_sample(mean, cov, num_samples)
    else:
        samples = rng.multivariate_normal(mean, cov, num_samples)
    # Use the explicit output dimension rather than -1: reshape(0, -1) on an
    # empty array raises because -1 cannot be inferred.
    return samples.reshape(num_samples, mean.shape[0])

# Forward model used throughout the notebook, with a 1-dimensional parameter.
ndim = 1

def model(theta):
    """Forward model: square of the parameter (element-wise for arrays)."""
    return theta**2

# Batched Jacobian of `model`: jax.jacfwd differentiates a single theta,
# jax.vmap maps that over the leading (batch) axis of its input.
model_grad = jax.vmap(jax.jacfwd(model), in_axes=0)

In [3]:
# Gaussian prior on theta: zero mean, covariance 0.1 * I.
prior_mean = np.zeros(ndim)
prior_cov = np.eye(ndim) * 0.1
# Observation-noise covariance (0.5 * I); passed to JointDistribution.from_model
# together with the prior when the joint distribution is built.
noise_cov = np.eye(ndim) * 0.5
# Gaussian approximate distribution of dimension ndim.
approx = ApproximateDistribution.gaussian(ndim)

In [4]:
# Observed data: three scalar observations, one per row (shape (3, 1)).
# NOTE(review): these values are hard-coded; they look like output of
# create_data, but true_theta / the seed are not recorded - confirm and record
# them so the data can be regenerated. jnp.array stores them as float32 by
# default (see dtype=float32 in the outputs), so trailing digits are dropped.
data = jnp.array([[10.01392484], [10.53159464], [10.60713809]])
joint = JointDistribution.from_model(data, model, noise_cov, prior_mean, prior_cov, model_grad)

# Number of samples the reparameterised reverse-KL estimator draws
# (previously a bare literal 5; the output above shows 5 epsilon rows).
NUM_KL_SAMPLES = 5
# Displays what appears to be (loss value, gradient Jaxtainer) - see output.
reverse_kl.reversekl_reparameterisation(Jaxtainer(approx.phi), approx, joint, NUM_KL_SAMPLES)



epsilon
[[ 0.88389313]
 [ 0.19586502]
 [ 0.35753652]
 [-2.343262  ]
 [-1.0848325 ]]
theta
[[ 0.88389313]
 [ 0.19586502]
 [ 0.35753652]
 [-2.343262  ]
 [-1.0848325 ]]
approx_lp
[-1.309572   -0.93812007 -0.9828547  -3.6643767  -1.5073693 ]
joint_lp
[-282.2496  -322.99524 -317.9128  -100.98205 -261.90384]
transform_del_phi
Jaxtainer({'chol_diag': DeviceArray([[[ 0.88389313]],

             [[ 0.19586502]],

             [[ 0.35753652]],

             [[-2.343262  ]],

             [[-1.0848325 ]]], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([[[1.]],

             [[1.]],

             [[1.]],

             [[1.]],

             [[1.]]], dtype=float32)})
joint_del_phi
Jaxtainer({'chol_diag': DeviceArray([[ 82.216965 ],
             [  4.3791585],
             [ 14.454853 ],
             [267.51605  ],
             [118.2608   ]], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([[  93.01686 ],
             [  2

(DeviceArray(255.52827, dtype=float32),
 Jaxtainer({'chol_diag': DeviceArray([-98.88861], dtype=float32), 'chol_lowerdiag': DeviceArray([], dtype=float32), 'mean': DeviceArray([13.872756], dtype=float32)}))