Multi Latent Likelihoods using new quadrature Likelihoods (#1559)
* HeteroskedasticLikelihood base class draft

* fixup

* cleanup

* cleanup heteroskedastic

* multioutput likelihood WIP

* Notebook exemplifying HeteroskedasticTFPDistribution usage (#1462)

* fixes

* typo fix; reshaping fix

* notebook showing how to use HeteroskedasticTFPDistribution likelihood

* converting to .pct.py format

* removed .ipynb

* better descriptions

* black auto-formatting

Co-authored-by: Gustavo Carvalho <gustavo.carvalho@delfosim.com>

* note and bugfix

* add comment

* Adding heteroskedastic tests (#1508)

These tests ensure that a heteroskedastic likelihood with a constant variance gives the same results as a Gaussian likelihood with the same variance.

* testing

* added QuadratureLikelihood to base, refactored ScalarLikelihood to use it

* fix

* using the first dimension to hold the quadrature summation

* adapting ndiagquad wrapper

* merged with gustavocmv/quadrature-change-shape

* removed unnecessary tf.init_scope

* removed print and tf.print

* removed print and tf.print

* Type annotations

Co-authored-by: Vincent Dutordoir <dutordoirv@gmail.com>

* Work

* Fix test

* Remove multioutput from PR

* Fix notebook

* Add student t test

* More tests

* Copyright

* Removed NDiagGHQuadratureLikelihood class in favor of non-abstract QuadratureLikelihood

* _set_latent_and_observation_dimension_eagerly

* n_gh ---> num_gauss_hermite_points

* removed NDiagGHQuadratureLikelihood from test

* black

* bugfix

* removing NDiagGHQuadratureLikelihood from test

* fixed bad commenting

* black

* refactoring scalar likelihood

* adding dtype casts to quadrature

* black

* small merging fixes

* DONE: swap n_gh for num_gauss_hermite_points

* black

Co-authored-by: ST John <st@prowler.io>
Co-authored-by: gustavocmv <47801305+gustavocmv@users.noreply.github.com>
Co-authored-by: Gustavo Carvalho <gustavo.carvalho@delfosim.com>
Co-authored-by: st-- <st--@users.noreply.github.com>
Co-authored-by: joshuacoales-pio <47976939+joshuacoales-pio@users.noreply.github.com>
6 people committed Sep 15, 2020
1 parent 799b659 commit ad6e031
Showing 7 changed files with 618 additions and 3 deletions.
214 changes: 214 additions & 0 deletions doc/source/notebooks/advanced/heteroskedastic.pct.py
@@ -0,0 +1,214 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,.pct.py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.4.2
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---

# %% [markdown]
# # Heteroskedastic Likelihood and Multi-Latent GP

# %% [markdown]
# ## Standard (Homoskedastic) Regression
# In standard GP Regression, the GP latent function is used to learn the location parameter of a likelihood distribution (usually a Gaussian) as a function of the input $x$, whereas the scale parameter is considered constant. This is a homoskedastic model, which is unable to capture variations of the noise distribution with the input $x$.
#
#
# ## Heteroskedastic Regression
# This notebook shows how to construct a model which uses multiple (here two) GP latent functions to learn both the location and the scale of the Gaussian likelihood distribution. It does so by connecting a **Multi-Output Kernel**, which generates multiple GP latent functions, to a **Heteroskedastic Likelihood**, which maps the latent GPs to a single likelihood distribution.
#
# The generative model is described as:
#
# $$ f_1(x) \sim \mathcal{GP}(0, k_1(\cdot, \cdot)) $$
# $$ f_2(x) \sim \mathcal{GP}(0, k_2(\cdot, \cdot)) $$
# $$ \text{loc}(x) = f_1(x) $$
# $$ \text{scale}(x) = \text{transform}(f_2(x)) $$
# $$ y_i|f_1, f_2, x_i \sim \mathcal{N}(\text{loc}(x_i),\;\text{scale}^2(x_i))$$
#
# Here the function $\text{transform}$ maps the GP $f_2$ to **positive-only values**, which is required because it represents the $\text{scale}$ of a Gaussian likelihood. In this notebook, the $\exp$ function is used as the $\text{transform}$; an alternative would be the $\text{softplus}$ function.

# %%
import numpy as np
import gpflow as gpf
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt


# %% [markdown]
# # Data Generation
# We generate some heteroskedastic data by replacing the random latent functions $f_1$ and $f_2$ of the generative model with deterministic $\sin$ and $\cos$ functions. The input $X$ consists of $N=1001$ uniformly spaced values in the interval $[0, 4\pi]$. The outputs $Y$ are still sampled from a Gaussian likelihood.
#
# $$ x_i \in [0, 4\pi], \quad i = 1,\dots,N $$
# $$ f_1(x) = \sin(x) $$
# $$ f_2(x) = \cos(x) $$
# $$ \text{loc}(x) = f_1(x) $$
# $$ \text{scale}(x) = \exp(f_2(x)) $$
# $$ y_i|x_i \sim \mathcal{N}(\text{loc}(x_i),\;\text{scale}^2(x_i))$$

# %%
N = 1001

np.random.seed(0)
tf.random.set_seed(0)

# Build inputs X
X = np.linspace(0, 4 * np.pi, N)[:, None] # X must be of shape [N, 1]

# Deterministic functions in place of latent ones
f1 = np.sin
f2 = np.cos

# Use transform = exp to ensure positive-only scale values
transform = np.exp

# Compute loc and scale as functions of input X
loc = f1(X)
scale = transform(f2(X))

# Sample outputs Y from Gaussian Likelihood
Y = np.random.normal(loc, scale)

# %% [markdown]
# # Plot Data
# Note how the spread of the distribution (shaded area) and of the outputs $Y$ varies depending on the input $X$.

# %%
plt.figure(figsize=(15, 5))
for k in (1, 2):
    x = X.squeeze()
    lb = (loc - k * scale).squeeze()
    ub = (loc + k * scale).squeeze()
    plt.fill_between(x, lb, ub, color="silver", alpha=1 - 0.05 * k ** 3)
    plt.plot(x, lb, color="silver")
    plt.plot(x, ub, color="silver")
plt.plot(X, loc, color="black")
plt.scatter(X, Y, color="gray", alpha=0.8)
plt.show()
plt.close()


# %% [markdown]
# # Build Model

# %% [markdown]
# ## Likelihood
# This implements the following part of the generative model:
# $$ \text{loc}(x) = f_1(x) $$
# $$ \text{scale}(x) = \text{transform}(f_2(x)) $$
# $$ y_i|f_1, f_2, x_i \sim \mathcal{N}(\text{loc}(x_i),\;\text{scale}^2(x_i))$$

# %%
likelihood = gpf.likelihoods.HeteroskedasticTFPConditional(
    distribution_class=tfp.distributions.Normal,  # Gaussian Likelihood
    transform=tfp.bijectors.Exp(),  # Exponential Transform
)

print(f"Likelihood's expected latent_dim: {likelihood.latent_dim}")

# %% [markdown]
# ## Kernel
# This implements the following part of the generative model:
# $$ f_1(x) \sim \mathcal{GP}(0, k_1(\cdot, \cdot)) $$
# $$ f_2(x) \sim \mathcal{GP}(0, k_2(\cdot, \cdot)) $$
# with both kernels modeled as separate and independent $\text{RBF}$ kernels.

# %%
kernel = gpf.kernels.SeparateIndependent(
    [
        gpf.kernels.RBF(),  # This is k1, the kernel of f1
        gpf.kernels.RBF(),  # This is k2, the kernel of f2
    ]
)
# The number of kernels passed to gpf.kernels.SeparateIndependent must equal likelihood.latent_dim

# %% [markdown]
# ## Inducing Points
# Since we will use the **SVGP** model to perform inference, we need to construct the inducing variables $U_1$ and $U_2$, each of size $M=20$, which are used to approximate $f_1$ and $f_2$ respectively, and to initialize the inducing point positions $Z_1$ and $Z_2$. This gives a total of $2M=40$ inducing variables and inducing points.
#
# The inducing variables and their corresponding inputs are kept separate and independent, but both $Z_1$ and $Z_2$ are initialized to the same $Z$, consisting of $M=20$ equally spaced points in $[\min(X), \max(X)]$.
#

# %%
M = 20 # Number of inducing variables for each f_i

# Initial inducing points position Z
Z = np.linspace(X.min(), X.max(), M)[:, None] # Z must be of shape [M, 1]

inducing_variable = gpf.inducing_variables.SeparateIndependentInducingVariables(
    [
        gpf.inducing_variables.InducingPoints(Z),  # This is U1 = f1(Z1)
        gpf.inducing_variables.InducingPoints(Z),  # This is U2 = f2(Z2)
    ]
)
# The number of inducing variables passed to gpf.inducing_variables.SeparateIndependentInducingVariables must equal likelihood.latent_dim
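
# %% [markdown]
# Since $Z_1$ and $Z_2$ are initialized to the same $Z$, the inducing inputs could alternatively be shared between the two latent GPs. A minimal sketch (illustrative only; the rest of the notebook uses the separate version above):

# %%
shared_inducing_variable = gpf.inducing_variables.SharedIndependentInducingVariables(
    gpf.inducing_variables.InducingPoints(Z)  # a single Z shared by both latent GPs
)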

# %% [markdown]
# ## SVGP Model
# Build the **SVGP** model by composing the **Kernel**, the **Likelihood** and the **Inducing Variables**.
#
# Note that the model needs to be told the number of latent GPs by passing `num_latent_gps=likelihood.latent_dim`.

# %%
model = gpf.models.SVGP(
    kernel=kernel,
    likelihood=likelihood,
    inducing_variable=inducing_variable,
    num_latent_gps=likelihood.latent_dim,
)

model

# %% [markdown]
# # Build Optimizers (NatGrad + Adam)
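#
# We use the natural gradient optimizer for the variational parameters `q_mu` and `q_sqrt`, and Adam for all other trainable parameters; `q_mu` and `q_sqrt` are therefore marked as not trainable below, so that Adam does not update them as well.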

# %%
data = (X, Y)
loss_fn = model.training_loss_closure(data)

gpf.utilities.set_trainable(model.q_mu, False)
gpf.utilities.set_trainable(model.q_sqrt, False)

variational_vars = [(model.q_mu, model.q_sqrt)]
natgrad_opt = gpf.optimizers.NaturalGradient(gamma=0.01)

adam_vars = model.trainable_variables
adam_opt = tf.optimizers.Adam(0.01)

# %% [markdown]
# # Run Optimization Loop
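#
# Each iteration first takes a natural gradient step on the variational parameters, followed by an Adam step on the remaining parameters; every `log_freq` epochs, the current loss is printed and the model predictions are plotted against the data.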

# %%
epochs = 100
log_freq = 20

for epoch in range(epochs + 1):
    natgrad_opt.minimize(loss_fn, variational_vars)
    adam_opt.minimize(loss_fn, adam_vars)

    # For every 'log_freq' epochs, print the epoch and plot the predictions against the data
    if epoch % log_freq == 0 and epoch > 0:
        print(f"Epoch {epoch} - Loss: {loss_fn().numpy() : .4f}")
        Ymean, Yvar = model.predict_y(X)
        Ymean = Ymean.numpy().squeeze()
        Ystd = tf.sqrt(Yvar).numpy().squeeze()
        plt.figure(figsize=(15, 5))
        for k in (1, 2):
            x = X.squeeze()
            lb = (Ymean - k * Ystd).squeeze()
            ub = (Ymean + k * Ystd).squeeze()
            plt.fill_between(x, lb, ub, color="silver", alpha=1 - 0.05 * k ** 3)
            plt.plot(x, lb, color="silver")
            plt.plot(x, ub, color="silver")
        plt.plot(X, Ymean, color="black")
        plt.scatter(X, Y, color="gray", alpha=0.8)
        plt.show()
        plt.close()
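
# %% [markdown]
# As a quick sketch of how the two latent GPs can be inspected separately: `model.predict_f` returns means and variances with one column per latent GP (shape `[N, 2]` here), the first column corresponding to the location function $f_1$ and the second to the pre-transform scale function $f_2$.

# %%
Fmean, Fvar = model.predict_f(X)  # each of shape [N, 2]
Fmean = Fmean.numpy()

plt.figure(figsize=(15, 5))
plt.plot(X, Fmean[:, 0], color="black", label="latent $f_1$ (loc)")
plt.plot(X, Fmean[:, 1], color="gray", label="latent $f_2$ (scale before transform)")
plt.legend()
plt.show()
plt.close()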
28 changes: 27 additions & 1 deletion gpflow/likelihoods/__init__.py
@@ -1,4 +1,24 @@
from .base import Likelihood, ScalarLikelihood, SwitchedLikelihood, MonteCarloLikelihood
# Copyright 2017-2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import (
Likelihood,
QuadratureLikelihood,
ScalarLikelihood,
SwitchedLikelihood,
MonteCarloLikelihood,
)
from .scalar_discrete import (
Bernoulli,
Ordinal,
@@ -17,3 +37,9 @@
Softmax,
RobustMax,
)

from .multilatent import (
MultiLatentLikelihood,
MultiLatentTFPConditional,
HeteroskedasticTFPConditional,
)
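
With this change, the new quadrature and multi-latent likelihood classes can be imported directly from `gpflow.likelihoods`, for example (a minimal usage sketch):

    from gpflow.likelihoods import QuadratureLikelihood, HeteroskedasticTFPConditional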
58 changes: 58 additions & 0 deletions gpflow/likelihoods/base.py
@@ -57,6 +57,8 @@
import abc
import warnings

from typing import Optional

from ..base import Module
from ..quadrature import hermgauss, ndiag_mc, ndiagquad, NDiagGHQuadrature

@@ -286,6 +288,62 @@ def _variational_expectations(self, Fmu, Fvar, Y):
raise NotImplementedError


class QuadratureLikelihood(Likelihood):
    def __init__(self, num_gauss_hermite_points: int = 20, **kwargs):
        super().__init__(**kwargs)
        self.num_gauss_hermite_points = num_gauss_hermite_points
        self._quadrature = None

    @property
    def quadrature(self):
        if self._quadrature is None:
            if self.latent_dim is None:
                raise Exception(
                    "latent_dim not specified. "
                    "Either set likelihood.latent_dim directly or "
                    "call a method which passes data to have it inferred."
                )
            with tf.init_scope():
                self._quadrature = NDiagGHQuadrature(self.latent_dim, self.num_gauss_hermite_points)
        return self._quadrature

    def _predict_mean_and_var(self, Fmu, Fvar):
        r"""
        :param Fmu: mean function evaluation Tensor, with shape [..., latent_dim]
        :param Fvar: variance of function evaluation Tensor, with shape [..., latent_dim]
        :returns: mean and variance, both with shape [..., observation_dim]
        """

        def conditional_y_squared(F):
            return self.conditional_variance(F) + self.conditional_mean(F) ** 2

        integrands = [self.conditional_mean, conditional_y_squared]
        E_y, E_y2 = self.quadrature(integrands, Fmu, Fvar)
        V_y = E_y2 - E_y ** 2
        return E_y, V_y

    def _quadrature_log_prob(self, F, Y):
        return tf.expand_dims(self.log_prob(F, Y), -1)

    def _predict_log_density(self, Fmu, Fvar, Y):
        r"""
        :param Fmu: mean function evaluation Tensor, with shape [..., latent_dim]
        :param Fvar: variance of function evaluation Tensor, with shape [..., latent_dim]
        :param Y: observation Tensor, with shape [..., observation_dim]
        :returns: log predictive density, with shape [...]
        """
        return tf.squeeze(self.quadrature.logspace(self._quadrature_log_prob, Fmu, Fvar, Y), -1)

    def _variational_expectations(self, Fmu, Fvar, Y):
        r"""
        :param Fmu: mean function evaluation Tensor, with shape [..., latent_dim]
        :param Fvar: variance of function evaluation Tensor, with shape [..., latent_dim]
        :param Y: observation Tensor, with shape [..., observation_dim]
        :returns: variational expectations, with shape [...]
        """
        return tf.squeeze(self.quadrature(self._quadrature_log_prob, Fmu, Fvar, Y), -1)


class ScalarLikelihood(Likelihood):
"""
A likelihood class that helps with scalar likelihood functions: likelihoods where