Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BQ stopping criteria that depend on model statistics #463

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions emukit/core/loop/outer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def run_loop(
_log.info("Iteration {}".format(self.loop_state.iteration))

self._update_models()
self._update_loop_state()
new_x = self.candidate_point_calculator.compute_next_points(self.loop_state, context)
_log.debug("Next suggested point(s): {}".format(new_x))
results = user_function.evaluate(new_x)
Expand All @@ -109,8 +110,14 @@ def run_loop(
self.iteration_end_event(self, self.loop_state)

self._update_models()
self._update_loop_state()
_log.info("Finished outer loop")

def _update_loop_state(self) -> None:
"""This method is called after the models are updated. Override this function to store additional statistics
other than the collected points and function values in the loop state."""
pass

def _update_models(self):
for model_updater in self.model_updaters:
model_updater.update(self.loop_state)
Expand Down
6 changes: 6 additions & 0 deletions emukit/quadrature/loop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@


from .bayesian_monte_carlo_loop import BayesianMonteCarlo # noqa: F401
from .bq_loop_state import QuadratureLoopState
from .bq_outer_loop import QuadratureOuterLoop
from .bq_stopping_conditions import CoefficientOfVariationStoppingCondition
from .vanilla_bq_loop import VanillaBayesianQuadratureLoop # noqa: F401
from .wsabil_loop import WSABILLoop # noqa: F401

__all__ = [
"QuadratureOuterLoop",
"BayesianMonteCarlo",
"VanillaBayesianQuadratureLoop",
"WSABILLoop",
"QuadratureLoopState",
"point_calculators",
"CoefficientOfVariationStoppingCondition",
]
9 changes: 5 additions & 4 deletions emukit/quadrature/loop/bayesian_monte_carlo_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from ...core.loop import FixedIntervalUpdater, ModelUpdater, OuterLoop
from ...core.loop.loop_state import create_loop_state
from ...core.loop import FixedIntervalUpdater, ModelUpdater
from ...core.parameter_space import ParameterSpace
from ..loop.point_calculators import BayesianMonteCarloPointCalculator
from ..methods import WarpedBayesianQuadratureModel
from .bq_loop_state import create_bq_loop_state
from .bq_outer_loop import QuadratureOuterLoop


class BayesianMonteCarlo(OuterLoop):
class BayesianMonteCarlo(QuadratureOuterLoop):
"""The loop for Bayesian Monte Carlo (BMC).


Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, model: WarpedBayesianQuadratureModel, model_updater: ModelUpd

space = ParameterSpace(model.reasonable_box_bounds.convert_to_list_of_continuous_parameters())
candidate_point_calculator = BayesianMonteCarloPointCalculator(model, space)
loop_state = create_loop_state(model.X, model.Y)
loop_state = create_bq_loop_state(model.X, model.Y)

super().__init__(candidate_point_calculator, model_updater, loop_state)

Expand Down
48 changes: 48 additions & 0 deletions emukit/quadrature/loop/bq_loop_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2020-2024 The Emukit Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


from typing import List, Optional

import numpy as np

from ...core.loop.loop_state import LoopState, create_loop_state
from ...core.loop.user_function_result import UserFunctionResult


class QuadratureLoopState(LoopState):
"""Contains the state of the BQ loop, which includes a history of all integrand evaluations and integral mean and
variance estimates.

:param initial_results: The results from previous integrand evaluations.

"""

def __init__(self, initial_results: List[UserFunctionResult]) -> None:

super().__init__(initial_results)

self.integral_means = []
self.integral_vars = []

def update_integral_stats(self, integral_mean: float, integral_var: float) -> None:
"""Adds the latest integral mean and variance to the loop state.

:param integral_mean: The latest integral mean estimate.
:param integral_var: The latest integral variance.
"""
self.integral_means.append(integral_mean)
self.integral_vars.append(integral_var)


def create_bq_loop_state(x_init: np.ndarray, y_init: np.ndarray, **kwargs) -> QuadratureLoopState:
"""Creates a BQ loop state object using the provided data.

:param x_init: x values for initial function evaluations. Shape: (n_initial_points x n_input_dims)
:param y_init: y values for initial function evaluations. Shape: (n_initial_points x n_output_dims)
:param kwargs: extra outputs observed from a function evaluation. Shape: (n_initial_points x n_dims)
:return: The BQ loop state.
"""

loop_state = create_loop_state(x_init, y_init, **kwargs)
return QuadratureLoopState(loop_state.results)
39 changes: 39 additions & 0 deletions emukit/quadrature/loop/bq_outer_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2020-2024 The Emukit Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


from typing import List, Union

from ...core.loop import OuterLoop
from ...core.loop.candidate_point_calculators import CandidatePointCalculator
from ...core.loop.model_updaters import ModelUpdater
from .bq_loop_state import QuadratureLoopState


class QuadratureOuterLoop(OuterLoop):
"""Base class for a Bayesian quadrature loop.

:param candidate_point_calculator: Finds next point(s) to evaluate.
:param model_updaters: Updates the model with the new data and fits the model hyper-parameters.
:param loop_state: Object that keeps track of the history of the BQ loop. Default is None, resulting in empty
initial state.

:raises ValueError: If more than one model updater is provided.

"""

def __init__(
self,
candidate_point_calculator: CandidatePointCalculator,
model_updaters: Union[ModelUpdater, List[ModelUpdater]],
loop_state: QuadratureLoopState = None,
):
if isinstance(model_updaters, list):
raise ValueError("The BQ loop only supports a single model.")

super().__init__(candidate_point_calculator, model_updaters, loop_state)

def _update_loop_state(self) -> None:
model = self.model_updaters[0].model # only works if there is one model, but for BQ nothing else makes sense
integral_mean, integral_var = model.integrate()
self.loop_state.update_integral_stats(integral_mean, integral_var)
63 changes: 63 additions & 0 deletions emukit/quadrature/loop/bq_stopping_conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2020-2024 The Emukit Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


import logging

import numpy as np

from ...core.loop.stopping_conditions import StoppingCondition
from .bq_loop_state import QuadratureLoopState

_log = logging.getLogger(__name__)


class CoefficientOfVariationStoppingCondition(StoppingCondition):
r"""Stops once the coefficient of variation (COV) falls below a threshold.

The COV is given by

.. math::
COV = \frac{\sigma}{\mu}

where :math:`\mu` and :math:`\sigma^2` are the current mean and variance respectively of the integral according to
the BQ posterior model.

:param eps: Threshold under which the COV must fall.
:param delay: Number of times the stopping condition needs to be true in a row in order to stop. Defaults to 1.

:raises ValueError: If `delay` is smaller than 1.
:raises ValueError: If `eps` is non-negative.

"""

def __init__(self, eps: float, delay: int = 1) -> None:

if delay < 1:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if delay < 1:
if delay < 1 or not isinstance(delay, int):

because if you pass float('inf'), then you will have an infinite loop right?

raise ValueError(f"delay ({delay}) must be and integer greater than zero.")

if eps <= 0.0:
raise ValueError(f"eps ({eps}) must be positive.")

self.eps = eps
self.delay = delay
self.times_true = 0 # counts how many times stopping has been triggered in a row

def should_stop(self, loop_state: QuadratureLoopState) -> bool:
if len(loop_state.integral_means) < 1:
return False

m = loop_state.integral_means[-1]
v = loop_state.integral_vars[-1]
should_stop = (np.sqrt(v) / m) < self.eps

if should_stop:
self.times_true += 1
else:
self.times_true = 0

should_stop = should_stop and (self.times_true >= self.delay)

if should_stop:
_log.info(f"Stopped as coefficient of variation is below threshold of {self.eps}.")
return should_stop
9 changes: 5 additions & 4 deletions emukit/quadrature/loop/vanilla_bq_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@


from ...core.acquisition import Acquisition
from ...core.loop import FixedIntervalUpdater, ModelUpdater, OuterLoop, SequentialPointCalculator
from ...core.loop.loop_state import create_loop_state
from ...core.loop import FixedIntervalUpdater, ModelUpdater, SequentialPointCalculator
from ...core.optimization import AcquisitionOptimizerBase, GradientAcquisitionOptimizer
from ...core.parameter_space import ParameterSpace
from ..acquisitions import IntegralVarianceReduction
from ..methods import VanillaBayesianQuadrature
from .bq_loop_state import create_bq_loop_state
from .bq_outer_loop import QuadratureOuterLoop


class VanillaBayesianQuadratureLoop(OuterLoop):
class VanillaBayesianQuadratureLoop(QuadratureOuterLoop):
"""The loop for standard ('vanilla') Bayesian Quadrature.

.. seealso::
Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
if acquisition_optimizer is None:
acquisition_optimizer = GradientAcquisitionOptimizer(space)
candidate_point_calculator = SequentialPointCalculator(acquisition, acquisition_optimizer)
loop_state = create_loop_state(model.X, model.Y)
loop_state = create_bq_loop_state(model.X, model.Y)

super().__init__(candidate_point_calculator, model_updater, loop_state)

Expand Down
9 changes: 5 additions & 4 deletions emukit/quadrature/loop/wsabil_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
"""The WSABI-L loop"""


from ...core.loop import FixedIntervalUpdater, ModelUpdater, OuterLoop, SequentialPointCalculator
from ...core.loop.loop_state import create_loop_state
from ...core.loop import FixedIntervalUpdater, ModelUpdater, SequentialPointCalculator
from ...core.optimization import AcquisitionOptimizerBase, GradientAcquisitionOptimizer
from ...core.parameter_space import ParameterSpace
from ..acquisitions import UncertaintySampling
from ..methods import WSABIL
from .bq_loop_state import create_bq_loop_state
from .bq_outer_loop import QuadratureOuterLoop


class WSABILLoop(OuterLoop):
class WSABILLoop(QuadratureOuterLoop):
"""The loop for WSABI-L.

.. rubric:: References
Expand Down Expand Up @@ -44,7 +45,7 @@ def __init__(
if acquisition_optimizer is None:
acquisition_optimizer = GradientAcquisitionOptimizer(space)
candidate_point_calculator = SequentialPointCalculator(acquisition, acquisition_optimizer)
loop_state = create_loop_state(model.X, model.Y)
loop_state = create_bq_loop_state(model.X, model.Y)

super().__init__(candidate_point_calculator, model_updater, loop_state)

Expand Down
Loading