In [None]:
# Reload imported modules automatically so that edits to their source (e.g. to gpjax)
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

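# Natural Gradients

This notebook fits a sparse variational Gaussian process (SVGP) to 5,000 noisy observations of a 1-D function, alternating natural-gradient updates of the variational parameters with Adam updates of the hyperparameters.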

In [None]:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax as ox
import tensorflow as tf
from jax import jit, lax

import gpjax as gpx
from gpjax.abstractions import progress_bar_scan
from gpjax.natural_gradients import natural_gradients

# Seed every random number generator used in this notebook for reproducibility:
# TensorFlow's (presumably used by the `.cache().repeat().shuffle()...` batching
# pipeline built from the dataset — TODO confirm) and JAX's explicit PRNG key.
tf.random.set_seed(42)
key = jr.PRNGKey(123)

## Dataset

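Generate a synthetic dataset: $n = 5000$ inputs drawn uniformly from $[-5, 5]$, with observations $y = \sin(4x) + \cos(2x) + \varepsilon$, where $\varepsilon \sim \mathcal{N}(0, 0.2^2)$: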

In [None]:
n = 5000
noise = 0.2

# Never reuse a JAX PRNG key for two different draws: both draws would consume the
# same random bits, so the observation noise would not be independent of the inputs.
# Split it into one key for the inputs and one for the noise (`key` itself is left
# untouched so re-running this cell gives the same data).
x_key, noise_key = jr.split(key)

# Inputs sorted in ascending order and shaped as a column vector (n, 1).
x = jr.uniform(key=x_key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)


def f(x):
    """Latent function to recover: sin(4x) + cos(2x), applied elementwise."""
    return jnp.sin(4 * x) + jnp.cos(2 * x)


signal = f(x)
# Additive Gaussian observation noise with standard deviation `noise`.
y = signal + jr.normal(noise_key, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)
Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=128).prefetch(buffer_size=1)

# Test inputs extend slightly beyond the training range to show extrapolation.
xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)

Initialise 100 inducing inputs, equally spaced on $[-5, 5]$:

In [None]:
z = jnp.linspace(-5.0, 5.0, 100).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.3, label="Observations")
ax.plot(xtest, f(xtest), label="Latent function")
# Mark each inducing input with a faint vertical line. A plain loop is used because the
# calls are made only for their drawing side effect; `float` passes a scalar x position
# rather than a length-1 array row.
for z_i in z.ravel():
    ax.axvline(x=float(z_i), color="black", alpha=0.3, linewidth=1)
ax.set(xlabel="x", ylabel="y", title="Training data and initial inducing inputs")
ax.legend()
plt.show()

## Model and variational inference strategy

Define model, variational family and variational inference strategy:

In [None]:
# GP prior with an RBF kernel.
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)

# Gaussian likelihood sized for the full training set of `n` points.
likelihood = gpx.Gaussian(num_datapoints=n)

# Combining prior and likelihood gives the posterior `p`.
p = prior * likelihood

# Variational family whose inducing inputs start at `z`; together with the posterior
# it defines the stochastic variational inference strategy `svgp`.
q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)
svgp = gpx.StochasticVI(posterior=p, variational_family=q)

Get the default parameters and transform them to the unconstrained space:

In [None]:
# Default parameters (in their constrained form) plus the `trainables` structure and the
# constraining / unconstraining transformations for the inference strategy.
initial_params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)

# The optimiser works on the unconstrained representation of the parameters.
params = gpx.transform(initial_params, unconstrainers)

## Natural gradients

Define natural gradient and hyperparameter gradient functions:

In [None]:
# Build the two gradient functions used by the optimisation loop: `nat_grads_fn` for the
# natural-gradient step and `hyper_grads_fn` for the hyperparameter step. Both are called
# as fn(params, trainables, batch) and return a (loss_value, gradient) pair.
# NOTE(review): the full dataset `D` (not a minibatch) is passed here, presumably so the
# objective can be scaled to the full training-set size — TODO confirm.
nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers)

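Run the optimisation loop. Each iteration takes a natural-gradient step on the variational parameters (SGD applied to the natural gradients), followed by an Adam step on the hyperparameters: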

In [None]:
# Optimisation settings:
n_iters = 10000
log_rate = 10
batch_size = 128

# Define optimisers:
adam = ox.adam(1e-3)  # <- hyperparameters
sgd = ox.sgd(1e-3)  # <- for natgrads

sgd_state = sgd.init(params)
adam_state = adam.init(params)

# Root key for minibatch sampling, derived from (but distinct from) the notebook's key.
batch_key = jr.fold_in(key, 1)


# Optimisation step:
@progress_bar_scan(n_iters, log_rate)
def step(params_opt_state, i):
    """Perform one optimisation iteration on a freshly sampled minibatch.

    Args:
        params_opt_state: Tuple ``(params, sgd_state, adam_state)`` carried by ``lax.scan``.
        i: Iteration index supplied by ``jnp.arange(n_iters)``; also used to derive
            this iteration's minibatch key.

    Returns:
        The updated ``(params, sgd_state, adam_state)`` carry and the loss value
        returned by the hyperparameter-gradient function.
    """
    params, sgd_state, adam_state = params_opt_state

    # `lax.scan` traces this function only once, so calling a Python-side batch
    # iterator here would bake a single minibatch into every iteration. Sample the
    # minibatch indices with JAX instead, from a key that depends on `i`.
    batch_indices = jr.choice(
        jr.fold_in(batch_key, i), D.n, shape=(batch_size,), replace=False
    )
    batch = gpx.Dataset(X=D.X[batch_indices], y=D.y[batch_indices])

    # Natural gradients update (the new optimiser state must be carried forward):
    loss_val, loss_gradient = nat_grads_fn(params, trainables, batch)
    updates, sgd_state = sgd.update(loss_gradient, sgd_state, params)
    params = ox.apply_updates(params, updates)

    # Hyperparameters update:
    loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch)
    updates, adam_state = adam.update(loss_gradient, adam_state, params)
    params = ox.apply_updates(params, updates)

    return (params, sgd_state, adam_state), loss_val


# Optimisation loop. `loss_history` holds the hyperparameter-step loss per iteration.
# Note: re-running this cell continues training from the current `params`.
(params, _, _), loss_history = lax.scan(
    step, (params, sgd_state, adam_state), jnp.arange(n_iters)
)

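Plot the predictive mean with a one-standard-deviation band over the test inputs; vertical lines mark the learned inducing inputs: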

In [None]:
# Map the optimised (unconstrained) parameters back to the constrained space.
learned_params = gpx.transform(params, constrainers)

# Variational latent distribution at the test inputs, then the predictive distribution
# obtained by passing it through the likelihood.
latent_dist = q(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray")
ax.plot(xtest, meanf, label="Posterior mean", color="tab:blue")
ax.fill_between(
    xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3, label="+/- 1 std. dev."
)
# Faint vertical lines mark the learned inducing inputs; a loop (not a list
# comprehension) because the calls are made only for their drawing side effect.
for z_i in learned_params["variational_family"]["inducing_inputs"].ravel():
    ax.axvline(x=float(z_i), color="black", alpha=0.3, linewidth=1)
ax.set(xlabel="x", ylabel="y", title="SVGP predictive distribution after training")
ax.legend()  # the labels above are only shown once a legend is drawn
plt.show()