In [1]:
import jax
import jax.numpy as jnp
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers

In [2]:
def create_data(model, true_theta, noise_cov, num_samples, ndim, prngkey):
    """Draw synthetic noisy observations of ``model(true_theta)``.

    Each observation is a draw from a multivariate normal centred on
    ``model(true_theta)`` with covariance ``noise_cov``.

    Args:
        model: Forward model mapping parameters to predicted observations.
        true_theta: Parameter value used to generate the data.
        noise_cov: Covariance matrix of the additive observation noise.
        num_samples: Number of observations to draw.
        ndim: Accepted for call-site compatibility; not used here.
        prngkey: JAX PRNG key for the noise draw.

    Returns:
        2-D array whose first axis has length ``num_samples``; every other
        axis is flattened into the second one.
    """
    noiseless_obs = model(true_theta)
    draws = jax.random.multivariate_normal(
        mean=noiseless_obs,
        cov=noise_cov,
        key=prngkey,
        shape=(num_samples,),
    )
    # Collapse everything after the sample axis so callers always get 2-D data.
    return jnp.reshape(draws, (num_samples, -1))

In [3]:
# First, let's define a model: an elementwise square, theta -> theta**2.
ndim = 1


def model(theta):
    """Forward model: elementwise square of the parameter array."""
    return theta**2


# Jacobian of `model`, vectorised over two leading batch axes
# (jax.vmap maps axis 0 by default at each level).
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model)))

In [22]:
# Create artificial data: `num_samples` noisy observations of model(true_theta).
prngkey = jax.random.PRNGKey(10)  # key for the observation-noise draw
true_theta = jnp.array([2])
num_samples = 1
noise_cov = jnp.identity(ndim) * 1
data = create_data(
    model,
    true_theta,
    noise_cov,
    num_samples,
    ndim,
    prngkey,
)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [2]
True x = model(theta): 
 [4]
Observations x_obs = model(theta) + noise: 
 [[2.6554594]]


In [23]:
# Create Gaussian approximate distribution:
# NOTE(review): this instance is replaced by a fresh `approx.Gaussian(ndim)` in the
# ReverseKL training cell before it is ever used, so this cell has no effect on results.
approx_dist = approx.Gaussian(ndim)

In [24]:
# Create Joint distribution from forward model:
# prior parameters for theta -- zero mean, identity covariance
# (presumably a Gaussian prior inside ModelPlusGaussian; confirm in `joint`).
prior_mean = jnp.zeros(ndim)
prior_cov = jnp.identity(ndim) * 1.0
joint_dist = joint.ModelPlusGaussian(
    model,
    noise_cov,
    prior_mean,
    prior_cov,
    model_grad,
)

In [25]:
# Fit a fresh Gaussian approximation to `data` with Adam under a ReverseKL loss
# (use_reparameterisation=True). verbose=True prints "Loss = ..., Params = ..." lines.
# NOTE(review): the recorded run below was stopped with a KeyboardInterrupt.
prngkey = jax.random.PRNGKey(12)
approx_dist = approx.Gaussian(ndim)
loss = losses.ReverseKL(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()
fit_kwargs = dict(verbose=True, max_iter=1000, num_samples=1000)
optimiser.fit(approx_dist, loss, data, prngkey, **fit_kwargs)

Loss = 3.350170373916626, Params = Jaxtainer({'mean': DeviceArray([-0.10000001], dtype=float32), 'log_chol_diag': DeviceArray([-0.10000001], dtype=float32)})
Loss = 3.260319709777832, Params = Jaxtainer({'mean': DeviceArray([-0.18682866], dtype=float32), 'log_chol_diag': DeviceArray([-0.15968904], dtype=float32)})
Loss = 3.270439863204956, Params = Jaxtainer({'mean': DeviceArray([-0.26481748], dtype=float32), 'log_chol_diag': DeviceArray([-0.17455044], dtype=float32)})
Loss = 3.272676706314087, Params = Jaxtainer({'mean': DeviceArray([-0.33431935], dtype=float32), 'log_chol_diag': DeviceArray([-0.16271082], dtype=float32)})
Loss = 3.2648744583129883, Params = Jaxtainer({'mean': DeviceArray([-0.38568926], dtype=float32), 'log_chol_diag': DeviceArray([-0.13933037], dtype=float32)})
Loss = 3.264331340789795, Params = Jaxtainer({'mean': DeviceArray([-0.40334502], dtype=float32), 'log_chol_diag': DeviceArray([-0.11636119], dtype=float32)})
Loss = 3.2709591388702393, Params = Jaxtainer({'mea

Loss = 3.2534544467926025, Params = Jaxtainer({'mean': DeviceArray([-0.17478144], dtype=float32), 'log_chol_diag': DeviceArray([-0.08885797], dtype=float32)})
Loss = 3.253798007965088, Params = Jaxtainer({'mean': DeviceArray([-0.17195573], dtype=float32), 'log_chol_diag': DeviceArray([-0.08902973], dtype=float32)})
Loss = 3.2539026737213135, Params = Jaxtainer({'mean': DeviceArray([-0.17105225], dtype=float32), 'log_chol_diag': DeviceArray([-0.09035807], dtype=float32)})
Loss = 3.253789186477661, Params = Jaxtainer({'mean': DeviceArray([-0.17214358], dtype=float32), 'log_chol_diag': DeviceArray([-0.09207663], dtype=float32)})
Loss = 3.2535650730133057, Params = Jaxtainer({'mean': DeviceArray([-0.17506915], dtype=float32), 'log_chol_diag': DeviceArray([-0.09336572], dtype=float32)})
Loss = 3.2533180713653564, Params = Jaxtainer({'mean': DeviceArray([-0.17940846], dtype=float32), 'log_chol_diag': DeviceArray([-0.09374897], dtype=float32)})
Loss = 3.2530980110168457, Params = Jaxtainer({'

Loss = 3.252932548522949, Params = Jaxtainer({'mean': DeviceArray([-0.18654393], dtype=float32), 'log_chol_diag': DeviceArray([-0.09268206], dtype=float32)})
Loss = 3.2529304027557373, Params = Jaxtainer({'mean': DeviceArray([-0.18703528], dtype=float32), 'log_chol_diag': DeviceArray([-0.092583], dtype=float32)})
Loss = 3.252925157546997, Params = Jaxtainer({'mean': DeviceArray([-0.18750891], dtype=float32), 'log_chol_diag': DeviceArray([-0.09258707], dtype=float32)})
Loss = 3.2529103755950928, Params = Jaxtainer({'mean': DeviceArray([-0.18787935], dtype=float32), 'log_chol_diag': DeviceArray([-0.09271015], dtype=float32)})
Loss = 3.25288724899292, Params = Jaxtainer({'mean': DeviceArray([-0.18810527], dtype=float32), 'log_chol_diag': DeviceArray([-0.09288097], dtype=float32)})
Loss = 3.252864122390747, Params = Jaxtainer({'mean': DeviceArray([-0.1881831], dtype=float32), 'log_chol_diag': DeviceArray([-0.09299372], dtype=float32)})
Loss = 3.252851724624634, Params = Jaxtainer({'mean': 

KeyboardInterrupt: 

In [27]:
# Inspect the fitted approximation's covariance (judging by the name and the
# (ndim, ndim) output).
# NOTE(review): the recorded output comes from the ReverseKL fit that was stopped
# with a KeyboardInterrupt, so it reflects a partially-optimised state.
approx_dist.cov()

DeviceArray([[0.83064824]], dtype=float32)

In [None]:
# Fit a fresh Gaussian approximation to `data` with Adam under a ForwardKL loss
# (use_reparameterisation=True), using the same settings as the ReverseKL run.
# NOTE(review): this cell has not been executed yet (In [None]).
prngkey = jax.random.PRNGKey(12)
approx_dist = approx.Gaussian(ndim)
loss = losses.ForwardKL(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()
fit_kwargs = dict(verbose=True, max_iter=1000, num_samples=1000)
optimiser.fit(approx_dist, loss, data, prngkey, **fit_kwargs)