In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers

In [2]:
def create_data(model, true_theta, noise_cov, num_samples):
    mean = model(true_theta)
    samples = mvn(mean, noise_cov, num_samples)
    return samples.reshape(num_samples, -1)

In [3]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model), in_axes=0), in_axes=0)

In [4]:
# Create artificial data:
true_theta = 2*np.random.rand(ndim)
noise_cov = 0.5*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [1.51787927]
True x = model(theta): 
 [2.30395748]
Observations x_obs = model(theta) + noise: 
 [[1.87528783]]


In [5]:
# Create Gaussian approximate distribution:
approx_dist = approx.Gaussian(ndim)



In [6]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = 0.1*np.identity(ndim)
noise_cov = 0.5*np.identity(ndim)
joint_dist = joint.ModelPlusGaussian(model, noise_cov, prior_mean, prior_cov, model_grad)

In [11]:
key = jax.random.PRNGKey(20)
approx_dist.sample_base(1, key)

DeviceArray([[-0.74462473]], dtype=float32)

In [7]:
key = jax.random.PRNGKey(20)
num_batch = 1
num_samples = 1
x_dim = ndim
theta = jax.random.normal(key, shape=(num_batch, num_samples, x_dim))
x = jax.random.normal(key, shape=(num_batch, x_dim))
joint_dist.logpdf_del_1(theta, x)

DeviceArray([[[11.315588]]], dtype=float32)

In [8]:
key = jax.random.PRNGKey(20)
num_batch = 1
num_samples = 5
x_dim = ndim
theta = jax.random.normal(key, shape=(num_batch, num_samples, x_dim))
approx_dist.logpdf_del_2(theta)

Jaxtainer({'chol_diag': DeviceArray([[[-0.99993795],
              [-0.6391506 ],
              [-0.62039256],
              [-0.92603296],
              [-0.97708076]]], dtype=float32), 'mean': DeviceArray([[[-0.0078783 ],
              [-0.6007074 ],
              [-0.6161229 ],
              [-0.27196884],
              [-0.151391  ]]], dtype=float32)})

In [9]:
# Fit sddistribution to reverse KL divergence:
loss = losses.ReverseKL(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()

In [10]:
prngkey = jax.random.PRNGKey(20)
optimiser.fit(approx_dist, loss, data, prngkey, verbose=True)

Loss = 3.559851884841919, Params = Jaxtainer({'mean': DeviceArray([-0.1], dtype=float32), 'chol_diag': DeviceArray([0.9], dtype=float32)})
Loss = 2.941087484359741, Params = Jaxtainer({'mean': DeviceArray([0.07441369], dtype=float32), 'chol_diag': DeviceArray([0.92558634], dtype=float32)})
Loss = 3.1198480129241943, Params = Jaxtainer({'mean': DeviceArray([-0.06388136], dtype=float32), 'chol_diag': DeviceArray([0.93611866], dtype=float32)})
Loss = 3.1320605278015137, Params = Jaxtainer({'mean': DeviceArray([0.05811283], dtype=float32), 'chol_diag': DeviceArray([0.94188714], dtype=float32)})
Loss = 3.202230930328369, Params = Jaxtainer({'mean': DeviceArray([-0.05454892], dtype=float32), 'chol_diag': DeviceArray([0.9454511], dtype=float32)})
Loss = 3.187124013900757, Params = Jaxtainer({'mean': DeviceArray([0.05221178], dtype=float32), 'chol_diag': DeviceArray([0.94778824], dtype=float32)})
Loss = 3.233907699584961, Params = Jaxtainer({'mean': DeviceArray([-0.05063773], dtype=float32), '

Loss = 3.1278321743011475, Params = Jaxtainer({'mean': DeviceArray([-0.07339499], dtype=float32), 'chol_diag': DeviceArray([0.926605], dtype=float32)})
Loss = 3.0783963203430176, Params = Jaxtainer({'mean': DeviceArray([0.07401828], dtype=float32), 'chol_diag': DeviceArray([0.9259817], dtype=float32)})
Loss = 3.1217591762542725, Params = Jaxtainer({'mean': DeviceArray([-0.07463723], dtype=float32), 'chol_diag': DeviceArray([0.92536277], dtype=float32)})
Loss = 3.071570634841919, Params = Jaxtainer({'mean': DeviceArray([0.07525185], dtype=float32), 'chol_diag': DeviceArray([0.9247481], dtype=float32)})
Loss = 3.1158111095428467, Params = Jaxtainer({'mean': DeviceArray([-0.07586213], dtype=float32), 'chol_diag': DeviceArray([0.92413783], dtype=float32)})
Loss = 3.064879894256592, Params = Jaxtainer({'mean': DeviceArray([0.0764681], dtype=float32), 'chol_diag': DeviceArray([0.9235319], dtype=float32)})
Loss = 3.10998797416687, Params = Jaxtainer({'mean': DeviceArray([-0.07706974], dtype=f

In [11]:
(1, 1)
(1, 10, 1)
[[0.5]]

[[0.5]]

In [14]:
key = jax.random.PRNGKey(20)
loss.eval(approx_dist, data, key)

(DeviceArray(3.119848, dtype=float32),
 Jaxtainer({'chol_diag': DeviceArray([6.0927553], dtype=float32), 'mean': DeviceArray([1.2652347], dtype=float32)}))