Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added switch_dim parameter to ChangePoints kernel #1671

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion gpflow/kernels/changepoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,16 @@ def __init__(
kernels: List[Kernel],
locations: List[float],
steepness: Union[float, List[float]] = 1.0,
switch_dim: int = 0,
name: Optional[str] = None,
):
"""
:param kernels: list of kernels defining the different regimes
:param locations: list of change-point locations in the 1d input space
:param steepness: the steepness parameter(s) of the sigmoids, this can be
common between them or decoupled
:param switch_dim: the (one) dimension of the input space along which
the change-points are defined
"""
if len(kernels) != len(locations) + 1:
raise ValueError(
Expand All @@ -76,6 +79,7 @@ def __init__(

super().__init__(kernels, name=name)

self.switch_dim = switch_dim
self.locations = Parameter(locations)
self.steepness = Parameter(steepness, transform=positive())

Expand Down Expand Up @@ -119,4 +123,5 @@ def _sigmoids(self, X: tf.Tensor) -> tf.Tensor:
locations = tf.sort(self.locations) # ensure locations are ordered
locations = tf.reshape(locations, (1, 1, -1))
steepness = tf.reshape(self.steepness, (1, 1, -1))
return tf.sigmoid(steepness * (X[:, :, None] - locations))
Xslice = tf.reshape(X[:, self.switch_dim], (-1, 1, 1))
return tf.sigmoid(steepness * (Xslice - locations))
25 changes: 25 additions & 0 deletions tests/gpflow/kernels/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,28 @@ def ref_periodic_kernel(X, base_name, lengthscales, signal_variance, period):
dist = np.sqrt(5) * np.sum(np.abs(sine_base), axis=-1)
exp_dist = (1 + dist + dist ** 2 / 3) * np.exp(-dist)
return signal_variance * exp_dist


def ref_changepoints(X, kernels, locations, steepness, switch_dim=0):
"""
Calculates K(X) for each kernel in `kernels`, then multiply by sigmoid functions
in order to smoothly transition betwen them. The sigmoid transitions are defined
by a location and a steepness parameter.
"""
locations = sorted(locations)
steepness = steepness if isinstance(steepness, list) else [steepness] * len(locations)
locations = np.array(locations).reshape((1, 1, -1))
steepness = np.array(steepness).reshape((1, 1, -1))

Xslice = X[:, switch_dim].reshape(-1, 1, 1)
sig_X = 1.0 / (1.0 + np.exp(-steepness * (Xslice - locations)))

starters = sig_X * np.transpose(sig_X, axes=(1, 0, 2))
stoppers = (1 - sig_X) * np.transpose((1 - sig_X), axes=(1, 0, 2))

ones = np.ones((X.shape[0], X.shape[0], 1))
starters = np.concatenate([ones, starters], axis=2)
stoppers = np.concatenate([stoppers, ones], axis=2)

kernel_stack = np.stack([k(X) for k in kernels], axis=2)
return (kernel_stack * starters * stoppers).sum(axis=2)
156 changes: 156 additions & 0 deletions tests/gpflow/kernels/test_changepoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,90 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose

import gpflow
from tests.gpflow.kernels.reference import ref_changepoints

rng = np.random.RandomState(1)


@pytest.mark.parametrize(
"locations, steepness, error_msg",
[
# 1. Kernels locations dimension mismatch
[
[1.0],
1.0,
r"Number of kernels \(3\) must be one more than the number of changepoint locations \(1\)",
],
# 2. Locations steepness dimension mismatch
[
[1.0, 2.0],
[1.0],
r"Dimension of steepness \(1\) does not match number of changepoint locations \(2\)",
],
],
)
def test_changepoints_init_fail(locations, steepness, error_msg):
kernels = [
gpflow.kernels.Matern12(),
gpflow.kernels.Linear(),
gpflow.kernels.Matern32(),
]
with pytest.raises(ValueError, match=error_msg):
gpflow.kernels.ChangePoints(kernels, locations, steepness)


def _assert_changepoints_kern_err(X, kernels, locations, steepness):
kernel = gpflow.kernels.ChangePoints(kernels, locations, steepness=steepness)
reference_gram_matrix = ref_changepoints(X, kernels, locations, steepness)

assert_allclose(kernel(X), reference_gram_matrix)
assert_allclose(kernel.K_diag(X), np.diag(reference_gram_matrix))


@pytest.mark.parametrize("N", [2, 10])
@pytest.mark.parametrize(
"kernels, locations, steepness",
[
# 1. Single changepoint
[[gpflow.kernels.Constant(), gpflow.kernels.Constant()], [2.0], 5.0],
# 2. Two changepoints
[
[
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
],
[1.0, 2.0],
5.0,
],
# 3. Multiple steepness
[
[
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
],
[1.0, 2.0],
[5.0, 10.0],
],
# 4. Variety of kernels
[
[
gpflow.kernels.Matern12(),
gpflow.kernels.Linear(),
gpflow.kernels.SquaredExponential(),
gpflow.kernels.Constant(),
],
[1.0, 2.0, 3.0],
5.0,
],
],
)
def test_changepoint_output(N, kernels, locations, steepness):
X_data = rng.randn(N, 1)
_assert_changepoints_kern_err(X_data, kernels, locations, steepness)


def test_changepoint_with_X1_X2():
Expand All @@ -16,3 +100,75 @@ def test_changepoint_with_X1_X2():
X2 = np.linspace(0, 50, N2).reshape(N2, 1)
K = k(X, X2)
assert K.shape == [N, N2]


@pytest.mark.parametrize("switch_dim", [0, 1])
def test_changepoint_xslice_sigmoid(switch_dim):
"""
Test shaping and slicing of X introduced to accommodate switch_dim parameter.
"""
X = rng.rand(10, 2)
locations = [2.0]
steepness = 5.0

X1 = X[:, [switch_dim]]
sig_X1 = 1.0 / (1.0 + np.exp(-steepness * (X1[:, :, None] - locations)))

Xslice = X[:, switch_dim].reshape(-1, 1, 1)
sig_Xslice = 1.0 / (1.0 + np.exp(-steepness * (Xslice - locations)))

assert_allclose(sig_X1, sig_Xslice)


@pytest.mark.parametrize("switch_dim", [0, 1])
def test_changepoint_xslice(switch_dim):
"""
Test switch_dim behaviour in comparison to slicing on input X.
"""
N, D = 10, 2
locations = [2.0]
steepness = 5.0
X = rng.randn(N, D)
RBF = gpflow.kernels.SquaredExponential

kernel = gpflow.kernels.ChangePoints(
[RBF(active_dims=[switch_dim]), RBF(active_dims=[switch_dim])],
locations,
steepness=steepness,
switch_dim=switch_dim,
)
reference_gram_matrix = ref_changepoints(
X[:, [switch_dim]], [RBF(), RBF()], locations, steepness
)

assert_allclose(kernel(X), reference_gram_matrix)


@pytest.mark.parametrize("D", [2, 3])
@pytest.mark.parametrize("switch_dim", [0, 1])
@pytest.mark.parametrize("active_dim", [0, 1])
def test_changepoint_ndim(D, switch_dim, active_dim):
"""
Test Changepoints with varying combinations of switch_dim and active_dim.
"""
N = 10
X = rng.randn(N, D)
RBF = gpflow.kernels.SquaredExponential
locations = [2.0]
steepness = 5.0

kernel = gpflow.kernels.ChangePoints(
[RBF(active_dims=[active_dim]), RBF(active_dims=[active_dim])],
locations,
steepness=steepness,
switch_dim=switch_dim,
)
reference_gram_matrix = ref_changepoints(
X,
[RBF(active_dims=[active_dim]), RBF(active_dims=[active_dim])],
locations,
steepness,
switch_dim=switch_dim,
)

assert_allclose(kernel(X), reference_gram_matrix)
103 changes: 0 additions & 103 deletions tests/gpflow/kernels/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,6 @@
rng = np.random.RandomState(1)


def _ref_changepoints(X, kernels, locations, steepness):
"""
Calculates K(X) for each kernel in `kernels`, then multiply by sigmoid functions
in order to smoothly transition betwen them. The sigmoid transitions are defined
by a location and a steepness parameter.
"""
locations = sorted(locations)
steepness = steepness if isinstance(steepness, list) else [steepness] * len(locations)
locations = np.array(locations).reshape((1, 1, -1))
steepness = np.array(steepness).reshape((1, 1, -1))

sig_X = 1.0 / (1.0 + np.exp(-steepness * (X[:, :, None] - locations)))

starters = sig_X * np.transpose(sig_X, axes=(1, 0, 2))
stoppers = (1 - sig_X) * np.transpose((1 - sig_X), axes=(1, 0, 2))

ones = np.ones((X.shape[0], X.shape[0], 1))
starters = np.concatenate([ones, starters], axis=2)
stoppers = np.concatenate([stoppers, ones], axis=2)

kernel_stack = np.stack([k(X) for k in kernels], axis=2)
return (kernel_stack * starters * stoppers).sum(axis=2)


@pytest.mark.parametrize("variance, lengthscales", [[2.3, 1.4]])
def test_rbf_1d(variance, lengthscales):
X = rng.randn(3, 1)
Expand Down Expand Up @@ -433,85 +409,6 @@ def test_ard_property(kernel_class, param_name, param_value, ard):
assert kernel.ard is ard


@pytest.mark.parametrize(
"locations, steepness, error_msg",
[
# 1. Kernels locations dimension mismatch
[
[1.0],
1.0,
r"Number of kernels \(3\) must be one more than the number of changepoint locations \(1\)",
],
# 2. Locations steepness dimension mismatch
[
[1.0, 2.0],
[1.0],
r"Dimension of steepness \(1\) does not match number of changepoint locations \(2\)",
],
],
)
def test_changepoints_init_fail(locations, steepness, error_msg):
kernels = [
gpflow.kernels.Matern12(),
gpflow.kernels.Linear(),
gpflow.kernels.Matern32(),
]
with pytest.raises(ValueError, match=error_msg):
gpflow.kernels.ChangePoints(kernels, locations, steepness)


def _assert_changepoints_kern_err(X, kernels, locations, steepness):
kernel = gpflow.kernels.ChangePoints(kernels, locations, steepness=steepness)
reference_gram_matrix = _ref_changepoints(X, kernels, locations, steepness)

assert_allclose(kernel(X), reference_gram_matrix)
assert_allclose(kernel.K_diag(X), np.diag(reference_gram_matrix))


@pytest.mark.parametrize("N", [2, 10])
@pytest.mark.parametrize(
"kernels, locations, steepness",
[
# 1. Single changepoint
[[gpflow.kernels.Constant(), gpflow.kernels.Constant()], [2.0], 5.0],
# 2. Two changepoints
[
[
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
],
[1.0, 2.0],
5.0,
],
# 3. Multiple steepness
[
[
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
gpflow.kernels.Constant(),
],
[1.0, 2.0],
[5.0, 10.0],
],
# 4. Variety of kernels
[
[
gpflow.kernels.Matern12(),
gpflow.kernels.Linear(),
gpflow.kernels.SquaredExponential(),
gpflow.kernels.Constant(),
],
[1.0, 2.0, 3.0],
5.0,
],
],
)
def test_changepoints(N, kernels, locations, steepness):
X_data = rng.randn(N, 1)
_assert_changepoints_kern_err(X_data, kernels, locations, steepness)


@pytest.mark.parametrize(
"active_dims_1, active_dims_2, is_separate",
[
Expand Down