Docs pytree #211

Merged
merged 48 commits into from
Apr 8, 2023
Commits
75b196c
Add module.
daniel-dodd Mar 27, 2023
fb2d6d1
Add stuff.
daniel-dodd Mar 27, 2023
1e3591f
Fix mytree links
thomaspinder Mar 27, 2023
54f7183
Merge branch 'Module' into refactor_kernels
thomaspinder Mar 27, 2023
f861244
Tests fixed
thomaspinder Mar 27, 2023
cf6443e
Refactor tests
thomaspinder Mar 27, 2023
623be42
Reformat
thomaspinder Mar 27, 2023
309d0c4
Module methods' return type is Self (the subclass)
frazane Mar 28, 2023
6859304
refactor matern12
frazane Mar 28, 2023
ace9e5f
stationary kernels refactoring
frazane Mar 28, 2023
b609c0b
tests draft
frazane Mar 29, 2023
f39d07b
spectral density as property (RBF)
frazane Mar 29, 2023
9823543
add jitter in gram test
frazane Mar 29, 2023
cc0c220
fix default engine for white kernel
frazane Mar 29, 2023
2d06b50
fix jaxtyping hints
frazane Mar 29, 2023
48d7fdb
Merge pull request #206 from JaxGaussianProcesses/refactor_kernels_zaf
daniel-dodd Mar 29, 2023
aee2a12
Fix bugs on the base.
daniel-dodd Mar 30, 2023
a669eb9
Refactored variational families.
daniel-dodd Mar 28, 2023
8051384
Update likelihoods and refactor collapsed variational family
daniel-dodd Mar 29, 2023
4c1c965
Update likelihoods.
daniel-dodd Mar 29, 2023
cacac4e
Add fit.py and test.
daniel-dodd Mar 30, 2023
e07ffe5
Remove types and add dataset.
daniel-dodd Mar 30, 2023
57daf09
Commit.
daniel-dodd Apr 2, 2023
3c6aaa0
Use tfb bijectors, update base.
daniel-dodd Apr 2, 2023
2ec8727
Minimal passing tests except for eigen and basis work
daniel-dodd Apr 2, 2023
cd1cce7
Improve dataset tests.
daniel-dodd Apr 3, 2023
aa356e5
Update fit testing.
daniel-dodd Apr 3, 2023
5877698
Refactor docs
thomaspinder Apr 3, 2023
bacfa64
Classification nb
thomaspinder Apr 3, 2023
18ab3a9
Collapsed VI
thomaspinder Apr 4, 2023
2557fe8
Sampling fixed
thomaspinder Apr 4, 2023
395f295
Sampling fixed
thomaspinder Apr 4, 2023
0700628
Graph kernel
thomaspinder Apr 4, 2023
5d29d7f
RFF refactored
thomaspinder Apr 3, 2023
d955c98
Graph kernel refactored
thomaspinder Apr 3, 2023
d6a045b
Fix imports and switch to FillTriangular for now, to avoid dtype error.
daniel-dodd Apr 4, 2023
10fd335
Merge pull request #208 from JaxGaussianProcesses/new_refactor_kernels
daniel-dodd Apr 4, 2023
6f4b684
Merge branch 'refactor_gpjax_to_pytrees' into docs_pytree
thomaspinder Apr 5, 2023
80abfd2
Docs complete
thomaspinder Apr 5, 2023
e7472fa
Docs outline
thomaspinder Apr 6, 2023
1616e1c
Push fix.
daniel-dodd Apr 6, 2023
2496ec6
DKL fixed
thomaspinder Apr 7, 2023
8e33bc0
Docs up-to-date
thomaspinder Apr 7, 2023
aa57156
Add flax to reqs
thomaspinder Apr 7, 2023
3f21cb9
Drop beartype refs
thomaspinder Apr 7, 2023
070dda3
Fix link fn. tests
thomaspinder Apr 7, 2023
fdff318
Add flax deps
thomaspinder Apr 7, 2023
38b5bf4
Documentation text updates
thomaspinder Apr 7, 2023
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -22,6 +22,7 @@ jobs:
pip install -e .
pip install -e .[dev]
pytest --cov=./ --cov-report=xml

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
13 changes: 2 additions & 11 deletions docs/index.md
@@ -60,6 +60,7 @@ installation
design
contributing
examples/intro_to_gps
examples/pytree
```

```{toctree}
@@ -74,8 +75,7 @@ examples/uncollapsed_vi
examples/collapsed_vi
examples/graph_kernels
examples/barycentres
examples/haiku
examples/tfp_integration
examples/deep_kernels
```

```{toctree}
@@ -88,15 +88,6 @@ examples/kernels
examples/yacht
```

```{toctree}
---
maxdepth: 1
caption: Experimental
hidden:
---
examples/natgrads
```

```{toctree}
---
maxdepth: 1
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -18,9 +18,9 @@ watermark
sphinxext-opengraph
blackjax>=0.8.2
jaxopt
dm-haiku
ipywidgets
pandas
scikit-learn
flax
# Install GPJax itself
.
17 changes: 17 additions & 0 deletions docs/sharp_bits.md
@@ -0,0 +1,17 @@
# 🔪 The sharp bits

## Pseudo-randomness

JAX handles pseudo-randomness through explicit PRNG keys rather than a global random state; we briefly acknowledge this here and point to the JAX docs for more information.

## Float64

Solving against the kernel Gram matrix is numerically sensitive, so Float64 (double) precision is needed when inverting the Gram matrix.

## Positive-definiteness

In finite precision the kernel Gram matrix can lose positive-definiteness, so a small amount of jitter must be added to its diagonal before factorising it.

## Slow-to-evaluate

Exact inference scales cubically in the number of observations, so more than several thousand data points will require the use of inducing points; don't try to use the ConjugateMLL objective on a million data points.
135 changes: 94 additions & 41 deletions examples/barycentres.pct.py
@@ -17,22 +17,28 @@
# %% [markdown]
# # Gaussian Processes Barycentres
#
# In this notebook we'll give an implementation of <strong data-cite="mallasto2017learning"></strong>. In this work, the existence of a Wasserstein barycentre between a collection of Gaussian processes is proven. When faced with trying to _average_ a set of probability distributions, the Wasserstein barycentre is an attractive choice as it enables uncertainty amongst the individual distributions to be incorporated into the averaged distribution. When compared to a naive _mean of means_ and _mean of variances_ approach to computing the average probability distributions, it can be seen that Wasserstein barycentres offer significantly more favourable uncertainty estimation.
# In this notebook we'll give an implementation of
# <strong data-cite="mallasto2017learning"></strong>. In this work, the existence of a
# Wasserstein barycentre between a collection of Gaussian processes is proven. When
# faced with trying to _average_ a set of probability distributions, the Wasserstein
# barycentre is an attractive choice as it enables uncertainty amongst the individual
# distributions to be incorporated into the averaged distribution. When compared to a
# naive _mean of means_ and _mean of variances_ approach to computing the average
# probability distribution, it can be seen that Wasserstein barycentres offer
# significantly more favourable uncertainty estimation.
#

# %%
# %% vscode={"languageId": "python"}
import typing as tp

import distrax as dx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.linalg as jsl
import matplotlib.pyplot as plt
import optax as ox
import tensorflow_probability.substrates.jax.distributions as tfd
from jax.config import config
from jaxutils import Dataset
import gpjax.kernels as jk

import gpjax as gpx

@@ -45,31 +51,55 @@
#
# ### Wasserstein distance
#
# The 2-Wasserstein distance metric between two probability measures $\mu$ and $\nu$ quantifies the minimal cost required to transport the unit mass from $\mu$ to $\nu$, or vice-versa. Typically, computing this metric requires solving a linear program. However, when $\mu$ and $\nu$ both belong to the family of multivariate Gaussian distributions, the solution is analytically given by
# The 2-Wasserstein distance metric between two probability measures $\mu$ and $\nu$
# quantifies the minimal cost required to transport the unit mass from $\mu$ to $\nu$,
# or vice-versa. Typically, computing this metric requires solving a linear program.
# However, when $\mu$ and $\nu$ both belong to the family of multivariate Gaussian
# distributions, the solution is analytically given by
# $$W_2^2(\mu, \nu) = \lVert m_1- m_2 \rVert^2_2 + \operatorname{Tr}(S_1 + S_2 - 2(S_1^{1/2}S_2S_1^{1/2})^{1/2}),$$
# where $\mu \sim \mathcal{N}(m_1, S_1)$ and $\nu\sim\mathcal{N}(m_2, S_2)$.
#
# ### Wasserstein barycentre
#
# For a collection of $T$ measures $\lbrace\mu_i\rbrace_{t=1}^T \in \mathcal{P}_2(\theta)$, the Wasserstein barycentre $\bar{\mu}$ is the measure that minimises the average Wasserstein distance to all other measures in the set. More formally, the Wasserstein barycentre is the Fréchet mean on a Wasserstein space that we can write as
# For a collection of $T$ measures
# $\lbrace\mu_t\rbrace_{t=1}^T \in \mathcal{P}_2(\theta)$, the Wasserstein barycentre
# $\bar{\mu}$ is the measure that minimises the average Wasserstein distance to all
# other measures in the set. More formally, the Wasserstein barycentre is the Fréchet
# mean on a Wasserstein space that we can write as
# $$\bar{\mu} = \operatorname{argmin}_{\mu\in\mathcal{P}_2(\theta)}\sum_{t=1}^T \alpha_t W_2^2(\mu, \mu_t),$$
# where $\alpha\in\mathbb{R}^T$ is a weight vector that sums to 1.
#
# As with the Wasserstein distance, identifying the Wasserstein barycentre $\bar{\mu}$ is often an computationally demanding optimisation problem. However, when all the measures admit a multivariate Gaussian density, the barycentre $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has analytical solutions
# As with the Wasserstein distance, identifying the Wasserstein barycentre $\bar{\mu}$
# is often a computationally demanding optimisation problem. However, when all the
# measures admit a multivariate Gaussian density, the barycentre
# $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has analytical solutions
# $$\bar{m} = \sum_{t=1}^T \alpha_t m_t\,, \quad \bar{S}=\sum_{t=1}^T\alpha_t (\bar{S}^{1/2}S_t\bar{S}^{1/2})^{1/2}\,. \qquad (\star)$$
# Identifying $\bar{S}$ is achieved through a fixed-point iterative update.
#
# ## Barycentre of Gaussian processes
#
# It was shown in <strong data-cite="mallasto2017learning"></strong> that the barycentre $\bar{f}$ of a collection of Gaussian processes $\lbrace f_i\rbrace_{i=1}^T$ such that $f_i \sim \mathcal{GP}(m_i, K_i)$ can be found using the same solutions as in $(\star)$. For a full theoretical understanding, we recommend reading the original paper. However, the central argument to this result is that one can first show that the barycentre GP $\bar{f}\sim\mathcal{GP}(\bar{m}, \bar{S})$ is non-degenerate for any finite set of GPs $\lbrace f_t\rbrace_{t=1}^T$ i.e., $T<\infty$. With this established, one can show that for a $n$-dimensional finite Gaussian distribution $f_{i,n}$, the Wasserstein metric between any two Gaussian distributions $f_{i, n}, f_{j, n}$ converges to the Wasserstein metric between GPs as $n\to\infty$.
# It was shown in <strong data-cite="mallasto2017learning"></strong> that the
# barycentre $\bar{f}$ of a collection of Gaussian processes
# $\lbrace f_i\rbrace_{i=1}^T$ such that $f_i \sim \mathcal{GP}(m_i, K_i)$ can be
# found using the same solutions as in $(\star)$. For a full theoretical understanding,
# we recommend reading the original paper. However, the central argument to this result
# is that one can first show that the barycentre GP
# $\bar{f}\sim\mathcal{GP}(\bar{m}, \bar{S})$ is non-degenerate for any finite set of
# GPs $\lbrace f_t\rbrace_{t=1}^T$ i.e., $T<\infty$. With this established, one can
# show that for an $n$-dimensional finite Gaussian distribution $f_{i,n}$, the
# Wasserstein metric between any two Gaussian distributions $f_{i, n}, f_{j, n}$
# converges to the Wasserstein metric between GPs as $n\to\infty$.
#
# In this notebook, we will demonstrate how this can be achieved in GPJax.
#
# ## Dataset
#
# We'll simulate five datasets and develop a Gaussian process posterior before identifying the Gaussian process barycentre at a set of test points. Each dataset will be a sine function with a different vertical shift, periodicity, and quantity of noise.
# We'll simulate five datasets and develop a Gaussian process posterior before
# identifying the Gaussian process barycentre at a set of test points. Each dataset
# will be a sine function with a different vertical shift, periodicity, and quantity
# of noise.

# %%
# %% vscode={"languageId": "python"}
n = 100
n_test = 200
n_datasets = 5
@@ -96,46 +126,56 @@
# %% [markdown]
# ## Learning a posterior distribution
#
# We'll now independently learn Gaussian process posterior distributions for each dataset. We won't spend any time here discussing how GP hyperparameters are optimised. For advice on achieving this, see the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) for advice on optimisation and the [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for advice on selecting an appropriate kernel.
# We'll now independently learn Gaussian process posterior distributions for each
# dataset. We won't spend any time here discussing how GP hyperparameters are
# optimised. For advice on achieving this, see the
# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)
# for advice on optimisation and the
# [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for
# advice on selecting an appropriate kernel.

# %%
def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri:

# %% vscode={"languageId": "python"}
def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
if y.ndim == 1:
y = y.reshape(-1, 1)
D = Dataset(X=x, y=y)
D = gpx.Dataset(X=x, y=y)

likelihood = gpx.Gaussian(num_datapoints=n)
posterior = gpx.Prior(kernel=jk.RBF()) * likelihood

parameter_state = gpx.initialise(posterior, key)
negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))
optimiser = ox.adam(learning_rate=0.01)

inference_state = gpx.fit(
objective=negative_mll,
parameter_state=parameter_state,
optax_optim=optimiser,
num_iters=1000,
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

opt_posterior, _ = gpx.fit(
model=posterior,
objective=jax.jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=500,
)

learned_params, training_history = inference_state.unpack()
return likelihood(learned_params, posterior(learned_params, D)(xtest))
latent_dist = opt_posterior.predict(xtest, train_data=D)
return opt_posterior.likelihood(latent_dist)


posterior_preds = [fit_gp(x, i) for i in ys]

# %% [markdown]
# ## Computing the barycentre
#
# In GPJax, the predictive distribution of a GP is given by a [Distrax](https://github.com/deepmind/distrax) distribution, making it straightforward to extract the mean vector and covariance matrix of each GP for learning a barycentre. We implement the fixed point scheme given in (3) in the following cell by utilising Jax's `vmap` operator to speed up large matrix operations using broadcasting in `tensordot`.
# In GPJax, the predictive distribution of a GP is given by a
# [TensorFlow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax)
# distribution, making it
# straightforward to extract the mean vector and covariance matrix of each GP for
# learning a barycentre. We implement the fixed point scheme given in (3) in the
# following cell by utilising Jax's `vmap` operator to speed up large matrix operations
# using broadcasting in `tensordot`.

# %%

# %% vscode={"languageId": "python"}
def sqrtm(A: jax.Array):
return jnp.real(jsl.sqrtm(A))


def wasserstein_barycentres(
distributions: tp.List[dx.MultivariateNormalTri], weights: jax.Array
distributions: tp.List[tfd.MultivariateNormalFullCovariance], weights: jax.Array
):
covariances = [d.covariance() for d in distributions]
cov_stack = jnp.stack(covariances)
@@ -152,9 +192,14 @@ def step(covariance_candidate: jax.Array, idx: None):


# %% [markdown]
# With a function defined for learning a barycentre, we'll now compute it using the `lax.scan` operator that drastically speeds up for loops in Jax (see the [Jax documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)). The iterative update will be executed 100 times, with convergence measured by the difference between the previous and current iteration that we can confirm by inspecting the `sequence` array in the following cell.

# %%
# With a function defined for learning a barycentre, we'll now compute it using the
# `lax.scan` operator that drastically speeds up for loops in Jax (see the
# [Jax documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)).
# The iterative update will be executed 100 times, with convergence measured by the
# difference between the previous and current iteration that we can confirm by
# inspecting the `sequence` array in the following cell.

# %% vscode={"languageId": "python"}
weights = jnp.ones((n_datasets,)) / n_datasets

means = jnp.stack([d.mean() for d in posterior_preds])
@@ -168,16 +213,19 @@ def step(covariance_candidate: jax.Array, idx: None):
)
L = jnp.linalg.cholesky(barycentre_covariance)

barycentre_process = dx.MultivariateNormalTri(barycentre_mean, L)
barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L)

# %% [markdown]
# ## Plotting the result
#
# With a barycentre learned, we can visualise the result. We can see that the result looks reasonable as it follows the sinusoidal curve of all the inferred GPs, and the uncertainty bands are sensible.
# With a barycentre learned, we can visualise the result. We can see that the result
# looks reasonable as it follows the sinusoidal curve of all the inferred GPs, and the
# uncertainty bands are sensible.


# %%
# %% vscode={"languageId": "python"}
def plot(
dist: dx.MultivariateNormalTri,
dist: tfd.MultivariateNormalTriL,
ax,
color: str = "tab:blue",
label: str = None,
@@ -206,13 +254,18 @@ def plot(
# %% [markdown]
# ## Displacement interpolation
#
# In the above example, we assigned uniform weights to each of the posteriors within the barycentre. In practice, we may have prior knowledge of which posterior is most likely to be the correct one. Regardless of the weights chosen, the barycentre remains a Gaussian process. We can interpolate between a pair of posterior distributions $\mu_1$ and $\mu_2$ to visualise the corresponding barycentre $\bar{\mu}$.
# In the above example, we assigned uniform weights to each of the posteriors within
# the barycentre. In practice, we may have prior knowledge of which posterior is most
# likely to be the correct one. Regardless of the weights chosen, the barycentre
# remains a Gaussian process. We can interpolate between a pair of posterior
# distributions $\mu_1$ and $\mu_2$ to visualise the corresponding barycentre
# $\bar{\mu}$.
#
# ![](figs/barycentre_gp.gif)

# %% [markdown]
# ## System configuration

# %%
# %% vscode={"languageId": "python"}
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Thomas Pinder (edited by Daniel Dodd)'