Merge branch 'GPflow:develop' into cp_switchdim

clwgg committed Jul 27, 2021
2 parents 62eeadd + 3e4b9f1 commit a461461
Showing 14 changed files with 508 additions and 76 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -78,5 +78,6 @@ Because GitHub's [graph of contributors](http://github.com/GPflow/GPflow/graphs/
[@insysion](https://github.com/insysion)
[@sam-willis](https://github.com/sam-willis)
[@vatsalaggarwal](https://github.com/vatsalaggarwal)
[@Andrew878](https://github.com/Andrew878)

Add yourself when you first contribute to GPflow's code, tests, or documentation!
6 changes: 4 additions & 2 deletions RELEASE.md
@@ -45,7 +45,9 @@ This release contains contributions from:

## Major Features and Improvements

* Refactor posterior base class to support other model types.
* Refactor posterior base class to support other model types. (#1695)
* Add new posterior class to enable faster predictions from the GPR model. (#1696)
* Construct Parameters from other Parameters and retain properties. (#1699)

## Bug Fixes and Other Changes

@@ -57,7 +59,7 @@ This release contains contributions from:

This release contains contributions from:

johnamcleod, st--
johnamcleod, st--, Andrew878


# Release 2.2.1
44 changes: 36 additions & 8 deletions doc/source/notebooks/advanced/fast_predictions.pct.py
@@ -5,7 +5,7 @@
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.11.2
# jupytext_version: 1.11.3
# kernelspec:
# display_name: Python 3
# language: python
@@ -64,6 +64,8 @@
# \begin{equation*}
# \alpha = [K_{mm} + \sigma^2I]^{-1}\mathbf{y}\\ Q^{-1} = [K_{mm} + \sigma^2I]^{-1}
# \end{equation*}
# _(Note: in practice we cache the Cholesky decomposition of Q, to take advantage of the `base_conditional_with_lm` utility function.)_
#
# in the case of the VGP and SVGP model these are:
# \begin{equation*}
# \alpha = K_{uu}^{-1}\mathbf{u}\\ Q^{-1} = K_{uu}^{-1}
@@ -76,31 +78,57 @@
#
# Note that in the (S)VGP case, $\alpha$ is the parameter as proposed by Opper and Archambeau for the mean of the predictive distribution.
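For intuition, here is a minimal NumPy-only sketch of how such cached quantities could be reused at prediction time. The function and argument names are illustrative assumptions rather than GPflow's API, and as noted above GPflow actually caches a Cholesky factor of Q rather than an explicit inverse.

    import numpy as np

    def predict_from_cache(Knm, knn_diag, alpha, Q_inv):
        """Illustrative only: predictive mean and marginal variance from cached alpha and Q^{-1}."""
        mean = Knm @ alpha                                          # K_nm @ alpha
        var = knn_diag - np.einsum("nm,mk,nk->n", Knm, Q_inv, Knm)  # k_nn - diag(K_nm Q^{-1} K_mn)
        return mean, var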

# +
import gpflow
import numpy as np

# Create some data
X = np.linspace(-1.1, 1.1, 1000)[:, None]
Y = np.cos(X)
Xnew = np.linspace(-1.1, 1.1, 1000)[:, None]

# + [markdown] id="FzCgor4nKUcW"
# ## Example
#
# We will construct an SVGP model to demonstrate the faster predictions from using the cached data in the GPFlow posterior classes (subclasses of `gpflow.posteriors.AbstractPosterior`).
# ## GPR Example
#
# We will construct a GPR model to demonstrate the faster predictions from using the cached data in the GPflow posterior classes (subclasses of `gpflow.posteriors.AbstractPosterior`).

# + id="BMnIdXNiKU6t"
import gpflow
import numpy as np
model = gpflow.models.GPR(
    (X, Y),
    gpflow.kernels.SquaredExponential(),
)
# -

# The `predict_f` method on the `GPModel` class performs no caching.

# %%timeit
model.predict_f(Xnew)

# To make use of the caching, first retrieve the posterior class from the model. The posterior class has methods to predict the parameters of marginal distributions at test points, in the same way as the `predict_f` method of the `GPModel`.
posterior = model.posterior()

# %%timeit
posterior.predict_f(Xnew)
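If you are running this comparison in a plain Python script rather than a notebook (where `%%timeit` is unavailable), the standard-library `timeit` module gives a rough equivalent; a sketch, with timings that will of course vary by machine:

    import timeit

    n_repeats = 10
    fused = timeit.timeit(lambda: model.predict_f(Xnew), number=n_repeats)
    cached = timeit.timeit(lambda: posterior.predict_f(Xnew), number=n_repeats)
    print(f"predict_f: {fused / n_repeats:.4f} s/call, cached posterior: {cached / n_repeats:.4f} s/call")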

# ## SVGP Example
#
# Likewise, we will construct an SVGP model to demonstrate the faster predictions from using the cached data in the GPflow posterior classes.

# + id="BMnIdXNiKU6t"
model = gpflow.models.SVGP(
    gpflow.kernels.SquaredExponential(),
    gpflow.likelihoods.Gaussian(),
    np.linspace(-1.1, 1.1, 1000)[:, None],
)

Xnew = np.linspace(-1.1, 1.1, 1000)[:, None]
# -

# The `predict_f` method on the `GPModel` class performs no caching.

# %%timeit
model.predict_f(Xnew)

# To make use of the caching, first retrieve the posterior class from the model. The posterior class has methods to predict the parameters of marginal distributions at test points, in the same way as the `predict_f` method of the `GPModel`.
# And again, we use the posterior object to take advantage of the caching:

posterior = model.posterior()

30 changes: 21 additions & 9 deletions gpflow/base.py
@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from enum import Enum
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.ops import array_ops
from typing_extensions import Final

from .config import default_float, default_summary_fmt
@@ -101,12 +99,12 @@ class PriorOn(Enum):
class Parameter(tfp.util.TransformedVariable):
def __init__(
self,
value: TensorData,
value: Union[TensorData, "Parameter"],
*,
transform: Optional[Transform] = None,
prior: Optional[Prior] = None,
prior_on: Union[str, PriorOn] = PriorOn.CONSTRAINED,
trainable: bool = True,
prior_on: Optional[Union[str, PriorOn]] = None,
trainable: Optional[bool] = None,
dtype: Optional[DType] = None,
name: Optional[str] = None,
):
@@ -117,14 +115,28 @@ def __init__(
therefore we need a positive constraint and it is natural to use constrained values.
A prior can be imposed either on the constrained version (default) or on the unconstrained version of the parameter.
"""
if transform is None:
    transform = tfp.bijectors.Identity()
if isinstance(value, Parameter):
    transform = transform or value.transform
    prior = prior or value.prior
    prior_on = prior_on or value.prior_on
    name = name or value.bijector.name
    trainable = value.trainable if trainable is None else trainable

    if dtype:
        value = _cast_to_dtype(value, dtype)
else:
    if transform is None:
        transform = tfp.bijectors.Identity()

    prior_on = prior_on if prior_on else PriorOn.CONSTRAINED
    trainable = trainable if trainable is not None else True

    value = _cast_to_dtype(value, dtype)

value = _cast_to_dtype(value, dtype)
_validate_unconstrained_value(value, transform, dtype)
super().__init__(value, transform, dtype=value.dtype, trainable=trainable, name=name)

self.prior = prior
self.prior = prior  # type: Optional[Prior]
self.prior_on = prior_on  # type: ignore  # see https://github.com/python/mypy/issues/3004

def log_prior_density(self) -> tf.Tensor:
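The constructor change above is what enables the "Construct Parameters from other Parameters and retain properties" item (#1699) in RELEASE.md. A minimal usage sketch of the intended behaviour; the concrete transform and prior below are arbitrary choices for illustration:

    import numpy as np
    import tensorflow_probability as tfp
    import gpflow

    base = gpflow.Parameter(
        1.0,
        transform=tfp.bijectors.Softplus(),
        prior=tfp.distributions.Gamma(np.float64(2.0), np.float64(2.0)),
        trainable=False,
    )

    # Constructing a Parameter from an existing Parameter keeps its transform,
    # prior, prior_on and trainable status unless they are explicitly overridden.
    copy = gpflow.Parameter(base)
    retrained = gpflow.Parameter(base, trainable=True)  # override a single property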
71 changes: 63 additions & 8 deletions gpflow/models/gpr.py
@@ -18,15 +18,17 @@

import gpflow

from .. import posteriors
from ..kernels import Kernel
from ..logdensities import multivariate_normal
from ..mean_functions import MeanFunction
from ..utilities.model_utils import add_noise_cov
from .model import GPModel, InputData, MeanAndVariance, RegressionData
from .training_mixins import InternalDataTrainingLossMixin
from .util import data_input_to_tensor


class GPR(GPModel, InternalDataTrainingLossMixin):
class GPR_deprecated(GPModel, InternalDataTrainingLossMixin):
r"""
Gaussian Process Regression.
@@ -69,9 +71,7 @@ def _add_noise_cov(self, K: tf.Tensor) -> tf.Tensor:
Returns K + σ² I, where σ² is the likelihood noise variance (scalar),
and I is the corresponding identity matrix.
"""
k_diag = tf.linalg.diag_part(K)
s_diag = tf.fill(tf.shape(k_diag), self.likelihood.variance)
return tf.linalg.set_diag(K, k_diag + s_diag)
return add_noise_cov(K, self.likelihood.variance)

def log_marginal_likelihood(self) -> tf.Tensor:
r"""
@@ -102,12 +102,12 @@ def predict_f(
where F* are points on the GP at new data points, Y are noisy observations at training data points.
"""
X_data, Y_data = self.data
err = Y_data - self.mean_function(X_data)
X, Y = self.data
err = Y - self.mean_function(X)

kmm = self.kernel(X_data)
kmm = self.kernel(X)
knn = self.kernel(Xnew, full_cov=full_cov)
kmn = self.kernel(X_data, Xnew)
kmn = self.kernel(X, Xnew)
kmm_plus_s = self._add_noise_cov(kmm)

conditional = gpflow.conditionals.base_conditional
@@ -116,3 +116,58 @@
) # [N, P], [N, P] or [P, N, N]
f_mean = f_mean_zero + self.mean_function(Xnew)
return f_mean, f_var


class GPR_with_posterior(GPR_deprecated):
    """
    This is an implementation of GPR that provides a posterior() method that
    enables caching for faster subsequent predictions.
    """

    def posterior(self, precompute_cache=posteriors.PrecomputeCacheType.TENSOR):
        """
        Create the Posterior object which contains precomputed matrices for
        faster prediction.

        precompute_cache has three settings:

        - `PrecomputeCacheType.TENSOR` (or `"tensor"`): Precomputes the cached
          quantities and stores them as tensors (which allows differentiating
          through the prediction). This is the default.
        - `PrecomputeCacheType.VARIABLE` (or `"variable"`): Precomputes the cached
          quantities and stores them as variables, which allows for updating
          their values without changing the compute graph (relevant for AOT
          compilation).
        - `PrecomputeCacheType.NOCACHE` (or `"nocache"` or `None`): Avoids
          immediate cache computation. This is useful for avoiding extraneous
          computations when you only want to call the posterior's
          `fused_predict_f` method.
        """
        X, Y = self.data

        return posteriors.GPRPosterior(
            kernel=self.kernel,
            X_data=X,
            Y_data=Y,
            likelihood_variance=self.likelihood.variance,
            mean_function=self.mean_function,
            precompute_cache=precompute_cache,
        )

    def predict_f(self, Xnew: InputData, full_cov=False, full_output_cov=False) -> MeanAndVariance:
        """
        For backwards compatibility, GPR's predict_f uses the fused (no-cache)
        computation, which is more efficient during training.

        For faster (cached) prediction, predict directly from the posterior object, i.e.,:
            model.posterior().predict_f(Xnew, ...)
        """
        return self.posterior(posteriors.PrecomputeCacheType.NOCACHE).fused_predict_f(
            Xnew, full_cov=full_cov, full_output_cov=full_output_cov
        )


class GPR(GPR_with_posterior):
    # subclassed to ensure __class__ == "GPR"
    pass
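As a usage sketch of the three `precompute_cache` settings described in the docstring above, assuming a GPR `model` and test inputs `Xnew` as in the notebook example earlier in this changeset:

    from gpflow.posteriors import PrecomputeCacheType

    post = model.posterior()                                  # default: PrecomputeCacheType.TENSOR
    post_var = model.posterior(PrecomputeCacheType.VARIABLE)  # cache stored in variables
    post_none = model.posterior(PrecomputeCacheType.NOCACHE)  # nothing precomputed

    mean, var = post.predict_f(Xnew)             # reuses the cached quantities
    mean, var = post_none.fused_predict_f(Xnew)  # recomputes everything on each call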
