In [30]:
import numpy as np
import jax
import jax.numpy as jnp
from approx_post.distributions import approx, joint
from approx_post import losses, optimisers
from arraytainers import Jaxtainer

In [31]:
def create_data(model, true_theta, noise_cov, num_samples, prngkey):
    mean = model(true_theta)
    samples = jax.random.multivariate_normal(key=prngkey, mean=mean, cov=noise_cov, shape=(num_samples,))
    return samples.reshape(num_samples, -1)

In [32]:
# First, let's define a model:
ndim = 1
model = lambda theta: theta**2
model_grad = jax.vmap(jax.vmap(jax.jacfwd(model), in_axes=0), in_axes=0)

In [35]:
# Create artificial data:
prngkey = jax.random.PRNGKey(10)
true_theta = jnp.array([2])
noise_cov = 0.01*np.identity(ndim)
num_samples = 1
data = create_data(model, true_theta, noise_cov, num_samples, prngkey)
print(f'True theta: \n {true_theta}')
print(f'True x = model(theta): \n {model(true_theta)}')
print(f'Observations x_obs = model(theta) + noise: \n {data}')

True theta: 
 [2]
True x = model(theta): 
 [4]
Observations x_obs = model(theta) + noise: 
 [[3.865546]]


In [34]:
# Create Gaussian approximate distribution:
approx_dist = approx.Gaussian(ndim)

In [28]:
# Create Joint distribution from forward model:
prior_mean = np.zeros(ndim)
prior_cov = 1.0*np.identity(ndim)
joint_dist = joint.ModelPlusGaussian(model, noise_cov, prior_mean, prior_cov, model_grad)

In [29]:
loss = losses.ReverseKL(joint_dist, use_reparameterisation=True)
optimiser = optimisers.Adam()
# data = jnp.array([[10.40]])
optimiser.fit(approx_dist, loss, data, prngkey, verbose = True, max_iter=1000, num_samples=1000)

Loss = 225.7564239501953, Params = Jaxtainer({'mean': DeviceArray([0.1], dtype=float32), 'chol_diag': DeviceArray([0.9], dtype=float32)})
Loss = 110.0890121459961, Params = Jaxtainer({'mean': DeviceArray([0.19544262], dtype=float32), 'chol_diag': DeviceArray([0.8048882], dtype=float32)})
Loss = 49.37482833862305, Params = Jaxtainer({'mean': DeviceArray([0.28343594], dtype=float32), 'chol_diag': DeviceArray([0.71804345], dtype=float32)})
Loss = 22.33612060546875, Params = Jaxtainer({'mean': DeviceArray([0.36243066], dtype=float32), 'chol_diag': DeviceArray([0.64148885], dtype=float32)})
Loss = 12.819815635681152, Params = Jaxtainer({'mean': DeviceArray([0.43232512], dtype=float32), 'chol_diag': DeviceArray([0.57558906], dtype=float32)})
Loss = 10.932327270507812, Params = Jaxtainer({'mean': DeviceArray([0.49394715], dtype=float32), 'chol_diag': DeviceArray([0.51953775], dtype=float32)})
Loss = 11.645439147949219, Params = Jaxtainer({'mean': DeviceArray([0.54849553], dtype=float32), 'cho

Loss = 1.880064606666565, Params = Jaxtainer({'mean': DeviceArray([0.9531635], dtype=float32), 'chol_diag': DeviceArray([0.05484301], dtype=float32)})
Loss = 1.8770745992660522, Params = Jaxtainer({'mean': DeviceArray([0.9526437], dtype=float32), 'chol_diag': DeviceArray([0.05414409], dtype=float32)})
Loss = 1.8721367120742798, Params = Jaxtainer({'mean': DeviceArray([0.95076984], dtype=float32), 'chol_diag': DeviceArray([0.05383635], dtype=float32)})
Loss = 1.8656593561172485, Params = Jaxtainer({'mean': DeviceArray([0.9477363], dtype=float32), 'chol_diag': DeviceArray([0.05389784], dtype=float32)})
Loss = 1.8602275848388672, Params = Jaxtainer({'mean': DeviceArray([0.94378525], dtype=float32), 'chol_diag': DeviceArray([0.05428232], dtype=float32)})
Loss = 1.859460473060608, Params = Jaxtainer({'mean': DeviceArray([0.9391893], dtype=float32), 'chol_diag': DeviceArray([0.05492517], dtype=float32)})
Loss = 1.8667176961898804, Params = Jaxtainer({'mean': DeviceArray([0.93423325], dtype=f

Loss = 1.9474972486495972, Params = Jaxtainer({'mean': DeviceArray([0.9248385], dtype=float32), 'chol_diag': DeviceArray([0.05824671], dtype=float32)})
Loss = 1.9491970539093018, Params = Jaxtainer({'mean': DeviceArray([0.9247452], dtype=float32), 'chol_diag': DeviceArray([0.0584247], dtype=float32)})
Loss = 1.951374888420105, Params = Jaxtainer({'mean': DeviceArray([0.9245987], dtype=float32), 'chol_diag': DeviceArray([0.05858965], dtype=float32)})
Loss = 1.9538074731826782, Params = Jaxtainer({'mean': DeviceArray([0.92441756], dtype=float32), 'chol_diag': DeviceArray([0.05873049], dtype=float32)})
Loss = 1.956275224685669, Params = Jaxtainer({'mean': DeviceArray([0.92422], dtype=float32), 'chol_diag': DeviceArray([0.05883872], dtype=float32)})
Loss = 1.958577036857605, Params = Jaxtainer({'mean': DeviceArray([0.9240228], dtype=float32), 'chol_diag': DeviceArray([0.05890888], dtype=float32)})
Loss = 1.9605463743209839, Params = Jaxtainer({'mean': DeviceArray([0.92383987], dtype=float3

Loss = 1.9581212997436523, Params = Jaxtainer({'mean': DeviceArray([0.9237765], dtype=float32), 'chol_diag': DeviceArray([0.05843685], dtype=float32)})
Loss = 1.9579147100448608, Params = Jaxtainer({'mean': DeviceArray([0.9237868], dtype=float32), 'chol_diag': DeviceArray([0.05842683], dtype=float32)})
Loss = 1.9577550888061523, Params = Jaxtainer({'mean': DeviceArray([0.9237903], dtype=float32), 'chol_diag': DeviceArray([0.05841812], dtype=float32)})
Loss = 1.9576536417007446, Params = Jaxtainer({'mean': DeviceArray([0.92378694], dtype=float32), 'chol_diag': DeviceArray([0.05841138], dtype=float32)})


KeyboardInterrupt: 