Skip to content

Commit

Permalink
Merge pull request optuna#4675 from Alnusjaponica/terminator-docs
Browse files Browse the repository at this point in the history
Add docstrings to `optuna.terminator`
  • Loading branch information
HideakiImamura committed May 25, 2023
2 parents 6ee468b + a0a22b4 commit 5cef676
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 30 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ API Reference
search_space
storages
study
terminator
trial
visualization/index
20 changes: 20 additions & 0 deletions docs/source/reference/terminator.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
.. module:: optuna.terminator

optuna.terminator
=================

The :mod:`~optuna.terminator` module implements a mechanism for automatically terminating the optimization process, accompanied by a callback class for the termination and evaluators for the estimated room for improvement in the optimization and statistical error of the objective function. The terminator stops the optimization process when the estimated potential improvement is smaller than the statistical error.

.. autosummary::
:toctree: generated/
:nosignatures:

optuna.terminator.BaseTerminator
optuna.terminator.Terminator
optuna.terminator.BaseImprovementEvaluator
optuna.terminator.RegretBoundEvaluator
optuna.terminator.BaseErrorEvaluator
optuna.terminator.CrossValidationErrorEvaluator
optuna.terminator.StaticErrorEvaluator
optuna.terminator.TerminatorCallback
optuna.terminator.report_cross_validation_scores
49 changes: 49 additions & 0 deletions optuna/terminator/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,55 @@

@experimental_class("3.2.0")
class TerminatorCallback:
"""A callback that terminates the optimization using Terminator.
This class implements a callback which wraps :class:`~optuna.terminator.Terminator`
so that it can be used with the :func:`~optuna.study.Study.optimize` method.
Args:
terminator:
A terminator object which determines whether to terminate the optimization by
assessing the room for optimization and statistical error. Defaults to a
:class:`~optuna.terminator.Terminator` object with default
improvement_evaluator and error_evaluator.
Example:
.. testcode::
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
import optuna
from optuna.terminator import TerminatorCallback
from optuna.terminator import report_cross_validation_scores
def objective(trial):
X, y = load_wine(return_X_y=True)
clf = RandomForestClassifier(
max_depth=trial.suggest_int("max_depth", 2, 32),
min_samples_split=trial.suggest_float("min_samples_split", 0, 1),
criterion=trial.suggest_categorical("criterion", ("gini", "entropy")),
)
scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
report_cross_validation_scores(trial, scores)
return scores.mean()
study = optuna.create_study(direction="maximize")
terminator = TerminatorCallback()
study.optimize(objective, n_trials=50, callbacks=[terminator])
.. seealso::
Please refer to :class:`~optuna.terminator.Terminator` for the details of
the terminator mechanism.
"""

def __init__(
self,
terminator: Optional[BaseTerminator] = None,
Expand Down
49 changes: 49 additions & 0 deletions optuna/terminator/erroreval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@


class BaseErrorEvaluator(metaclass=abc.ABCMeta):
"""Base class for error evaluators."""

@abc.abstractmethod
def evaluate(
self,
Expand All @@ -27,11 +29,34 @@ def evaluate(

@experimental_class("3.2.0")
class CrossValidationErrorEvaluator(BaseErrorEvaluator):
"""An error evaluator for objective functions based on cross-validation.
This evaluator evaluates the objective function's statistical error, which comes from the
randomness of the dataset. This evaluator assumes that the objective function is the average of
the cross-validation and uses the scaled variance of the cross-validation scores in the best
trial at the moment as the statistical error.
"""

def evaluate(
self,
trials: list[FrozenTrial],
study_direction: StudyDirection,
) -> float:
"""Evaluate the statistical error of the objective function based on cross-validation.
Args:
trials:
A list of trials to consider. The best trial in `trials` is used to compute the
statistical error.
study_direction:
The direction of the study.
Returns:
A float representing the statistical error of the objective function.
"""
trials = [trial for trial in trials if trial.state == TrialState.COMPLETE]
assert len(trials) > 0

Expand Down Expand Up @@ -62,13 +87,37 @@ def evaluate(

@experimental_class("3.2.0")
def report_cross_validation_scores(trial: Trial, scores: list[float]) -> None:
"""A function to report cross-validation scores of a trial.
This function should be called within the objective function to report the cross-validation
scores. The reported scores are used to evaluate the statistical error for termination
judgement.
Args:
trial:
A :class:`~optuna.trial.Trial` object to report the cross-validation scores.
scores:
The cross-validation scores of the trial.
"""
if len(scores) <= 1:
raise ValueError("The length of `scores` is expected to be greater than one.")
trial.storage.set_trial_system_attr(trial._trial_id, _CROSS_VALIDATION_SCORES_KEY, scores)


@experimental_class("3.2.0")
class StaticErrorEvaluator(BaseErrorEvaluator):
"""An error evaluator that always returns a constant value.
This evaluator can be used to terminate the optimization when the evaluated improvement
potential is below the fixed threshold.
Args:
constant:
A user-specified constant value to always return as an error estimate.
"""

def __init__(self, constant: float) -> None:
self._constant = constant

Expand Down
23 changes: 23 additions & 0 deletions optuna/terminator/improvement/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

@experimental_class("3.2.0")
class BaseImprovementEvaluator(metaclass=abc.ABCMeta):
"""Base class for improvement evaluators."""

@abc.abstractmethod
def evaluate(
self,
Expand All @@ -39,6 +41,27 @@ def evaluate(

@experimental_class("3.2.0")
class RegretBoundEvaluator(BaseImprovementEvaluator):
"""An improvement evaluator for upper bound on the regret with high-probability confidence.
This evaluator evaluates the regret of the current best solution, which is defined as the
difference between the objective value of the best solution and that of the global optimum.
To be specific, this evaluator calculates the upper bound on the regret based on the fact
that the empirical estimator of the objective function is bounded by lower and upper
confidence bounds with high probability under the Gaussian process model assumption.
Args:
gp:
A Gaussian process model on which the evaluation is based. If not specified, the
default Gaussian process model is used.
top_trials_ratio:
A ratio of top trials to be considered when estimating the regret. Defaults to 0.5.
min_n_trials:
A minimum number of complete trials to estimate the regret. Defaults to 20.
min_lcb_n_additional_samples:
A minimum number of additional samples to estimate the lower confidence bound.
Defaults to 2000.
"""

def __init__(
self,
gp: Optional[BaseGaussianProcess] = None,
Expand Down
70 changes: 40 additions & 30 deletions optuna/terminator/terminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@


class BaseTerminator(metaclass=abc.ABCMeta):
"""Base class for terminators."""

@abc.abstractmethod
def should_terminate(self, study: Study) -> bool:
pass


@experimental_class("3.2.0")
class Terminator(BaseTerminator):
"""Automatic stopping mechanism for Optuna studies
"""Automatic stopping mechanism for Optuna studies.
This class implements an automatic stopping mechanism for Optuna studies, aiming to prevent
unnecessary computation. The study is terminated when the statistical error, e.g.
Expand All @@ -36,7 +38,7 @@ class Terminator(BaseTerminator):
:class:`~optuna.terminator.RegretBoundEvaluator` object.
error_evaluator:
An evaluator for calculating the statistical error, e.g. cross-validation error.
Defaults to a :class:`~optuna.terminator.erroreval.CrossValidationErrorEvaluator`
Defaults to a :class:`~optuna.terminator.CrossValidationErrorEvaluator`
object.
min_n_trials:
The minimum number of trials before termination is considered. Defaults to ``20``.
Expand All @@ -45,45 +47,52 @@ class Terminator(BaseTerminator):
ValueError: If ``min_n_trials`` is not a positive integer.
Example:
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
import optuna
from optuna.terminator.terminator import Terminator
from optuna.terminator.serror import report_cross_validation_scores
.. testcode::
import logging
import sys
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
import optuna
from optuna.terminator import Terminator
from optuna.terminator import report_cross_validation_scores
study = optuna.create_study(direction="maximize")
terminator = Terminator()
min_n_trials = 20
while True:
trial = study.ask()
study = optuna.create_study(direction="maximize")
terminator = Terminator()
min_n_trials = 20
X, y = load_wine(return_X_y=True)
while True:
trial = study.ask()
clf = RandomForestClassifier(
max_depth=trial.suggest_int("max_depth", 2, 32),
min_samples_split=trial.suggest_float("min_samples_split", 0, 1),
criterion=trial.suggest_categorical("criterion", ("gini", "entropy")),
)
X, y = load_wine(return_X_y=True)
scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
report_cross_validation_scores(trial, scores)
clf = RandomForestClassifier(
max_depth=trial.suggest_int("max_depth", 2, 32),
min_samples_split=trial.suggest_float("min_samples_split", 0, 1),
criterion=trial.suggest_categorical("criterion", ("gini", "entropy")),
)
value = scores.mean()
print(f"Trial #{trial.number} finished with value {value}.")
study.tell(trial, value)
scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
report_cross_validation_scores(trial, scores)
if trial.number > min_n_trials and terminator.should_terminate(study):
print("Terminated by Optuna Terminator!")
break
value = scores.mean()
logging.info(f"Trial #{trial.number} finished with value {value}.")
study.tell(trial, value)
if trial.number > min_n_trials and terminator.should_terminate(study):
logging.info("Terminated by Optuna Terminator!")
break
.. seealso::
Please refer to :class:`~optuna.terminator.callbacks.TerminationCallback` for to use the
terminator mechanism with the :func:`~optuna.study.Study.optimize` method.
Please refer to :class:`~optuna.terminator.TerminatorCallback` for how to use
the terminator mechanism with the :func:`~optuna.study.Study.optimize` method.
"""

def __init__(
Expand All @@ -100,6 +109,7 @@ def __init__(
self._min_n_trials = min_n_trials

def should_terminate(self, study: Study) -> bool:
"""Judge whether the study should be terminated based on the reported values."""
trials = study.get_trials(states=[TrialState.COMPLETE])

if len(trials) < self._min_n_trials:
Expand Down

0 comments on commit 5cef676

Please sign in to comment.