Support deterministic randomness. #2001

Open · wants to merge 1 commit into base: develop
8 changes: 8 additions & 0 deletions RELEASE.md
@@ -41,6 +41,10 @@ This release contains contributions from:

* `gpflow.experimental.check_shapes` has been removed, in favour of an independent release. Use
`pip install check_shapes` and `import check_shapes` instead.
* Many model methods now take an extra parameter, `seed`, and any custom models you have will need
  to be updated to accept this parameter (see the sketch below).
* Many likelihood methods now take an extra parameter, `seed`, and any custom likelihoods you have
  will need to be updated to accept this parameter.
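
A minimal sketch of the signature change for a custom model, based on the updates in this pull
request (the class name and method body are illustrative):

```python
from gpflow.base import Seed
from gpflow.models import BayesianModel


class MyModel(BayesianModel):
    # Before: def maximum_log_likelihood_objective(self, data):
    def maximum_log_likelihood_objective(self, data, seed: Seed = None):
        ...  # forward `seed` to any sampling or likelihood calls made here
```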

## Known Caveats

@@ -51,6 +55,10 @@ This release contains contributions from:
## Major Features and Improvements

* Major rework of documentation landing page and "getting started" section.
* Deterministic randomness. Many model and likelihood methods now take a `seed` parameter that can
  be used to get deterministic randomness. See the
  [new notebook](https://gpflow.github.io/GPflow/2.7.0/notebooks/advanced/random_state.html) for
  details.

## Bug Fixes and Other Changes

16 changes: 10 additions & 6 deletions benchmark/run.py
@@ -32,6 +32,7 @@

import numpy as np
import pandas as pd
import tensorflow_probability as tfp
from tabulate import tabulate

import benchmark.benchmarks # pylint: disable=unused-import # Make sure registry is loaded.
@@ -63,19 +64,22 @@ def _collect_metrics(
test_data = dataset.test

rng = np.random.default_rng(random_seed)
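    # SeedStream deterministically yields a fresh TF-style seed on each mk_tf_seed() call.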
mk_tf_seed = tfp.util.SeedStream(seed=rng.integers(1_000_000, size=2), salt="_collect_metrics")
model_fac = MODEL_FACTORIES.get(task.model_name)
model = model_fac.create_model(train_data, rng)

metrics = {}

-    model.predict_y(test_data.X)  # Warm-up TF.
+    model.predict_y(test_data.X, seed=mk_tf_seed())  # Warm-up TF.

print("Model before training:")
gpflow.utilities.print_summary(model)

if task.do_optimise:
t_before = perf_counter()
-        loss_fn = gpflow.models.training_loss_closure(model, train_data.XY, compile=task.do_compile)
+        loss_fn = gpflow.models.training_loss_closure(
+            model, train_data.XY, seed=mk_tf_seed(), compile=task.do_compile
+        )
opt_log = gpflow.optimizers.Scipy().minimize(
loss_fn,
variables=model.trainable_variables,
@@ -101,10 +105,10 @@ def _collect_metrics(
metrics[mt.prediction_time] = t_after - t_before

metrics[mt.nlpd] = -np.sum(
-        likelihood.predict_log_density(test_data.X, f_m, f_v, test_data.Y)
+        likelihood.predict_log_density(test_data.X, f_m, f_v, test_data.Y, seed=mk_tf_seed())
)

-    y_m, _y_v = likelihood.predict_mean_and_var(test_data.X, f_m, f_v)
+    y_m, _y_v = likelihood.predict_mean_and_var(test_data.X, f_m, f_v, seed=mk_tf_seed())
error = test_data.Y - y_m
metrics[mt.mae] = np.average(np.abs(error))
metrics[mt.rmse] = np.average(error ** 2) ** 0.5
@@ -122,10 +126,10 @@ def _collect_metrics(
metrics[mt.posterior_prediction_time] = t_after - t_before

metrics[mt.posterior_nlpd] = -np.sum(
-        likelihood.predict_log_density(test_data.X, f_m, f_v, test_data.Y)
+        likelihood.predict_log_density(test_data.X, f_m, f_v, test_data.Y, seed=mk_tf_seed())
)

-    y_m, _y_v = likelihood.predict_mean_and_var(test_data.X, f_m, f_v)
+    y_m, _y_v = likelihood.predict_mean_and_var(test_data.X, f_m, f_v, seed=mk_tf_seed())
error = test_data.Y - y_m
metrics[mt.posterior_mae] = np.average(np.abs(error))
metrics[mt.posterior_rmse] = np.average(error ** 2) ** 0.5
210 changes: 210 additions & 0 deletions doc/sphinx/notebooks/advanced/random_state.pct.py
@@ -0,0 +1,210 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---

# %% [markdown]
# # Managing random state
#
# Many GPflow methods take an optional `seed` parameter. This can take three types of values:
# * `None`.
# * An integer.
# * A tensor of shape `[2]` and `dtype=tf.int32`.
#
# Below we will go over how these are interpreted, and what that means.

# %%
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

import gpflow

# %% [markdown]
# Let us quickly create a model to test on:

# %%
model = gpflow.models.GPR(
(np.zeros((0, 1)), np.zeros((0, 1))),
kernel=gpflow.kernels.SquaredExponential(),
)
Xplot = np.linspace(0.0, 10.0, 100)[:, None]

# %% [markdown]
# ## Seed `None`
#
# When the seed is set to `None`, the randomness depends on the state of the TensorFlow global seed. This is the default behaviour.

# %%
tf.random.set_seed(123)
Yplot = model.predict_f_samples(Xplot, seed=None)
plt.plot(Xplot, Yplot)

tf.random.set_seed(123)
Yplot = model.predict_f_samples(Xplot, seed=None)
plt.plot(Xplot, Yplot, ls=":")

tf.random.set_seed(456)
Yplot = model.predict_f_samples(Xplot, seed=None)
_ = plt.plot(Xplot, Yplot, ls="-.")

# %% [markdown]
# ## Integer seed
# When the seed is set to an integer, the randomness depends on both the TensorFlow global seed and the `seed` passed to the method.

# %%
tf.random.set_seed(123)
Yplot = model.predict_f_samples(Xplot, seed=1)
plt.plot(Xplot, Yplot)

tf.random.set_seed(123)
Yplot = model.predict_f_samples(Xplot, seed=1)
plt.plot(Xplot, Yplot, ls=":")

tf.random.set_seed(123)
Yplot = model.predict_f_samples(Xplot, seed=2)
plt.plot(Xplot, Yplot, ls="-.")

tf.random.set_seed(456)
Yplot = model.predict_f_samples(Xplot, seed=2)
_ = plt.plot(Xplot, Yplot, ls="--")

# %% [markdown]
# ## Full random state as seed
# If you set the `seed` to a tensor of shape `[2]` and `dtype=tf.int32`, that will completely define the randomness used, and the TensorFlow global seed will be ignored.

# %%
tf.random.set_seed(123)
Yplot = model.predict_f_samples(
Xplot, seed=tf.constant([12, 34], dtype=tf.int32)
)
plt.plot(Xplot, Yplot)

tf.random.set_seed(456)
Yplot = model.predict_f_samples(
Xplot, seed=tf.constant([12, 34], dtype=tf.int32)
)
plt.plot(Xplot, Yplot, ls=":")

tf.random.set_seed(456)
Yplot = model.predict_f_samples(
Xplot, seed=tf.constant([56, 78], dtype=tf.int32)
)
_ = plt.plot(Xplot, Yplot, ls="-.")


# %% [markdown]
# When using the full state as the random seed, it is important to be careful about when you do and do not reuse a seed. If you write a function that takes a seed, and that function makes multiple calls that also take seeds, you should generally pass a different seed to each call. You can use [tfp.random.split_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/split_seed) to create multiple new seeds from one seed. For example:

# %%
def calls_two(seed: tf.Tensor) -> None:
s1, s2 = tfp.random.split_seed(seed)
model.predict_f_samples(Xplot, seed=s1)
model.predict_f_samples(Xplot, seed=s2)


def calls_in_loop(seed: tf.Tensor) -> None:
for _ in range(5):
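        # Overwrite the carried seed each iteration, so every derived seed is fresh and independent.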
seed, s = tfp.random.split_seed(seed)
calls_two(s)


calls_in_loop(tf.constant([12, 34], dtype=tf.int32))

# %% [markdown]
# ## Stable randomness in model optimisation
#
# By default the `training_loss` method has `seed=None`, which means that every call to it uses different randomness. Some of the GPflow likelihoods rely on randomness, so your loss function becomes noisy. You can get a deterministic loss function by using `training_loss_closure` instead, which allows you to bind a fixed seed to your loss:

# %%
X = np.array([[1.0], [3.0], [9.0]])
Y = np.array([[0.0], [2.0], [2.0]])
model = gpflow.models.GPR(
(X, Y),
kernel=gpflow.kernels.SquaredExponential(),
)
fixed_seed = tf.constant([12, 34], dtype=tf.int32)
opt = gpflow.optimizers.Scipy()
opt.minimize(
model.training_loss_closure(seed=fixed_seed), model.trainable_variables
)


# %% [markdown]
# ## Stable randomness under changing parameters
#
# Generally, random sampling with a fixed random state is stable under changing model parameters: the sample warps smoothly rather than being redrawn. For example, observe how adding data to a model warps the random sample:

# %%
model1 = gpflow.models.GPR(
(np.zeros((0, 1)), np.zeros((0, 1))),
kernel=gpflow.kernels.SquaredExponential(),
)

X = np.array([[1.0], [3.0], [9.0]])
Y = np.array([[0.0], [2.0], [2.0]])

model2 = gpflow.models.GPR(
(X, Y),
kernel=gpflow.kernels.SquaredExponential(),
)

fixed_seed = tf.constant([12, 34], dtype=tf.int32)

plt.scatter(X, Y, c="C1")

Yplot = model1.predict_f_samples(Xplot, seed=fixed_seed)
plt.plot(Xplot, Yplot)

Yplot = model2.predict_f_samples(Xplot, seed=fixed_seed)
_ = plt.plot(Xplot, Yplot)

# %% [markdown]
# Or, let us try optimising model parameters to fit a random sample to data:

# %%
model = gpflow.models.GPR(
(np.zeros((0, 1)), np.zeros((0, 1))),
kernel=gpflow.kernels.SquaredExponential(),
)

X = np.array([[1.0], [3.0], [9.0]])
Y = np.array([[0.0], [-1.0], [2.0]])
fixed_seed = tf.constant([12, 34], dtype=tf.int32)

plt.scatter(X, Y, c="black")

Yplot = model.predict_f_samples(Xplot, seed=fixed_seed)
plt.plot(Xplot, Yplot)

data_indices = tf.searchsorted(Xplot[:, 0], X[:, 0])


def loss() -> tf.Tensor:
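    # A fixed seed reuses the same underlying randomness, so the loss is a deterministic function of the parameters.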
Yplot = model.predict_f_samples(Xplot, seed=fixed_seed)
Ypred = tf.gather(Yplot, data_indices)
delta = Ypred - Y
return tf.reduce_sum(delta ** 2)


opt = gpflow.optimizers.Scipy()
opt.minimize(loss, model.trainable_variables)

Yplot = model.predict_f_samples(Xplot, seed=fixed_seed)
_ = plt.plot(Xplot, Yplot)

# %% [markdown]
# ## Implementation details
#
# Behind the scenes, any `seed` is pushed through [tfp.random.sanitize_seed](https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed) and fed to the `tf.random.stateless_*` functions.
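
# %% [markdown]
# A minimal illustration of that pipeline (a sketch of the general mechanism, not GPflow's internal code):

# %%
for raw_seed in [None, 42, tf.constant([12, 34], dtype=tf.int32)]:
    # sanitize_seed maps each accepted seed type to an int32 tensor of shape [2]...
    sanitized = tfp.random.sanitize_seed(raw_seed)
    # ...which the stateless samplers consume deterministically.
    sample = tf.random.stateless_normal(shape=[3], seed=sanitized)
    print(sanitized.numpy(), sample.numpy())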
4 changes: 2 additions & 2 deletions doc/sphinx/notebooks/tailor/mixture_density_network.pct.py
@@ -88,7 +88,7 @@ def sinusoidal_data(N, noise):

import tensorflow as tf

-from gpflow.base import Parameter
+from gpflow.base import Parameter, Seed
from gpflow.models import BayesianModel, ExternalDataTrainingLossMixin

# %% [markdown]
@@ -140,7 +140,7 @@ def eval_network(self, X):
return pis, mus, sigmas

def maximum_log_likelihood_objective(
-        self, data: Tuple[tf.Tensor, tf.Tensor]
+        self, data: Tuple[tf.Tensor, tf.Tensor], seed: Seed = None
):
x, y = data
pis, mus, sigmas = self.eval_network(x)
7 changes: 7 additions & 0 deletions doc/sphinx/user_guide.rst
@@ -72,6 +72,13 @@ Advanced needs

This section explains the more complex models and features that are available in GPflow.

.. toctree::
:maxdepth: 1

notebooks/advanced/random_state

How to manage GPflow's random state, and how to get deterministic randomness.

.. toctree::
:maxdepth: 1

12 changes: 12 additions & 0 deletions gpflow/base.py
@@ -41,6 +41,18 @@
else:
AnyNDArray = Union[np.ndarray] # type: ignore[misc]

Seed = Union[None, int, tf.Tensor]
"""
Type of optional random seeds.
Use a tensor of shape ``[2]`` to use that as a deterministic seed.
Alternatively use an integer or ``None`` to use the TensorFlow global random seed.
See also:
* https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed
* https://www.tensorflow.org/api_docs/python/tf/random/create_rng_state
"""

VariableData = Union[List[Any], Tuple[Any], AnyNDArray, int, float] # deprecated
Transform = Union[tfp.bijectors.Bijector]
Prior = Union[tfp.distributions.Distribution]
9 changes: 7 additions & 2 deletions gpflow/conditionals/multioutput/sample_conditionals.py
@@ -17,7 +17,7 @@
import tensorflow as tf
from check_shapes import check_shapes

-from ...base import SamplesMeanAndVariance
+from ...base import SamplesMeanAndVariance, Seed
from ...inducing_variables import (
SeparateIndependentInducingVariables,
SharedIndependentInducingVariables,
@@ -50,6 +50,7 @@ def _sample_conditional(
q_sqrt: Optional[tf.Tensor] = None,
white: bool = False,
num_samples: Optional[int] = None,
seed: Seed = None,
) -> SamplesMeanAndVariance:
"""
`sample_conditional` will return a sample from the conditional distribution.
@@ -58,6 +59,8 @@
However, for some combinations of Mok and Mof, more efficient sampling routines exist.
The dispatcher will make sure that we use the most efficient one.
:param seed: Random seed. Interpreted as by
`tfp.random.sanitize_seed <https://www.tensorflow.org/probability/api_docs/python/tfp/random/sanitize_seed>`_\.
:return: samples, mean, cov
"""
if full_cov:
@@ -71,7 +74,9 @@
g_mu, g_var = ind_conditional(
Xnew, inducing_variable, kernel, f, white=white, q_sqrt=q_sqrt
) # [..., N, L], [..., N, L]
-    g_sample = sample_mvn(g_mu, g_var, full_cov, num_samples=num_samples)  # [..., (S), N, L]
+    g_sample = sample_mvn(
+        g_mu, g_var, full_cov, num_samples=num_samples, seed=seed
+    )  # [..., (S), N, L]
f_mu, f_var = mix_latent_gp(kernel.W, g_mu, g_var, full_cov, full_output_cov)
f_sample = tf.tensordot(g_sample, kernel.W, [[-1], [-1]]) # [..., N, P]
return f_sample, f_mu, f_var