In [7]:
import jax
import jax.numpy as jnp
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers

In [8]:
def create_data(model, true_theta, noise_cov, num_samples, ndim, prngkey):
    """Simulate noisy observations of the forward model.

    Evaluates ``model(true_theta)`` and draws ``num_samples`` multivariate
    normal samples centred on that value with covariance ``noise_cov``.

    Args:
        model: Callable mapping a parameter array to the noiseless observation.
        true_theta: Parameter value at which the model is evaluated.
        noise_cov: Covariance matrix of the additive Gaussian noise.
        num_samples: Number of independent draws.
        ndim: Not used by this function; kept in the signature for callers.
        prngkey: JAX PRNG key driving the sampling.

    Returns:
        Array with ``num_samples`` rows, one flattened draw per row.
    """
    noiseless = model(true_theta)
    draws = jax.random.multivariate_normal(
        prngkey, noiseless, noise_cov, shape=(num_samples,)
    )
    # Collapse any trailing axes so each row holds one flattened sample.
    return draws.reshape((num_samples, -1))

In [9]:
# Forward model: the observable is the elementwise square of the parameters.
ndim = 1


def model(theta):
    """Quadratic forward model, x = theta**2 (elementwise)."""
    return theta**2


# Forward-mode Jacobian of the model, mapped over the two leading axes of its
# input (jax.vmap maps over axis 0 by default).
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model)))

In [10]:
# Generate synthetic observations: evaluate the model at a known theta and add
# Gaussian noise, using a fixed seed so the draw is repeatable.
num_samples = 1
true_theta = jnp.array([2])
noise_cov = 0.01*jnp.identity(ndim)
prngkey = jax.random.PRNGKey(10)
data = create_data(
    model=model,
    true_theta=true_theta,
    noise_cov=noise_cov,
    num_samples=num_samples,
    ndim=ndim,
    prngkey=prngkey,
)

# Report the ground truth next to the noisy observations.
print(
    f'True theta: \n {true_theta}',
    f'True x = model(theta): \n {model(true_theta)}',
    f'Observations x_obs = model(theta) + noise: \n {data}',
    sep='\n',
)

True theta: 
 [2]
True x = model(theta): 
 [4]
Observations x_obs = model(theta) + noise: 
 [[3.865546]]


In [11]:
# Create the Gaussian approximating distribution, constructed from ndim
# (presumably the dimension of theta — confirm against approx_post).
# NOTE(review): both fitting cells further down rebuild `approx_dist` with this
# same call before using it, so the instance made here is overwritten.
approx_dist = approx.Gaussian(ndim)

In [15]:
# Create the joint distribution from the forward model.
# Prior hyper-parameters: a zero mean vector of length ndim and an
# ndim x ndim identity covariance.
prior_mean = jnp.zeros(ndim)
prior_cov = 1.0*jnp.identity(ndim)  # the leading 1.0 scales the identity; change it to rescale the prior covariance
# Passed positionally: forward model, noise covariance used to generate the
# data, prior mean, prior covariance, and the vmapped model Jacobian.
# NOTE(review): presumably this encodes x = model(theta) + Gaussian noise with
# a Gaussian prior on theta — confirm against the approx_post documentation.
joint_dist = joint.ModelPlusGaussian(model, noise_cov, prior_mean, prior_cov, model_grad)

In [16]:
# Fit the Gaussian approximation with the ForwardKL loss, with
# use_reparameterisation switched off.
# NOTE(review): the recorded run of this cell stopped with
# "NameError: name 'epsilon' is not defined". Nothing in this notebook defines
# or uses `epsilon`, so the failure presumably originates inside approx_post on
# the use_reparameterisation=False path — TODO confirm and fix upstream.
approx_dist = approx.Gaussian(ndim)  # re-created here with the same call as the earlier cell
prngkey = jax.random.PRNGKey(12)  # fixed seed; rebinds the `prngkey` name used for data generation
loss = losses.ForwardKL(joint_dist, use_reparameterisation=False)
optimiser = optimisers.Adam()
# `num_samples=1000` is a keyword of fit() and is separate from the notebook
# variable num_samples used when generating data (presumably the number of
# Monte Carlo draws per iteration — confirm).
optimiser.fit(approx_dist, loss, data, prngkey, verbose = True, max_iter=1000, num_samples=1000)

NameError: name 'epsilon' is not defined

In [9]:
# Fit the Gaussian approximation with the ForwardKL loss, this time with
# use_reparameterisation switched on; otherwise identical to the previous
# fitting cell.
# NOTE(review): this cell's execution count (9) is lower than the preceding
# cell's (16), which suggests it was run out of order or after a kernel
# restart; it still relies on ndim, joint_dist and data from earlier cells.
# NOTE(review): in the recorded output the reported Loss rises from about 6.42
# to about 8.03 and the run ends with KeyboardInterrupt, so the optimisation
# does not appear to be converging — investigate before trusting the result.
approx_dist = approx.Gaussian(ndim)  # re-created here with the same call as the earlier cells
prngkey = jax.random.PRNGKey(12)  # fixed seed; rebinds the `prngkey` name used for data generation
loss = losses.ForwardKL(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()
# verbose=True produces the per-iteration Loss/Params lines seen in the output;
# `num_samples=1000` is a keyword of fit(), separate from the notebook variable
# num_samples used when generating data.
optimiser.fit(approx_dist, loss, data, prngkey, verbose = True, max_iter=1000, num_samples=1000)

Jaxtainer({'chol_diag': DeviceArray([[-11.]], dtype=float32), 'mean': DeviceArray([[-3.]], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-30666.797]], dtype=float32), 'mean': DeviceArray([[-6683.672]], dtype=float32)})
[-6.4189386]
Jaxtainer({'chol_diag': DeviceArray([[-362772.4]], dtype=float32), 'mean': DeviceArray([[-77241.17]], dtype=float32)})
Loss = 6.418938636779785, Params = Jaxtainer({'mean': DeviceArray([0.1], dtype=float32), 'chol_diag': DeviceArray([1.1], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-10.]], dtype=float32), 'mean': DeviceArray([[-2.7272725]], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-45888.008]], dtype=float32), 'mean': DeviceArray([[-10172.429]], dtype=float32)})
[-6.514249]
Jaxtainer({'chol_diag': DeviceArray([[-540432.44]], dtype=float32), 'mean': DeviceArray([[-115807.7]], dtype=float32)})
Loss = 6.514248847961426, Params = Jaxtainer({'mean': DeviceArray([0.19908386], dtype=float32), 'chol_diag': DeviceArray([1.1991262], 

Jaxtainer({'chol_diag': DeviceArray([[-3.7399983]], dtype=float32), 'mean': DeviceArray([[-1.0199995]], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-1497275.1]], dtype=float32), 'mean': DeviceArray([[-352784.62]], dtype=float32)})
[-7.4977493]
Jaxtainer({'chol_diag': DeviceArray([[-18311606.]], dtype=float32), 'mean': DeviceArray([[-4040053.8]], dtype=float32)})
Loss = 7.497749328613281, Params = Jaxtainer({'mean': DeviceArray([2.0529933], dtype=float32), 'chol_diag': DeviceArray([3.054682], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-3.6010292]], dtype=float32), 'mean': DeviceArray([[-0.9820989]], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-1692892.]], dtype=float32), 'mean': DeviceArray([[-399209.72]], dtype=float32)})
[-7.535614]
Jaxtainer({'chol_diag': DeviceArray([[-20755734.]], dtype=float32), 'mean': DeviceArray([[-4581989.]], dtype=float32)})
Loss = 7.535614013671875, Params = Jaxtainer({'mean': DeviceArray([2.1680446], dtype=float32), 'chol_d

Jaxtainer({'chol_diag': DeviceArray([[-2.2]], dtype=float32), 'mean': DeviceArray([[-0.6]], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-8198243.5]], dtype=float32), 'mean': DeviceArray([[-1949692.]], dtype=float32)})
[-8.028377]
Jaxtainer({'chol_diag': DeviceArray([[-1.03947896e+08]], dtype=float32), 'mean': DeviceArray([[-23101274.]], dtype=float32)})
Loss = 8.028376579284668, Params = Jaxtainer({'mean': DeviceArray([4.2326064], dtype=float32), 'chol_diag': DeviceArray([5.], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-2.2]], dtype=float32), 'mean': DeviceArray([[-0.6]], dtype=float32)})
Jaxtainer({'chol_diag': DeviceArray([[-8336429.]], dtype=float32), 'mean': DeviceArray([[-1984930.4]], dtype=float32)})
[-8.028377]
Jaxtainer({'chol_diag': DeviceArray([[-1.0561039e+08]], dtype=float32), 'mean': DeviceArray([[-23484560.]], dtype=float32)})
Loss = 8.028376579284668, Params = Jaxtainer({'mean': DeviceArray([4.375177], dtype=float32), 'chol_diag': DeviceArray([5.],

KeyboardInterrupt: 