Merged
38 changes: 38 additions & 0 deletions src/sagemaker/session.py
@@ -2512,6 +2512,8 @@ def tune( # noqa: C901
random_seed=None,
environment=None,
hpo_resource_config=None,
autotune=False,
auto_parameters=None,
):
"""Create an Amazon SageMaker hyperparameter tuning job.

@@ -2617,6 +2619,11 @@ def tune( # noqa: C901
* volume_kms_key_id: The AWS Key Management Service (AWS KMS) key
that Amazon SageMaker uses to encrypt data on the storage
volume attached to the ML compute instance(s) that run the training job.
autotune (bool): Whether the parameter ranges or other unset settings of a tuning job
should be chosen automatically (default: False).
auto_parameters (dict[str, str]): Dictionary of auto parameters. The keys are the
names of the auto parameters and the values are example values used as hints for
the automatically chosen ranges (default: ``None``).
"""

tune_request = {
@@ -2633,6 +2640,7 @@ def tune( # noqa: C901
random_seed=random_seed,
strategy_config=strategy_config,
completion_criteria_config=completion_criteria_config,
auto_parameters=auto_parameters,
),
"TrainingJobDefinition": self._map_training_config(
static_hyperparameters=static_hyperparameters,
@@ -2659,6 +2667,9 @@ def tune( # noqa: C901
if warm_start_config is not None:
tune_request["WarmStartConfig"] = warm_start_config

if autotune:
tune_request["Autotune"] = {"Mode": "Enabled"}

tags = _append_project_tags(tags)
if tags is not None:
tune_request["Tags"] = tags
@@ -2675,6 +2686,7 @@ def create_tuning_job(
training_config_list=None,
warm_start_config=None,
tags=None,
autotune=False,
):
"""Create an Amazon SageMaker hyperparameter tuning job.

@@ -2694,6 +2706,8 @@ def create_tuning_job(
other required configurations.
tags (list[dict]): List of tags for labeling the tuning job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
autotune (bool): Whether the parameter ranges or other unset settings of a tuning job
should be chosen automatically.
"""

if training_config is None and training_config_list is None:
@@ -2710,6 +2724,7 @@ def create_tuning_job(
training_config_list=training_config_list,
warm_start_config=warm_start_config,
tags=tags,
autotune=autotune,
)

def submit(request):
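``create_tuning_job`` can also be invoked directly with the new flag. A hedged sketch, with the ``tuning_config`` and ``training_config`` bodies elided; their expected keys follow the ``_map_tuning_config`` and ``_map_training_config`` signatures shown below:

session.create_tuning_job(
    job_name="my-autotune-job",
    tuning_config={...},    # kwargs for _map_tuning_config, including auto_parameters
    training_config={...},  # kwargs for _map_training_config
    autotune=True,
)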
@@ -2727,6 +2742,7 @@ def _get_tuning_request(
training_config_list=None,
warm_start_config=None,
tags=None,
autotune=False,
):
"""Construct CreateHyperParameterTuningJob request

@@ -2742,13 +2758,17 @@ def _get_tuning_request(
other required configurations.
tags (list[dict]): List of tags for labeling the tuning job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
autotune (bool): Whether the parameter ranges or other unset settings of a tuning job
should be chosen automatically.
Returns:
dict: A dictionary for the CreateHyperParameterTuningJob request.
"""
tune_request = {
"HyperParameterTuningJobName": job_name,
"HyperParameterTuningJobConfig": self._map_tuning_config(**tuning_config),
}
if autotune:
tune_request["Autotune"] = {"Mode": "Enabled"}

if training_config is not None:
tune_request["TrainingJobDefinition"] = self._map_training_config(**training_config)
@@ -2794,6 +2814,7 @@ def _map_tuning_config(
random_seed=None,
strategy_config=None,
completion_criteria_config=None,
auto_parameters=None,
):
"""Construct tuning job configuration dictionary.

@@ -2820,6 +2841,8 @@ def _map_tuning_config(
strategy.
completion_criteria_config (dict): A configuration
for the completion criteria.
auto_parameters (dict): Dictionary of auto parameters. The keys are names of auto
parameters and values are example values of auto parameters.

Returns:
A dictionary of tuning job configuration. For format details, please refer to
@@ -2849,6 +2872,13 @@ def _map_tuning_config(
if parameter_ranges is not None:
tuning_config["ParameterRanges"] = parameter_ranges

if auto_parameters is not None:
if parameter_ranges is None:
tuning_config["ParameterRanges"] = {}
tuning_config["ParameterRanges"]["AutoParameters"] = [
{"Name": name, "ValueHint": value} for name, value in auto_parameters.items()
]

if strategy_config is not None:
tuning_config["StrategyConfig"] = strategy_config

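The ``auto_parameters`` branch above flattens the dictionary into the ``AutoParameters`` list shape the API expects. A self-contained sketch with illustrative parameter names and values:

auto_parameters = {"max_depth": "5", "eta": "0.1"}
auto_parameter_list = [
    {"Name": name, "ValueHint": value} for name, value in auto_parameters.items()
]
# auto_parameter_list == [
#     {"Name": "max_depth", "ValueHint": "5"},
#     {"Name": "eta", "ValueHint": "0.1"},
# ]

``_map_training_config`` below applies the same mapping to each training job definition's ``HyperParameterRanges``.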
@@ -2910,6 +2940,7 @@ def _map_training_config(
checkpoint_local_path=None,
max_retry_attempts=None,
environment=None,
auto_parameters=None,
):
"""Construct a dictionary of training job configuration from the arguments.

@@ -3030,6 +3061,13 @@ def _map_training_config(
if parameter_ranges is not None:
training_job_definition["HyperParameterRanges"] = parameter_ranges

if auto_parameters is not None:
if parameter_ranges is None:
training_job_definition["HyperParameterRanges"] = {}
training_job_definition["HyperParameterRanges"]["AutoParameters"] = [
{"Name": name, "ValueHint": value} for name, value in auto_parameters.items()
]

if max_retry_attempts is not None:
training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
