Heteroscedastic GPR model #1704
base: develop
@@ -0,0 +1,169 @@
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import gpflow as gf
from gpflow import Parameter
from gpflow.likelihoods import MultiLatentTFPConditional
from gpflow.utilities import print_summary, positive

N = 101

np.random.seed(0)
tf.random.set_seed(0)
# Build inputs X
X = np.linspace(0, 1, N)[:, None]

# Create outputs Y, which include heteroscedastic noise: the noise standard
# deviation grows linearly with the input X
rand_samples = np.random.normal(0.0, 1.0, size=N)[:, None]
noise = rand_samples * (0.05 + 0.75 * X)
signal = 2 * np.sin(2 * np.pi * X)

Y = 5 * (signal + noise)
# %% [markdown]
# ### Plot Data
# Note how the distribution density (shaded area) and the outputs $Y$ both change depending on the input $X$.

# %%
def plot_distribution(X, Y, mean=None, std=None):
    plt.figure(figsize=(15, 5))
    if mean is not None:
        x = X.squeeze()
        for k in (1, 2):
            lb = (mean - k * std).squeeze()
            ub = (mean + k * std).squeeze()
            plt.fill_between(x, lb, ub, color="silver", alpha=1 - 0.05 * k ** 3)
            plt.plot(x, lb, color="silver")
            plt.plot(x, ub, color="silver")
        plt.plot(X, mean, color="black")
    plt.scatter(X, Y, color="gray", alpha=0.8)
    plt.show()
    plt.close()


# plot_distribution(X, Y)
# %% [markdown]
# ## Build Model

# %% [markdown]
# ### Likelihood
# This implements the following part of the generative model:
# $$ \text{loc}(x) = f_1(x) $$
# $$ \text{scale}(x) = \text{transform}(f_2(x)) $$
# $$ y_i|f_1, f_2, x_i \sim \mathcal{N}(\text{loc}(x_i),\;\text{scale}(x_i)^2)$$
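
# %% [markdown]
# As a quick illustration (an editor's sketch, not part of the PR), one draw
# from this generative model can be simulated directly; `f1` and `f2` below are
# hypothetical stand-ins for samples of the two latent functions, and softplus
# is assumed for the transform purely for the example.

# %%
f1 = np.sin(2 * np.pi * X)  # pretend sample of f_1: the location
f2 = X - 1.0  # pretend sample of f_2
scale = np.log1p(np.exp(f2))  # softplus keeps the scale positive
y_draw = np.random.normal(loc=f1, scale=scale)  # y_i ~ N(loc(x_i), scale(x_i)^2)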
# %% [markdown]
# ### Select a kernel

# %%
kernel = gf.kernels.Matern52()
# %% [markdown]
# ### HeteroskedasticGPR Model
# Build the **GPR** model with the data and kernel
# %%
class LinearLikelihood(MultiLatentTFPConditional):
    def __init__(self, ndims: int = 1, **kwargs):
        gradient_prior = tfp.distributions.Normal(loc=np.float64(0.0), scale=np.float64(1.0))
        self.noise_gradient = Parameter(np.ones(ndims), transform=positive(lower=1e-6), prior=gradient_prior)
        self.constant_variance = Parameter(1.0, transform=positive(lower=1e-6))
        # Floor on the noise variance, keeping the Normal scale strictly positive.
        self.minimum_noise_variance = 1e-6

        def conditional_distribution(Fs) -> tfp.distributions.Distribution:
            tf.debugging.assert_equal(tf.shape(Fs)[-1], 2)
            loc = Fs[..., :1]
            scale = self.scale_transform(Fs[..., 1:])
            return tfp.distributions.Normal(loc, scale)

        super().__init__(latent_dim=2, conditional_distribution=conditional_distribution, **kwargs)

    def scale_transform(self, X):
        """Determine the likelihood variance at the specified input locations X."""
        linear_variance = tf.reduce_sum(tf.square(X) * self.noise_gradient, axis=-1, keepdims=True)
        noise_variance = linear_variance + self.constant_variance
        return tf.maximum(noise_variance, self.minimum_noise_variance)


model = gf.models.het_GPR(data=(X, Y), kernel=kernel, likelihood=LinearLikelihood(), mean_function=None)
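
# %% [markdown]
# Quick sanity check (an editor's addition, not part of the PR): evaluate the
# likelihood's `scale_transform` at a few training inputs to see the
# input-dependent noise term implied by the untrained parameters.

# %%
print(model.likelihood.scale_transform(tf.constant(X)).numpy()[::25].squeeze())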
# %% [markdown]
# ## Model Optimization
# Optimization proceeds exactly as in the standard GPR notebook.

# %%
opt = gf.optimizers.Scipy()

# %%
opt_logs = opt.minimize(model.training_loss, model.trainable_variables, options=dict(maxiter=100))
print_summary(model)
print_summary(model.posterior())
## Predict the mean and variance of the observations at the training inputs
mean, var = model.predict_y(X)
plot_distribution(X, Y, mean.numpy(), np.sqrt(var.numpy()))

# fmean, fvar = model.predict_f(X)
# plot_distribution(X, Y, fmean.numpy(), np.sqrt(fvar.numpy()))
## Repeat for standard (homoscedastic) GPR
base_model = gf.models.GPR(data=(X, Y), kernel=kernel, mean_function=None)
opt_logs = opt.minimize(base_model.training_loss, base_model.trainable_variables, options=dict(maxiter=100))
print_summary(base_model)

## Predict the mean and variance of the observations at the training inputs
mean, var = base_model.predict_y(X)
plot_distribution(X, Y, mean.numpy(), np.sqrt(var.numpy()))

# Compare marginal likelihoods: exp(lml - base_lml) is the marginal-likelihood
# ratio p(Y | het model) / p(Y | base model), i.e. a Bayes factor in favour of
# the heteroscedastic model.
base_lml = base_model.log_marginal_likelihood()
lml = model.log_marginal_likelihood()
print("Base LML", base_lml)
print("Het LML", lml)
print("Odds ratio", np.exp(lml - base_lml))
class PseudoPoissonLikelihood(MultiLatentTFPConditional):
    """While the Poisson likelihood is non-Gaussian, here we mimic the ..."""

    # Work-in-progress stub.
    def __init__(self, ndims: int = 1, **kwargs):
        super().__init__(**kwargs)

    def _scale_transform(self, f):
        """Determine the likelihood variance based upon the function value."""
        pass
# ## generate test points for prediction
# xx = np.linspace(-0.1, 1.1, 100).reshape(100, 1)  # test points must be of shape (N, D)
#
# ## predict mean and variance of latent GP at test points
# mean, var = model.predict_f(xx)
#
# ## generate 10 samples from posterior
# tf.random.set_seed(1)  # for reproducibility
# samples = model.predict_f_samples(xx, 10)  # shape (10, 100, 1)
#
# ## plot
# plt.figure(figsize=(12, 6))
# plt.plot(X, Y, "kx", mew=2)
# plt.plot(xx, mean, "C0", lw=2)
# plt.fill_between(
#     xx[:, 0],
#     mean[:, 0] - 1.96 * np.sqrt(var[:, 0]),
#     mean[:, 0] + 1.96 * np.sqrt(var[:, 0]),
#     color="C0",
#     alpha=0.2,
# )
#
# plt.plot(xx, samples[:, :, 0].numpy().T, "C0", linewidth=0.5)
# _ = plt.xlim(-0.1, 1.1)


# %% [markdown]
# ## Further reading
#
# See [Kernel Identification Through Transformers](https://arxiv.org/abs/2106.08185) by Simpson et al.
@@ -0,0 +1,108 @@
from typing import Optional

import tensorflow as tf

from gpflow import posteriors
from ..kernels import Kernel
from ..logdensities import multivariate_normal
from ..mean_functions import MeanFunction
from ..models.gpr import GPR_with_posterior
from ..models.training_mixins import RegressionData, InputData
from ..types import MeanAndVariance
from ..utilities import add_linear_noise_cov
class het_GPR(GPR_with_posterior):
    """While the vanilla GPR enforces a constant noise variance across the
    input space, here we allow the noise amplitude to vary linearly (and hence
    the noise variance to change quadratically) across the input space.
    """

Review comment: Is it true that the noise amplitude varies linearly with this model? Even with the …
    def __init__(
        self,
        data: RegressionData,
        kernel: Kernel,
        likelihood,
        mean_function: Optional[MeanFunction] = None,
        noise_variance: float = 1.0,
    ):
        super().__init__(data, kernel, mean_function, noise_variance)
        self.likelihood = likelihood

Review comment (on the `likelihood` argument): Could this type annotation be …
    def posterior(self, precompute_cache=posteriors.PrecomputeCacheType.TENSOR):
        """
        Create the Posterior object, which contains precomputed matrices for
        faster prediction.

        precompute_cache has three settings:

        - `PrecomputeCacheType.TENSOR` (or `"tensor"`): Precomputes the cached
          quantities and stores them as tensors (which allows differentiating
          through the prediction). This is the default.
        - `PrecomputeCacheType.VARIABLE` (or `"variable"`): Precomputes the cached
          quantities and stores them as variables, which allows for updating
          their values without changing the compute graph (relevant for AOT
          compilation).
        - `PrecomputeCacheType.NOCACHE` (or `"nocache"` or `None`): Avoids
          immediate cache computation. This is useful for avoiding extraneous
          computations when you only want to call the posterior's
          `fused_predict_f` method.

        See the usage sketch just below this method.
        """
        X, Y = self.data

        return posteriors.HeteroskedasticGPRPosterior(
            kernel=self.kernel,
            X_data=X,
            Y_data=Y,
            likelihood=self.likelihood,
            mean_function=self.mean_function,
            precompute_cache=precompute_cache,
        )
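    # A minimal usage sketch (editor's note, not part of the PR), assuming a
    # constructed `model = het_GPR(...)`:
    #
    #     post = model.posterior()            # cache as tensors (the default)
    #     post = model.posterior("variable")  # cache in variables (AOT-friendly)
    #     post = model.posterior(None)        # no cache; call post.fused_predict_f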
    def predict_y(
        self, Xnew: InputData, full_cov: bool = False, full_output_cov: bool = False
    ) -> MeanAndVariance:
        """
        Compute the mean and variance of the held-out data at the input points.
        """
        if full_cov or full_output_cov:
            # See https://github.com/GPflow/GPflow/issues/1461
            raise NotImplementedError(
                "The predict_y method currently supports only the argument values "
                "full_cov=False and full_output_cov=False"
            )

        f_mean, f_var = self.predict_f(Xnew, full_cov=full_cov, full_output_cov=full_output_cov)
        # The likelihood's second "latent" slot is filled with the inputs
        # themselves (the noise scale depends on X, not on a second GP), so it
        # carries zero variance.
        Fs = tf.concat([f_mean, Xnew], axis=-1)
        dummy_f_var = tf.zeros_like(f_var)
        F_vars = tf.concat([f_var, dummy_f_var], axis=-1)
        return self.likelihood.predict_mean_and_var(Fs, F_vars)
    def _add_noise_cov(self, X, K: tf.Tensor) -> tf.Tensor:
        """
        Returns K + diag(σ²), where σ² is the vector of input-dependent
        likelihood noise variances evaluated at the input points X.
        """
        dummy_F = tf.zeros_like(X)
        Fs = tf.concat([dummy_F, X], axis=-1)
        variances = self.likelihood.conditional_variance(Fs)
        return add_linear_noise_cov(K, tf.squeeze(variances))
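    # Editor's note: `add_linear_noise_cov` is assumed here to behave like
    # K + tf.linalg.diag(variances), i.e. to broadcast the per-point noise
    # variances onto the diagonal of the kernel matrix.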
    def log_marginal_likelihood(self) -> tf.Tensor:
        r"""
        Computes the log marginal likelihood

        .. math::
            \log p(Y \mid \theta) = \log \mathcal{N}\left(Y \mid m(X),\, K(X, X) + \mathrm{diag}(\sigma^2)\right).
        """
        X, Y = self.data
        K = self.kernel(X)
        ks = self._add_noise_cov(X, K)
        L = tf.linalg.cholesky(ks)
        m = self.mean_function(X)

        # [R,] log-likelihoods for each independent dimension of Y
        log_prob = multivariate_normal(Y, m, L)
        return tf.reduce_sum(log_prob)
Review comment: It would be useful if there were a way to follow the existing GPflow model naming convention in this case.