Skip to content

Commit

Permalink
Proposed fix for the fully_correlated_conditional_repeat issue (#1652)
Browse files Browse the repository at this point in the history
Fixes #1651
  • Loading branch information
johnamcleod committed Apr 6, 2021
1 parent 5b76536 commit 21e5317
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 12 deletions.
3 changes: 2 additions & 1 deletion gpflow/conditionals/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,8 @@ def fully_correlated_conditional_repeat(
addvar = tf.reshape(tf.reduce_sum(tf.square(LTA), axis=1), (R, N, P)) # [R, N, P]
fvar = fvar[None, ...] + addvar # [R, N, P]
else:
fvar = tf.broadcast_to(fvar[None], tf.shape(fmean))
fvar_shape = tf.concat([[R], tf.shape(fvar)], axis=0)
fvar = tf.broadcast_to(fvar[None], fvar_shape)

shape_constraints.extend(
[(Knn, intended_cov_shape), (fmean, ["R", "N", "P"]), (fvar, ["R"] + intended_cov_shape),]
Expand Down
109 changes: 98 additions & 11 deletions tests/gpflow/conditionals/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,29 +317,116 @@ def test_sample_conditional_mixedkernel():
)


@pytest.fixture(
name="fully_correlated_q_sqrt_factory",
params=[lambda _, __: None, lambda LM, R: tf.eye(LM, batch_shape=(R,))],
)
def _q_sqrt_factory_fixture(request):
return request.param


@pytest.mark.parametrize("R", [1, 2, 5])
@pytest.mark.parametrize(
"func, R",
"whiten",
[
(fully_correlated_conditional_repeat, 5),
(fully_correlated_conditional_repeat, 1),
(fully_correlated_conditional, 1),
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="fully_correlated_conditional_repeat does not support whiten=False"
),
),
],
)
def test_fully_correlated_conditional_repeat_shapes(func, R):
def test_fully_correlated_conditional_repeat_shapes_fc_and_foc(
R, fully_correlated_q_sqrt_factory, full_cov, full_output_cov, whiten
):

L, M, N, P = Data.L, Data.M, Data.N, Data.P

Kmm = tf.ones((L * M, L * M)) + default_jitter() * tf.eye(L * M)
Kmn = tf.ones((L * M, N, P))
Knn = tf.ones((N, P))

if full_cov and full_output_cov:
Knn = tf.ones((N, P, N, P))
expected_v_shape = [R, N, P, N, P]
elif not full_cov and full_output_cov:
Knn = tf.ones((N, P, P))
expected_v_shape = [R, N, P, P]
elif full_cov and not full_output_cov:
Knn = tf.ones((P, N, N))
expected_v_shape = [R, P, N, N]
else:
Knn = tf.ones((N, P))
expected_v_shape = [R, N, P]

f = tf.ones((L * M, R))
q_sqrt = None
white = True
q_sqrt = fully_correlated_q_sqrt_factory(L * M, R)

m, v = fully_correlated_conditional_repeat(
Kmn,
Kmm,
Knn,
f,
full_cov=full_cov,
full_output_cov=full_output_cov,
q_sqrt=q_sqrt,
white=whiten,
)

assert m.shape.as_list() == [R, N, P]
assert v.shape.as_list() == expected_v_shape


m, v = func(
Kmn, Kmm, Knn, f, full_cov=False, full_output_cov=False, q_sqrt=q_sqrt, white=white,
@pytest.mark.parametrize(
"whiten",
[
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="fully_correlated_conditional does not support whiten=False"
),
),
],
)
def test_fully_correlated_conditional_shapes_fc_and_foc(
fully_correlated_q_sqrt_factory, full_cov, full_output_cov, whiten
):
L, M, N, P = Data.L, Data.M, Data.N, Data.P

Kmm = tf.ones((L * M, L * M)) + default_jitter() * tf.eye(L * M)
Kmn = tf.ones((L * M, N, P))

if full_cov and full_output_cov:
Knn = tf.ones((N, P, N, P))
expected_v_shape = [N, P, N, P]
elif not full_cov and full_output_cov:
Knn = tf.ones((N, P, P))
expected_v_shape = [N, P, P]
elif full_cov and not full_output_cov:
Knn = tf.ones((P, N, N))
expected_v_shape = [P, N, N]
else:
Knn = tf.ones((N, P))
expected_v_shape = [N, P]

f = tf.ones((L * M, 1))
q_sqrt = fully_correlated_q_sqrt_factory(L * M, 1)

m, v = fully_correlated_conditional(
Kmn,
Kmm,
Knn,
f,
full_cov=full_cov,
full_output_cov=full_output_cov,
q_sqrt=q_sqrt,
white=whiten,
)

assert v.shape.as_list() == m.shape.as_list()
assert m.shape.as_list() == [N, P]
assert v.shape.as_list() == expected_v_shape


# ------------------------------------------
Expand Down
25 changes: 25 additions & 0 deletions tests/gpflow/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2021 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest


@pytest.fixture(name="full_cov", params=[True, False])
def _full_cov_fixture(request):
return request.param


@pytest.fixture(name="full_output_cov", params=[True, False])
def _full_output_cov_fixture(request):
return request.param

0 comments on commit 21e5317

Please sign in to comment.