-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
AcquisitionMaximizer
functionality
Added an `AbstractAcquisitionMaximizer` abstract base class for maximising acquisition functions, as well as a concrete `ContinuousAcquisitionMaximizer` implementation which uses L-BFGS-B for optimising continuous acquisition functions. Unit tests have also been added.
- Loading branch information
1 parent
a3579bb
commit e2f6459
Showing
2 changed files
with
343 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
from abc import ( | ||
ABC, | ||
abstractmethod, | ||
) | ||
from dataclasses import dataclass | ||
|
||
import jax.numpy as jnp | ||
from jaxopt import ScipyBoundedMinimize | ||
|
||
from gpjax.decision_making.acquisition_functions import AcquisitionFunction | ||
from gpjax.decision_making.search_space import ( | ||
AbstractSearchSpace, | ||
ContinuousSearchSpace, | ||
) | ||
from gpjax.typing import ( | ||
Array, | ||
Float, | ||
KeyArray, | ||
ScalarFloat, | ||
) | ||
|
||
|
||
def _get_discrete_maximizer(
    query_points: Float[Array, "N D"], acquisition_function: AcquisitionFunction
) -> Float[Array, "1 D"]:
    """Select, from a finite set of candidates, the point with the highest
    acquisition value.

    Args:
        query_points (Float[Array, "N D"]): Candidate points at which the
            acquisition function is evaluated.
        acquisition_function (AcquisitionFunction): Function used to score
            each candidate point.

    Returns:
        Float[Array, "1 D"]: The candidate in `query_points` attaining the
            maximum acquisition value.
    """
    scores = acquisition_function(query_points)
    # keepdims=True yields a [1, 1] index array, which take_along_axis
    # broadcasts across the D columns to extract a single [1, D] row.
    best_idx = jnp.argmax(scores, axis=0, keepdims=True)
    return jnp.take_along_axis(query_points, best_idx, axis=0)
|
||
|
||
@dataclass
class AbstractAcquisitionMaximizer(ABC):
    """Interface for strategies that maximize an acquisition function over a
    search space."""

    @abstractmethod
    def maximize(
        self,
        acquisition_function: AcquisitionFunction,
        search_space: AbstractSearchSpace,
        key: KeyArray,
    ) -> Float[Array, "1 D"]:
        """Return the point in `search_space` at which `acquisition_function`
        attains its maximum.

        Args:
            acquisition_function (AcquisitionFunction): Function to be
                maximized.
            search_space (AbstractSearchSpace): Domain over which the
                maximization is performed.
            key (KeyArray): JAX PRNG key used for any random sampling.

        Returns:
            Float[Array, "1 D"]: The maximizing point.
        """
        raise NotImplementedError
|
||
|
||
@dataclass
class ContinuousAcquisitionMaximizer(AbstractAcquisitionMaximizer):
    """Maximize acquisition functions over a continuous domain with L-BFGS-B.

    The strategy is two-stage: first the acquisition function is evaluated at
    `num_initial_samples` points drawn from the search space, then L-BFGS-B is
    run from the best of those initial points.
    """

    # Number of random starting candidates drawn before gradient-based refinement.
    num_initial_samples: int

    def __post_init__(self):
        # Validate eagerly so a misconfigured maximizer fails at construction
        # rather than on first use.
        if self.num_initial_samples < 1:
            raise ValueError(
                f"num_initial_samples must be greater than 0, got {self.num_initial_samples}."
            )

    def maximize(
        self,
        acquisition_function: AcquisitionFunction,
        search_space: ContinuousSearchSpace,
        key: KeyArray,
    ) -> Float[Array, "1 D"]:
        """Maximize `acquisition_function` over the continuous `search_space`.

        Args:
            acquisition_function (AcquisitionFunction): Function to be
                maximized.
            search_space (ContinuousSearchSpace): Box-bounded continuous
                domain over which to maximize.
            key (KeyArray): JAX PRNG key used to draw the initial samples.

        Returns:
            Float[Array, "1 D"]: Point at which the acquisition function is
                maximized.
        """
        candidates = search_space.sample(self.num_initial_samples, key=key)
        initial_point = _get_discrete_maximizer(candidates, acquisition_function)

        # The Jaxopt minimizer requires a scalar-valued objective. It calls
        # the acquisition function one point at a time, so the function
        # returns an array of shape [1, 1] which we index down to a scalar.
        # The sign is flipped because acquisition functions should be
        # *maximized*, whereas the Jaxopt solver performs minimization.
        def negated_scalar_acquisition(x: Float[Array, "1 D"]) -> ScalarFloat:
            return -acquisition_function(x)[0][0]

        optimizer = ScipyBoundedMinimize(
            fun=negated_scalar_acquisition, method="l-bfgs-b"
        )
        box = (search_space.lower_bounds, search_space.upper_bounds)
        return optimizer.run(initial_point, bounds=box).params
216 changes: 216 additions & 0 deletions
216
tests/test_decision_making/test_acquisition_maximizer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
# Copyright 2023 The GPJax Contributors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
|
||
from abc import ABC | ||
|
||
from jax.config import config | ||
import jax.numpy as jnp | ||
import jax.random as jr | ||
from jaxtyping import ( | ||
Array, | ||
Float, | ||
) | ||
import pytest | ||
|
||
from gpjax.decision_making.acquisition_maximizer import ( | ||
AbstractAcquisitionMaximizer, | ||
ContinuousAcquisitionMaximizer, | ||
_get_discrete_maximizer, | ||
) | ||
from gpjax.decision_making.search_space import ContinuousSearchSpace | ||
from gpjax.typing import KeyArray | ||
|
||
config.update("jax_enable_x64", True) | ||
|
||
|
||
class TestContinuousAcquisitionFunction(ABC):
    """Base class for analytic test functions used as stand-in acquisition
    functions in these tests.

    Subclasses provide a bounded domain, the known global maximizer, and an
    `evaluate` method implementing the function itself.
    """

    # Domain over which the test function is defined.
    search_space: ContinuousSearchSpace
    # Known global maximizer of the test function, shape [1, D].
    maximizer: Float[Array, "1 D"]

    # Fix: the original declaration omitted `self`, making the method
    # uncallable on instances; all subclasses define `evaluate(self, x)`.
    def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
        """Evaluate the test function at each row of `x`."""
        raise NotImplementedError
|
||
|
||
class NegativeForrester(TestContinuousAcquisitionFunction):
    """The negated 1D Forrester function; its maximum coincides with the
    Forrester function's global minimum."""

    search_space = ContinuousSearchSpace(
        lower_bounds=jnp.array([0.0], dtype=jnp.float64),
        upper_bounds=jnp.array([1.0], dtype=jnp.float64),
    )
    maximizer = jnp.array([[0.75725]])

    def evaluate(self, x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
        quadratic_term = (6 * x - 2) ** 2
        oscillation = jnp.sin(12 * x - 4)
        return -quadratic_term * oscillation
|
||
|
||
class NegativeGoldsteinPrice(TestContinuousAcquisitionFunction):
    """The negated 2D Goldstein-Price function; its maximum coincides with the
    Goldstein-Price function's global minimum at (0, -1)."""

    search_space = ContinuousSearchSpace(
        lower_bounds=jnp.array([-2.0, -2.0], dtype=jnp.float64),
        upper_bounds=jnp.array([2.0, 2.0], dtype=jnp.float64),
    )
    maximizer = jnp.array([[0.0, -1.0]])

    def evaluate(self, x: Float[Array, "N 2"]) -> Float[Array, "N 1"]:
        # Arithmetic order matches the standard Goldstein-Price formula and is
        # kept fixed to preserve floating-point behavior exactly.
        x1 = x[:, 0]
        x2 = x[:, 1]
        first_factor = 1.0 + (x1 + x2 + 1.0) ** 2 * (
            19.0
            - 14.0 * x1
            + 3.0 * (x1**2)
            - 14.0 * x2
            + 6.0 * x1 * x2
            + 3.0 * (x2**2)
        )
        second_factor = 30.0 + (2.0 * x1 - 3.0 * x2) ** 2 * (
            18.0
            - 32.0 * x1
            + 12.0 * (x1**2)
            + 48.0 * x2
            - 36.0 * x1 * x2
            + 27.0 * (x2**2)
        )
        return -(first_factor * second_factor).reshape(-1, 1)
|
||
|
||
class Quadratic(TestContinuousAcquisitionFunction):
    """A simple concave quadratic on [0, 1], maximized at x = 0.5."""

    search_space = ContinuousSearchSpace(
        lower_bounds=jnp.array([0.0], dtype=jnp.float64),
        upper_bounds=jnp.array([1.0], dtype=jnp.float64),
    )
    maximizer = jnp.array([[0.5]])

    def evaluate(self, x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
        offset = x - 0.5
        return -(offset**2)
|
||
|
||
def test_abstract_acquisition_maximizer():
    # The ABC declares `maximize` as abstract, so direct instantiation must
    # fail with a TypeError.
    with pytest.raises(TypeError):
        AbstractAcquisitionMaximizer()
|
||
|
||
@pytest.mark.parametrize(
    "test_acquisition_function, dimensionality",
    [(NegativeForrester(), 1), (NegativeGoldsteinPrice(), 2)],
)
@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10)])
@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_discrete_maximizer_returns_correct_point(
    test_acquisition_function: TestContinuousAcquisitionFunction,
    dimensionality: int,
    key: KeyArray,
):
    """The discrete maximizer must pick the sampled point whose acquisition
    value equals the maximum over all sampled points."""
    samples = test_acquisition_function.search_space.sample(1000, key=key)
    sample_values = test_acquisition_function.evaluate(samples)
    expected_max_value = jnp.max(sample_values)
    chosen_point = _get_discrete_maximizer(
        samples, test_acquisition_function.evaluate
    )
    assert chosen_point.shape == (1, dimensionality)
    assert chosen_point.dtype == jnp.float64
    assert (
        test_acquisition_function.evaluate(chosen_point)[0][0]
        == expected_max_value
    )
|
||
|
||
@pytest.mark.parametrize("num_initial_samples", [0, -1, -10])
def test_continuous_maximizer_raises_error_with_erroneous_num_initial_samples(
    num_initial_samples: int,
):
    # __post_init__ validates that num_initial_samples >= 1, so any
    # non-positive value must be rejected at construction time.
    with pytest.raises(ValueError):
        ContinuousAcquisitionMaximizer(num_initial_samples=num_initial_samples)
|
||
|
||
@pytest.mark.parametrize(
    "test_acquisition_function, dimensionality",
    [(NegativeForrester(), 1), (NegativeGoldsteinPrice(), 2)],
)
@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10)])
@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_continous_maximizer_returns_same_point_with_same_key(
    test_acquisition_function: TestContinuousAcquisitionFunction,
    dimensionality: int,
    key: KeyArray,
):
    """Two independently constructed maximizers given the same PRNG key must
    be fully deterministic and agree on the returned point."""
    first_maximizer = ContinuousAcquisitionMaximizer(num_initial_samples=2000)
    second_maximizer = ContinuousAcquisitionMaximizer(num_initial_samples=2000)
    first_point = first_maximizer.maximize(
        acquisition_function=test_acquisition_function.evaluate,
        search_space=test_acquisition_function.search_space,
        key=key,
    )
    second_point = second_maximizer.maximize(
        acquisition_function=test_acquisition_function.evaluate,
        search_space=test_acquisition_function.search_space,
        key=key,
    )
    assert first_point.shape == (1, dimensionality)
    assert first_point.dtype == jnp.float64
    assert second_point.shape == (1, dimensionality)
    assert second_point.dtype == jnp.float64
    assert jnp.equal(first_point, second_point).all()
|
||
|
||
@pytest.mark.parametrize(
    "test_acquisition_function, dimensionality",
    [
        (NegativeForrester(), 1),
        (NegativeGoldsteinPrice(), 2),
    ],
)
@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10)])
@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_continuous_maximizer_finds_correct_point(
    test_acquisition_function: TestContinuousAcquisitionFunction,
    dimensionality: int,
    key: KeyArray,
):
    """The continuous maximizer must recover the known global maximizer of
    each analytic test function."""
    maximizer_under_test = ContinuousAcquisitionMaximizer(
        num_initial_samples=1000
    )
    found_point = maximizer_under_test.maximize(
        acquisition_function=test_acquisition_function.evaluate,
        search_space=test_acquisition_function.search_space,
        key=key,
    )
    assert found_point.shape == (1, dimensionality)
    assert found_point.dtype == jnp.float64
    assert jnp.allclose(
        found_point, test_acquisition_function.maximizer, atol=1e-6
    ).all()
|
||
|
||
@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10), jr.PRNGKey(1)])
@pytest.mark.filterwarnings(
    "ignore::UserWarning"
)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
def test_continuous_maximizer_jaxopt_component(key: KeyArray):
    # With a single initial sample, success depends on the L-BFGS-B
    # refinement rather than on dense random sampling.
    quadratic_acquisition_function = Quadratic()
    continuous_acquisition_maximizer = ContinuousAcquisitionMaximizer(
        num_initial_samples=1  # Force JaxOpt L-BFGS-B to do the heavy lifting
    )
    maximizer = continuous_acquisition_maximizer.maximize(
        acquisition_function=quadratic_acquisition_function.evaluate,
        search_space=quadratic_acquisition_function.search_space,
        key=key,
    )
    assert maximizer.shape == (1, 1)
    assert maximizer.dtype == jnp.float64
    assert jnp.allclose(
        maximizer, quadratic_acquisition_function.maximizer, atol=1e-6
    ).all()