
Commit

clean up
geoalgo committed Oct 10, 2023
1 parent b10f4a5 commit b19940e
Showing 6 changed files with 94 additions and 46 deletions.
2 changes: 1 addition & 1 deletion examples/launch_plot_results.py
@@ -72,5 +72,5 @@
tuning_experiment.plot()

# Print the best configuration found from the tuner and retrain it
- best_config, trial_id = tuner.best_config()
+ trial_id, best_config = tuner.best_config()
tuner.trial_backend.start_trial(config=best_config)
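
For readers skimming the diff: `Tuner.best_config()` now returns the trial id first, then the configuration. A minimal sketch of the updated example flow, assuming a completed tuning run held in `tuner` (the checkpoint variant only works if the training script stored a checkpoint):

```python
# Unpack in the new (trial_id, config) order, then retrain the best
# configuration on the same backend.
trial_id, best_config = tuner.best_config()
tuner.trial_backend.start_trial(config=best_config)

# Or resume the best trial from its last checkpoint instead of retraining
# from scratch (requires the training script to have saved one).
tuner.trial_backend.start_trial(config=best_config, checkpoint_trial_id=trial_id)
```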
51 changes: 15 additions & 36 deletions syne_tune/experiments/experiment_result.py
@@ -21,6 +21,7 @@
import numpy as np
import pandas as pd

+ from syne_tune import Tuner
from syne_tune.constants import (
ST_METADATA_FILENAME,
ST_RESULTS_DATAFRAME_FILENAME,
@@ -29,7 +30,7 @@
ST_TUNER_TIME,
)
from syne_tune.try_import import try_import_aws_message, try_import_visual_message
- from syne_tune.util import experiment_path, s3_experiment_path
+ from syne_tune.util import experiment_path, s3_experiment_path, metric_name_mode

try:
import boto3
@@ -60,7 +61,7 @@ class ExperimentResult:
name: str
results: pd.DataFrame
metadata: Dict[str, Any]
tuner: "Tuner"
tuner: Tuner
path: Path

def __str__(self):
@@ -97,7 +98,7 @@ def plot_hypervolume(
len(metrics_to_plot) > 1
), "Only one metric defined, cannot compute hypervolume"

- metrics, metric_names, metric_modes = zip(
+ metric_names, metric_modes = zip(
*[self._metric_name_mode(metric) for metric in metrics_to_plot]
)
assert np.all(
@@ -138,8 +139,8 @@ def plot(
If None, the figure is shown
:param plt_kwargs: Arguments to :func:`matplotlib.pyplot.plot`
"""
- metric, metric_name, metric_mode = self._metric_name_mode(
-     metric_to_plot, verbose=True
+ metric_name, metric_mode = self._metric_name_mode(
+     metric_to_plot
)

df = self.results
@@ -177,7 +178,7 @@ def best_config(self, metric: Union[str, int] = 0) -> Dict[str, Any]:
default to 0 - first metric defined in the Scheduler
:return: Configuration corresponding to best metric value
"""
- metric, metric_name, metric_mode = self._metric_name_mode(metric, verbose=True)
+ metric_name, metric_mode = self._metric_name_mode(metric)

# locate best result
if metric_mode == "min":
@@ -190,40 +191,18 @@ def best_config(self, metric: Union[str, int] = 0) -> Dict[str, Any]:
return {k: v for k, v in res.items() if not k.startswith("st_")}

def _metric_name_mode(
-     self, metric: Union[str, int], verbose: bool = False
- ) -> Tuple[int, str, str]:
+     self, metric: Union[str, int]
+ ) -> Tuple[str, str]:
"""
- Determine the metric, name and its mode given ambiguous input.
+ Determine the name and its mode given ambiguous input.
:param metric: Index or name of the selected metric
- :param verbose: If True, prints a warning message when only one metric is selected from many
"""
-     if isinstance(metric, str):
-         assert (
-             metric in self.metric_names()
-         ), f"Attempted to use {metric} while available metrics are {self.metric_names()}"
-         metric_name = metric
-         metric = self.metric_names().index(metric)
-     elif isinstance(metric, int):
-         assert metric < len(
-             self.metric_names()
-         ), f"Attempted to use metric index={metric} with {len(self.metric_names())} available metrics"
-         metric_name = self.metric_names()[metric]
-     else:
-         raise AttributeError(
-             f"metric must be <int> or <str> but {type(metric)} was provided"
-         )
-
-     if len(self.metric_names()) > 1 and verbose:
-         logger.warning(
-             "Several metrics exists, this will "
-             f"use metric={metric_name} (index={metric}) out of {self.metric_names()}."
-         )
-
-     metric_mode = self.metric_mode()
-     if isinstance(metric_mode, list):
-         metric_mode = metric_mode[metric]
+     return metric_name_mode(
+         metric_names=self.metric_names(),
+         metric_mode=self.metric_mode(),
+         metric=metric,
+     )

-     return metric, metric_name, metric_mode


def download_single_experiment(
Empty file added syne_tune/experiments/util.py
21 changes: 12 additions & 9 deletions syne_tune/tuner.py
@@ -31,7 +31,6 @@
ST_TUNER_DILL_FILENAME,
TUNER_DEFAULT_SLEEP_TIME,
)
- from syne_tune.experiments import load_experiment
from syne_tune.optimizer.scheduler import SchedulerDecision, TrialScheduler
from syne_tune.optimizer.schedulers.remove_checkpoints import (
RemoveCheckpointsSchedulerMixin,
@@ -44,7 +43,7 @@
check_valid_sagemaker_name,
experiment_path,
name_from_base,
-     dump_json_with_numpy,
+     dump_json_with_numpy, metric_name_mode,
)

logger = logging.getLogger(__name__)
@@ -685,17 +684,21 @@ def _default_callback():
"""
return StoreResultsCallback()

- def best_config(self, metric: Optional[Union[str, int]] = 0) -> Dict[str, Any]:
+ def best_config(self, metric: Optional[Union[str, int]] = 0) -> Tuple[int, Dict[str, Any]]:
"""
:param metric: Indicates which metric to use, can be the index or a name of the metric.
default to 0 - first metric defined in the Scheduler
:return: the best configuration found while tuning for the metric given and the associated trial-id
"""
- tuning_experiment = load_experiment(self.name)
-
- best_config_info = tuning_experiment.best_config(metric)
- trial_id = best_config_info["trial_id"]
- config = {k.replace("config_", ""): v for k, v in best_config_info.items() if k.startswith("config_")}
+ metric_name, metric_mode = metric_name_mode(
+     metric_names=self.scheduler.metric_names(),
+     metric_mode=self.scheduler.metric_mode(),
+     metric=metric,
+ )
+ trial_id, best_metric = print_best_metric_found(
+     self.tuning_status, metric_names=[metric_name], mode=metric_mode
+ )
+ config = self.trial_backend._trial_dict[trial_id].config

logger.info(
f"If you want to retrain the best configuration found, you can run: \n"
@@ -704,4 +707,4 @@ def best_config(self, metric: Optional[Union[str, int]] = 0) -> Dict[str, Any]:
f"```tuner.trial_backend.start_trial(config={config}, checkpoint_trial_id={trial_id})``` to start from "
f"last checkpoint (your script should have stored a checkpoint)"
)
- return config, trial_id
+ return trial_id, config
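
Since `best_config` accepts either an index or a metric name, here is a short sketch of selecting by name (the metric names are illustrative; they must match what the scheduler defines):

```python
# Hypothetical scheduler with metrics ["loss", "accuracy"]: pick the best
# trial for "accuracy" rather than the default first metric.
trial_id, config = tuner.best_config(metric="accuracy")
print(f"Best trial {trial_id} used config {config}")
```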
41 changes: 41 additions & 0 deletions syne_tune/util.py
@@ -21,6 +21,8 @@
from typing import Optional, List, Union, Dict, Any, Iterable
from time import perf_counter
from contextlib import contextmanager
+ from typing import Tuple, Union, List
+ import logging

import numpy as np

@@ -31,6 +33,8 @@
)
from syne_tune.try_import import try_import_aws_message

+ logger = logging.getLogger(__name__)

try:
import sagemaker
except ImportError:
@@ -319,3 +323,40 @@ def find_first_of_type(a: Iterable[Any], typ) -> Optional[Any]:
return next(x for x in a if isinstance(x, typ))
except StopIteration:
return None

+ def metric_name_mode(
+     metric_names: List[str], metric_mode: Union[str, List[str]], metric: Union[str, int]
+ ) -> Tuple[str, str]:
+     """
+     Retrieve the metric mode given a metric queried by either index or name.
+
+     :param metric_names: metric names defined in a scheduler
+     :param metric_mode: metric mode or modes of a scheduler
+     :param metric: index or name of the selected metric
+     :return: the name and the mode of the queried metric
+     """
+     if isinstance(metric, str):
+         assert (
+             metric in metric_names
+         ), f"Attempted to use {metric} while available metrics are {metric_names}"
+         metric_name = metric
+     elif isinstance(metric, int):
+         assert metric < len(
+             metric_names
+         ), f"Attempted to use metric index={metric} with {len(metric_names)} available metrics"
+         metric_name = metric_names[metric]
+     else:
+         raise AttributeError(
+             f"metric must be <int> or <str> but {type(metric)} was provided"
+         )
+
+     if len(metric_names) > 1:
+         logger.warning(
+             "Several metrics exist; this will "
+             f"use metric={metric_name} (index={metric}) out of {metric_names}."
+         )
+
+     if isinstance(metric_mode, list):
+         metric_index = metric_names.index(metric_name) if isinstance(metric, str) else metric
+         metric_mode = metric_mode[metric_index]
+
+     return metric_name, metric_mode
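
A brief usage sketch of the new helper (metric names and modes here are illustrative):

```python
from syne_tune.util import metric_name_mode

# Per-metric modes: query by name...
name, mode = metric_name_mode(
    metric_names=["loss", "accuracy"],
    metric_mode=["min", "max"],
    metric="accuracy",
)
assert (name, mode) == ("accuracy", "max")

# ...or by index, with one mode shared across all metrics.
name, mode = metric_name_mode(
    metric_names=["loss", "accuracy"],
    metric_mode="min",
    metric=0,
)
assert (name, mode) == ("loss", "min")
```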
25 changes: 25 additions & 0 deletions tst/experiments/test_metric_name_mode.py
@@ -0,0 +1,25 @@
+ import pytest
+
+ from syne_tune.util import metric_name_mode
+
+ metric_names = ["m1", "m2", "m3"]
+
+
+ @pytest.mark.parametrize(
+     "metric_mode, query_metric, expected_metric, expected_mode",
+     [
+         ("max", "m2", "m2", "max"),
+         ("min", "m2", "m2", "min"),
+         (["max", "min", "max"], "m2", "m2", "min"),
+         (["max", "min", "max"], "m3", "m3", "max"),
+         ("max", 1, "m2", "max"),
+         ("min", 1, "m2", "min"),
+         (["max", "min", "max"], 1, "m2", "min"),
+         (["max", "min", "max"], 2, "m3", "max"),
+     ],
+ )
+ def test_metric_name_mode(metric_mode, query_metric, expected_metric, expected_mode):
+     metric_name, metric_mode = metric_name_mode(
+         metric_names=metric_names, metric_mode=metric_mode, metric=query_metric
+     )
+     assert metric_name == expected_metric
+     assert metric_mode == expected_mode