Improve capabilities for non-Gaussian likelihoods #1631

Draft · wants to merge 24 commits into base: develop

Changes from 17 commits (of 24 total)

Commits
2bfbae2
plot heteroskedastic y distribution
willcowley Jan 15, 2021
dbe09c6
Initial exploration for an alternative interface with non-Gaussian li…
avullo Jan 15, 2021
3458f1a
Initial sketch for a prototype conditional output distribution.
avullo Jan 15, 2021
38b88b3
use ConditionalLikelihood in classificationnotebook
willcowley Jan 15, 2021
2758130
update ConditionalLikelihood
willcowley Jan 18, 2021
402c8b4
update notebooks
willcowley Jan 18, 2021
cada4f6
reroute model.predict_y, predict_log_density to conditional_y_dist
st-- Jan 18, 2021
b04d630
move ConditionedLikelihood into gpflow.likelihoods
st-- Jan 18, 2021
b47d114
format on notebooks
st-- Jan 18, 2021
d7e6b21
add missing import
st-- Jan 18, 2021
d600663
fix warn() call
st-- Jan 18, 2021
b9304c3
fix rename
st-- Jan 18, 2021
d6bf0e1
add y_percentile and parameter_percentile helpers to ConditionalLikel…
willcowley Jan 18, 2021
8467a73
update notebooks to use y_dist.y_percentile and y_dist.parameter_perc…
willcowley Jan 18, 2021
9a3a41a
Adding some docstrings.
avullo Jan 18, 2021
0333234
Add also some explanations about the new interface in the notebooks a…
avullo Jan 18, 2021
5c6db88
Reformatting.
avullo Jan 18, 2021
4db4cea
Merge branch 'develop' into avullo-willcowley/working-bee-ef1
avullo Sep 14, 2021
2b103e7
Merge branch 'develop' into avullo-willcowley/working-bee-ef1
avullo Sep 29, 2021
47e61b7
Apply suggestions from code review
avullo Oct 1, 2021
1dd3223
Consistency with numpy quantile interface.
avullo Oct 7, 2021
702278c
Rearranging and updating notebooks with more user-friendly descriptions.
avullo Oct 7, 2021
5e92b2e
update classification-redesigned
willcowley Oct 15, 2021
458c6a6
update heteorskedastic notebook
willcowley Oct 15, 2021
74 changes: 73 additions & 1 deletion doc/source/notebooks/advanced/heteroskedastic.pct.py
@@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.6.0
# jupytext_version: 1.9.1
# kernelspec:
# display_name: Python 3
# language: python
@@ -214,6 +214,78 @@ def optimisation_step():

model

# %% [markdown]
# ## The conditional output distribution
#
# Here we show how to obtain the conditional output distribution and plot samples from it. To visualise the associated uncertainty, we also query percentiles of that distribution (the 15.9/84.1 and 2.5/97.5 percentiles correspond roughly to ±1 and ±2 standard deviations of a Gaussian).
# Although the likelihood is Gaussian, the marginal posterior is not; plotting just a mean and a variance would therefore misrepresent the predictive distribution.

# %%
plot_distribution(X, Y, Ymean, Ystd)

y_dist = model.conditional_y_dist(X)
samples = y_dist.sample(10_000)

# The following is equivalent to doing:
# y_lo_lo, y_lo, y_hi, y_hi_hi = np.quantile(samples, q=(0.025, 0.159, 0.841, 0.975), axis=0)
# Note how, contrary to the binary classification case, here we get the percentiles directly from the
# conditional output distribution
y_lo_lo, y_lo, y_hi, y_hi_hi = y_dist.y_percentile(p=(2.5, 15.9, 84.1, 97.5), num_samples=10_000)
Contributor (review comment): Q: it's a bit annoying that numpy uses fractions of 1 and we use percentiles in our API. I guess this is because of TFP's API?
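
# Relating the percentile API to NumPy quantile levels (a sketch; this assumes y_percentile expects
# percentiles in [0, 100], i.e. p = 100 * q for a NumPy quantile level q in [0, 1]):
q = np.array([0.025, 0.159, 0.841, 0.975])
bands_from_samples = np.quantile(samples, q=q, axis=0)  # bands estimated from the samples drawn above
bands_from_api = y_dist.y_percentile(p=100 * q, num_samples=10_000)  # should agree up to Monte Carlo error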


fig, ax = plt.subplots(1, 1, figsize=(15, 5))
ax.plot(X, np.mean(samples, axis=0), c="k")
ax.fill_between(X.squeeze(), y_lo[..., 0], y_hi[..., 0], color="silver", alpha=1 - 0.05 * 1 ** 3)
ax.fill_between(
    X.squeeze(), y_lo_lo[..., 0], y_hi_hi[..., 0], color="silver", alpha=1 - 0.05 * 2 ** 3
)
ax.scatter(X, Y, color="gray", alpha=0.8)

# %%
p_mu_samples, p_var_samples = y_dist.parameter_samples(10_000)
# p_mu_lo, p_mu_hi = np.quantile(p_mu_samples, q=(0.159, 0.841), axis=0)
# p_var_lo, p_var_hi = np.quantile(p_var_samples, q=(0.159, 0.841), axis=0)

(p_mu_lo, p_mu_hi), (p_var_lo, p_var_hi) = y_dist.parameter_percentile(
    p=(15.9, 84.1), num_samples=10_000
)

# %%
fig, ax = plt.subplots(1, 1, figsize=(15, 5))
ax.scatter(X, Y, color="gray", alpha=0.8)
ax.plot(X, Ymean, c="k")
ax.plot(X, np.mean(p_mu_samples, axis=0), ls="--")

ax.fill_between(X.squeeze(), p_mu_lo[..., 0], p_mu_hi[..., 0], color="silver", alpha=0.8)

fig, ax = plt.subplots(1, 1, figsize=(15, 5))
ax.plot(X, scale ** 2, c="k")
ax.plot(X, np.mean(p_var_samples, axis=0), ls="--")
ax.fill_between(X.squeeze(), p_var_lo[..., 0], p_var_hi[..., 0], color="silver", alpha=0.8)

# %%
from scipy.stats import norm
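
# A standalone illustration (independent of the fitted model above): even when f and g are both
# Gaussian, the marginal of y, with y | f, g ~ Normal(mean=f, scale=exp(g)), is heavier-tailed than
# any single Gaussian, which is why a mean-and-variance summary alone can be misleading here.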

f_mu = 0.0
f_var = 0.1
g_mu = 0.0
g_var = 0.1
n_samples = 100_000

f = np.random.normal(loc=f_mu, scale=np.sqrt(f_var), size=n_samples)
g = np.random.normal(loc=g_mu, scale=np.sqrt(g_var), size=n_samples)

y = np.random.normal(loc=f, scale=np.exp(g))

fig, ax = plt.subplots(1, 1)
ax.hist(y, bins=100, density=True, alpha=0.5)
xx = np.linspace(-10, 10, 101)
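# Moment-style Gaussian reference for comparison: the noise variance below is taken as the squared
# expected noise scale, (E[exp(g)])^2 = exp(g_mu + g_var / 2) ** 2 (a simple approximation used for
# this illustration).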
noise_var = np.exp(g_mu + g_var / 2) ** 2
y_var = f_var + noise_var
ax.plot(xx, norm.pdf(xx, loc=f_mu, scale=y_var ** 0.5), c="k")
ax.set_yscale("log")
ax.set_ylim(1e-8, 1e0)


# %% [markdown]
# ## Further reading
#
207 changes: 207 additions & 0 deletions doc/source/notebooks/basics/classification-redesigned.pct.py
@@ -0,0 +1,207 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,.pct.py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.9.1
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---

# %% [markdown]
# # Basic (binary) GP classification model
#
#
# This notebook shows how to build a GP classification model using variational inference.
# Here we consider binary (two-class, 0 vs. 1) classification only (there is a separate notebook on [multiclass classification](../advanced/multiclass_classification.ipynb)).
# We first look at a one-dimensional example, and then show how you can adapt this when the input space is two-dimensional.

# %%
import numpy as np
import gpflow
import tensorflow as tf

import matplotlib.pyplot as plt

# %matplotlib inline

plt.rcParams["figure.figsize"] = (8, 4)

# %% [markdown]
# ## One-dimensional example
#
# First of all, let's have a look at the data. `X` and `Y` denote the input and output values.
# **NOTE:** `X` and `Y` must be two-dimensional NumPy arrays of shape $N \times D$ and $N \times 1$ respectively, where $N$ is the number of data points (one per row) and $D$ is the number of input dimensions/features:

# %%
X = np.genfromtxt("data/classif_1D_X.csv").reshape(-1, 1)
Y = np.genfromtxt("data/classif_1D_Y.csv").reshape(-1, 1)

# for this redesigned notebook we overwrite the CSV data with a small synthetic dataset instead
n_data = 40
X = np.random.rand(n_data) * 2 + 2
Y = (np.random.rand(n_data) > 0.25).astype(float)

X = X.reshape(-1, 1)
Y = Y.reshape(-1, 1)

plt.figure(figsize=(10, 6))
_ = plt.plot(X, Y, "C3x", ms=8, mew=2)

# %% [markdown]
# ### Reminders on GP classification
#
# For a binary classification model using GPs, we can simply use a `Bernoulli` likelihood. The details of the generative model are as follows:
#
# __1. Define the latent GP:__ we start from a Gaussian process $f \sim \mathcal{GP}(0, k(\cdot, \cdot'))$:

# %%
# build the kernel and covariance matrix
k = gpflow.kernels.Matern52(variance=20.0)
x_grid = np.linspace(0, 6, 200).reshape(-1, 1)
K = k(x_grid)

# sample from a multivariate normal
rng = np.random.RandomState(6)

L = np.linalg.cholesky(K)
f_grid = np.dot(L, rng.randn(200, 5))
plt.plot(x_grid, f_grid, "C0", linewidth=1)
_ = plt.plot(x_grid, f_grid[:, 1], "C0", linewidth=2)

# %% [markdown]
# __2. Squash them to $[0, 1]$:__ the samples of the GP are mapped to $[0, 1]$.
# By default, GPflow uses the standard normal cumulative distribution function (inverse probit function): $p(x) = \Phi(f(x)) = \frac{1}{2} (1 + \operatorname{erf}(f(x) / \sqrt{2}))$.
# (This choice has the advantage that predictive mean, variance and density can be computed analytically, but any choice of invlink is possible, e.g. the logistic (inverse logit) function $p(x) = \frac{\exp(f(x))}{1 + \exp(f(x))}$. Simply pass another function as the `invlink` argument to the `Bernoulli` likelihood class.)

# %%
def invlink(f):
    return gpflow.likelihoods.Bernoulli().invlink(f).numpy()


p_grid = invlink(f_grid)
plt.plot(x_grid, p_grid, "C1", linewidth=1)
_ = plt.plot(x_grid, p_grid[:, 1], "C1", linewidth=2)
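
# %% [markdown]
# As mentioned above, a different inverse link can be used by passing it as `invlink`. A small
# illustration (a sketch, not part of the original notebook) comparing the default probit link with
# a logistic (sigmoid) link:

# %%
logit_bernoulli = gpflow.likelihoods.Bernoulli(invlink=tf.sigmoid)
p_grid_logit = logit_bernoulli.invlink(f_grid).numpy()
plt.plot(x_grid, p_grid[:, 1], "C1", linewidth=2, label="probit (default)")
plt.plot(x_grid, p_grid_logit[:, 1], "C2", linewidth=2, label="logistic (sigmoid)")
_ = plt.legend()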

# %% [markdown]
# __3. Sample from a Bernoulli:__ for each observation point $X_i$, the class label $Y_i \in \{0, 1\}$ is generated by sampling from a Bernoulli distribution $Y_i \sim \mathcal{B}(g(X_i))$.

# %%
# Select some input locations
ind = rng.randint(0, 200, (30,))
X_gen = x_grid[ind]

# evaluate probability and get Bernoulli draws
p = p_grid[ind, 1:2]
Y_gen = rng.binomial(1, p)

# plot
plt.plot(x_grid, p_grid[:, 1], "C1", linewidth=2)
plt.plot(X_gen, p, "C1o", ms=6)
_ = plt.plot(X_gen, Y_gen, "C3x", ms=8, mew=2)

# %% [markdown]
# ### Implementation with GPflow
#
# For the model described above, the posterior $f(x)|Y$ (say $p$) is not Gaussian any more and does not have a closed-form expression.
# A common approach is then to look for the best approximation of this posterior by a tractable distribution (say $q$) such as a Gaussian distribution.
# In variational inference, the quality of an approximation is measured by the Kullback-Leibler divergence $\mathrm{KL}[q \| p]$.
# For more details on this model, see Nickisch and Rasmussen (2008).
#
# The inference problem is thus turned into an optimization problem: finding the best parameters for $q$.
# In our case, we introduce $U \sim \mathcal{N}(q_\mu, q_\Sigma)$, and we choose $q$ to have the same distribution as $f | f(X) = U$.
# The parameters $q_\mu$ and $q_\Sigma$ can be seen as parameters of $q$, which can be optimized in order to minimise $\mathrm{KL}[q \| p]$.
#
# This variational inference model is called `VGP` in GPflow:

# %%
m = gpflow.models.VGP(
    (X, Y), likelihood=gpflow.likelihoods.Bernoulli(), kernel=gpflow.kernels.Matern52()
)
gpflow.set_trainable(m.kernel, False)
opt = gpflow.optimizers.Scipy()
opt.minimize(m.training_loss, variables=m.trainable_variables)

# %% [markdown]
# We can now inspect the result of the optimization with `gpflow.utilities.print_summary(m)`:

# %%
gpflow.utilities.print_summary(m, fmt="notebook")

# %% [markdown]
# In this table, the first two lines are associated with the kernel parameters, and the last two correspond to the variational parameters.
# **NOTE:** In practice, $q_\Sigma$ is actually parameterized by its lower-triangular square root $q_\Sigma = q_\text{sqrt} q_\text{sqrt}^T$ in order to ensure its positive-definiteness.
#
# For more details on how to handle models in GPflow (getting and setting parameters, fixing some of them during optimization, using priors, and so on), see [Manipulating GPflow models](../understanding/models.ipynb).
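
# %% [markdown]
# As an illustration (a sketch using the standard GPflow parameter names `q_mu` and `q_sqrt`), the
# covariance $q_\Sigma$ can be reconstructed from its lower-triangular square root:

# %%
q_sqrt = m.q_sqrt.numpy()[0]  # lower-triangular square root, shape [N, N]
q_Sigma = q_sqrt @ q_sqrt.T  # reconstructed covariance of q
print(q_Sigma.shape)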

# %% [markdown]
# ### Predictions
#
# Finally, we will see how to use model predictions to plot the resulting model.
# We will replicate the figures of the generative model above, but using the approximate posterior distribution given by the model.

# %% [markdown]
# ## The conditional output distribution
#
# Here we show how to obtain the conditional output distribution and plot samples from it. To visualise the associated uncertainty, we take percentiles of samples of the likelihood parameter (the Bernoulli probability) rather than of the output values themselves, because the outputs are binary, so their percentiles would be neither convenient nor informative to plot.
Contributor (review comment): I'm not sure what is meant with the second bit of the sentence: ", because the those of the output values are binary and would neither be convenient nor interesting to plot"


# %%
tf.random.set_seed(6)
y_dist = m.conditional_y_dist(x_grid.reshape(-1, 1))

y_samples = y_dist.sample(100)

plt.figure(figsize=(12, 8))
plt.plot(x_grid, np.mean(y_samples, axis=0))

((p_lo, p_50, p_hi),) = y_dist.parameter_percentile(p=(2.5, 50.0, 97.5), num_samples=10_000)
(l1,) = plt.plot(x_grid, p_50)
plt.fill_between(
    x_grid.flatten(), np.ravel(p_lo), np.ravel(p_hi), alpha=0.3, color=l1.get_color(),
)
# plot data
plt.plot(X, Y, "C3x", ms=8, mew=2)
plt.ylim((-0.5, 1.5))

# %%
# The cell above is functionally equivalent to the following manual approach, but more user-friendly
# and intuitive: draw samples from the latent process, squash them through the inverse link, and then
# compute the sample mean and the lower (q=0.05) and upper (q=0.95) quantiles.
samples = m.predict_f_samples(x_grid, 10).numpy().squeeze().T


def compute_y_sample_statistics(model, num_samples: int = 100):
    p = invlink(model.predict_f_samples(x_grid, num_samples).numpy().squeeze().T)
    mean = np.mean(p, axis=1)
    p_low, p_high = np.quantile(a=p, q=[0.05, 0.95], axis=1)

    return mean, p_low, p_high


y_mu, y_p_low, y_p_high = compute_y_sample_statistics(m, 10000)

plt.figure(figsize=(12, 8))
plt.plot(x_grid.flatten(), y_mu)
plt.fill_between(
    x_grid.flatten(), np.ravel(y_p_low), np.ravel(y_p_high), alpha=0.3, color="C0",
)

# %% [markdown]
# ## Further reading
#
# There are dedicated notebooks giving more details on how to manipulate [models](../understanding/models.ipynb) and [kernels](../advanced/kernels.ipynb).
#
# This notebook covers only very basic classification models. You might also be interested in:
# * [Multiclass classification](../advanced/multiclass_classification.ipynb) if you have more than two classes.
# * [Sparse models](../advanced/gps_for_big_data.ipynb). The models above have one inducing variable $U_i$ per observation point $X_i$, which does not scale to large datasets. Sparse Variational GP (SVGP) is an efficient alternative where the variables $U_i$ are defined at some inducing input locations $Z_i$ that can also be optimized.
# * [Exact inference](../advanced/mcmc.ipynb). We have seen that variational inference provides an approximation to the posterior. GPflow also supports exact inference using Markov Chain Monte Carlo (MCMC) methods, and the kernel parameters can also be assigned prior distributions in order to avoid point estimates.
#
# ## References
#
# Hannes Nickisch and Carl Edward Rasmussen. 'Approximations for binary Gaussian process classification'. *Journal of Machine Learning Research* 9(Oct):2035--2078, 2008.
1 change: 1 addition & 0 deletions gpflow/likelihoods/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from .base import (
    ConditionedLikelihood,
    Likelihood,
    MonteCarloLikelihood,
    QuadratureLikelihood,