Heteroscedastic GPR model #1704

Draft · wants to merge 19 commits into base: develop · Changes from 14 commits
169 changes: 169 additions & 0 deletions doc/source/notebooks/advanced/het_gpr_script.py
@@ -0,0 +1,169 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import gpflow as gf
from gpflow import Parameter
from gpflow.likelihoods import MultiLatentTFPConditional

from gpflow.utilities import print_summary, positive

N = 101

np.random.seed(0)
tf.random.set_seed(0)

# Build inputs X
X = np.linspace(0, 1, N)[:, None]

# Create outputs Y which includes heteroscedastic noise
rand_samples = np.random.normal(0.0, 1.0, size=N)[:, None]
noise = rand_samples * (0.05 + 0.75 * X)
signal = 2 * np.sin(2 * np.pi * X)

Y = 5 * (signal + noise)

# %% [markdown]
# ### Plot Data
# Note how the distribution density (shaded area) and the outputs $Y$ both change depending on the input $X$.

# %%
def plot_distribution(X, Y, mean=None, std=None):
    plt.figure(figsize=(15, 5))
    if mean is not None:
        x = X.squeeze()
        for k in (1, 2):
            lb = (mean - k * std).squeeze()
            ub = (mean + k * std).squeeze()
            plt.fill_between(x, lb, ub, color="silver", alpha=1 - 0.05 * k ** 3)
            plt.plot(x, lb, color="silver")
            plt.plot(x, ub, color="silver")
        plt.plot(X, mean, color="black")
    plt.scatter(X, Y, color="gray", alpha=0.8)
    plt.show()
    plt.close()


# plot_distribution(X, Y)


# %% [markdown]
# ## Build Model

# %% [markdown]
# ### Likelihood
# This implements the following part of the generative model:
# $$ \text{loc}(x) = f_1(x) $$
# $$ \text{scale}(x) = \text{transform}(f_2(x)) $$
# $$ y_i|f_1, f_2, x_i \sim \mathcal{N}(\text{loc}(x_i),\;\text{scale}(x_i)^2)$$
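
# %% [markdown]
# As an illustrative sketch of this generative model (with an assumed softplus transform
# for the scale, and an assumed latent scale function $f_2$), observations could be
# sampled directly:

# %%
# f1 = 2 * np.sin(2 * np.pi * X)                    # loc(x) = f_1(x)
# f2 = -1.0 + 2.0 * X                               # assumed latent scale function f_2(x)
# scale = np.log1p(np.exp(f2))                      # transform(f_2(x)), here softplus
# y_sample = np.random.normal(loc=f1, scale=scale)  # y_i | f_1, f_2, x_i ~ N(loc, scale^2)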

# %% [markdown]
# ### Select a kernel
# %%
kernel = gf.kernels.Matern52()

# %% [markdown]
# ### HeteroskedasticGPR Model
# Build the **GPR** model with the data and kernel

# %%
class LinearLikelihood(MultiLatentTFPConditional):

    def __init__(self, ndims: int = 1, **kwargs):
        gradient_prior = tfp.distributions.Normal(loc=np.float64(0.0), scale=np.float64(1.0))
        self.noise_gradient = Parameter(np.ones(ndims), transform=positive(lower=1e-6), prior=gradient_prior)
        self.constant_variance = Parameter(1.0, transform=positive(lower=1e-6))
        self.minimum_noise_variance = 1e-6  # ?

        def conditional_distribution(Fs) -> tfp.distributions.Distribution:
            tf.debugging.assert_equal(tf.shape(Fs)[-1], 2)
            loc = Fs[..., :1]
            scale = self.scale_transform(Fs[..., 1:])
            return tfp.distributions.Normal(loc, scale)

        super().__init__(latent_dim=2, conditional_distribution=conditional_distribution, **kwargs)

    def scale_transform(self, X):
        """ Determine the likelihood variance at the specified input locations X. """
        linear_variance = tf.reduce_sum(tf.square(X) * self.noise_gradient, axis=-1, keepdims=True)
        noise_variance = linear_variance + self.constant_variance
        return tf.maximum(noise_variance, self.minimum_noise_variance)

model = gf.models.het_GPR(data=(X, Y), kernel=kernel, likelihood=LinearLikelihood(), mean_function=None)

# %% [markdown]
# ## Model Optimization proceeds as in the GPR notebook
# %%
opt = gf.optimizers.Scipy()

# %% [markdown]
# %%
opt_logs = opt.minimize(model.training_loss, model.trainable_variables, options=dict(maxiter=100))
print_summary(model)
print_summary(model.posterior())

## predict mean and variance of the observations at the training inputs
mean, var = model.predict_y(X)
plot_distribution(X, Y, mean.numpy(), np.sqrt(var.numpy()))

# fmean, fvar = model.predict_f(X)
# plot_distribution(X, Y, fmean.numpy(), np.sqrt(fvar.numpy()))

## Repeat for standard GPR
base_model = gf.models.GPR(data=(X, Y), kernel=kernel, mean_function=None)
opt_logs = opt.minimize(base_model.training_loss, base_model.trainable_variables, options=dict(maxiter=100))
print_summary(base_model)

## predict mean and variance of the observations at the training inputs
mean, var = base_model.predict_y(X)
plot_distribution(X, Y, mean.numpy(), np.sqrt(var.numpy()))

base_lml = base_model.log_marginal_likelihood()
lml = model.log_marginal_likelihood()
print("Base LML", base_lml)
print("Het LML", lml)
print("Odds ratio", np.exp(lml-base_lml))


class PseudoPoissonLikelihood(MultiLatentTFPConditional):
    """ While the Poisson likelihood is non-Gaussian, here we mimic its mean-variance
    relationship by tying the Gaussian noise variance to the latent function value
    (not yet implemented). """

    def __init__(self, ndims: int = 1, **kwargs):
        super().__init__(**kwargs)

    def _scale_transform(self, f):
        """ Determine the likelihood variance based upon the function value. """
        pass
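
# %% [markdown]
# One possible completion of `_scale_transform` (an illustrative sketch only; the class
# above leaves it unimplemented): mimic the Poisson property Var[y] = E[y] by tying the
# Gaussian noise variance to a positive-transformed latent mean.

# %%
# def _scale_transform(self, f):
#     mean = tf.nn.softplus(f)   # keep the implied mean positive (assumed transform)
#     return tf.sqrt(mean)       # Var[y] = mean, so the Normal scale is sqrt(mean)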


# ## generate test points for prediction
# xx = np.linspace(-0.1, 1.1, 100).reshape(100, 1) # test points must be of shape (N, D)
#
# ## predict mean and variance of latent GP at test points
# mean, var = model.predict_f(xx)
#
# ## generate 10 samples from posterior
# tf.random.set_seed(1) # for reproducibility
# samples = model.predict_f_samples(xx, 10) # shape (10, 100, 1)
#
# ## plot
# plt.figure(figsize=(12, 6))
# plt.plot(X, Y, "kx", mew=2)
# plt.plot(xx, mean, "C0", lw=2)
# plt.fill_between(
# xx[:, 0],
# mean[:, 0] - 1.96 * np.sqrt(var[:, 0]),
# mean[:, 0] + 1.96 * np.sqrt(var[:, 0]),
# color="C0",
# alpha=0.2,
# )
#
# plt.plot(xx, samples[:, :, 0].numpy().T, "C0", linewidth=0.5)
# _ = plt.xlim(-0.1, 1.1)


# %% [markdown]
# ## Further reading
#
# See [Kernel Identification Through Transformers](https://arxiv.org/abs/2106.08185) by Simpson et al.
1 change: 1 addition & 0 deletions gpflow/models/__init__.py
@@ -1,6 +1,7 @@
from .gplvm import GPLVM, BayesianGPLVM
from .gpmc import GPMC
from .gpr import GPR
from .het_gpr import het_GPR
from .model import BayesianModel, GPModel
from .sgpmc import SGPMC
from .sgpr import GPRFITC, SGPR
108 changes: 108 additions & 0 deletions gpflow/models/het_gpr.py
@@ -0,0 +1,108 @@
import tensorflow as tf
from typing import Optional

from gpflow import posteriors
from ..kernels import Kernel
from ..logdensities import multivariate_normal
from ..mean_functions import MeanFunction
from ..models.gpr import GPR_with_posterior
from ..models.training_mixins import RegressionData, InputData
from ..types import MeanAndVariance
from ..utilities import add_linear_noise_cov


class het_GPR(GPR_with_posterior):
Contributor: It would be useful if there were a way to follow the existing GPflow model naming convention in this case.

""" While the vanilla GPR enforces a constant noise variance across the input space, here we allow the
noise amplitude to vary linearly (and hence the noise variance to change quadratically) across the input space.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it true that the noise amplitude varies linearly with this model? Even with the PseudoPoissonLikelihood?

"""

    def __init__(
        self,
        data: RegressionData,
        kernel: Kernel,
        likelihood,

Contributor: Could this type annotation be MultiLatentTFPConditional?

        mean_function: Optional[MeanFunction] = None,
        noise_variance: float = 1.0,
    ):
        super().__init__(data, kernel, mean_function, noise_variance)
        self.likelihood = likelihood

    def posterior(self, precompute_cache=posteriors.PrecomputeCacheType.TENSOR):
        """
        Create the Posterior object which contains precomputed matrices for
        faster prediction.

        precompute_cache has three settings:

        - `PrecomputeCacheType.TENSOR` (or `"tensor"`): Precomputes the cached
          quantities and stores them as tensors (which allows differentiating
          through the prediction). This is the default.
        - `PrecomputeCacheType.VARIABLE` (or `"variable"`): Precomputes the cached
          quantities and stores them as variables, which allows for updating
          their values without changing the compute graph (relevant for AOT
          compilation).
        - `PrecomputeCacheType.NOCACHE` (or `"nocache"` or `None`): Avoids
          immediate cache computation. This is useful for avoiding extraneous
          computations when you only want to call the posterior's
          `fused_predict_f` method.
        """

        X, Y = self.data

        return posteriors.HeteroskedasticGPRPosterior(
            kernel=self.kernel,
            X_data=X,
            Y_data=Y,
            likelihood=self.likelihood,
            mean_function=self.mean_function,
            precompute_cache=precompute_cache,
        )
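
    # Illustrative usage sketch (assuming a trained `model` and test inputs `Xnew`):
    #
    #     post = model.posterior(posteriors.PrecomputeCacheType.VARIABLE)
    #     f_mean, f_var = post.predict_f(Xnew)  # re-uses the precomputed cache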

    def predict_y(
        self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        """
        Compute the mean and variance of the held-out data at the input points.
        """
        if full_cov or full_output_cov:
            # See https://github.com/GPflow/GPflow/issues/1461
            raise NotImplementedError(
                "The predict_y method currently supports only the argument values "
                "full_cov=False and full_output_cov=False"
            )

        f_mean, f_var = self.predict_f(Xnew, full_cov=full_cov, full_output_cov=full_output_cov)
        # The second "latent" column passed to the likelihood is the input location itself,
        # so the likelihood's scale transform is evaluated at Xnew; its variance column is a
        # dummy of zeros because the inputs are known exactly.
        Fs = tf.concat([f_mean, Xnew], axis=-1)
        dummy_f_var = tf.zeros_like(f_var)
        F_vars = tf.concat([f_var, dummy_f_var], axis=-1)
        return self.likelihood.predict_mean_and_var(Fs, F_vars)

    def _add_noise_cov(self, X, K: tf.Tensor) -> tf.Tensor:
        """
        Returns K + diag(σ²), where σ² is the (input-dependent) likelihood noise
        variance evaluated at the inputs X.
        """
        # Evaluate the likelihood's conditional variance at the inputs; the latent
        # mean column is a dummy of zeros, since only the input column is used.
        dummy_F = tf.zeros_like(X)
        Fs = tf.concat([dummy_F, X], axis=-1)
        variances = self.likelihood.conditional_variance(Fs)
        return add_linear_noise_cov(K, tf.squeeze(variances))

    def log_marginal_likelihood(self) -> tf.Tensor:
        r"""
        Computes the log marginal likelihood.

        .. math::
            \log p(Y \mid \theta) = \log \mathcal{N}\left(Y \mid m(X),\, K(X, X) + \mathrm{diag}(\sigma^2(X))\right)

        where :math:`\sigma^2(X)` is the input-dependent noise variance supplied by the likelihood.
        """
        X, Y = self.data
        K = self.kernel(X)
        ks = self._add_noise_cov(X, K)
        L = tf.linalg.cholesky(ks)
        m = self.mean_function(X)

        # [R,] log-likelihoods for each independent dimension of Y
        log_prob = multivariate_normal(Y, m, L)
        return tf.reduce_sum(log_prob)



36 changes: 33 additions & 3 deletions gpflow/posteriors.py
@@ -40,8 +40,9 @@
    SeparateIndependentInducingVariables,
    SharedIndependentInducingVariables,
)
from .likelihoods import Likelihood
from .types import MeanAndVariance
from .utilities import Dispatcher, add_noise_cov
from .utilities import Dispatcher, add_noise_cov, add_linear_noise_cov


class _QDistribution(Module):
@@ -258,7 +259,7 @@ def _precompute(self) -> Tuple[tf.Tensor, tf.Tensor]:
        """

        Kmm = self.kernel(self.X_data)
        Kmm_plus_s = add_noise_cov(Kmm, self.likelihood_variance)
        Kmm_plus_s = self.add_noise(Kmm, self.X_data)

        # obtain the cholesky decomposition of Kmm_plus_s
        Lm = tf.linalg.cholesky(Kmm_plus_s)
@@ -286,12 +287,41 @@ def _conditional_fused(
        Kmm = self.kernel(self.X_data)
        Knn = self.kernel(Xnew, full_cov=full_cov)
        Kmn = self.kernel(self.X_data, Xnew)
        Kmm_plus_s = add_noise_cov(Kmm, self.likelihood_variance)
        Kmm_plus_s = self.add_noise(Kmm, self.X_data)

        return base_conditional(
            Kmn, Kmm_plus_s, Knn, err, full_cov=full_cov, white=False
        )  # [N, P], [N, P] or [P, N, N]

    def add_noise(self, K: tf.Tensor, X: tf.Tensor):
        return add_noise_cov(K, self.likelihood_variance)


class HeteroskedasticGPRPosterior(GPRPosterior):

    def __init__(
        self,
        kernel,
        X_data: tf.Tensor,
        Y_data: tf.Tensor,
        likelihood: Likelihood,
        mean_function: Optional[mean_functions.MeanFunction] = None,
        *,
        precompute_cache: Optional[PrecomputeCacheType],
    ):
        self.likelihood = likelihood
        super().__init__(
            kernel,
            X_data,
            Y_data,
            likelihood.constant_variance,
            mean_function=mean_function,
            precompute_cache=precompute_cache,
        )

    def evaluate_linear_noise_variance(self, X: tf.Tensor):
        """ Noise variance contribution. """
        return self.likelihood.scale_transform(X)

    def add_noise(self, K: tf.Tensor, X: tf.Tensor):
        noise_variance = self.evaluate_linear_noise_variance(X) + self.likelihood_variance
        return add_linear_noise_cov(K, noise_variance)


class BasePosterior(AbstractPosterior):
    def __init__(
9 changes: 9 additions & 0 deletions gpflow/utilities/model_utils.py
@@ -11,3 +11,12 @@ def add_noise_cov(K: tf.Tensor, likelihood_variance: Parameter) -> tf.Tensor:
    k_diag = tf.linalg.diag_part(K)
    s_diag = tf.fill(tf.shape(k_diag), likelihood_variance)
    return tf.linalg.set_diag(K, k_diag + s_diag)


def add_linear_noise_cov(K: tf.Tensor, noise_variance: tf.Tensor) -> tf.Tensor:
    """
    Returns K + diag(σ²), where σ² is the likelihood noise variance (vector).
    """
    k_diag = tf.linalg.diag_part(K)
    return tf.linalg.set_diag(K, k_diag + tf.reshape(noise_variance, [-1]))
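
# Illustrative example (a sketch): with K = tf.eye(3, dtype=tf.float64) and
# noise_variance = tf.constant([0.1, 0.2, 0.3], dtype=tf.float64),
# add_linear_noise_cov(K, noise_variance) returns a matrix with diagonal
# [1.1, 1.2, 1.3] and unchanged (zero) off-diagonal entries.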