Commit

Documentation text updates
thomaspinder committed Apr 7, 2023
1 parent fdff318 commit 38b5bf4
Showing 7 changed files with 84 additions and 116 deletions.
17 changes: 9 additions & 8 deletions examples/barycentres.pct.py
@@ -28,7 +28,7 @@
# significantly more favourable uncertainty estimation.
#

# %%
# %% vscode={"languageId": "python"}
import typing as tp

import jax
@@ -99,7 +99,7 @@
# will be a sine function with a different vertical shift, periodicity, and quantity
# of noise.

# %%
# %% vscode={"languageId": "python"}
n = 100
n_test = 200
n_datasets = 5
@@ -135,7 +135,7 @@
# advice on selecting an appropriate kernel.


# %%
# %% vscode={"languageId": "python"}
def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
if y.ndim == 1:
y = y.reshape(-1, 1)
@@ -161,14 +161,15 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
# ## Computing the barycentre
#
# In GPJax, the predictive distribution of a GP is given by a
# [Distrax](https://github.com/deepmind/distrax) distribution, making it
# [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax)
# distribution, making it
# straightforward to extract the mean vector and covariance matrix of each GP for
# learning a barycentre. We implement the fixed point scheme given in (3) in the
# following cell by utilising Jax's `vmap` operator to speed up large matrix operations
# using broadcasting in `tensordot`.
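A minimal sketch of one such fixed-point update is given below, assuming the iteration $\Sigma \leftarrow \sum_{k} w_k \left(\Sigma^{1/2} \Sigma_k \Sigma^{1/2}\right)^{1/2}$; the names `psd_sqrtm`, `barycentre_step`, `covs` and `weights` are illustrative rather than the notebook's own.

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def psd_sqrtm(A: jax.Array) -> jax.Array:
    # Matrix square root of a PSD matrix, discarding any tiny imaginary component.
    return jnp.real(jsl.sqrtm(A))


def barycentre_step(current_cov: jax.Array, covs: jax.Array, weights: jax.Array) -> jax.Array:
    # One fixed-point update: Sigma <- sum_k w_k (Sigma^{1/2} Sigma_k Sigma^{1/2})^{1/2}.
    root = psd_sqrtm(current_cov)                                 # (N, N)
    inner = jax.vmap(lambda C: psd_sqrtm(root @ C @ root))(covs)  # (K, N, N)
    return jnp.tensordot(weights, inner, axes=1)                  # weighted sum over K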


# %%
# %% vscode={"languageId": "python"}
def sqrtm(A: jax.Array):
return jnp.real(jsl.sqrtm(A))

@@ -198,7 +199,7 @@ def step(covariance_candidate: jax.Array, idx: None):
# difference between the previous and current iteration that we can confirm by
# inspecting the `sequence` array in the following cell.

# %%
# %% vscode={"languageId": "python"}
weights = jnp.ones((n_datasets,)) / n_datasets

means = jnp.stack([d.mean() for d in posterior_preds])
@@ -222,7 +223,7 @@ def step(covariance_candidate: jax.Array, idx: None):
# uncertainty bands are sensible.


# %%
# %% vscode={"languageId": "python"}
def plot(
dist: tfd.MultivariateNormalTriL,
ax,
@@ -265,6 +266,6 @@ def plot(
# %% [markdown]
# ## System configuration

# %%
# %% vscode={"languageId": "python"}
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Thomas Pinder (edited by Daniel Dodd)'
12 changes: 4 additions & 8 deletions examples/classification.pct.py
@@ -96,9 +96,7 @@
# marginal log-likelihood.

# %% [markdown]
# To begin we obtain an initial parameter state through the `initialise` callable (see
# the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)).
# We can obtain a MAP estimate by optimising the marginal log-likelihood with
# We can obtain a MAP estimate by optimising the log-posterior density with
# Optax's optimisers.
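A generic Optax gradient loop of the kind referred to here looks roughly as follows; the parameter pytree, the `negative_log_posterior` objective and the learning rate are placeholders for illustration, not GPJax API.

import jax
import jax.numpy as jnp
import optax as ox

params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)}  # placeholder pytree


def negative_log_posterior(p):
    # Placeholder objective standing in for the negative log-posterior density.
    return (p["lengthscale"] - 2.0) ** 2 + (p["variance"] - 0.5) ** 2


optimiser = ox.adam(learning_rate=0.01)
opt_state = optimiser.init(params)


@jax.jit
def update(params, opt_state):
    loss, grads = jax.value_and_grad(negative_log_posterior)(params)
    updates, opt_state = optimiser.update(grads, opt_state, params)
    return ox.apply_updates(params, updates), opt_state, loss


for _ in range(500):
    params, opt_state, loss = update(params, opt_state)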

# %%
@@ -179,7 +177,7 @@
# \log\tilde{p}(\boldsymbol{f}|\mathcal{D}) = \log\tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) + \left[\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}}\right]^{T} (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[\nabla^2 \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) + \mathcal{O}(\lVert \boldsymbol{f} - \hat{\boldsymbol{f}} \rVert^3).
# \end{align}
#
# Now since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode,
# Since $\nabla \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})$ is zero at the mode,
# this suggests the following approximation
# \begin{align}
# \tilde{p}(\boldsymbol{f}|\mathcal{D}) \approx \tilde{p}(\hat{\boldsymbol{f}}|\mathcal{D}) \exp\left\{ -\frac{1}{2} (\boldsymbol{f}-\hat{\boldsymbol{f}})^{T} \left[-\nabla^2 \log\tilde{p}({\boldsymbol{f}}|\mathcal{D})|_{\hat{\boldsymbol{f}}} \right] (\boldsymbol{f}-\hat{\boldsymbol{f}}) \right\}
Expand Down Expand Up @@ -297,7 +295,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
# %% [markdown]
# ## MCMC inference
#
# At the high level, an MCMC sampler works by starting at an initial position and
# An MCMC sampler works by starting at an initial position and
# drawing a sample from a cheap-to-simulate distribution known as the _proposal_. The
# next step is to determine whether this sample could be considered a draw from the
# posterior. We accomplish this using an _acceptance probability_ determined via the
@@ -314,9 +312,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
# Rather than implementing a suite of MCMC samplers, GPJax relies on MCMC-specific
# libraries for sampling functionality. We focus on
# [BlackJax](https://github.com/blackjax-devs/blackjax/) in this notebook, which we
# recommend adopting for general applications. However, we also support TensorFlow
# Probability as demonstrated in the
# [TensorFlow Probability Integration notebook](https://gpjax.readthedocs.io/en/latest/nbs/tfp_integration.html).
# recommend adopting for general applications.
#
# We'll use the No U-Turn Sampler (NUTS) implementation given in BlackJax for sampling.
# For the interested reader, NUTS is a Hamiltonian Monte Carlo sampling scheme where
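The basic BlackJax sampling pattern alluded to here is sketched below with a toy log-density; the step size, mass matrix and number of samples are arbitrary illustrative choices rather than the notebook's settings.

import jax
import jax.numpy as jnp
import blackjax


def logdensity(position):
    # Toy standard-normal log-density standing in for the GP's log-posterior.
    return -0.5 * jnp.sum(position**2)


nuts = blackjax.nuts(logdensity, step_size=1e-2, inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))
step_fn = jax.jit(nuts.step)

key = jax.random.PRNGKey(123)
samples = []
for _ in range(100):
    key, subkey = jax.random.split(key)
    state, info = step_fn(subkey, state)
    samples.append(state.position)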
12 changes: 6 additions & 6 deletions examples/collapsed_vi.pct.py
@@ -83,7 +83,8 @@
plt.show()

# %% [markdown]
# Next we define the posterior model for the data.
# Next we define the true posterior model for the data. Note that, whilst we can define
# this, it is intractable to evaluate.

# %%
meanf = gpx.Constant()
@@ -93,9 +94,10 @@
posterior = prior * likelihood

# %% [markdown]
# We now define the SGPR model through `CollapsedVariationalGaussian`. Since the form
# of the collapsed optimal posterior depends on the Gaussian likelihood's observation
# noise, we pass this to the constructer.
# We now define the SGPR model through `CollapsedVariationalGaussian`. Through a
# set of inducing points $\boldsymbol{z}$ this object builds an approximation to the
# true posterior distribution. Consequently, we pass the true posterior and initial
# inducing points into the constructor as arguments.

# %%
q = gpx.CollapsedVariationalGaussian(posterior=posterior, inducing_inputs=z)
@@ -130,8 +132,6 @@
# %% [markdown]
# We show predictions of our model with the learned inducing points overlaid in grey.

# %%

# %%
latent_dist = opt_posterior(xtest, train_data=D)
predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
37 changes: 13 additions & 24 deletions examples/deep_kernels.py
@@ -26,8 +26,8 @@

# %%
import typing as tp
from dataclasses import dataclass
from typing import Dict
from dataclasses import dataclass, field
from typing import Dict, Any

import jax
import jax.numpy as jnp
@@ -39,6 +39,7 @@
from scipy.signal import sawtooth
from flax import linen as nn
from simple_pytree import static_field
import flax

import gpjax as gpx
import gpjax.kernels as jk
@@ -92,19 +93,12 @@
# ### Implementation
#
# Although deep kernels are not currently supported natively in GPJax, defining one is
# straightforward as we now demonstrate. Using the base `AbstractKernel` object given
# in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the
# user supplying the neural network and base kernel of their choice. Kernel matrices
# straightforward as we now demonstrate. Inheriting from the base `AbstractKernel`
# in GPJax, we create the `DeepKernelFunction` object that allows the
# user to supply the neural network and base kernel of their choice. Kernel matrices
# are then computed using the regular `gram` and `cross_covariance` functions.


# %%
import flax
from dataclasses import field
from typing import Any
from simple_pytree import static_field


@dataclass
class DeepKernelFunction(AbstractKernel):
base_kernel: AbstractKernel = None
@@ -132,12 +126,11 @@ def __call__(self, x: Float[Array, "D"], y: Float[Array, "D"]) -> Float[Array, "
#
# With a deep kernel object created, we proceed to define a neural network. Here we
# consider a small multi-layer perceptron with two linear hidden layers and ReLU
# activation functions between the layers. The first hidden layer contains 32 units,
# while the second layer contains 64 units. Finally, we'll make the output of our
# network a single unit. However, it would be possible to project our data into a
# $d-$dimensional space for $d>1$. In these instances, making the
# [base kernel ARD](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions)
# would be sensible.
# activation functions between the layers. The first hidden layer contains 64 units,
# while the second layer contains 32 units. Finally, we'll make the output of our
# network three units wide. The corresponding kernel that we define will then be of
# [ARD form](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Active-dimensions)
# to allow for different lengthscales in each dimension of the feature space.
# Users may wish to design more intricate network structures for more complex tasks,
# functionality that Flax supports well.
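A minimal `flax.linen` sketch of a network matching this description is shown below; the class name and the choice of `feature_space_dim = 3` are assumptions made for illustration.

import flax.linen as nn

feature_space_dim = 3  # width of the learned feature space passed to the base kernel


class Network(nn.Module):
    """Two-hidden-layer MLP with ReLU activations."""

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(features=64)(x))
        x = nn.relu(nn.Dense(features=32)(x))
        return nn.Dense(features=feature_space_dim)(x)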

@@ -164,8 +157,7 @@ def __call__(self, x):
#
# Having characterised the feature extraction network, we move to define a Gaussian
# process parameterised by this deep kernel. We consider a Matérn 5/2 base
# kernel and assume a Gaussian likelihood. Parameters, trainability status and
# transformations are initialised in the usual manner.
# kernel and assume a Gaussian likelihood.

# %%
base_kernel = gpx.Matern52(active_dims=list(range(feature_space_dim)))
@@ -186,14 +178,11 @@ def __call__(self, x):
# [Optax](https://optax.readthedocs.io/en/latest/) for optimisation. In particular, we
# showcase the ability to use a learning rate scheduler that decays the optimiser's
# learning rate throughout the inference. We decrease the learning rate according to a
# half-cosine curve over 1000 iterations, providing us with large step sizes early in
# half-cosine curve over 700 iterations, providing us with large step sizes early in
# the optimisation procedure before approaching more conservative values, ensuring we
# do not step too far. We also consider a linear warmup, where the learning rate is
# increased from 0 to 1 over 50 steps to get a reasonable initial learning rate value.
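In outline, such a schedule and its use inside an optimiser might look like the sketch below; the peak value of 1.0 follows the warmup description above, while the choice of Adam is illustrative.

import optax as ox

schedule = ox.warmup_cosine_decay_schedule(
    init_value=0.0,    # start of the linear warmup
    peak_value=1.0,    # learning rate reached at the end of the warmup
    warmup_steps=50,   # linear warmup from 0 to the peak value
    decay_steps=700,   # total schedule length, with a half-cosine decay after warmup
)
optimiser = ox.adam(learning_rate=schedule)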

# %%
negative_mll = gpx.ConjugateMLL(negative=True)

# %%
schedule = ox.warmup_cosine_decay_schedule(
init_value=0.0,
75 changes: 33 additions & 42 deletions examples/kernels.pct.py
@@ -17,15 +17,13 @@
# %% [markdown]
# # Kernel Guide
#
# In this guide, we introduce the kernels available in GPJax and demonstrate how to create custom ones.
#
#
# from typing import Dict

from dataclasses import dataclass
# In this guide, we introduce the kernels available in GPJax and demonstrate how to
# create custom kernels.

# %%
import distrax as dx
from typing import Dict

from dataclasses import dataclass
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
@@ -51,6 +49,11 @@
#
# * Matérn 1/2, 3/2 and 5/2.
# * RBF (or squared exponential).
# * Rational quadratic.
# * Powered exponential.
# * White noise.
# * Linear.
# * Polynomial.
# * [Graph kernels](https://gpjax.readthedocs.io/en/latest/nbs/graph_kernels.html).
#
@@ -102,7 +105,8 @@
print(f"Lengthscales: {slice_kernel.lengthscale}")

# %% [markdown]
# We'll now simulate some data and evaluate the kernel on the previously selected input dimensions.
# We'll now simulate some data and evaluate the kernel on the previously selected
# input dimensions.

# %%
# Inputs
@@ -117,13 +121,13 @@
#
# The pointwise product or sum of two positive definite kernels yields a positive
# definite kernel. Consequently, summing or multiplying sets of kernels is a
# valid operation that can give rich kernel functions. In GPJax, sums of kernels
# can be created by applying the `+` operator as follows.
# valid operation that can give rich kernel functions. In GPJax, functionality for
# a sum kernel is provided by the `SumKernel` class.

# %%
k1 = gpx.kernels.RBF()
k2 = gpx.kernels.Polynomial()
sum_k = gpx.kernels.ProductKernel(kernels=[k1, k2])
sum_k = gpx.kernels.SumKernel(kernels=[k1, k2])

fig, ax = plt.subplots(ncols=3, figsize=(20, 5))
im0 = ax[0].matshow(k1.gram(x).to_dense())
@@ -135,7 +139,7 @@
fig.colorbar(im2, ax=ax[2])

# %% [markdown]
# Similarily, products of kernels can be created through the `*` operator.
# Similarly, products of kernels can be created through the `ProductKernel` class.

# %%
k3 = gpx.kernels.Matern32()
@@ -153,7 +157,6 @@
fig.colorbar(im2, ax=ax[2])
fig.colorbar(im3, ax=ax[3])


# %% [markdown]
# ## Custom kernel
#
Expand All @@ -171,7 +174,8 @@
# ### Circular kernel
#
# When the underlying space is polar, typical Euclidean kernels such as Matérn
# kernels are insufficient at the boundary as discontinuities will be present.
# kernels are insufficient at the boundary where discontinuities will present
# themselves.
# This is due to the fact that in a polar space $\lvert 0 - 2\pi \rvert = 0$, i.e.,
# the space wraps. Euclidean kernels have no mechanism in them to represent this
# logic and will instead treat $0$ and $2\pi$ as elements far apart. Circular
@@ -198,13 +202,10 @@ def angular_distance(x, y, c):


@dataclass
class _Polar:
class Polar(gpx.kernels.AbstractKernel):
period: float = static_field(2 * jnp.pi)
tau: float = param_field(jnp.array([4.0]), bijector=tfb.Softplus(low=4.0))


@dataclass
class Polar(gpx.kernels.AbstractKernel, _Polar):
def __post_init__(self):
self.c = self.period / 2.0

@@ -219,35 +220,25 @@ def __call__(


# %% [markdown]
# We unpack this now to make better sense of it. In the kernel's `__init__`
# function we simply specify the length of a single period. As the underlying
# domain is a circle, this is $2\pi$. Next we define the kernel's `__call__`
# function which is a direct implementation of Equation (1). Finally, we define
# the Kernel's parameter property which contains just one value $\tau$ that we
# initialise to 4 in the kernel's `__init__`.
#
#
# ### Custom Parameter Bijection
#
# The constraint on $\tau$ makes optimisation challenging with gradient descent.
# It would be much easier if we could instead parameterise $\tau$ to be on the
# real line. Fortunately, this can be taken care of with GPJax's `add parameter`
# function, only requiring us to define the parameter's name and matching
# bijection (either a Distrax of TensorFlow probability bijector). Under the
# hood, calling this function updates a configuration object to register this
# parameter and its corresponding transform.
# We unpack this now to make better sense of it. In the kernel's initialiser
# we specify the length of a single period. As the underlying
# domain is a circle, this is $2\pi$. Next, we define
# the kernel's half-period parameter. As the kernel is a `dataclass` and `c` is a
# function of `period`, we must define it in the `__post_init__` method.
# Finally, we define the kernel's `__call__`
# function which is a direct implementation of Equation (1).
#
# To define a bijector here we'll make use of the `Lambda` operator given in
# Distrax. This lets us convert any regular Jax function into a bijection. Given
# that we require $\tau$ to be strictly greater than $4.$, we'll apply a
# [softplus
# transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html)
# where the lower bound is shifted by $4$.
# To constrain $\tau$ to be greater than 4, we use a `Softplus` bijector with a
# clipped lower bound of 4.0. This is done by specifying the `bijector` argument
# when we define the parameter field.
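To see the constraint in action, the small sketch below pushes an unconstrained value through a `Softplus(low=4.0)` bijector and back, assuming `tfb` refers to the TensorFlow Probability (JAX substrate) bijectors module used in the kernel definition above.

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors

bijector = tfb.Softplus(low=4.0)
unconstrained = jnp.array(-1.5)
constrained = bijector.forward(unconstrained)  # always strictly greater than 4.0
recovered = bijector.inverse(constrained)      # back to the unconstrained value
print(constrained, recovered)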

# %% [markdown]
# ### Using our polar kernel
#
# We proceed to fit a GP with our custom circular kernel to a random sequence of points on a circle (see the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) for further details on this process).
# We proceed to fit a GP with our custom circular kernel to a random sequence of
# points on a circle (see the
# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)
# for further details on this process).

# %%
# Simulate data