This repository has been archived by the owner on Dec 20, 2023. It is now read-only.

Minimal Gaussian process library in JAX with a simple (custom) approach to state management.


aidanscannell/GPJax


GPJax: Gaussian processes in JAX


UPDATE: This repo is no longer maintained. Please see JaxGaussianProcess/GPJax if you’re interested in a Gaussian process library in JAX.

GPJax is a minimal package for implementing Gaussian process models in Python using JAX. I have spent a lot of time using GPflow and I like how they implement their GP library, in particular their focus on variational inference and their implementation of GP conditionals. As such, this package takes a similar approach but offers the benefits (and ease) of having JAX under the hood.

GPJax keeps to the spirit of JAX (to provide an extensible system for composable function transformations) by implementing its features as pure functions. However, managing the parameters associated with the different components in Gaussian process methods (compositions of mean functions, kernels, variational parameters, etc.) makes a purely functional approach less appealing. I have experimented with both haiku and objax for state management, and neither of them provided a level of abstraction I was happy with. As a result, GPJax now implements its own simple approach to state management.

This package is a work in progress and functionality will be implemented when I need it for my research. If you like what I’m doing or have any recommendations then please reach out, or even better, get involved!

State Management

  1. gpjax relies upon its gpjax.Module class, which all constructs with trainable state must inherit from, e.g. MeanFunction, Kernel, GPModel.
  2. Each gpjax.Module subclass must implement a get_params() method that returns a dictionary of the parameters associated with the object, e.g. for a stationary kernel kernel_param_dict = kernel.get_params() returns {'lengthscales': DeviceArray([1.], dtype=float64), 'variance': DeviceArray([1.], dtype=float64)}.
  3. JAX’s functionality depends upon pure functions and GPJax implements all functionality as standalone functions; classes then call the relevant functions.
    • For example, the squared exponential kernel class is really just a convenience class for storing its associated parameters and providing a wrapper around its associated function. This is handy because it means that we can either create an instance of the kernel class and call it, or we can call the standalone function,
      kernel = SquaredExponential()
      params = kernel.get_params()  # get the dictionary of parameters associated with kernel
      
      K_object = kernel(
          params, X1, X2
      )  # call the kernel object with the dictionary of parameters
      
      K_func = squared_exponential_cov_fn(
          params, X1, X2
      )  # call the standalone function with the dictionary of parameters
      assert (K_object == K_func).all()  # reduce the element-wise comparison to a single boolean
              
  4. In general, classes provide wrappers around functions that accept a dictionary of parameters as the first argument,
    • gpjax.models.SVGP.predict_f(params, Xnew,...)
    • gpjax.kernels.SquaredExponential.K(params, X1, X2)
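
The pattern described in the list above can be sketched without GPJax installed. The code below is an illustrative re-implementation under stated assumptions, not the library's actual source: the Module base class, the constructor arguments and the body of squared_exponential_cov_fn are all hypothetical stand-ins, chosen only to show a class that stores parameters and delegates to a pure function taking the parameter dictionary as its first argument.

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class Module(ABC):
    """Hypothetical stand-in for gpjax.Module: anything with trainable state."""

    @abstractmethod
    def get_params(self) -> dict:
        """Return a dictionary of the parameters associated with the object."""


def squared_exponential_cov_fn(params: dict, X1, X2):
    """Pure function: the output depends only on the arguments passed in."""
    # Scale the inputs by the lengthscales, then take pairwise squared distances.
    X1s = X1 / params["lengthscales"]
    X2s = X2 / params["lengthscales"]
    sq_dist = jnp.sum((X1s[:, None, :] - X2s[None, :, :]) ** 2, axis=-1)
    return params["variance"] * jnp.exp(-0.5 * sq_dist)


class SquaredExponential(Module):
    """Convenience class: stores initial parameters and wraps the function."""

    def __init__(self, lengthscales=None, variance=None):
        self.lengthscales = jnp.ones(1) if lengthscales is None else lengthscales
        self.variance = jnp.ones(1) if variance is None else variance

    def get_params(self) -> dict:
        return {"lengthscales": self.lengthscales, "variance": self.variance}

    def __call__(self, params: dict, X1, X2):
        # No hidden state is read here; everything comes from `params`.
        return squared_exponential_cov_fn(params, X1, X2)


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
X1 = jax.random.uniform(key1, shape=(4, 2))
X2 = jax.random.uniform(key2, shape=(3, 2))

kernel = SquaredExponential()
params = kernel.get_params()
K_object = kernel(params, X1, X2)  # call through the class
K_func = squared_exponential_cov_fn(params, X1, X2)  # call the function directly
assert jnp.allclose(K_object, K_func)
```

Because the class never reads self inside __call__, the object and the standalone function are interchangeable wherever a JAX transformation expects a pure function.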

Example

Let’s take a look at how state is managed in an example using a kernel.

import jax
from gpjax.kernels import SquaredExponential

key = jax.random.PRNGKey(0)

# create some dummy data
input_dim = 5
num_x1 = 100
num_x2 = 30
key1, key2 = jax.random.split(key)  # use independent keys so X1 and X2 are not correlated
X1 = jax.random.uniform(key1, shape=(num_x1, input_dim))
X2 = jax.random.uniform(key2, shape=(num_x2, input_dim))

# create an instance of the SE kernel
kernel = SquaredExponential()
params = kernel.get_params()  # get the dictionary of parameters associated with kernel
K = kernel(params, X1, X2)  # call the kernel with the dictionary of parameters

Importantly, this is a pure function because the kernel’s hyperparameters (lengthscales and variance in this case) are passed as an argument. This means that we can easily compose it with any JAX transformation, for example, jax.jit and jax.jacfwd. We can calculate the derivative of our kernel w.r.t. its hyperparameters using jax.jacfwd:

# create a function that returns the derivative of kernel w.r.t params (its first argument)
jac_kernel_wrt_params_fn = jax.jacfwd(kernel, argnums=0)
# evaluate the derivative of the kernel wrt its hyperparameters
jac_kernel_wrt_params = jac_kernel_wrt_params_fn(params, X1, X2)
print(jac_kernel_wrt_params["lengthscales"].shape)
print(jac_kernel_wrt_params["variance"].shape)

(100, 30, 1)
(100, 30, 1)
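
The same idea extends to stacking transformations. The sketch below is self-contained and does not import GPJax: se_kernel is a hypothetical stand-in for the library's squared exponential kernel, and objective is an arbitrary scalar function made up for illustration. It shows jax.jit wrapped around jax.jacfwd, plus jax.grad returning gradients in the same dictionary structure as the parameters.

```python
import jax
import jax.numpy as jnp


def se_kernel(params, X1, X2):
    """Stand-in squared exponential kernel (not the GPJax implementation)."""
    diff = (X1[:, None, :] - X2[None, :, :]) / params["lengthscales"]
    return params["variance"] * jnp.exp(-0.5 * jnp.sum(diff**2, axis=-1))


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
X1 = jax.random.uniform(key1, shape=(100, 5))
X2 = jax.random.uniform(key2, shape=(30, 5))
params = {"lengthscales": jnp.ones(1), "variance": jnp.ones(1)}

# jit-compile the Jacobian function; the result mirrors the structure of `params`
jac_fn = jax.jit(jax.jacfwd(se_kernel, argnums=0))
jac = jac_fn(params, X1, X2)
print(jac["lengthscales"].shape)
print(jac["variance"].shape)


def objective(params, X1, X2):
    # any scalar function of the kernel matrix can be differentiated with jax.grad
    return jnp.sum(se_kernel(params, X1, X2))


grads = jax.grad(objective)(params, X1, X2)
# one gradient descent step applied to every entry of the parameter dictionary
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```

Each Jacobian entry has the kernel matrix shape followed by the parameter shape, so both printed shapes are (100, 30, 1) here. Keeping the parameters in a plain dictionary means jax.tree_util.tree_map can update all of them in one call.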

Install GPJax

This is a Python package that should be installed into a virtual environment. Start by cloning this repo from GitHub:

git clone https://github.com/aidanscannell/GPJax.git

The package can then be installed into a virtual environment by adding it as a local dependency.

Install with Poetry

GPJax’s dependencies and packaging are managed with Poetry rather than other tools such as Pipenv. To install GPJax into an existing Poetry environment, add it as a dependency under [tool.poetry.dependencies] (in the ./pyproject.toml configuration file) with the following line:

gpjax = {path = "/path/to/gpjax"}

If you want to develop the gpjax codebase then set develop=true:

gpjax = {path = "/path/to/gpjax", develop=true}

The dependencies in a ./pyproject.toml file are resolved and installed with:

poetry install

If you do not require the development packages then you can opt to install without them (newer Poetry releases deprecate --no-dev in favour of --without dev),

poetry install --no-dev

Install with pip

Create a new virtualenv and activate it, for example,

mkvirtualenv --python=python3 gpjax-env
workon gpjax-env

cd into the root of this package and install it and its dependencies with,

pip install .

If you want to develop the gpjax codebase then install it in “editable” or “develop” mode with:

pip install -e .

TODOs

  • [ ] Implement mean functions
    • [X] Implement zero
    • [X] Implement constant
  • [ ] Implement kernels
    • [X] Implement base
    • [X] Implement squared exponential
    • [X] Implement multi output
      • [X] Implement separate independent
      • [ ] Implement shared independent
      • [ ] Implement LinearCoregionalization
  • [ ] Implement conditionals
    • [X] Implement single-output conditionals
    • [X] Implement multi-output conditionals
    • [X] Implement dispatch for single/multioutput
    • [ ] Implement dispatch for different inducing variables
  • [ ] Implement likelihoods
    • [X] Implement base likelihood
    • [X] Implement Gaussian likelihood
    • [ ] Implement Bernoulli likelihood
    • [ ] Implement Softmax likelihood
  • [ ] Implement gpjax.models
    • [X] Implement gpjax.models.GPModel
      • [X] predict_f
      • [X] predict_y
    • [ ] Implement gpjax.models.GPR
    • [ ] Implement gpjax.models.SVGP
      • [X] predict_f
      • [X] init_variational_parameters
      • [X] KL
      • [X] lower bound
  • [ ] Notebook examples
    • [ ] GPR regression
    • [X] SVGP regression
    • [ ] SVGP classification
  • [X] Tests for mean functions
    • [X] Tests for zero
    • [X] Tests for constant
  • [X] Tests for kernels
    • [X] Tests for squared exponential
    • [X] Tests for separate independent
  • [ ] Tests for conditionals
    • [ ] Tests for single output conditionals
    • [ ] Tests for multi output conditionals
  • [ ] Tests for likelihoods
    • [ ] Tests for Gaussian likelihood
    • [ ] Tests for Bernoulli likelihood
    • [ ] Tests for Softmax likelihood
  • [ ] Tests for gpjax.models.SVGP
    • [X] Tests for gpjax.models.SVGP.predict_f
    • [X] Tests for gpjax.models.SVGP.prior_kl
