# Geometric Kernels

[GPJax](https://github.com/JaxGaussianProcesses/GPJax) is a Python package for working wit Gaussian processes in JAX. This notebook highlights how Geometric Kernels can be integrated into the functionality provided in GPJax.

In [1]:
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
from jax.config import config
import geometric_kernels.jax 
from geometric_kernels.frontends.jax.gpjax import GeometricKernel
from geometric_kernels.kernels import MaternKarhunenLoeveKernel
from geometric_kernels.spaces import Mesh
import jax
import meshzoo         

config.update("jax_enable_x64", True)
key = jr.PRNGKey(123)

2022-12-16 09:44:56.819201: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-12-16 09:44:57.453849: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-16 09:44:57.453937: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
INFO:root:Using numpy backend


## Data

We'll now define a dataset that we'll seek to model. The data support used in this example is an icose sphere from the [MeshZoo](https://github.com/meshpro/meshzoo) library. A Matérn kernel is then defined on the sphere and a single draw is taken from the Gaussian process' prior distribution at a random set of points to give us a response variable.

In [2]:
resolution = 40
num_data = 25
vertices, faces = meshzoo.icosa_sphere(resolution)
mesh = Mesh(vertices, faces)

truncation_level = 20
base_kernel = MaternKarhunenLoeveKernel(mesh, truncation_level)
geometric_kernel = GeometricKernel(base_kernel)

init_params = geometric_kernel._initialise_params(key)


def get_data():
    _X = jr.randint(key, minval=0, maxval=mesh.num_vertices, shape=(num_data, 1))
    _K = geometric_kernel.gram(init_params, _X)
    _L = jnp.linalg.cholesky(_K.to_dense() + jnp.eye(_K.shape[0]) * 1e-6)
    _y = _L @ jr.normal(key, (num_data,))
    return _X, _y


X, y = get_data()
X_test = jnp.arange(mesh.num_vertices).reshape(mesh.num_vertices, 1)

In [3]:
resolution = 40
num_data = 25
vertices, faces = meshzoo.icosa_sphere(resolution)
mesh = Mesh(vertices, faces)

truncation_level = 20
base_kernel = MaternKarhunenLoeveKernel(mesh, truncation_level)
geometric_kernel = GeometricKernel(base_kernel)

init_params = geometric_kernel._initialise_params(key)


def get_data():
    _X = jr.randint(key, minval=0, maxval=mesh.num_vertices, shape=(num_data, 1))
    _K = geometric_kernel.gram(init_params, _X)
    _L = jnp.linalg.cholesky(_K.to_dense() + jnp.eye(_K.shape[0]) * 1e-6)
    _y = _L @ jr.normal(key, (num_data,))
    return _X, _y


X, y = get_data()
X_test = jnp.arange(mesh.num_vertices).reshape(mesh.num_vertices, 1)

## Model specification

A model can now be defined. We'll purposefully keep this section brief as the workflow is identical to that of a regular Gaussian process regression model that is detailed in [full](https://gpjax.readthedocs.io/en/latest/examples/regression.html).

In [4]:
data = gpx.Dataset(X=X, y=y.reshape(-1, 1))

prior = gpx.Prior(kernel=geometric_kernel)
gpx.config.add_parameter("nu", gpx.config.Softplus)

likelihood = gpx.likelihoods.Gaussian(num_datapoints=num_data)

posterior = likelihood * prior

As with a regular conjugate Gaussian process, the marginal log-likelihood is tractable and can be evaluated using the posterior's `marginal_log_likelihood` method.

In [5]:
params, _, _ = gpx.initialise(posterior, key).unpack()

posterior.marginal_log_likelihood(data)(params)

DeviceArray(-22.975046, dtype=float32)

Derivatives of the marginal log-likelihood can be taken.

In [6]:
grads = jax.grad(posterior.marginal_log_likelihood(data, negative=True))(params)
print(grads)

{'kernel': {'lengthscale': DeviceArray([0.00352209], dtype=float32), 'nu': DeviceArray([-0.00406692], dtype=float32)}, 'likelihood': {'obs_noise': DeviceArray([12.498416], dtype=float32)}, 'mean_function': {}}


Finally, the predictive posterior can be computed for making predictions at unseen points. Evaluating the predictive posterior distribution returns a multivariate Gaussian distribution for which we can compute the posterior mean and variance as follows.

In [7]:
predictive_posterior = posterior.predict(params, data)(X_test)

mu = predictive_posterior.mean()
sigma2 = predictive_posterior.variance()