In [1]:
import jax
import jax.numpy as jnp
import approx_post
import seaborn as sns

In [2]:
def create_data(model, true_theta, noise_cov, num_samples, ndim, prngkey, model_input=0):
    """Draw noisy synthetic observations from a forward model.

    Each observation is ``model(true_theta, model_input) + eps`` with
    ``eps ~ N(0, noise_cov)`` drawn via ``jax.random.multivariate_normal``.

    Args:
        model: Callable ``model(theta, x)`` returning the noise-free
            observation (scalar or 1-D array).
        true_theta: Parameter value the data are generated from.
        noise_cov: Noise covariance matrix; its side length must equal the
            number of model outputs.
        num_samples: Number of independent observations to draw (0 allowed).
        ndim: Unused; kept so existing calls keep working.
        prngkey: JAX PRNG key used for the noise draw.
        model_input: Second argument passed to ``model``; defaults to 0, the
            value that used to be hard-coded here.

    Returns:
        Array of shape ``(num_samples, num_outputs)``.
    """
    # atleast_1d lets scalar-output models work too: multivariate_normal
    # requires a mean with at least one dimension.
    mean = jnp.atleast_1d(model(true_theta, model_input))
    samples = jax.random.multivariate_normal(key=prngkey, mean=mean, cov=noise_cov, shape=(num_samples,))
    # Use the explicit output width instead of -1: reshaping an empty
    # (num_samples == 0) array with a -1 axis is ambiguous and raises.
    return samples.reshape(num_samples, mean.shape[-1])

In [3]:
# First, let's define a model: a forward map from the parameter theta to the
# noise-free observation. It takes a second argument `x` because create_data
# (and presumably approx_post.models.from_jax) calls it as model(theta, x).
ndim = 1


def model_func(theta, x):
    """Forward model: elementwise square of ``theta``; ``x`` is ignored.

    Note that model_func(theta, x) == model_func(-theta, x), so observations
    alone cannot distinguish the sign of theta.
    """
    return theta**2


model, model_grad = approx_post.models.from_jax(model_func)

In [4]:
# Create artificial data: noisy observations of model_func(true_theta, 0).
num_samples = 1
noise_cov = 0.1*jnp.identity(ndim)
# NOTE(review): jnp.array([2]) has an integer dtype; JAX autodiff rejects
# integer inputs, so use jnp.array([2.0]) if true_theta is ever differentiated.
true_theta = jnp.array([2])
prngkey = jax.random.PRNGKey(10)
data = create_data(
    model=model_func,
    true_theta=true_theta,
    noise_cov=noise_cov,
    num_samples=num_samples,
    ndim=ndim,
    prngkey=prngkey,
)
print(
    f'True theta: \n {true_theta}',
    f'True x = model(theta): \n {model_func(true_theta, 0)}',
    f'Observations x_obs = model(theta) + noise: \n {data}',
    sep='\n',
)



True theta: 
 [2]
True x = model(theta): 
 [4]
Observations x_obs = model(theta) + noise: 
 [[3.5748188]]


In [5]:
# Create Gaussian approximate distribution (Gaussian(ndim), presumably over theta).
# NOTE(review): both fitting cells below rebind approx_dist to a fresh
# Gaussian before calling fit, so this particular instance is never fitted.
approx_dist = approx_post.distributions.approx.Gaussian(ndim)

In [6]:
# Create Joint distribution from forward model: combines `model` (and its
# gradient) with noise_cov and a prior N(prior_mean, prior_cov) on theta.
# The prior is tight (std = sqrt(0.1) ~ 0.32) around 0, far from
# true_theta = 2, so it pulls the posterior strongly toward 0.
prior_cov = 0.1*jnp.identity(ndim)
prior_mean = jnp.zeros(ndim)
joint_dist = approx_post.distributions.joint.ModelPlusGaussian(
    model,
    noise_cov,
    prior_mean,
    prior_cov,
    model_grad,
)

In [7]:
# Fit a fresh Gaussian approximation with the ELBO loss, using
# reparameterised gradient estimates.
approx_dist = approx_post.distributions.approx.Gaussian(ndim)
prngkey = jax.random.PRNGKey(12)
loss = approx_post.losses.ELBO(joint_dist, use_reparameterisation=True)
optimiser = approx_post.optimisers.Adam()
# Store fit's return value under its own name rather than rebinding `loss`,
# so `loss` still refers to the ELBO loss object after this cell runs.
# TODO confirm what fit returns (loss history? final loss?).
elbo_fit_output = optimiser.fit(approx_dist, loss, data, prngkey, verbose=True, max_iter=50, num_samples=1000)
# Keep a handle on this distribution: the ForwardKL cell below rebinds
# `approx_dist` to a new Gaussian. Assumes fit updates approx_dist in place
# (the verbose log prints its params) - TODO confirm.
approx_dist_elbo = approx_dist

Loss = 46.42293930053711, Params = Jaxtainer({'mean': DeviceArray([-0.1], dtype=float32), 'log_chol_diag': DeviceArray([-0.1], dtype=float32)})
Loss = 46.40510940551758, Params = Jaxtainer({'mean': DeviceArray([-0.19955799], dtype=float32), 'log_chol_diag': DeviceArray([-0.06496035], dtype=float32)})
Loss = 45.89853286743164, Params = Jaxtainer({'mean': DeviceArray([-0.29935193], dtype=float32), 'log_chol_diag': DeviceArray([-0.01416349], dtype=float32)})
Loss = 45.640193939208984, Params = Jaxtainer({'mean': DeviceArray([-0.3956759], dtype=float32), 'log_chol_diag': DeviceArray([0.00168393], dtype=float32)})
Loss = 45.63945388793945, Params = Jaxtainer({'mean': DeviceArray([-0.48024344], dtype=float32), 'log_chol_diag': DeviceArray([-0.02508833], dtype=float32)})
Loss = 45.3537712097168, Params = Jaxtainer({'mean': DeviceArray([-0.5568623], dtype=float32), 'log_chol_diag': DeviceArray([-0.072787], dtype=float32)})
Loss = 44.88300323486328, Params = Jaxtainer({'mean': DeviceArray([-0.6

In [10]:
# Fit a fresh Gaussian approximation with the forward-KL loss, without the
# reparameterisation trick.
approx_dist = approx_post.distributions.approx.Gaussian(ndim)
prngkey = jax.random.PRNGKey(12)
loss = approx_post.losses.ForwardKL(joint_dist, use_reparameterisation=False)
optimiser = approx_post.optimisers.Adam()
# NOTE(review): verbose=True with max_iter=1000 produced very long output in
# the saved run; consider verbose=False or fewer iterations.
# Capture fit's return value explicitly instead of leaving it as the
# displayed cell result. TODO confirm what fit returns.
forward_kl_fit_output = optimiser.fit(approx_dist, loss, data, prngkey, verbose=True, max_iter=1000, num_samples=1000)
# Keep a dedicated handle on this distribution so it can be compared with
# other fits even if `approx_dist` is rebound later (assumes fit updates
# approx_dist in place - TODO confirm).
approx_dist_forward_kl = approx_dist

Loss = 0.0026401900686323643, Params = Jaxtainer({'mean': DeviceArray([0.09999659], dtype=float32), 'log_chol_diag': DeviceArray([0.09999955], dtype=float32)})
Loss = 0.00238200556486845, Params = Jaxtainer({'mean': DeviceArray([0.1372815], dtype=float32), 'log_chol_diag': DeviceArray([0.19769272], dtype=float32)})
Loss = 0.0022438950836658478, Params = Jaxtainer({'mean': DeviceArray([0.15107341], dtype=float32), 'log_chol_diag': DeviceArray([0.29202056], dtype=float32)})
Loss = 0.0021345692221075296, Params = Jaxtainer({'mean': DeviceArray([0.1267155], dtype=float32), 'log_chol_diag': DeviceArray([0.38159305], dtype=float32)})
Loss = 0.002060098107904196, Params = Jaxtainer({'mean': DeviceArray([0.07687004], dtype=float32), 'log_chol_diag': DeviceArray([0.4650827], dtype=float32)})
Loss = 0.0019993737805634737, Params = Jaxtainer({'mean': DeviceArray([0.01981904], dtype=float32), 'log_chol_diag': DeviceArray([0.54074323], dtype=float32)})
Loss = 0.001967070158571005, Params = Jaxtaine

Loss = 0.0019607138819992542, Params = Jaxtainer({'mean': DeviceArray([-0.36540708], dtype=float32), 'log_chol_diag': DeviceArray([0.56253785], dtype=float32)})
Loss = 0.0019587031565606594, Params = Jaxtainer({'mean': DeviceArray([-0.3550079], dtype=float32), 'log_chol_diag': DeviceArray([0.55692834], dtype=float32)})
Loss = 0.0019567525014281273, Params = Jaxtainer({'mean': DeviceArray([-0.3430634], dtype=float32), 'log_chol_diag': DeviceArray([0.5514417], dtype=float32)})
Loss = 0.0019553101155906916, Params = Jaxtainer({'mean': DeviceArray([-0.33035463], dtype=float32), 'log_chol_diag': DeviceArray([0.54623383], dtype=float32)})
Loss = 0.001954571343958378, Params = Jaxtainer({'mean': DeviceArray([-0.3174898], dtype=float32), 'log_chol_diag': DeviceArray([0.5414634], dtype=float32)})
Loss = 0.001954470993950963, Params = Jaxtainer({'mean': DeviceArray([-0.30492118], dtype=float32), 'log_chol_diag': DeviceArray([0.53728336], dtype=float32)})
Loss = 0.0019548037089407444, Params = Ja

KeyboardInterrupt: 

In [None]:
# NOTE(review): leftover scratch check converting a 0-d JAX array to a Python
# scalar; it was never executed in the saved notebook - consider removing.
int(jnp.asarray(100))