In [None]:
%load_ext autoreload
%autoreload 2
import gpjax as gpx

In [None]:
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from jax import jit, lax
import optax as ox

import gpjax as gpx
from gpjax.natural_gradients import natural_gradients
from gpjax.abstractions import progress_bar_scan

# Set seeds for reproducibility: TensorFlow's global seed and a JAX PRNG key.
import tensorflow as tf
tf.random.set_seed(4)
key = jr.PRNGKey(123)

import typing as tp
from copy import deepcopy

import distrax as dx
import jax.numpy as jnp
import jax.scipy as jsp
from jax import lax, value_and_grad, jacobian
from jaxtyping import f64

from gpjax.config import get_defaults
from gpjax.gps import AbstractPosterior
from gpjax.parameters import (
    build_identity,
    build_trainables_false,
    build_trainables_true,
    trainable_params,
    transform,
)
from gpjax.types import Dataset
from gpjax.utils import I
from gpjax.variational_families import (
    AbstractVariationalFamily,
    ExpectationVariationalGaussian,
    NaturalVariationalGaussian,
)
from gpjax.variational_inference import StochasticVI
# Jitter value read from GPJax's default configuration.
DEFAULT_JITTER = get_defaults()["jitter"]


from gpjax.natural_gradients import natural_gradients, natural_to_expectation, _expectation_elbo, _stop_gradients_nonmoments, _stop_gradients_moments, fit_natgrads

# Dataset and inducing points:

In [None]:
n = 5000
noise = 0.2

# Draw the inputs and the observation noise from independent PRNG keys;
# passing the same key to two samplers makes the draws statistically dependent.
x_key, noise_key = jr.split(key)


def f(x):
    """Noise-free latent function used to generate the observations."""
    return jnp.sin(4 * x) + jnp.cos(2 * x)


# Sorted inputs on [-5, 5], stored as a column vector.
x = jr.uniform(key=x_key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)
signal = f(x)
y = signal + jr.normal(noise_key, shape=signal.shape) * noise

D = gpx.Dataset(X=x, y=y)
# Mini-batched view of the dataset for stochastic optimisation.
Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=256).prefetch(buffer_size=1)

# Test inputs extend slightly beyond the training range.
xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)

In [None]:
# Inducing inputs: two points at the ends of the training input range.
z = jnp.linspace(-5.0, 5.0, 2).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", alpha=0.3, label="Observations")
ax.plot(xtest, f(xtest), label="Latent function")
# Mark each inducing input with a vertical line; a plain loop avoids building
# a throwaway list of artists just for its side effects.
for z_i in z:
    ax.axvline(x=z_i, color="black", alpha=0.3, linewidth=1)
ax.set(xlabel="x", ylabel="y", title="Training data, latent function and inducing inputs")
ax.legend()
plt.show()

# Natgrads code

In [None]:

def natural_gradients(
    stochastic_vi: StochasticVI,
    train_data: Dataset,
    transformations: dict,
    xi_to_nat: tp.Callable[[tp.Dict], tp.Dict],
    nat_to_xi: tp.Callable[[tp.Dict], tp.Dict],
) -> tp.Tuple[
    tp.Callable[[dict, dict, Dataset], tp.Tuple],
    tp.Callable[[dict, dict, Dataset], tp.Tuple],
]:
    """
    Builds the gradient functions used for natural gradient optimisation of a
    variational Gaussian.
    Args:
        stochastic_vi: A StochasticVI object; its ``posterior`` and
            ``variational_family`` attributes are used.
        train_data: A Dataset used to construct the ELBO objectives.
        transformations: A dictionary of transformations, applied to the
            parameters with ``transform`` before natural gradients are computed.
        xi_to_nat: Maps the variational moments in the user's parameterisation ξ
            to natural moments θ. Only used when the variational family is not a
            NaturalVariationalGaussian.
        nat_to_xi: Maps natural moments θ back to ξ; its Jacobian ∂ξ/∂θ is
            evaluated. Only used when the variational family is not a
            NaturalVariationalGaussian.
    Returns:
        Tuple: ``(nat_grads_fn, hyper_grads_fn)``. Both take
        ``(params, trainables, batch)`` and return ``(value, gradients)``.
        NOTE: for non-natural variational families the natural gradient is not
        implemented yet and ``nat_grads_fn`` returns ``(value, None)``.
    """
    posterior = stochastic_vi.posterior
    variational_family = stochastic_vi.variational_family

    # The ELBO under the user chosen parameterisation xi.
    xi_elbo = stochastic_vi.elbo(train_data, transformations, negative=True)

    # The ELBO under the expectation parameterisation, L(η).
    expectation_elbo = _expectation_elbo(posterior, variational_family, train_data)

    def _expectation_value_and_grad(
        params: dict, expectation_moments: dict, trainables: dict, batch: Dataset
    ) -> tp.Tuple:
        """Evaluates L(η) and its gradient ∂L/∂η (shared by both branches below)."""
        # Full params with expectation moments.
        expectation_params = deepcopy(params)
        expectation_params["variational_family"]["moments"] = expectation_moments

        def loss_fn(params: dict, batch: Dataset) -> f64["1"]:
            # Determine hyperparameters that should be trained; the moments are
            # always marked trainable here, regardless of `trainables`.
            trains = deepcopy(trainables)
            trains["variational_family"]["moments"] = build_trainables_true(
                params["variational_family"]["moments"]
            )
            params = trainable_params(params, trains)

            # Stop gradients for non-moment parameters.
            params = _stop_gradients_nonmoments(params)

            return expectation_elbo(params, batch)

        return value_and_grad(loss_fn)(expectation_params, batch)

    if isinstance(variational_family, NaturalVariationalGaussian):

        def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> tp.Tuple:
            """
            Computes the natural gradients of the ELBO.
            Args:
                params: A dictionary of parameters.
                trainables: A dictionary of trainables.
                batch: A Dataset.
            Returns:
                Tuple: The loss value and a dictionary of natural gradients.
            """
            # Transform parameters to constrained space.
            params = transform(params, transformations)

            # Get natural moments θ.
            natural_moments = params["variational_family"]["moments"]

            # Get expectation moments η.
            expectation_moments = natural_to_expectation(natural_moments)

            # Compute gradient ∂L/∂η:
            value, dL_dexp = _expectation_value_and_grad(
                params, expectation_moments, trainables, batch
            )

            # This is a renaming of the gradient components to match the natural parameterisation pytree.
            natural_gradient = dL_dexp
            natural_gradient["variational_family"]["moments"] = {
                "natural_vector": dL_dexp["variational_family"]["moments"][
                    "expectation_vector"
                ],
                "natural_matrix": dL_dexp["variational_family"]["moments"][
                    "expectation_matrix"
                ],
            }

            return value, natural_gradient

    else:

        def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> tp.Tuple:
            """
            Computes the loss value for a general parameterisation ξ.
            Args:
                params: A dictionary of parameters.
                trainables: A dictionary of trainables.
                batch: A Dataset.
            Returns:
                Tuple: The loss value and ``None``; the natural gradient in ξ is
                not implemented yet.
            """
            # Transform parameters to constrained space.
            params = transform(params, transformations)

            # Get natural moments θ.
            natural_moments = xi_to_nat(params["variational_family"]["moments"])

            # Get expectation moments η.
            expectation_moments = natural_to_expectation(natural_moments)

            # Gradient function ∂ξ/∂θ (currently unused, see the TODO below):
            dxi_dnat = jacobian(nat_to_xi)(natural_moments)

            # Compute gradient ∂L/∂η:
            value, dL_dexp = _expectation_value_and_grad(
                params, expectation_moments, trainables, batch
            )

            # The issue is combining: ∂ξ/∂θ ∂L/∂η
            # TODO: presumably this contraction is a Jacobian-vector product of
            # `nat_to_xi` at θ with tangent ∂L/∂η (cf. Salimbeni et al., 2018),
            # which would not need the explicit Jacobian above -- confirm before
            # implementing.
            natural_gradient = None

            return value, natural_gradient

    def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> tp.Tuple:
        """
        Computes the hyperparameter gradients of the ELBO.
        Args:
            params: A dictionary of parameters.
            trainables: A dictionary of trainables.
            batch: A Dataset.
        Returns:
            Tuple: The loss value and a dictionary of hyperparameter gradients.
        """

        def loss_fn(params: dict, batch: Dataset) -> f64["1"]:
            # Determine hyperparameters that should be trained.
            params = trainable_params(params, trainables)

            # Stop gradients for the moment parameters.
            params = _stop_gradients_moments(params)

            return xi_elbo(params, batch)

        value, dL_dhyper = value_and_grad(loss_fn)(params, batch)

        return value, dL_dhyper

    return nat_grads_fn, hyper_grads_fn

# Example

As a test case for computing natural gradients, we use the expectation family, i.e. $\xi = \eta$ (in this special case the expression simplifies: $\frac{d\xi}{d\theta} \frac{d\mathcal{L}}{d\eta} = \frac{d\mathcal{L}}{d\theta}$).

We begin by defining the bijection between $\xi$ and $\theta$:

In [None]:
def xi_to_nat(moments: dict, jitter: float = 1e-6) -> dict:
    """Convert expectation moments η into natural moments θ.

    With μ = expectation_vector and S = expectation_matrix - μμᵀ, the natural
    moments are θ₁ = S⁻¹μ and θ₂ = -½S⁻¹.

    Args:
        moments: Dictionary with keys "expectation_vector" and
            "expectation_matrix". The vector is assumed to be a column of shape
            (m, 1) so that ``mu @ mu.T`` is an (m, m) outer product.
        jitter: Value added to the diagonal of S before the Cholesky
            factorisation for numerical stability.

    Returns:
        Dictionary with keys "natural_matrix" and "natural_vector".
    """
    mu = moments["expectation_vector"]
    expectation_matrix = moments["expectation_matrix"]

    m = mu.shape[0]
    identity = jnp.eye(m)

    # Recover the covariance from the second moment and stabilise it.
    S = expectation_matrix - jnp.matmul(mu, mu.T)
    S += identity * jitter

    L = jnp.linalg.cholesky(S)

    # Solve L X = I so that X = L⁻¹. Solving against S instead would give
    # X = Lᵀ, and XᵀX would then reproduce S rather than its inverse.
    L_inv = jsp.linalg.solve_triangular(L, identity, lower=True)

    # S⁻¹ = (L Lᵀ)⁻¹ = L⁻ᵀ L⁻¹.
    S_inv = jnp.matmul(L_inv.T, L_inv)

    natural_matrix = -0.5 * S_inv
    natural_vector = jnp.matmul(S_inv, mu)

    return {"natural_matrix": natural_matrix, "natural_vector": natural_vector}

def nat_to_xi(moments: dict, jitter: float = 1e-6) -> dict:
    """Convert natural moments θ into expectation moments η.

    With S⁻¹ = -2·natural_matrix and μ = S·natural_vector, the expectation
    moments are η₁ = μ and η₂ = S + μμᵀ.

    Args:
        moments: Dictionary with keys "natural_vector" and "natural_matrix".
            The vector is assumed to be a column of shape (m, 1) so that
            ``mu @ mu.T`` is an (m, m) outer product.
        jitter: Value added to the diagonal of S⁻¹ before the Cholesky
            factorisation for numerical stability.

    Returns:
        Dictionary with keys "expectation_vector" and "expectation_matrix".
    """
    natural_vector = moments["natural_vector"]
    natural_matrix = moments["natural_matrix"]

    m = natural_vector.shape[0]
    identity = jnp.eye(m)

    # Precision matrix implied by the natural matrix, stabilised with jitter.
    S_inv = -2 * natural_matrix
    S_inv += identity * jitter
    L = jnp.linalg.cholesky(S_inv)

    # C = L⁻¹, hence S = (L Lᵀ)⁻¹ = Cᵀ C.
    C = jsp.linalg.solve_triangular(L, identity, lower=True)
    S = jnp.matmul(C.T, C)

    mu = jnp.matmul(S, natural_vector)

    expectation_vector = mu
    expectation_matrix = S + jnp.matmul(mu, mu.T)

    return {
        "expectation_vector": expectation_vector,
        "expectation_matrix": expectation_matrix,
    }

We would then set up the model as follows:

In [None]:
likelihood = gpx.Gaussian(num_datapoints=n)
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
p = prior * likelihood

# Variational family in the expectation parameterisation (ξ = η); this
# exercises the code path of `natural_gradients` for non-natural families.
q = gpx.ExpectationVariationalGaussian(prior=prior, inducing_inputs=z)

svgp = gpx.StochasticVI(posterior=p, variational_family=q)

params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)

# Apply the `unconstrainers` transformations to the initial parameters.
params = gpx.transform(params, unconstrainers)

We then obtain our gradient functions as follows: 

In [None]:
# Build the natural-gradient and hyperparameter-gradient functions for the model.
nat_grads_fn, hyper_grads_fn = natural_gradients(
    svgp,
    D,
    constrainers,
    xi_to_nat=xi_to_nat,
    nat_to_xi=nat_to_xi,
)

We can then evaluate them, for example:

In [None]:
# Evaluate on the full dataset; the result is displayed as the cell output.
nat_grads_fn(
    params=params,
    trainables=trainables,
    batch=D,
)

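This returns a tuple containing the value of the loss function and the natural gradient. The gradient is currently `None`, because combining $\frac{d\xi}{d\theta}$ with $\frac{d\mathcal{L}}{d\eta}$ has not been implemented yet.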

In reality, we won't see these as we have a training loop abstraction that could look something like this:

In [None]:
# Run the natural-gradient training loop on the batched dataset.
learned_params = fit_natgrads(
    svgp,
    params=params,
    trainables=trainables,
    transformations=constrainers,
    train_data=Dbatched,
    n_iters=5000,
    xi_to_nat=xi_to_nat,
    nat_to_xi=nat_to_xi,
)

# Apply the `constrainers` transformations to the learned parameters.
learned_params = gpx.transform(learned_params, constrainers)