# Natural Gradients:

In this notebook we demonstrate how to train a sparse variational Gaussian process in GPJax using natural gradients.

As Salimbeni et al. (2018) explain,

"The ordinary gradient turns out to be an unnatural direction to follow for variational inference since we are optimizing a distribution, rather than a set of pa- rameters directly. One way to define the gradient is the direction that achieves maximum change subject to a perturbation within a small euclidean ball. To see why the euclidean distance is an unnatural metric for probability distributions, consider the two Gaussians $\mathcal{N}(0, 0.1)$ and $\mathcal{N}(0, 0.2)$, compared to $\mathcal{N}(0, 1000.1)$ and $\mathcal{N}N(0,1000.2)$."

In [None]:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from jax import jit, lax
import optax as ox

import gpjax as gpx
from gpjax.abstractions import progress_bar_scan

# Fixed seed so every random draw in the notebook is reproducible.
SEED = 123
key = jr.PRNGKey(SEED)

# Dataset:

Generate a synthetic dataset of noisy observations of $f(x) = \sin(4x) + \cos(2x)$:

In [None]:
n = 5000
noise = 0.2

# Draw the inputs and the observation noise from independent subkeys. Passing
# the same key to both jr.uniform and jr.normal would derive the noise from the
# same random bits as the inputs, so the two would not be independent.
key, x_key, noise_key = jr.split(key, 3)


def f(x):
    """Latent function to recover: sin(4x) + cos(2x), applied elementwise."""
    return jnp.sin(4 * x) + jnp.cos(2 * x)


# n inputs drawn uniformly on [-5, 5), sorted and reshaped to an (n, 1) column.
x = jr.uniform(key=x_key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)
signal = f(x)
# Additive Gaussian noise with standard deviation `noise`.
y = signal + jr.normal(noise_key, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)
# Test inputs extend slightly beyond the training range on both sides.
xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)

Initialise the inducing points:

In [None]:
# Inducing inputs: 20 evenly spaced points over the training range, as a (20, 1) column.
z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.3)
ax.plot(xtest, f(xtest))
# Mark each inducing input with a vertical line. A plain loop is used because
# axvline is called only for its side effect; a list comprehension would build
# a throwaway list of line objects.
for z_i in z:
    ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
plt.show()

# Model and variational inference strategy:

Define the model, the variational family and the variational inference strategy:

In [None]:
# GP prior with an RBF kernel, combined with a Gaussian likelihood
# (constructed with num_datapoints=n) to form the posterior p.
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
likelihood = gpx.Gaussian(num_datapoints=n)
p = prior * likelihood

# Variational family (NaturalVariationalGaussian) over the inducing inputs z,
# wrapped in a stochastic variational inference objective.
q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)
svgp = gpx.StochasticVI(posterior=p, variational_family=q)

In [None]:
# Default parameters, the trainable mask, and the constrain/unconstrain transforms.
parameter_state = gpx.initialise(svgp)
params, trainables, constrainers, unconstrainers = parameter_state.unpack()

# Work with the parameters in the unconstrained space.
params = gpx.transform(params, unconstrainers)

# JIT-compiled negative ELBO on the full dataset. Note: the training cell below
# (gpx.fit_natgrads) does not use this function.
loss_fn = jit(svgp.elbo(D, constrainers, negative=True))

The cell above gets the default parameters and transforms them to the unconstrained space.

# Natural gradients:

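Train the model with `gpx.fit_natgrads`. It uses one optimiser for the variational moments (`moment_optim`, SGD with step size 1.0) and another for the hyperparameters (`hyper_optim`, Adam with learning rate 1e-3), on mini-batches of 100 points: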

In [None]:
# Mini-batch training: SGD(1.0) on the variational moments, Adam(1e-3) on the
# hyperparameters, 10000 iterations with batches of 100 points.
inference_state = gpx.fit_natgrads(
    svgp,
    params=params,
    trainables=trainables,
    transformations=constrainers,
    train_data=D,
    n_iters=10000,
    batch_size=100,
    key=jr.PRNGKey(42),
    moment_optim=ox.sgd(1.0),
    hyper_optim=ox.adam(1e-3),
)
learned_params, training_history = inference_state.unpack()

# Map the learned parameters back to the constrained space for prediction.
learned_params = gpx.transform(learned_params, constrainers)

Plot the predictive mean, a one-standard-deviation band, and the learned inducing inputs:

In [None]:
# Predictive distribution at the test inputs: evaluate the learned variational
# family at xtest, then pass the result through the likelihood.
latent_dist = q(learned_params)(xtest)
predictive_dist = likelihood(latent_dist, learned_params)

meanf = predictive_dist.mean()
sigma = predictive_dist.stddev()

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.15, label="Training Data", color="tab:gray")
ax.plot(xtest, meanf, label="Posterior mean", color="tab:blue")
ax.fill_between(
    xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3, label="One sigma"
)
# Mark the learned inducing inputs; a plain loop because axvline is called
# only for its side effect.
for z_i in learned_params["variational_family"]["inducing_inputs"]:
    ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
# Without a legend the `label=` arguments above are never displayed.
ax.legend()
plt.show()

# Natural gradients and sparse variational Gaussian process regression:

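As mentioned in Hensman et al. (2013), for a Gaussian likelihood a single natural-gradient step with unit step size moves the variational distribution straight to its optimum. This recovers the collapsed bound of sparse variational GP regression (SGPR, Titsias 2009).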

We demonstrate this now:

In [None]:
n = 1000
noise = 0.2

# Independent subkeys for the inputs and the observation noise: passing the
# same key to jr.uniform and jr.normal would derive the noise from the same
# random bits as the inputs, so the two would not be independent.
key, x_key, noise_key = jr.split(key, 3)


def f(x):
    """Latent function to recover: sin(4x) + cos(2x), applied elementwise."""
    return jnp.sin(4 * x) + jnp.cos(2 * x)


# n inputs drawn uniformly on [-5, 5), sorted and reshaped to an (n, 1) column.
x = jr.uniform(key=x_key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)
signal = f(x)
# Additive Gaussian noise with standard deviation `noise`.
y = signal + jr.normal(noise_key, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)

# Test inputs extend slightly beyond the training range on both sides.
xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)

In [None]:
# Inducing inputs: 20 evenly spaced points over the training range, as a (20, 1) column.
z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.3)
ax.plot(xtest, f(xtest))
# Mark each inducing input with a vertical line; a plain loop because axvline
# is called only for its side effect.
for z_i in z:
    ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
plt.show()

In [None]:
# Rebuild the GP for the n = 1000 dataset: RBF-kernel prior combined with a
# Gaussian likelihood constructed with num_datapoints=n.
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
likelihood = gpx.Gaussian(num_datapoints=n)
p = prior * likelihood

We begin with the SVGP model and take a single natural-gradient step:

In [None]:
from gpjax.natural_gradients import natural_gradients

# Fresh SVGP on the current dataset D, using the NaturalVariationalGaussian
# family over the inducing inputs z and the prior/likelihood defined above.
q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)
svgp = gpx.StochasticVI(posterior=p, variational_family=q)
params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()

# Move the initial parameters to the unconstrained space.
params = gpx.transform(params, unconstrainers)

# natural_gradients returns two gradient functions; only nat_grads_fn is used
# in this cell (hyper_grads_fn is unused here).
nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers)

# Plain SGD with step size 1.0, applied to the natural gradients.
moment_optim = ox.sgd(1.0)

moment_state = moment_optim.init(params)

# Natural-gradient update. nat_grads_fn returns (loss, gradient); the full
# dataset D is passed as the batch, so this is a full-batch step. Print the
# loss before the step, then apply exactly one optimiser update.
loss_val, loss_gradient = nat_grads_fn(params, trainables, D)
print(loss_val)

updates, moment_state = moment_optim.update(loss_gradient, moment_state, params)
params = ox.apply_updates(params, updates)

# Loss after the single step (the gradient is discarded). This value is meant
# to be compared with the collapsed SGPR loss printed in the next cell.
loss_val, _ = nat_grads_fn(params, trainables, D)

print(loss_val)

Now compute the collapsed SGPR loss with the same prior, likelihood and inducing inputs, for comparison:

In [None]:
# NOTE(review): build_identity is not used in this cell.
from gpjax.parameters import build_identity

# Collapsed (SGPR) variational family and inference objective, built on the
# same prior, likelihood and inducing inputs as the SVGP above.
q = gpx.CollapsedVariationalGaussian(prior=prior, likelihood=likelihood, inducing_inputs=z)
sgpr = gpx.CollapsedVI(posterior=p, variational_family=q)

# Parameters must come from `sgpr`, the model whose ELBO is evaluated below;
# `svgp` carries a different variational family (NaturalVariationalGaussian).
params, trainables, constrainers, unconstrainers = gpx.initialise(sgpr).unpack()

params = gpx.transform(params, unconstrainers)

# Negative collapsed ELBO at the initial parameters, to compare with the SVGP
# loss printed after one natural-gradient step.
loss_fn = sgpr.elbo(D, constrainers, negative=True)

loss_val = loss_fn(params)

print(loss_val)

Any remaining discrepancy between the SVGP loss after one natural-gradient step and the SGPR loss is due to the quadrature approximation.