In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers
from arraytainers import Jaxtainer

In [2]:
def create_data(model, true_theta, noise_cov, num_samples, prngkey):
    mean = model(true_theta)
    samples = jax.random.multivariate_normal(key=prngkey, mean=mean, cov=noise_cov, shape=(num_samples,))
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model), in_axes=0), in_axes=0)

In [4]:
# Create artificial data:
prngkey = jax.random.PRNGKey(10)
true_theta = jnp.array([2])
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples, prngkey)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')



True theta: 
 [2]
True x = model(theta): 
 [4]
Observations x_obs = model(theta) + noise: 
 [[3.865546]]


In [5]:
# Create Gaussian approximate distribution:
approx_dist = approx.Gaussian(ndim)

In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = 1.0*np.identity(ndim)
joint_dist = joint.ModelPlusGaussian(model, noise_cov, prior_mean, prior_cov, model_grad)

In [7]:
loss = losses.ReverseKL(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()
optimiser.fit(approx_dist, loss, data, prngkey, verbose = True, max_iter=1000, num_samples=1000)

Loss = 506.48834228515625, Params = Jaxtainer({'mean': DeviceArray([0.1], dtype=float32), 'chol_diag': DeviceArray([1.1], dtype=float32)})
Loss = 497.05157470703125, Params = Jaxtainer({'mean': DeviceArray([0.19719544], dtype=float32), 'chol_diag': DeviceArray([1.1717205], dtype=float32)})
Loss = 502.0058898925781, Params = Jaxtainer({'mean': DeviceArray([0.16163215], dtype=float32), 'chol_diag': DeviceArray([1.1703749], dtype=float32)})
Loss = 501.2159729003906, Params = Jaxtainer({'mean': DeviceArray([0.10370986], dtype=float32), 'chol_diag': DeviceArray([1.1394336], dtype=float32)})
Loss = 497.9275207519531, Params = Jaxtainer({'mean': DeviceArray([0.04778116], dtype=float32), 'chol_diag': DeviceArray([1.1017612], dtype=float32)})
Loss = 497.2038269042969, Params = Jaxtainer({'mean': DeviceArray([0.00582805], dtype=float32), 'chol_diag': DeviceArray([1.0724211], dtype=float32)})
Loss = 498.48834228515625, Params = Jaxtainer({'mean': DeviceArray([-0.02340532], dtype=float32), 'chol_d

Loss = 24.615636825561523, Params = Jaxtainer({'mean': DeviceArray([1.7923945], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 24.857635498046875, Params = Jaxtainer({'mean': DeviceArray([1.8044128], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 22.176795959472656, Params = Jaxtainer({'mean': DeviceArray([1.8273015], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 17.547359466552734, Params = Jaxtainer({'mean': DeviceArray([1.8584378], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 12.292728424072266, Params = Jaxtainer({'mean': DeviceArray([1.8949046], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 7.749441623687744, Params = Jaxtainer({'mean': DeviceArray([1.9335747], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 4.946377277374268, Params = Jaxtainer({'mean': DeviceArray([1.9712197], dtype=floa

Loss = 4.284099578857422, Params = Jaxtainer({'mean': DeviceArray([1.9590269], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 4.295891284942627, Params = Jaxtainer({'mean': DeviceArray([1.9569432], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 4.312337398529053, Params = Jaxtainer({'mean': DeviceArray([1.9558303], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 4.323840141296387, Params = Jaxtainer({'mean': DeviceArray([1.9557256], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 4.325020790100098, Params = Jaxtainer({'mean': DeviceArray([1.9565481], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 4.316204071044922, Params = Jaxtainer({'mean': DeviceArray([1.9581188], dtype=float32), 'chol_diag': DeviceArray([0.03162277], dtype=float32)})
Loss = 4.30224084854126, Params = Jaxtainer({'mean': DeviceArray([1.9601883], dtype=float32), 

KeyboardInterrupt: 