feat: Add method to get the best configuration directly from Tuner, add com… #767

Merged: 7 commits, Oct 12, 2023
Changes from 3 commits
8 changes: 8 additions & 0 deletions docs/source/faq.rst
@@ -337,6 +337,14 @@ You can take a look at this example
`examples/launch_checkpoint_example.py <examples.html#retrieving-the-best-checkpoint>`__
which shows how to retrieve the best checkpoint obtained after tuning an XGBoost model.

How can I retrain the best model found after tuning?
====================================================

After tuning, you can retrieve the best configuration with ``trial_id, best_config = tuner.best_config()`` and retrain it
with ``tuner.trial_backend.start_trial(config=best_config)``. You can take a look at this example
`examples/launch_plot_results.py <examples.html#plot-results-of-tuning-experiment>`__
which shows how to retrain the best model found while tuning.
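
A minimal sketch of that workflow, assuming tuning has already finished and ``tuner`` is the ``Tuner`` instance used for the experiment:

```python
# Sketch: retrain the best configuration found during tuning.
# Assumes `tuner` has finished running and its trial backend is still available.
trial_id, best_config = tuner.best_config()  # trial-id and configuration of the best trial
print(f"Best trial {trial_id} used configuration {best_config}")

# Start a fresh trial with the best configuration (training from scratch)
tuner.trial_backend.start_trial(config=best_config)
```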

Which schedulers make use of checkpointing?
===========================================

9 changes: 7 additions & 2 deletions examples/launch_plot_results.py
@@ -25,7 +25,7 @@
)

if __name__ == "__main__":
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger().setLevel(logging.INFO)

random_seed = 31415927
max_steps = 100
@@ -64,8 +64,13 @@
tuner.run()

tuning_experiment = load_experiment(tuner.name)
print(tuning_experiment)

# Print the best configuration found from experiment-results
print(f"best result found: {tuning_experiment.best_config()}")

# Plot the best value found over time
tuning_experiment.plot()

# Print the best configuration found from the tuner and retrain it
trial_id, best_config = tuner.best_config()
tuner.trial_backend.start_trial(config=best_config)
Collaborator

Maybe plot again, and hopefully show improvement? Or consider splitting out into a separate retraining example?

Otherwise it feels a bit random - why train again and then do nothing with it after?

Collaborator Author

One use case could be to run the best configuration again with a larger budget. I do not have a use case personally, but I know some people ask for this, so it is probably worth having an example that shows how it can be done.
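
A sketch of that use case, building on the ``best_config()`` call from this PR; the budget key ``steps`` is only an assumption about the training script's configuration, not something defined in this change:

```python
# Illustration: re-run the best configuration with a larger budget.
# The name of the budget entry ("steps") is hypothetical and depends on the training script.
trial_id, best_config = tuner.best_config()

larger_budget_config = dict(best_config)
larger_budget_config["steps"] = 10 * larger_budget_config.get("steps", 100)

tuner.trial_backend.start_trial(config=larger_budget_config)
```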

49 changes: 11 additions & 38 deletions syne_tune/experiments/experiment_result.py
@@ -30,7 +30,7 @@
ST_TUNER_TIME,
)
from syne_tune.try_import import try_import_aws_message, try_import_visual_message
from syne_tune.util import experiment_path, s3_experiment_path
from syne_tune.util import experiment_path, s3_experiment_path, metric_name_mode

try:
import boto3
@@ -98,7 +98,7 @@ def plot_hypervolume(
len(metrics_to_plot) > 1
), "Only one metric defined, cannot compute hypervolume"

metrics, metric_names, metric_modes = zip(
metric_names, metric_modes = zip(
*[self._metric_name_mode(metric) for metric in metrics_to_plot]
)
assert np.all(
@@ -139,9 +139,7 @@ def plot(
If None, the figure is shown
:param plt_kwargs: Arguments to :func:`matplotlib.pyplot.plot`
"""
metric, metric_name, metric_mode = self._metric_name_mode(
metric_to_plot, verbose=True
)
metric_name, metric_mode = self._metric_name_mode(metric_to_plot)

df = self.results
if df is not None and len(df) > 0:
@@ -178,7 +176,7 @@ def best_config(self, metric: Union[str, int] = 0) -> Dict[str, Any]:
default to 0 - first metric defined in the Scheduler
:return: Configuration corresponding to best metric value
"""
metric, metric_name, metric_mode = self._metric_name_mode(metric, verbose=True)
metric_name, metric_mode = self._metric_name_mode(metric)

# locate best result
if metric_mode == "min":
@@ -190,41 +188,16 @@
# Don't include internal fields
return {k: v for k, v in res.items() if not k.startswith("st_")}

def _metric_name_mode(
self, metric: Union[str, int], verbose: bool = False
) -> Tuple[int, str, str]:
def _metric_name_mode(self, metric: Union[str, int]) -> Tuple[str, str]:
"""
Determine the metric, name and its mode given ambiguous input.
Determine the name and its mode given ambiguous input.
:param metric: Index or name of the selected metric
:param verbose: If True, prints a warning message when only one metric is selected from many
"""
if isinstance(metric, str):
assert (
metric in self.metric_names()
), f"Attepted to use {metric} while available metrics are {self.metric_names()}"
metric_name = metric
metric = self.metric_names().index(metric)
elif isinstance(metric, int):
assert metric < len(
self.metric_names()
), f"Attepted to use metric index={metric} with {len(self.metric_names())} availale metrics"
metric_name = self.metric_names()[metric]
else:
raise AttributeError(
f"metic must be <int> or <str> but {type(metric)} was provided"
)

if len(self.metric_names()) > 1 and verbose:
logger.warning(
"Several metrics exists, this will "
f"use metric={metric_name} (index={metric}) out of {self.metric_names()}."
)

metric_mode = self.metric_mode()
if isinstance(metric_mode, list):
metric_mode = metric_mode[metric]

return metric, metric_name, metric_mode
return metric_name_mode(
metric_names=self.metric_names(),
metric_mode=self.metric_mode(),
metric=metric,
)


def download_single_experiment(
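
For context, a short sketch of how the public API around the refactored helper can be used, assuming a finished tuning experiment whose name is available as ``tuner.name``:

```python
# Sketch: query a stored experiment by metric index or name.
from syne_tune.experiments import load_experiment

tuning_experiment = load_experiment(tuner.name)

# Select the metric by index (0 = first metric defined in the scheduler) ...
print(tuning_experiment.best_config(metric=0))
# ... or by name; with several metrics, a warning reports which one is used.
print(tuning_experiment.best_config(metric=tuning_experiment.metric_names()[0]))

# Plot the best value of that metric over time
tuning_experiment.plot(metric_to_plot=0)
```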
Empty file added syne_tune/experiments/util.py
Empty file.
30 changes: 29 additions & 1 deletion syne_tune/tuner.py
@@ -14,7 +14,7 @@
import time
from collections import OrderedDict
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Any
from typing import Callable, Dict, List, Optional, Set, Tuple, Any, Union

import dill as dill

@@ -44,6 +44,7 @@
experiment_path,
name_from_base,
dump_json_with_numpy,
metric_name_mode,
)

logger = logging.getLogger(__name__)
@@ -683,3 +684,30 @@ def _default_callback():
:return: Default callback to store results
"""
return StoreResultsCallback()

def best_config(
self, metric: Optional[Union[str, int]] = 0
) -> Tuple[int, Dict[str, Any]]:
"""
:param metric: Indicates which metric to use, can be the index or the name of the metric;
defaults to 0, the first metric defined in the scheduler
:return: the trial-id and the best configuration found while tuning for the given metric
"""
metric_name, metric_mode = metric_name_mode(
metric_names=self.scheduler.metric_names(),
metric_mode=self.scheduler.metric_mode(),
metric=metric,
)
trial_id, best_metric = print_best_metric_found(
self.tuning_status, metric_names=[metric_name], mode=metric_mode
)
config = self.trial_backend._trial_dict[trial_id].config

logger.info(
f"If you want to retrain the best configuration found, you can run: \n"
f"```tuner.trial_backend.start_trial(config={config})``` to start training from scratch\n"
f"or\n"
f"```tuner.trial_backend.start_trial(config={config}, checkpoint_trial_id={trial_id})``` to start from "
f"last checkpoint (your script should have stored a checkpoint)"
Comment on lines +710 to +711
Collaborator

Here or in the FAQ entry, would it make sense to explain when you would use best_config() versus a checkpoint?

Collaborator Author

It is not really one versus the other: you can only restart from a checkpoint if your script supports checkpointing, which may not be the case.

I do not think it would make sense to explain checkpointing there as it has its own set of FAQ items, for instance:

)
return trial_id, config
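
For reference, a usage sketch of this method based on the hint printed above; resuming from a checkpoint only works if the training script actually stores one:

```python
# Sketch: retrain the best configuration returned by Tuner.best_config().
trial_id, best_config = tuner.best_config()

# Option 1: train the best configuration again from scratch
tuner.trial_backend.start_trial(config=best_config)

# Option 2: resume from the best trial's last checkpoint
# (only valid if the training script stored a checkpoint)
tuner.trial_backend.start_trial(config=best_config, checkpoint_trial_id=trial_id)
```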
44 changes: 44 additions & 0 deletions syne_tune/util.py
@@ -21,6 +21,8 @@
from typing import Optional, List, Union, Dict, Any, Iterable
from time import perf_counter
from contextlib import contextmanager
from typing import Tuple, Union, List
import logging

import numpy as np

@@ -31,6 +33,8 @@
)
from syne_tune.try_import import try_import_aws_message

logger = logging.getLogger(__name__)

try:
import sagemaker
except ImportError:
@@ -319,3 +323,43 @@ def find_first_of_type(a: Iterable[Any], typ) -> Optional[Any]:
return next(x for x in a if isinstance(x, typ))
except StopIteration:
return None


def metric_name_mode(
metric_names: List[str], metric_mode: Union[str, List[str]], metric: Union[str, int]
) -> Tuple[str, str]:
"""
Retrieve the metric name and mode given a metric queried by either index or name.
:param metric_names: metric names defined in a scheduler
:param metric_mode: metric mode or modes of a scheduler
:param metric: Index or name of the selected metric
:return: the name and the mode of the queried metric
"""
if isinstance(metric, str):
assert (
metric in metric_names
), f"Attempted to use {metric} while available metrics are {metric_names}"
metric_name = metric
elif isinstance(metric, int):
assert metric < len(
metric_names
), f"Attempted to use metric index={metric} with {len(metric_names)} available metrics"
metric_name = metric_names[metric]
else:
raise AttributeError(
f"metric must be <int> or <str> but {type(metric)} was provided"
)

if len(metric_names) > 1:
logger.warning(
"Several metrics exists, this will "
f"use metric={metric_name} (index={metric}) out of {metric_names}."
)

if isinstance(metric_mode, list):
metric_index = (
metric_names.index(metric_name) if isinstance(metric, str) else metric
)
metric_mode = metric_mode[metric_index]

return metric_name, metric_mode
26 changes: 26 additions & 0 deletions tst/experiments/test_metric_name_mode.py
@@ -0,0 +1,26 @@
import pytest

from syne_tune.util import metric_name_mode

metric_names = ["m1", "m2", "m3"]


@pytest.mark.parametrize(
"metric_mode, query_metric, expected_metric, expected_mode,",
[
("max", "m2", "m2", "max"),
("min", "m2", "m2", "min"),
(["max", "min", "max"], "m2", "m2", "min"),
(["max", "min", "max"], "m3", "m3", "max"),
("max", 1, "m2", "max"),
("min", 1, "m2", "min"),
(["max", "min", "max"], 1, "m2", "min"),
(["max", "min", "max"], 2, "m3", "max"),
],
)
def test_metric_name_mode(metric_mode, query_metric, expected_metric, expected_mode):
metric_name, metric_mode = metric_name_mode(
metric_names=metric_names, metric_mode=metric_mode, metric=query_metric
)
assert metric_name == expected_metric
assert metric_mode == expected_mode