In [1]:
import jax
import jax.numpy as jnp
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers

In [2]:
def create_data(model, true_theta, noise_cov, num_samples, ndim, prngkey):
    """Draw noisy synthetic observations of a forward model.

    Samples ``num_samples`` observations from a multivariate normal whose mean
    is ``model(true_theta)`` and whose covariance is ``noise_cov``.

    Args:
        model: Callable mapping a parameter array to the noiseless model output.
        true_theta: Parameter value at which the model is evaluated.
        noise_cov: (d, d) covariance of the additive Gaussian observation noise,
            where d is the length of the model output.
        num_samples: Number of independent observations to draw.
        ndim: Not used by this function; kept so existing positional calls
            continue to work.
        prngkey: JAX PRNG key used for sampling.

    Returns:
        Array of shape ``(num_samples, d)``.
    """
    # jax.random.multivariate_normal requires mean.ndim >= 1, so promote a
    # scalar model output (e.g. a single summary statistic) to shape (1,).
    mean = jnp.atleast_1d(model(true_theta))
    samples = jax.random.multivariate_normal(key=prngkey, mean=mean, cov=noise_cov, shape=(num_samples,))
    # Guarantee a 2-D (num_samples, d) result.
    return samples.reshape(num_samples, -1)

In [3]:
# First, define the forward model and its Jacobian.
ndim = 1  # problem dimensionality, reused by later cells

def model(theta):
    """Elementwise forward model: x = theta**2."""
    return theta**2

# Jacobian of `model`, mapped over two leading batch axes
# (jax.vmap maps over axis 0 by default).
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model)))

In [4]:
# Create artificial data: noisy observations of model(true_theta).
true_theta = jnp.array([2])
noise_cov = 0.1 * jnp.identity(ndim)
num_samples = 1
prngkey = jax.random.PRNGKey(10)
data = create_data(model, true_theta, noise_cov, num_samples, ndim, prngkey)
true_x = model(true_theta)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {true_x}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')



True theta: 
 [2]
True x = model(theta): 
 [4]
Observations x_obs = model(theta) + noise: 
 [[3.5748188]]


In [5]:
# Create Gaussian approximate distribution (constructed with `ndim`).
# NOTE(review): the fitting cells below re-create `approx_dist` before
# calling fit, so this instance is overwritten without being used.
approx_dist = approx.Gaussian(ndim)

In [6]:
# Build the joint distribution from the forward model, the observation-noise
# covariance, the prior (mean, covariance) and the model Jacobian.
prior_mean = jnp.zeros(ndim)
prior_cov = jnp.eye(ndim) * 0.1
joint_dist = joint.ModelPlusGaussian(
    model, noise_cov, prior_mean, prior_cov, model_grad
)

In [7]:
# Fit a fresh Gaussian approximation with the ELBO loss
# (use_reparameterisation=True) and the Adam optimiser.
approx_dist = approx.Gaussian(ndim)
prngkey = jax.random.PRNGKey(12)
loss = losses.ELBO(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()
# Note: `num_samples` here is the optimiser keyword (1000), not the global
# `num_samples` used to generate `data`.
optimiser.fit(
    approx_dist,
    loss,
    data,
    prngkey,
    verbose=True,
    max_iter=1000,
    num_samples=1000,
)

Loss = 46.42293930053711, Params = Jaxtainer({'mean': DeviceArray([-0.1], dtype=float32), 'log_chol_diag': DeviceArray([-0.1], dtype=float32)})
Loss = 46.40510940551758, Params = Jaxtainer({'mean': DeviceArray([-0.19955799], dtype=float32), 'log_chol_diag': DeviceArray([-0.06496035], dtype=float32)})
Loss = 45.89853286743164, Params = Jaxtainer({'mean': DeviceArray([-0.29935193], dtype=float32), 'log_chol_diag': DeviceArray([-0.01416349], dtype=float32)})
Loss = 45.640193939208984, Params = Jaxtainer({'mean': DeviceArray([-0.3956759], dtype=float32), 'log_chol_diag': DeviceArray([0.00168393], dtype=float32)})
Loss = 45.63945388793945, Params = Jaxtainer({'mean': DeviceArray([-0.48024344], dtype=float32), 'log_chol_diag': DeviceArray([-0.02508833], dtype=float32)})
Loss = 45.3537712097168, Params = Jaxtainer({'mean': DeviceArray([-0.5568623], dtype=float32), 'log_chol_diag': DeviceArray([-0.072787], dtype=float32)})
Loss = 44.88300323486328, Params = Jaxtainer({'mean': DeviceArray([-0.6

Loss = 17.724828720092773, Params = Jaxtainer({'mean': DeviceArray([-1.7732109], dtype=float32), 'log_chol_diag': DeviceArray([-2.7135334], dtype=float32)})
Loss = 17.729774475097656, Params = Jaxtainer({'mean': DeviceArray([-1.7710443], dtype=float32), 'log_chol_diag': DeviceArray([-2.712759], dtype=float32)})
Loss = 17.723817825317383, Params = Jaxtainer({'mean': DeviceArray([-1.7655146], dtype=float32), 'log_chol_diag': DeviceArray([-2.7112577], dtype=float32)})
Loss = 17.71150016784668, Params = Jaxtainer({'mean': DeviceArray([-1.7577721], dtype=float32), 'log_chol_diag': DeviceArray([-2.7090852], dtype=float32)})
Loss = 17.700632095336914, Params = Jaxtainer({'mean': DeviceArray([-1.7492082], dtype=float32), 'log_chol_diag': DeviceArray([-2.706291], dtype=float32)})
Loss = 17.697072982788086, Params = Jaxtainer({'mean': DeviceArray([-1.7412179], dtype=float32), 'log_chol_diag': DeviceArray([-2.70292], dtype=float32)})
Loss = 17.70148277282715, Params = Jaxtainer({'mean': DeviceArr

Loss = 17.639751434326172, Params = Jaxtainer({'mean': DeviceArray([-1.7458072], dtype=float32), 'log_chol_diag': DeviceArray([-2.4195933], dtype=float32)})
Loss = 17.639591217041016, Params = Jaxtainer({'mean': DeviceArray([-1.7465107], dtype=float32), 'log_chol_diag': DeviceArray([-2.416719], dtype=float32)})
Loss = 17.639516830444336, Params = Jaxtainer({'mean': DeviceArray([-1.747035], dtype=float32), 'log_chol_diag': DeviceArray([-2.4140127], dtype=float32)})
Loss = 17.63955307006836, Params = Jaxtainer({'mean': DeviceArray([-1.7472796], dtype=float32), 'log_chol_diag': DeviceArray([-2.4114728], dtype=float32)})
Loss = 17.63966178894043, Params = Jaxtainer({'mean': DeviceArray([-1.7472126], dtype=float32), 'log_chol_diag': DeviceArray([-2.4090962], dtype=float32)})
Loss = 17.63982391357422, Params = Jaxtainer({'mean': DeviceArray([-1.7468717], dtype=float32), 'log_chol_diag': DeviceArray([-2.4068782], dtype=float32)})
Loss = 17.640031814575195, Params = Jaxtainer({'mean': DeviceAr

KeyboardInterrupt: 

In [None]:
# Inspect the fitted approximation's covariance.
# NOTE(review): assumes `cov()` returns the covariance of `approx_dist` — confirm against approx.Gaussian.
approx_dist.cov()

In [None]:
# Fit a fresh Gaussian approximation with the forward-KL loss
# (use_reparameterisation=True) and the Adam optimiser.
approx_dist = approx.Gaussian(ndim)
prngkey = jax.random.PRNGKey(12)
loss = losses.ForwardKL(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()
# Note: `num_samples` here is the optimiser keyword (1000), not the global
# `num_samples` used to generate `data`.
optimiser.fit(
    approx_dist,
    loss,
    data,
    prngkey,
    verbose=True,
    max_iter=1000,
    num_samples=1000,
)

In [None]:
# Scratch check: convert a 0-d JAX array to a plain Python scalar.
# NOTE(review): leftover exploration cell; consider deleting.
int(jnp.asarray(100))