In [1]:
import numpy as np
import jax
from numpy.random import multivariate_normal as mvn
from approx_post import ApproximateDistribution, AmortisedApproximation, JointDistribution, reverse_kl, forward_kl, fit_approximation

In [2]:
def create_data(model, theta_vals, noise_cov, num_obs, rng=None):
    """Simulate noisy observations of ``model`` at each parameter value.

    The model output ``model(theta_vals)`` is reshaped to
    ``(num_batch, theta_dim)`` and every row is used as a mean. ``num_obs``
    zero-mean Gaussian noise vectors with covariance ``noise_cov`` are drawn
    once and added to every mean.

    Args:
        model: Callable applied to ``theta_vals``; its output must be
            reshapeable to ``(-1, theta_dim)``.
        theta_vals: Parameter values passed to ``model``.
        noise_cov: Noise covariance matrix; ``theta_dim`` is taken from
            ``noise_cov.shape[0]``.
        num_obs: Number of noisy observations per parameter value.
        rng: Optional ``numpy.random.Generator``. When given, the noise is drawn
            from it, so results are reproducible. When ``None`` (the default),
            the legacy global ``numpy.random`` state is used, as before.

    Returns:
        Array of shape ``(num_batch, num_obs, theta_dim)``.
    """
    theta_dim = noise_cov.shape[0]
    mean_vals = model(theta_vals).reshape(-1, theta_dim)  # (num_batch, theta_dim)
    zero_mean = np.zeros(theta_dim)
    # Zero-mean noise samples, shape (num_obs, theta_dim):
    if rng is None:
        epsilon = mvn(zero_mean, noise_cov, num_obs)
    else:
        epsilon = rng.multivariate_normal(zero_mean, noise_cov, num_obs)
    # Broadcast the same noise draws onto every mean.
    # NOTE(review): all theta values share identical noise realisations; confirm
    # this is intended rather than independent noise per theta value.
    samples = mean_vals[:, None, :] + epsilon  # (num_batch, num_obs, theta_dim)
    return samples

In [3]:
# Define the forward model and its batched Jacobian:
ndim = 1

def model(theta):
    """Forward model: elementwise square of the parameters."""
    return theta**2

# Forward-mode Jacobian of `model`, vectorised with vmap.
# NOTE(review): in_axes=1 maps over the *second* axis of the input; confirm this
# matches the theta layout that JointDistribution.from_model passes to model_grad.
model_grad = jax.vmap(jax.jacfwd(model), in_axes=1)

In [4]:
# Create artificial data: noisy observations of model(theta) on a grid of theta values.
theta_vals = np.linspace(0, 10, 200)
noise_cov = 0.01 * np.identity(ndim)
num_obs = 2
data = create_data(model, theta_vals, noise_cov, num_obs)
# Sanity check: one set of `num_obs` observations per theta value.
assert data.shape == (theta_vals.size, num_obs, ndim)

In [5]:
# Build a Gaussian approximate distribution, then wrap it in a neural-network
# amortised approximation constructed from `data`:
gaussian_approx = ApproximateDistribution.gaussian(ndim)
approx = AmortisedApproximation.nn(gaussian_approx, data)



In [6]:
# Quick check of the theta grid's shape:
np.shape(theta_vals)

(200,)

In [11]:
# Evaluate `approx.phi` on the theta grid and display the result.
# NOTE(review): the amortised approximation was built from `data` (observations),
# but here it is evaluated on `theta_vals` (parameter values). Confirm that `phi`
# is meant to take parameter values rather than observations.
approx.phi(theta_vals)

DeviceArray([[6.73177242e-02, 6.73177242e-02],
             [6.21010251e-02, 3.28072384e-02],
             [5.72885871e-02, 1.59885827e-02],
             [5.28490879e-02, 7.79202534e-03],
             [4.87536155e-02, 3.79743660e-03],
             [4.49755229e-02, 1.85067824e-03],
             [4.14902046e-02, 9.01926833e-04],
             [3.82749960e-02, 4.39553143e-04],
             [3.53089012e-02, 2.14215877e-04],
             [3.25727239e-02, 1.04398096e-04],
             [3.00485194e-02, 5.08782796e-05],
             [2.77199373e-02, 2.47954904e-05],
             [2.55718231e-02, 1.20840677e-05],
             [2.35901847e-02, 5.88916419e-06],
             [2.17621047e-02, 2.87007560e-06],
             [2.00756639e-02, 1.39872986e-06],
             [1.85199138e-02, 6.81669576e-07],
             [1.70847662e-02, 3.32211584e-07],
             [1.57608166e-02, 1.61903287e-07],
             [1.45394281e-02, 7.89033336e-08],
             [1.34127177e-02, 3.84534609e-08],
             

In [8]:
# Removed stray scratch cell: it evaluated the undefined name `a`, which raised
# NameError and aborted "Restart & Run All" before the cells below could run.

In [None]:
# Create the joint distribution from the forward model, the noise covariance and
# a prior given by its mean and covariance (presumably Gaussian; confirm in approx_post):
prior_cov = np.eye(ndim)
prior_mean = np.zeros((ndim,))
joint = JointDistribution.from_model(model, noise_cov, prior_mean, prior_cov, model_grad)

In [None]:
# Fit the approximate distribution using the reverse KL divergence to `joint` as
# the loss, with the reparameterisation option enabled:
loss = reverse_kl(approx, joint, use_reparameterisation=True)
approx = fit_approximation(loss, approx, data, verbose=True)