diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 144c51a86f..817f0a1354 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -78,5 +78,6 @@ Because GitHub's [graph of contributors](http://github.com/GPflow/GPflow/graphs/ [@insysion](https://github.com/insysion) [@sam-willis](https://github.com/sam-willis) [@vatsalaggarwal](https://github.com/vatsalaggarwal) +[@Andrew878](https://github.com/Andrew878) Add yourself when you first contribute to GPflow's code, tests, or documentation! diff --git a/RELEASE.md b/RELEASE.md index a86a810c64..5ccf8b1319 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -45,7 +45,9 @@ This release contains contributions from: ## Major Features and Improvements -* Refactor posterior base class to support other model types. +* Refactor posterior base class to support other model types. (#1695) +* Add new posterior class to enable faster predictions from the GPR model. (#1696) +* Construct Parameters from other Parameters and retain properties. (#1699) ## Bug Fixes and Other Changes @@ -57,7 +59,7 @@ This release contains contributions from: This release contains contributions from: -johnamcleod, st-- +johnamcleod, st--, Andrew878 # Release 2.2.1 diff --git a/doc/source/notebooks/advanced/fast_predictions.pct.py b/doc/source/notebooks/advanced/fast_predictions.pct.py index c68d748b8f..e0dae110e3 100644 --- a/doc/source/notebooks/advanced/fast_predictions.pct.py +++ b/doc/source/notebooks/advanced/fast_predictions.pct.py @@ -5,7 +5,7 @@ # extension: .py # format_name: light # format_version: '1.5' -# jupytext_version: 1.11.2 +# jupytext_version: 1.11.3 # kernelspec: # display_name: Python 3 # language: python @@ -64,6 +64,8 @@ # \begin{equation*} # \alpha = [K_{mm} + \sigma^2I]^{-1}\mathbf{y}\\ Q^{-1} = [K_{mm} + \sigma^2I]^{-1} # \end{equation*} +# _(note in practice, we cache the cholesky decomposition of Q to take advantage of the 'base_conditional_with_lm' utility function)_ +# # in the case of the VGP and SVGP model these are: # \begin{equation*} # \alpha = K_{uu}^{-1}\mathbf{u}\\ Q^{-1} = K_{uu}^{-1} @@ -76,23 +78,49 @@ # # Note that in the (S)VGP case, $\alpha$ is the parameter as proposed by Opper and Archambeau for the mean of the predictive distribution. +# + +import gpflow +import numpy as np + +# Create some data +X = np.linspace(-1.1, 1.1, 1000)[:, None] +Y = np.cos(X) +Xnew = np.linspace(-1.1, 1.1, 1000)[:, None] + # + [markdown] id="FzCgor4nKUcW" -# ## Example # -# We will construct an SVGP model to demonstrate the faster predictions from using the cached data in the GPFlow posterior classes (subclasses of `gpflow.posteriors.AbstractPosterior`). +# ## GPR Example +# +# We will construct a GPR model to demonstrate the faster predictions from using the cached data in the GPFlow posterior classes (subclasses of `gpflow.posteriors.AbstractPosterior`). # + id="BMnIdXNiKU6t" -import gpflow -import numpy as np +model = gpflow.models.GPR( + (X, Y), + gpflow.kernels.SquaredExponential(), +) +# - + +# The `predict_f` method on the `GPModel` class performs no caching. +# %%timeit +model.predict_f(Xnew) +# To make use of the caching, first retrieve the posterior class from the model. The posterior class has methods to predict the parameters of marginal distributions at test points, in the same way as the `predict_f` method of the `GPModel`. +posterior = model.posterior() + +# %%timeit +posterior.predict_f(Xnew) + +# ## SVGP Example +# +# Likewise, we will construct an SVGP model to demonstrate the faster predictions from using the cached data in the GPFlow posterior classes. + +# + id="BMnIdXNiKU6t" model = gpflow.models.SVGP( gpflow.kernels.SquaredExponential(), gpflow.likelihoods.Gaussian(), np.linspace(-1.1, 1.1, 1000)[:, None], ) - -Xnew = np.linspace(-1.1, 1.1, 1000)[:, None] # - # The `predict_f` method on the `GPModel` class performs no caching. @@ -100,7 +128,7 @@ # %%timeit model.predict_f(Xnew) -# To make use of the caching, first retrieve the posterior class from the model. The posterior class has methods to predict the parameters of marginal distributions at test points, in the same way as the `predict_f` method of the `GPModel` . +# And again using the posterior object and caching posterior = model.posterior() diff --git a/gpflow/base.py b/gpflow/base.py index de8a133452..c66fe57bb7 100644 --- a/gpflow/base.py +++ b/gpflow/base.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools from enum import Enum from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union import numpy as np import tensorflow as tf import tensorflow_probability as tfp -from tensorflow.python.ops import array_ops from typing_extensions import Final from .config import default_float, default_summary_fmt @@ -101,12 +99,12 @@ class PriorOn(Enum): class Parameter(tfp.util.TransformedVariable): def __init__( self, - value: TensorData, + value: Union[TensorData, "Parameter"], *, transform: Optional[Transform] = None, prior: Optional[Prior] = None, - prior_on: Union[str, PriorOn] = PriorOn.CONSTRAINED, - trainable: bool = True, + prior_on: Optional[Union[str, PriorOn]] = None, + trainable: Optional[bool] = None, dtype: Optional[DType] = None, name: Optional[str] = None, ): @@ -117,14 +115,28 @@ def __init__( therefore we need a positive constraint and it is natural to use constrained values. A prior can be imposed either on the constrained version (default) or on the unconstrained version of the parameter. """ - if transform is None: - transform = tfp.bijectors.Identity() + if isinstance(value, Parameter): + transform = transform or value.transform + prior = prior or value.prior + prior_on = prior_on or value.prior_on + name = name or value.bijector.name + trainable = value.trainable if trainable is None else trainable + + if dtype: + value = _cast_to_dtype(value, dtype) + else: + if transform is None: + transform = tfp.bijectors.Identity() + + prior_on = prior_on if prior_on else PriorOn.CONSTRAINED + trainable = trainable if trainable is not None else True + + value = _cast_to_dtype(value, dtype) - value = _cast_to_dtype(value, dtype) _validate_unconstrained_value(value, transform, dtype) super().__init__(value, transform, dtype=value.dtype, trainable=trainable, name=name) - self.prior = prior + self.prior = prior # type: Optional[Prior] self.prior_on = prior_on # type: ignore # see https://github.com/python/mypy/issues/3004 def log_prior_density(self) -> tf.Tensor: diff --git a/gpflow/models/gpr.py b/gpflow/models/gpr.py index 45f56dee75..dabef02d44 100644 --- a/gpflow/models/gpr.py +++ b/gpflow/models/gpr.py @@ -18,15 +18,17 @@ import gpflow +from .. import posteriors from ..kernels import Kernel from ..logdensities import multivariate_normal from ..mean_functions import MeanFunction +from ..utilities.model_utils import add_noise_cov from .model import GPModel, InputData, MeanAndVariance, RegressionData from .training_mixins import InternalDataTrainingLossMixin from .util import data_input_to_tensor -class GPR(GPModel, InternalDataTrainingLossMixin): +class GPR_deprecated(GPModel, InternalDataTrainingLossMixin): r""" Gaussian Process Regression. @@ -69,9 +71,7 @@ def _add_noise_cov(self, K: tf.Tensor) -> tf.Tensor: Returns K + σ² I, where σ² is the likelihood noise variance (scalar), and I is the corresponding identity matrix. """ - k_diag = tf.linalg.diag_part(K) - s_diag = tf.fill(tf.shape(k_diag), self.likelihood.variance) - return tf.linalg.set_diag(K, k_diag + s_diag) + return add_noise_cov(K, self.likelihood.variance) def log_marginal_likelihood(self) -> tf.Tensor: r""" @@ -102,12 +102,12 @@ def predict_f( where F* are points on the GP at new data points, Y are noisy observations at training data points. """ - X_data, Y_data = self.data - err = Y_data - self.mean_function(X_data) + X, Y = self.data + err = Y - self.mean_function(X) - kmm = self.kernel(X_data) + kmm = self.kernel(X) knn = self.kernel(Xnew, full_cov=full_cov) - kmn = self.kernel(X_data, Xnew) + kmn = self.kernel(X, Xnew) kmm_plus_s = self._add_noise_cov(kmm) conditional = gpflow.conditionals.base_conditional @@ -116,3 +116,58 @@ def predict_f( ) # [N, P], [N, P] or [P, N, N] f_mean = f_mean_zero + self.mean_function(Xnew) return f_mean, f_var + + +class GPR_with_posterior(GPR_deprecated): + """ + This is an implementation of GPR that provides a posterior() method that + enables caching for faster subsequent predictions. + """ + + def posterior(self, precompute_cache=posteriors.PrecomputeCacheType.TENSOR): + """ + Create the Posterior object which contains precomputed matrices for + faster prediction. + + precompute_cache has three settings: + + - `PrecomputeCacheType.TENSOR` (or `"tensor"`): Precomputes the cached + quantities and stores them as tensors (which allows differentiating + through the prediction). This is the default. + - `PrecomputeCacheType.VARIABLE` (or `"variable"`): Precomputes the cached + quantities and stores them as variables, which allows for updating + their values without changing the compute graph (relevant for AOT + compilation). + - `PrecomputeCacheType.NOCACHE` (or `"nocache"` or `None`): Avoids + immediate cache computation. This is useful for avoiding extraneous + computations when you only want to call the posterior's + `fused_predict_f` method. + """ + + X, Y = self.data + + return posteriors.GPRPosterior( + kernel=self.kernel, + X_data=X, + Y_data=Y, + likelihood_variance=self.likelihood.variance, + mean_function=self.mean_function, + precompute_cache=precompute_cache, + ) + + def predict_f(self, Xnew: InputData, full_cov=False, full_output_cov=False) -> MeanAndVariance: + """ + For backwards compatibility, GPR's predict_f uses the fused (no-cache) + computation, which is more efficient during training. + + For faster (cached) prediction, predict directly from the posterior object, i.e.,: + model.posterior().predict_f(Xnew, ...) + """ + return self.posterior(posteriors.PrecomputeCacheType.NOCACHE).fused_predict_f( + Xnew, full_cov=full_cov, full_output_cov=full_output_cov + ) + + +class GPR(GPR_with_posterior): + # subclassed to ensure __class__ == "GPR" + pass diff --git a/gpflow/posteriors.py b/gpflow/posteriors.py index bba970afa8..0ecf38f201 100644 --- a/gpflow/posteriors.py +++ b/gpflow/posteriors.py @@ -21,9 +21,10 @@ import tensorflow_probability as tfp from . import covariances, kernels, mean_functions -from .base import Module, TensorType +from .base import Module, Parameter, TensorType from .conditionals.util import ( base_conditional, + base_conditional_with_lm, expand_independent_outputs, fully_correlated_conditional, independent_interdomain_conditional, @@ -40,7 +41,7 @@ SharedIndependentInducingVariables, ) from .types import MeanAndVariance -from .utilities import Dispatcher +from .utilities import Dispatcher, add_noise_cov class _QDistribution(Module): @@ -180,6 +181,117 @@ def _conditional_with_precompute( Relies on cached alpha and Qinv. """ + def update_cache(self, precompute_cache: Optional[PrecomputeCacheType] = None) -> None: + """ + Sets the cache depending on the value of `precompute_cache` to a + `tf.Tensor`, `tf.Variable`, or clears the cache. If `precompute_cache` + is not given, the setting defaults to the most-recently-used one. + """ + if precompute_cache is None: + try: + precompute_cache = cast( + PrecomputeCacheType, + self._precompute_cache, # type: ignore + ) + except AttributeError: + raise ValueError( + "You must pass precompute_cache explicitly (the cache had not been updated before)." + ) + else: + self._precompute_cache = precompute_cache + + if precompute_cache is PrecomputeCacheType.NOCACHE: + self.alpha = self.Qinv = None + + elif precompute_cache is PrecomputeCacheType.TENSOR: + self.alpha, self.Qinv = self._precompute() + + elif precompute_cache is PrecomputeCacheType.VARIABLE: + alpha, Qinv = self._precompute() + if isinstance(self.alpha, tf.Variable) and isinstance(self.Qinv, tf.Variable): + # re-use existing variables + self.alpha.assign(alpha) + self.Qinv.assign(Qinv) + else: # create variables + self.alpha = tf.Variable(alpha, trainable=False) + self.Qinv = tf.Variable(Qinv, trainable=False) + + +class GPRPosterior(AbstractPosterior): + def __init__( + self, + kernel, + X_data: tf.Tensor, + Y_data: tf.Tensor, + likelihood_variance: Parameter, + mean_function: Optional[mean_functions.MeanFunction] = None, + *, + precompute_cache: Optional[PrecomputeCacheType], + ): + + super().__init__(kernel, X_data, mean_function=mean_function) + self.mean_function = mean_function + self.Y_data = Y_data + self.likelihood_variance = likelihood_variance + + if precompute_cache is not None: + self.update_cache(precompute_cache) + + def _conditional_with_precompute( + self, Xnew, full_cov: bool = False, full_output_cov: bool = False + ) -> MeanAndVariance: + """ + Computes predictive mean and (co)variance at Xnew, *excluding* mean_function. + Relies on cached alpha and Qinv. + """ + Kmn = self.kernel(self.X_data, Xnew) + Knn = self.kernel(Xnew, full_cov=full_cov) + + return base_conditional_with_lm(Kmn, self.Qinv, Knn, self.alpha, full_cov=full_cov) + + def _precompute(self) -> Tuple[tf.Tensor, tf.Tensor]: + + """ + Precomputes the cholesky decomposition of Kmm_plus_s for later reuse we will call + base_conditional_with_lm implementation ('Qinv' in the Abstract Posterior class). We also + precompute the less compute intensive error term ('alpha' in the Abstract Posterior class) + """ + + Kmm = self.kernel(self.X_data) + Kmm_plus_s = add_noise_cov(Kmm, self.likelihood_variance) + + # obtain the cholesky decomposition of Kmm_plus_s + Lm = tf.linalg.cholesky(Kmm_plus_s) + + alpha = self.Y_data - self.mean_function(self.X_data) # type: ignore + tf.debugging.assert_shapes( + [ + (Lm, ["M", "M"]), + (Kmm, ["M", "M"]), + ] + ) + return alpha, Lm + + def _conditional_fused( + self, Xnew, full_cov: bool = False, full_output_cov: bool = False + ) -> MeanAndVariance: + """ + Computes predictive mean and (co)variance at Xnew, *excluding* mean_function + Does not make use of caching + """ + + # taken directly from the deprecated GPR implementation + err = self.Y_data - self.mean_function(self.X_data) # type: ignore + + Kmm = self.kernel(self.X_data) + Knn = self.kernel(Xnew, full_cov=full_cov) + Kmn = self.kernel(self.X_data, Xnew) + Kmm_plus_s = add_noise_cov(Kmm, self.likelihood_variance) + + return base_conditional( + Kmn, Kmm_plus_s, Knn, err, full_cov=full_cov, white=False + ) # [N, P], [N, P] or [P, N, N] + class BasePosterior(AbstractPosterior): def __init__( @@ -217,41 +329,6 @@ def _set_qdist(self, q_mu, q_sqrt): else: self._q_dist = _MvNormal(q_mu, q_sqrt) - def update_cache(self, precompute_cache: Optional[PrecomputeCacheType] = None): - """ - Sets the cache depending on the value of `precompute_cache` to a - `tf.Tensor`, `tf.Variable`, or clears the cache. If `precompute_cache` - is not given, the setting defaults to the most-recently-used one. - """ - if precompute_cache is None: - try: - precompute_cache = cast( - PrecomputeCacheType, - self._precompute_cache, # type: ignore - ) - except AttributeError: - raise ValueError( - "You must pass precompute_cache explicitly (the cache had not been updated before)." - ) - else: - self._precompute_cache = precompute_cache - - if precompute_cache is PrecomputeCacheType.NOCACHE: - self.alpha = self.Qinv = None - - elif precompute_cache is PrecomputeCacheType.TENSOR: - self.alpha, self.Qinv = self._precompute() - - elif precompute_cache is PrecomputeCacheType.VARIABLE: - alpha, Qinv = self._precompute() - if isinstance(self.alpha, tf.Variable) and isinstance(self.Qinv, tf.Variable): - # re-use existing variables - self.alpha.assign(alpha) - self.Qinv.assign(Qinv) - else: # create variables - self.alpha = tf.Variable(alpha, trainable=False) - self.Qinv = tf.Variable(Qinv, trainable=False) - def _precompute(self): Kuu = covariances.Kuu(self.X_data, self.kernel, jitter=default_jitter()) # [(R), M, M] q_mu = self._q_dist.q_mu diff --git a/gpflow/utilities/__init__.py b/gpflow/utilities/__init__.py index 9ae827e06a..107e278de2 100644 --- a/gpflow/utilities/__init__.py +++ b/gpflow/utilities/__init__.py @@ -1,4 +1,5 @@ from .bijectors import * from .misc import * +from .model_utils import * from .multipledispatch import Dispatcher from .traversal import * diff --git a/gpflow/utilities/model_utils.py b/gpflow/utilities/model_utils.py new file mode 100644 index 0000000000..17da8459bb --- /dev/null +++ b/gpflow/utilities/model_utils.py @@ -0,0 +1,13 @@ +import tensorflow as tf + +from ..base import Parameter + + +def add_noise_cov(K: tf.Tensor, likelihood_variance: Parameter) -> tf.Tensor: + """ + Returns K + σ² I, where σ² is the likelihood noise variance (scalar), + and I is the corresponding identity matrix. + """ + k_diag = tf.linalg.diag_part(K) + s_diag = tf.fill(tf.shape(k_diag), likelihood_variance) + return tf.linalg.set_diag(K, k_diag + s_diag) diff --git a/tests/gpflow/models/test_gpr_posterior.py b/tests/gpflow/models/test_gpr_posterior.py new file mode 100644 index 0000000000..33b364eda8 --- /dev/null +++ b/tests/gpflow/models/test_gpr_posterior.py @@ -0,0 +1,68 @@ +from itertools import product +from typing import Tuple + +import numpy as np +import pytest +import tensorflow as tf + +import gpflow +from gpflow.models.gpr import GPR_deprecated, GPR_with_posterior +from gpflow.posteriors import PrecomputeCacheType + +input_dim = 7 +output_dim = 1 + + +def make_models(regression_data): + """Helper function to create models""" + + k = gpflow.kernels.Matern52() + + mold = GPR_deprecated(data=regression_data, kernel=k) + mnew = GPR_with_posterior(data=regression_data, kernel=k) + return mold, mnew + + +def _get_data_for_tests(): + """Helper function to create testing data""" + X = np.random.randn(100, input_dim) + Y = np.random.randn(100, output_dim) + X_new = np.random.randn(100, input_dim) + return X, X_new, Y + + +@pytest.mark.parametrize("cache_type", [PrecomputeCacheType.TENSOR, PrecomputeCacheType.VARIABLE]) +@pytest.mark.parametrize("full_cov", [True, False]) +@pytest.mark.parametrize("full_output_cov", [True, False]) +def test_old_vs_new_gp_fused( + cache_type: PrecomputeCacheType, full_cov: bool, full_output_cov: bool +): + X, X_new, Y = _get_data_for_tests() + mold, mnew = make_models((X, Y)) + + mu_old, var2_old = mold.predict_f(X_new, full_cov=full_cov, full_output_cov=full_output_cov) + mu_new_fuse, var2_new_fuse = mnew.predict_f( + X_new, full_cov=full_cov, full_output_cov=full_output_cov + ) + # check new fuse is same as old version + np.testing.assert_allclose(mu_new_fuse, mu_old) + np.testing.assert_allclose(var2_new_fuse, var2_old) + + +@pytest.mark.parametrize("cache_type", [PrecomputeCacheType.TENSOR, PrecomputeCacheType.VARIABLE]) +@pytest.mark.parametrize("full_cov", [True, False]) +@pytest.mark.parametrize("full_output_cov", [True, False]) +def test_old_vs_new_with_posterior( + cache_type: PrecomputeCacheType, full_cov: bool, full_output_cov: bool +): + X, X_new, Y = _get_data_for_tests() + mold, mnew = make_models((X, Y)) + + mu_old, var2_old = mold.predict_f(X_new, full_cov=full_cov, full_output_cov=full_output_cov) + mu_new_cache, var2_new_cache = mnew.posterior(cache_type).predict_f( + X_new, full_cov=full_cov, full_output_cov=full_output_cov + ) + + # check new cache is same as old version + np.testing.assert_allclose(mu_old, mu_new_cache) + np.testing.assert_allclose(var2_old, var2_new_cache) diff --git a/tests/gpflow/models/test_svgp_posterior.py b/tests/gpflow/models/test_svgp_posterior.py index 5e0dab3fae..afb0f094d3 100644 --- a/tests/gpflow/models/test_svgp_posterior.py +++ b/tests/gpflow/models/test_svgp_posterior.py @@ -63,15 +63,14 @@ def make_models(M=64, D=input_dim, L=3, q_diag=False, whiten=True, mo=None): ), ], ) -def test_old_vs_new_svgp(q_diag, white, multioutput): +@pytest.mark.parametrize("full_cov", [True, False]) +@pytest.mark.parametrize("full_output_cov", [True, False]) +def test_old_vs_new_svgp(q_diag, white, multioutput, full_cov: bool, full_output_cov: bool): mold, mnew = make_models(q_diag=q_diag, whiten=white, mo=multioutput) X = np.random.randn(100, input_dim) - Xt = tf.convert_to_tensor(X) - for full_cov in (True, False): - for full_output_cov in (True, False): - mu, var = mnew.predict_f(X, full_cov=full_cov, full_output_cov=full_output_cov) - mu2, var2 = mold.predict_f(X, full_cov=full_cov, full_output_cov=full_output_cov) - np.testing.assert_allclose(mu, mu2) - np.testing.assert_allclose(var, var2) + mu, var = mnew.predict_f(X, full_cov=full_cov, full_output_cov=full_output_cov) + mu2, var2 = mold.predict_f(X, full_cov=full_cov, full_output_cov=full_output_cov) + np.testing.assert_allclose(mu, mu2) + np.testing.assert_allclose(var, var2) diff --git a/tests/gpflow/optimizers/test_mcmc.py b/tests/gpflow/optimizers/test_mcmc.py index cce728a1a6..7a9b96bdfd 100644 --- a/tests/gpflow/optimizers/test_mcmc.py +++ b/tests/gpflow/optimizers/test_mcmc.py @@ -34,12 +34,14 @@ def build_model(data): return model -def build_model_with_uniform_prior(data, prior_on, prior_width): +def build_model_with_uniform_prior_no_transforms(data, prior_on, prior_width): def parameter(value): low_value = -100 high_value = low_value + prior_width prior = Uniform(low=np.float64(low_value), high=np.float64(high_value)) - return gpflow.Parameter(value, prior=prior, prior_on=prior_on) + return gpflow.Parameter( + value, transform=tfp.bijectors.Identity(), prior=prior, prior_on=prior_on + ) k = gpflow.kernels.Matern52(lengthscales=0.3) k.variance = parameter(k.variance) @@ -115,7 +117,7 @@ def test_mcmc_helper_target_function_unconstrained(): prior_width = 200.0 data = build_data() - model = build_model_with_uniform_prior(data, "unconstrained", prior_width) + model = build_model_with_uniform_prior_no_transforms(data, "unconstrained", prior_width) hmc_helper = gpflow.optimizers.SamplingHelper( model.log_posterior_density, model.trainable_parameters @@ -138,7 +140,7 @@ def test_mcmc_helper_target_function_no_transforms(prior_on): prior_width = 200.0 data = build_data() - model = build_model_with_uniform_prior(data, prior_on, prior_width) + model = build_model_with_uniform_prior_no_transforms(data, prior_on, prior_width) hmc_helper = gpflow.optimizers.SamplingHelper( model.log_posterior_density, model.trainable_parameters diff --git a/tests/gpflow/posteriors/test_posteriors.py b/tests/gpflow/posteriors/test_posteriors.py index fbb33485b5..32518b279b 100644 --- a/tests/gpflow/posteriors/test_posteriors.py +++ b/tests/gpflow/posteriors/test_posteriors.py @@ -21,11 +21,13 @@ import gpflow import gpflow.ci_utils from gpflow.conditionals import conditional +from gpflow.mean_functions import Zero from gpflow.models.util import inducingpoint_wrapper from gpflow.posteriors import ( AbstractPosterior, FallbackIndependentLatentPosterior, FullyCorrelatedPosterior, + GPRPosterior, IndependentPosteriorMultiOutput, IndependentPosteriorSingleOutput, LinearCoregionalizationPosterior, @@ -625,6 +627,31 @@ def test_posterior_update_cache_with_variables_no_precompute( assert isinstance(posterior.Qinv, tf.Variable) +@pytest.mark.parametrize( + "precompute_cache_type", [PrecomputeCacheType.NOCACHE, PrecomputeCacheType.TENSOR] +) +def test_gpr_posterior_update_cache_with_variables_no_precompute( + register_posterior_test, q_sqrt_factory, whiten, precompute_cache_type +): + kernel = gpflow.kernels.SquaredExponential() + X = np.random.randn(NUM_INDUCING_POINTS, INPUT_DIMS) + Y = np.random.randn(NUM_INDUCING_POINTS, 1) + + posterior = GPRPosterior( + kernel=kernel, + X_data=X, + Y_data=Y, + likelihood_variance=gpflow.Parameter(0.1), + precompute_cache=precompute_cache_type, + mean_function=Zero(), + ) + posterior.update_cache(PrecomputeCacheType.VARIABLE) + register_posterior_test(posterior, GPRPosterior) + + assert isinstance(posterior.alpha, tf.Variable) + assert isinstance(posterior.Qinv, tf.Variable) + + def test_posterior_update_cache_with_variables_update_value(q_sqrt_factory, whiten): # setup posterior kernel = gpflow.kernels.SquaredExponential() diff --git a/tests/gpflow/test_base.py b/tests/gpflow/test_base.py index d8c45ef477..aab6938466 100644 --- a/tests/gpflow/test_base.py +++ b/tests/gpflow/test_base.py @@ -18,8 +18,10 @@ import numpy as np import pytest import tensorflow as tf +import tensorflow_probability as tfp import gpflow +from gpflow.base import PriorOn from gpflow.utilities import positive @@ -79,3 +81,135 @@ def exec(self, x: tf.Tensor) -> tf.Tensor: m1 = tf.saved_model.load(dirname) actual = m1.exec(x) np.testing.assert_equal(actual, expected) + + +@pytest.mark.parametrize("value", [0.0, [1.2, 1.1]]) +def test_construct_parameter_from_existing_parameter_check_value(value): + initial_parameter = gpflow.Parameter(value) + new_parameter = gpflow.Parameter(initial_parameter) + + np.testing.assert_equal(new_parameter.numpy(), value) + + +@pytest.mark.parametrize("value", [0.0, [1.2, 1.1]]) +def test_construct_parameter_from_existing_parameter_override_value(value): + initial_parameter = gpflow.Parameter(value) + new_parameter = gpflow.Parameter(initial_parameter + 1.0) + + np.testing.assert_equal(new_parameter.numpy(), np.array(value) + 1.0) + + +def test_construct_parameter_from_existing_parameter_check_transform(): + transform = tfp.bijectors.Sigmoid( + tf.constant(0.0, dtype=tf.float64), tf.constant(2.0, dtype=tf.float64) + ) + initial_parameter = gpflow.Parameter([1.2, 1.1], transform=transform) + new_parameter = gpflow.Parameter(initial_parameter) + + assert new_parameter.transform == transform + + +def test_construct_parameter_from_existing_parameter_override_transform(): + initial_parameter = gpflow.Parameter([1.2, 1.1]) + + transform = tfp.bijectors.Sigmoid( + tf.constant(0.0, dtype=tf.float64), tf.constant(2.0, dtype=tf.float64) + ) + new_parameter = gpflow.Parameter(initial_parameter, transform=transform) + + assert new_parameter.transform == transform + + +def test_construct_parameter_from_existing_parameter_check_prior(): + prior = tfp.distributions.Normal(0.0, 1.0) + initial_parameter = gpflow.Parameter([1.2, 1.1], prior=prior) + new_parameter = gpflow.Parameter(initial_parameter) + + assert new_parameter.prior == prior + + +def test_construct_parameter_from_existing_parameter_override_prior(): + initial_parameter = gpflow.Parameter([1.2, 1.1]) + + prior = tfp.distributions.Normal(0.0, 1.0) + new_parameter = gpflow.Parameter(initial_parameter, prior=prior) + + assert new_parameter.prior == prior + + +@pytest.mark.parametrize("prior_on", [PriorOn.CONSTRAINED, PriorOn.UNCONSTRAINED]) +def test_construct_parameter_from_existing_parameter_check_prior_on(prior_on): + initial_parameter = gpflow.Parameter([1.2, 1.1], prior_on=prior_on) + new_parameter = gpflow.Parameter(initial_parameter) + + assert new_parameter.prior_on == prior_on + + +@pytest.mark.parametrize("prior_on", [PriorOn.CONSTRAINED, PriorOn.UNCONSTRAINED]) +def test_construct_parameter_from_existing_parameter_override_prior_on(prior_on): + initial_parameter = gpflow.Parameter([1.2, 1.1]) + new_parameter = gpflow.Parameter(initial_parameter, prior_on=prior_on) + + assert new_parameter.prior_on == prior_on + + +@pytest.mark.parametrize("trainable", [True, False]) +def test_construct_parameter_from_existing_parameter_check_trainable(trainable): + initial_parameter = gpflow.Parameter([1.2, 1.1], trainable=trainable) + new_parameter = gpflow.Parameter(initial_parameter) + + assert new_parameter.trainable == trainable + + +@pytest.mark.parametrize("trainable", [True, False]) +def test_construct_parameter_from_existing_parameter_override_trainable(trainable): + initial_parameter = gpflow.Parameter([1.2, 1.1], trainable=trainable) + new_parameter = gpflow.Parameter(initial_parameter, trainable=not trainable) + + assert new_parameter.trainable is not trainable + + +@pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) +def test_construct_parameter_from_existing_parameter_check_dtype(dtype): + initial_parameter = gpflow.Parameter([1.1, 2.1], dtype=dtype) + new_parameter = gpflow.Parameter(initial_parameter) + + assert new_parameter.dtype == dtype + + +@pytest.mark.parametrize("dtype", [tf.float32, tf.float64]) +def test_construct_parameter_from_existing_parameter_override_dtype(dtype): + initial_parameter = gpflow.Parameter([1.1, 2.1]) + new_parameter = gpflow.Parameter(initial_parameter, dtype=dtype) + + assert new_parameter.dtype == dtype + + +def test_construct_parameter_from_existing_parameter_check_name(): + transform = tfp.bijectors.Sigmoid( + tf.constant(0.0, dtype=tf.float64), tf.constant(2.0, dtype=tf.float64) + ) + initial_parameter = gpflow.Parameter([1.2, 1.1], transform=transform) + new_parameter = gpflow.Parameter(initial_parameter) + + assert new_parameter.name == transform.name + + +def test_construct_parameter_from_existing_parameter_override_name(): + initial_parameter = gpflow.Parameter([1.2, 1.1]) + transform = tfp.bijectors.Sigmoid( + tf.constant(0.0, dtype=tf.float64), tf.constant(2.0, dtype=tf.float64) + ) + new_parameter = gpflow.Parameter(initial_parameter, transform=transform) + + assert new_parameter.name == transform.name + + +def test_construct_parameter_from_existing_parameter_value_becomes_invalid(): + initial_parameter = gpflow.Parameter(0.0) + transform = tfp.bijectors.Reciprocal() + + with pytest.raises(tf.errors.InvalidArgumentError) as exc: + gpflow.Parameter(initial_parameter, transform=transform) + + assert "gpflow.Parameter" in exc.value.message diff --git a/tests/gpflow/utilities/test_model_utils.py b/tests/gpflow/utilities/test_model_utils.py new file mode 100644 index 0000000000..64b51db14d --- /dev/null +++ b/tests/gpflow/utilities/test_model_utils.py @@ -0,0 +1,13 @@ +import pytest +import tensorflow as tf + +import gpflow +from gpflow.utilities import add_noise_cov + + +@pytest.mark.parametrize("input_tensor", [tf.constant([[1.0, 0.5], [0.5, 1.0]])]) +@pytest.mark.parametrize("variance", [gpflow.Parameter(1.0, dtype=tf.float32)]) +@pytest.mark.parametrize("expected_tensor", [tf.constant([[2.0, 0.5], [0.5, 2.0]])]) +def test_add_noise_cov(input_tensor, variance, expected_tensor): + actual_tensor = add_noise_cov(input_tensor, variance) + tf.debugging.assert_equal(actual_tensor, expected_tensor)