diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py
index 9762cbfca..ade690825 100644
--- a/examples/barycentres.pct.py
+++ b/examples/barycentres.pct.py
@@ -28,7 +28,7 @@
 # significantly more favourable uncertainty estimation.
 #
 
-# %%
+# %% vscode={"languageId": "python"}
 import typing as tp
 
 import jax
@@ -99,7 +99,7 @@
 # will be a sine function with a different vertical shift, periodicity, and quantity
 # of noise.
 
-# %%
+# %% vscode={"languageId": "python"}
 n = 100
 n_test = 200
 n_datasets = 5
@@ -135,7 +135,7 @@
 # advice on selecting an appropriate kernel.
 
 
-# %%
+# %% vscode={"languageId": "python"}
 def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
     if y.ndim == 1:
         y = y.reshape(-1, 1)
@@ -161,14 +161,15 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
 # ## Computing the barycentre
 #
 # In GPJax, the predictive distribution of a GP is given by a
-# [Distrax](https://github.com/deepmind/distrax) distribution, making it
+# [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax)
+# distribution, making it
 # straightforward to extract the mean vector and covariance matrix of each GP for
 # learning a barycentre. We implement the fixed point scheme given in (3) in the
 # following cell by utilising Jax's `vmap` operator to speed up large matrix operations
 # using broadcasting in `tensordot`.
 
 
-# %%
+# %% vscode={"languageId": "python"}
 def sqrtm(A: jax.Array):
     return jnp.real(jsl.sqrtm(A))
 
@@ -198,7 +199,7 @@ def step(covariance_candidate: jax.Array, idx: None):
 # difference between the previous and current iteration that we can confirm by
 # inspecting the `sequence` array in the following cell.
 
-# %%
+# %% vscode={"languageId": "python"}
 weights = jnp.ones((n_datasets,)) / n_datasets
 
 means = jnp.stack([d.mean() for d in posterior_preds])
@@ -222,7 +223,7 @@ def step(covariance_candidate: jax.Array, idx: None):
 # uncertainty bands are sensible.
 
 
-# %%
+# %% vscode={"languageId": "python"}
 def plot(
     dist: tfd.MultivariateNormalTriL,
     ax,
@@ -265,6 +266,6 @@ def plot(
 # %% [markdown]
 # ## System configuration
 
-# %%
+# %% vscode={"languageId": "python"}
 # %reload_ext watermark
 # %watermark -n -u -v -iv -w -a 'Thomas Pinder (edited by Daniel Dodd)'
diff --git a/examples/classification.pct.py b/examples/classification.pct.py
index 7c1bc9c78..d1b2cb1fa 100644
--- a/examples/classification.pct.py
+++ b/examples/classification.pct.py
@@ -96,9 +96,7 @@
 # marginal log-likelihood.
 
 # %% [markdown]
-# To begin we obtain an initial parameter state through the `initialise` callable (see
-# the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)).
-# We can obtain a MAP estimate by optimising the marginal log-likelihood with
+# We can obtain a MAP estimate by optimising the log-posterior density with
 # Optax's optimisers.
 
 # %%
@@ -179,7 +177,7 @@
 # \log\tilde{p}(\boldsymbol{f}|\mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) + \left[\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}}\right]^{T} (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \mathcal{O}(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3).
 # \end{align}
 #
-# Now since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode,
+# Since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode,
 # this suggests the following approximation
 # \begin{align}
 # \tilde{p}(\boldsymbol{f}|\mathcal{D}) \approx \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) \exp\left\{ \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) \right\}
@@ -297,7 +295,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
 # %% [markdown]
 # ## MCMC inference
 #
-# At the high level, an MCMC sampler works by starting at an initial position and
+# An MCMC sampler works by starting at an initial position and
 # drawing a sample from a cheap-to-simulate distribution known as the _proposal_. The
 # next step is to determine whether this sample could be considered a draw from the
 # posterior. We accomplish this using an _acceptance probability_ determined via the
@@ -314,9 +312,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
 # Rather than implementing a suite of MCMC samplers, GPJax relies on MCMC-specific
 # libraries for sampling functionality. We focus on
 # [BlackJax](https://github.com/blackjax-devs/blackjax/) in this notebook, which we
-# recommend adopting for general applications. However, we also support TensorFlow
-# Probability as demonstrated in the
-# [TensorFlow Probability Integration notebook](https://gpjax.readthedocs.io/en/latest/nbs/tfp_integration.html).
+# recommend adopting for general applications.
 #
 # We'll use the No U-Turn Sampler (NUTS) implementation given in BlackJax for sampling.
 # For the interested reader, NUTS is a Hamiltonian Monte Carlo sampling scheme where
diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py
index dea8e9a22..a7e0525ff 100644
--- a/examples/collapsed_vi.pct.py
+++ b/examples/collapsed_vi.pct.py
@@ -83,7 +83,8 @@
 plt.show()
 
 # %% [markdown]
-# Next we define the posterior model for the data.
+# Next we define the true posterior model for the data - note that whilst we can define
+# this, it is intractable to evaluate.
 
 # %%
 meanf = gpx.Constant()
@@ -93,9 +94,10 @@
 posterior = prior * likelihood
 
 # %% [markdown]
-# We now define the SGPR model through `CollapsedVariationalGaussian`. Since the form
-# of the collapsed optimal posterior depends on the Gaussian likelihood's observation
-# noise, we pass this to the constructer.
+# We now define the SGPR model through `CollapsedVariationalGaussian`. Through a
+# set of inducing points $\boldsymbol{z}$ this object builds an approximation to the
+# true posterior distribution. Consequently, we pass the true posterior and initial
+# inducing points into the constructor as arguments.
 
 # %%
 q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
@@ -130,8 +132,6 @@
 # %% [markdown]
 # We show predictions of our model with the learned inducing points overlayed in grey.
 
-# %%
-
 # %%
 latent_dist = opt_posterior(xtest, train_data=D)
 predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
diff --git a/examples/deep_kernels.py b/examples/deep_kernels.py
index dac62226f..7beb9fc38 100644
--- a/examples/deep_kernels.py
+++ b/examples/deep_kernels.py
@@ -26,8 +26,8 @@
 
 # %%
 import typing as tp
-from dataclasses import dataclass
-from typing import Dict
+from dataclasses import dataclass, field
+from typing import Dict, Any
 
 import jax
 import jax.numpy as jnp
@@ -39,6 +39,7 @@
 from scipy.signal import sawtooth
 from flax import linen as nn
 from simple_pytree import static_field
+import flax
 
 import gpjax as gpx
 import gpjax.kernels as jk
@@ -92,19 +93,12 @@
 # ### Implementation
 #
 # Although deep kernels are not currently supported natively in GPJax, defining one is
-# straightforward as we now demonstrate. Using the base `AbstractKernel` object given
-# in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the
-# user supplying the neural network and base kernel of their choice. Kernel matrices
+# straightforward as we now demonstrate. Inheriting from the base `AbstractKernel`
+# in GPJax, we create the `DeepKernelFunction` object that allows the
+# user to supply the neural network and base kernel of their choice. Kernel matrices
 # are then computed using the regular `gram` and `cross_covariance` functions.
-
 
 # %%
-import flax
-from dataclasses import field
-from typing import Any
-from simple_pytree import static_field
-
-
 @dataclass
 class DeepKernelFunction(AbstractKernel):
     base_kernel: AbstractKernel = None
@@ -132,12 +126,11 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "
 #
 # With a deep kernel object created, we proceed to define a neural network. Here we
 # consider a small multi-layer perceptron with two linear hidden layers and ReLU
-# activation functions between the layers. The first hidden layer contains 32 units,
-# while the second layer contains 64 units. Finally, we'll make the output of our
-# network a single unit. However, it would be possible to project our data into a
-# $d-$dimensional space for $d>1$. In these instances, making the
-# [base kernel ARD](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions)
-# would be sensible.
+# activation functions between the layers. The first hidden layer contains 64 units,
+# while the second layer contains 32 units. Finally, we'll make the output of our
+# network three units wide. The corresponding kernel that we define will then be of
+# [ARD form](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions)
+# to allow for different lengthscales in each dimension of the feature space.
 # Users may wish to design more intricate network structures for more complex tasks,
 # which functionality is supported well in Haiku.
 
@@ -164,8 +157,7 @@ def __call__(self, x):
 #
 # Having characterised the feature extraction network, we move to define a Gaussian
 # process parameterised by this deep kernel. We consider a third-order Matérn base
-# kernel and assume a Gaussian likelihood. Parameters, trainability status and
-# transformations are initialised in the usual manner.
+# kernel and assume a Gaussian likelihood.
 
 # %%
 base_kernel = gpx.Matern52(active_dims=list(range(feature_space_dim)))
@@ -186,14 +178,11 @@
 # [Optax](https://optax.readthedocs.io/en/latest/) for optimisation. In particular, we
 # showcase the ability to use a learning rate scheduler that decays the optimiser's
 # learning rate throughout the inference. We decrease the learning rate according to a
-# half-cosine curve over 1000 iterations, providing us with large step sizes early in
+# half-cosine curve over 700 iterations, providing us with large step sizes early in
 # the optimisation procedure before approaching more conservative values, ensuring we
 # do not step too far. We also consider a linear warmup, where the learning rate is
 # increased from 0 to 1 over 50 steps to get a reasonable initial learning rate value.
 
-# %%
-negative_mll = gpx.ConjugateMLL(negative=True)
-
 # %%
 schedule = ox.warmup_cosine_decay_schedule(
     init_value=0.0,
diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py
index 4d9130fb3..7f883a36a 100644
--- a/examples/kernels.pct.py
+++ b/examples/kernels.pct.py
@@ -17,15 +17,13 @@
 # %% [markdown]
 # # Kernel Guide
 #
-# In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones.
-#
-#
-# from typing import Dict
-
-from dataclasses import dataclass
+# In this guide, we introduce the kernels available in GPJax and demonstrate how to
+# create custom kernels.
 
 # %%
-import distrax as dx
+from typing import Dict
+
+from dataclasses import dataclass
 import jax.numpy as jnp
 import jax.random as jr
 import matplotlib.pyplot as plt
@@ -51,6 +49,10 @@
 #
 # * Matérn 1/2, 3/2 and 5/2.
 # * RBF (or squared exponential).
+# * Rational quadratic.
+# * Powered exponential.
+# * White noise.
+# * Linear.
 # * Polynomial.
 # * [Graph kernels](https://gpjax.readthedocs.io/en/latest/nbs/graph_kernels.html).
 #
@@ -102,7 +105,8 @@
 print(f"Lengthscales: {slice_kernel.lengthscale}")
 
 # %% [markdown]
-# We'll now simulate some data and evaluate the kernel on the previously selected input dimensions.
+# We'll now simulate some data and evaluate the kernel on the previously selected
+# input dimensions.
 
 # %%
 # Inputs
@@ -117,13 +121,13 @@
 #
 # The product or sum of two positive definite matrices yields a positive
 # definite matrix. Consequently, summing or multiplying sets of kernels is a
-# valid operation that can give rich kernel functions. In GPJax, sums of kernels
-# can be created by applying the `+` operator as follows.
+# valid operation that can give rich kernel functions. In GPJax, functionality for
+# a sum kernel is provided by the `SumKernel` class.
 
 # %%
 k1 = gpx.kernels.RBF()
 k2 = gpx.kernels.Polynomial()
-sum_k = gpx.kernels.ProductKernel(kernels=[k1, k2])
+sum_k = gpx.kernels.SumKernel(kernels=[k1, k2])
 
 fig, ax = plt.subplots(ncols=3, figsize=(20, 5))
 im0 = ax[0].matshow(k1.gram(x).to_dense())
@@ -135,7 +139,7 @@
 fig.colorbar(im2, ax=ax[2])
 
 # %% [markdown]
-# Similarily, products of kernels can be created through the `*` operator.
+# Similarly, products of kernels can be created through the `ProductKernel` class.
 
 # %%
 k3 = gpx.kernels.Matern32()
@@ -153,7 +157,6 @@
 fig.colorbar(im2, ax=ax[2])
 fig.colorbar(im3, ax=ax[3])
 
-
 # %% [markdown]
 # ## Custom kernel
 #
@@ -171,7 +174,8 @@
 # ### Circular kernel
 #
 # When the underlying space is polar, typical Euclidean kernels such as Matérn
-# kernels are insufficient at the boundary as discontinuities will be present.
+# kernels are insufficient at the boundary where discontinuities will present
+# themselves.
 # This is due to the fact that for a polar space $\lvert 0, 2\pi\rvert=0$ i.e.,
 # the space wraps. Euclidean kernels have no mechanism in them to represent this
 # logic and will instead treat $0$ and $2\pi$ and elements far apart. Circular
@@ -198,13 +202,10 @@ def angular_distance(x, y, c):
 
 
 @dataclass
-class _Polar:
+class Polar(gpx.kernels.AbstractKernel):
     period: float = static_field(2 * jnp.pi)
     tau: float = param_field(jnp.array([4.0]), bijector=tfb.Softplus(low=4.0))
 
-
-@dataclass
-class Polar(gpx.kernels.AbstractKernel, _Polar):
     def __post_init__(self):
         self.c = self.period / 2.0
 
@@ -219,35 +220,25 @@ def __call__(
 
 
 # %% [markdown]
-# We unpack this now to make better sense of it. In the kernel's `__init__`
-# function we simply specify the length of a single period. As the underlying
-# domain is a circle, this is $2\pi$. Next we define the kernel's `__call__`
-# function which is a direct implementation of Equation (1). Finally, we define
-# the Kernel's parameter property which contains just one value $\tau$ that we
-# initialise to 4 in the kernel's `__init__`.
-#
-#
-# ### Custom Parameter Bijection
-#
-# The constraint on $\tau$ makes optimisation challenging with gradient descent.
-# It would be much easier if we could instead parameterise $\tau$ to be on the
-# real line. Fortunately, this can be taken care of with GPJax's `add parameter`
-# function, only requiring us to define the parameter's name and matching
-# bijection (either a Distrax of TensorFlow probability bijector). Under the
-# hood, calling this function updates a configuration object to register this
-# parameter and its corresponding transform.
+# We unpack this now to make better sense of it. In the kernel's initialiser
+# we specify the length of a single period. As the underlying
+# domain is a circle, this is $2\pi$. Next, we define
+# the kernel's half-period parameter. As the kernel is a `dataclass` and `c` is a
+# function of `period`, we must define it in the `__post_init__` method.
+# Finally, we define the kernel's `__call__`
+# function which is a direct implementation of Equation (1).
 #
-# To define a bijector here we'll make use of the `Lambda` operator given in
-# Distrax. This lets us convert any regular Jax function into a bijection. Given
-# that we require $\tau$ to be strictly greater than $4.$, we'll apply a
-# [softplus
-# transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html)
-# where the lower bound is shifted by $4$.
+# To constrain $\tau$ to be greater than 4, we use a `Softplus` bijector with a
+# clipped lower bound of 4.0. This is done by specifying the `bijector` argument
+# when we define the parameter field.
 
 # %% [markdown]
 # ### Using our polar kernel
 #
-# We proceed to fit a GP with our custom circular kernel to a random sequence of points on a circle (see the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) for further details on this process).
+# We proceed to fit a GP with our custom circular kernel to a random sequence of
+# points on a circle (see the
+# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)
+# for further details on this process).
 
 # %%
 # Simulate data
diff --git a/examples/regression.pct.py b/examples/regression.pct.py
index e5204c7e7..3682173c9 100644
--- a/examples/regression.pct.py
+++ b/examples/regression.pct.py
@@ -43,7 +43,8 @@
 #
 # $$\boldsymbol{y} \sim \mathcal{N} \left(\sin(4\boldsymbol{x}) + \cos(2 \boldsymbol{x}), \textbf{I} * 0.3^2 \right).$$
 #
-# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels for later.
+# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs and labels
+# for later.
 
 # %% vscode={"languageId": "python"}
 n = 100
@@ -75,7 +76,6 @@
 # observations $\mathcal{D}$ via Gaussian process regression. We begin by defining a
 # Gaussian process prior in the next section.
 
-# %% [markdown]
 # ## Defining the prior
 #
 # A zero-mean Gaussian process (GP) places a prior distribution over real-valued
@@ -98,16 +98,16 @@
 
 # %% vscode={"languageId": "python"}
 kernel = gpx.kernels.RBF()
-meanf = gpx.mean_functions.Constant(constant=0.0)
-meanf = meanf.replace_trainable(constant=False)
+meanf = gpx.mean_functions.Zero()
 prior = gpx.Prior(mean_function=meanf, kernel=kernel)
 
 # %% [markdown]
 #
 # The above construction forms the foundation for GPJax's models. Moreover, the GP prior
-# we have just defined can be represented by a [Distrax](https://github.com/deepmind/distrax)
+# we have just defined can be represented by a
+# [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax)
 # multivariate Gaussian distribution. Such functionality enables trivial sampling, and
-# mean and covariance evaluation of the GP.
+# the evaluation of the GP's mean and covariance.
 
 # %% vscode={"languageId": "python"}
 prior_dist = prior.predict(xtest)
@@ -160,23 +160,15 @@
 #
 # ## Parameter state
 #
-# So far, all of the objects that we've defined have been stateless. To give our model
-# state, we can use the `initialise` function provided in GPJax. Upon calling this, a
-# `ParameterState` class is returned that contains four dictionaries:
-#
-# | Dictionary | Description |
-# |---|---|
-# | `params` | Initial parameter values. |
-# | `trainable` | Boolean dictionary that determines the training status of parameters (`True` for being trained and `False` otherwise). |
-# | `bijectors` | Bijectors that can map parameters between the _unconstrained space_ and their original _constrained space_. |
-#
-# Further, upon calling `initialise`, we can state specific initial values for some, or
-# all, of the parameters within our model. By default, the kernel lengthscale and
-# variance and the likelihood's variance parameter are all initialised to 1. However,
-# in the following cell, we'll demonstrate how the kernel lengthscale can be
-# initialised to 0.5.
+# As outlined in the [PyTrees](https://jax.readthedocs.io/en/latest/pytrees.html)
+# documentation, parameters are contained within the model and form the leaves of the
+# PyTree. Consequently, in this particular model, we have three parameters: the
+# kernel lengthscale, kernel variance and the observation noise variance. Whilst
+# we have initialised each of these to 1, we can learn Type 2 MLEs for each of
+# these parameters by optimising the marginal log-likelihood (MLL).
 
 # %% vscode={"languageId": "python"}
+# TODO: drop this once `step` is implemented into `Objectives`
 negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
 negative_mll(posterior, train_data=D)
 
@@ -199,10 +191,12 @@
 )
 
 # %% [markdown]
-# Similar to the `ParameterState` object above, the returned variable from the `fit`
-# function is a class, namely an `InferenceState` object that contains the parameters'
-# final values and a tracked array of the evaluation of our objective function
-# throughout optimisation.
+# Calling `fit` returns two objects: the optimised posterior and a history of
+# training losses. We can plot the training loss to see how the optimisation has
+# progressed.
+
+# %% vscode={"languageId": "python"}
+plt.plot(history)
 
 # %% [markdown]
 # ## Prediction
diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py
index ed20e8d9e..32f48e91e 100644
--- a/examples/uncollapsed_vi.pct.py
+++ b/examples/uncollapsed_vi.pct.py
@@ -110,14 +110,11 @@
 # We show a cost comparison between the approaches below, where $b$ is the mini-batch
 # size.
 #
-#
-#
 # | | GPs | sparse GPs | SVGP |
 # | -- | -- | -- | -- |
 # | Inference cost | $\mathcal{O}(n^3)$ | $\mathcal{O}(n m^2)$ | $\mathcal{O}(b m^2 + m^3)$ |
 # | Memory cost | $\mathcal{O}(n^2)$ | $\mathcal{O}(n m)$ | $\mathcal{O}(b m + m^2)$ |
 #
-#
 # To apply SVGP inference to our dataset, we begin by initialising $m = 50$ equally
 # spaced inducing inputs $\boldsymbol{z}$ across our observed data's support. These
 # are depicted below via horizontal black lines.
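
A minimal standalone sketch of the `SumKernel`/`ProductKernel` usage documented in the kernels guide above. It is not part of the diff: it assumes a GPJax version that exposes `gpx.kernels.SumKernel`, `gpx.kernels.ProductKernel` and the `gram(...).to_dense()` interface used in these examples, and the toy inputs `x` are illustrative only.

```python
import jax.numpy as jnp
import gpjax as gpx

# Toy inputs on which to evaluate the kernels (illustrative values only).
x = jnp.linspace(-3.0, 3.0, 20).reshape(-1, 1)

k1 = gpx.kernels.RBF()
k2 = gpx.kernels.Polynomial()

# Combination kernels, as used in examples/kernels.pct.py above.
sum_k = gpx.kernels.SumKernel(kernels=[k1, k2])
prod_k = gpx.kernels.ProductKernel(kernels=[k1, k2])

K1 = k1.gram(x).to_dense()
K2 = k2.gram(x).to_dense()

# A sum (product) kernel's Gram matrix should equal the elementwise
# sum (product) of its components' Gram matrices.
assert jnp.allclose(sum_k.gram(x).to_dense(), K1 + K2, atol=1e-6)
assert jnp.allclose(prod_k.gram(x).to_dense(), K1 * K2, atol=1e-6)
```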