In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers
from arraytainers import Jaxtainer

In [2]:
def create_data(model, true_theta, noise_cov, num_samples, ndim, prngkey):
    """Simulate noisy observations of a forward model.

    Draws ``num_samples`` samples from a multivariate normal distribution
    centred on ``model(true_theta)`` with covariance ``noise_cov``.

    Args:
        model: callable mapping a parameter array to the noise-free observation.
        true_theta: parameter value at which ``model`` is evaluated.
        noise_cov: covariance matrix of the additive Gaussian noise.
        num_samples: number of noisy observations to draw.
        ndim: not used by this function; kept so existing call sites still work.
        prngkey: JAX PRNG key consumed by the sampler.

    Returns:
        Array with ``num_samples`` rows, one flattened observation per row.
    """
    noiseless = model(true_theta)
    draws = jax.random.multivariate_normal(prngkey, noiseless, noise_cov, (num_samples,))
    # Flatten everything after the sample axis so callers always get a 2-D array.
    return draws.reshape((num_samples, -1))

In [3]:
# First, define the forward model: an element-wise square of the parameters.
ndim = 2


def model(theta):
    """Forward model: element-wise square of ``theta``."""
    return theta**2


# Jacobian of the model, vectorised over two leading batch axes
# (an outer and an inner vmap, each mapping over axis 0).
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model)))

In [16]:
# Create artificial data: evaluate the model at a known "true" theta and add Gaussian noise.
prngkey = jax.random.PRNGKey(10)
# Use a float dtype: jax.jacfwd (behind model_grad) rejects integer inputs, so parameter
# values must be inexact even when they happen to be whole numbers. The sampled data is
# unchanged, because multivariate_normal already promotes an integer mean to float.
true_theta = jnp.array([1.0, 2.0])
noise_cov = 0.01*np.identity(ndim)  # isotropic noise: std 0.1 in each dimension
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples, ndim, prngkey)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [1 2]
True x = model(theta): 
 [1 4]
Observations x_obs = model(theta) + noise: 
 [[0.9378123 3.8724568]]


In [22]:
# Create Gaussian approximate distribution:
# this is the variational family over the ndim-dimensional theta. Its parameters appear in the
# fit log as 'mean', 'chol_diag' and 'chol_lowerdiag', and they are what the optimiser tunes.
approx_dist = approx.Gaussian(ndim)

In [23]:
# Create Joint distribution from forward model:
# Prior over theta: zero mean, unit (identity) covariance.
prior_mean = np.zeros((ndim,))
prior_cov = np.eye(ndim)
# NOTE(review): presumably a Gaussian likelihood around model(theta) with covariance noise_cov,
# combined with the Gaussian prior above; model_grad supplies the model's Jacobian.
# Confirm against joint.ModelPlusGaussian.
joint_dist = joint.ModelPlusGaussian(model, noise_cov, prior_mean, prior_cov, model_grad)

In [24]:
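# Optional (disabled): warm-start the Gaussian mean at the true theta instead of its initial value.
# approx_dist._phi['mean'] = jnp.array([1., 2.])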

In [None]:
# Fit the approximation using the ForwardKL loss and the Adam optimiser.
# NOTE(review): if fit updates approx_dist in place, re-running this cell continues from the
# previously fitted parameters; re-run the approx_dist cell first for a fresh fit.
# num_samples here is presumably the number of Monte Carlo samples per iteration (not the
# number of observations) — confirm.
fit_settings = dict(verbose=True, max_iter=1000, num_samples=1000)
prngkey = jax.random.PRNGKey(12)
loss = losses.ForwardKL(joint_dist, use_reparameterisation=False)
optimiser = optimisers.Adam()
optimiser.fit(approx_dist, loss, data, prngkey, **fit_settings)

Loss = 0.003762881737202406, Params = Jaxtainer({'mean': DeviceArray([ 0.0999989 , -0.09999941], dtype=float32), 'chol_diag': DeviceArray([1.0999898, 1.0999995], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.09999944], dtype=float32)})
Loss = 0.00397178390994668, Params = Jaxtainer({'mean': DeviceArray([ 0.10518708, -0.10684451], dtype=float32), 'chol_diag': DeviceArray([1.0777528, 1.1986061], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.19792913], dtype=float32)})
Loss = 0.0035494219046086073, Params = Jaxtainer({'mean': DeviceArray([ 0.10301962, -0.09323387], dtype=float32), 'chol_diag': DeviceArray([1.013583 , 1.2921257], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.2891729], dtype=float32)})
Loss = 0.003785498905926943, Params = Jaxtainer({'mean': DeviceArray([ 0.08753198, -0.09941825], dtype=float32), 'chol_diag': DeviceArray([0.9777185, 1.3820344], dtype=float32), 'chol_lowerdiag': DeviceArray([-0.35625684], dtype=float32)})
Loss = 0.003659829730167985, Params = Jax