Description
Describe the bug
When I specify a TuningJobCompletionCriteriaConfig and call the to_input_req method on the instantiated object (which the HPO tuner does when fit() is called), it throws an error, which makes the class unusable. This happens because to_input_req assigns values to nested keys of a dictionary whose parent keys do not exist yet.
Additionally, TuningJobCompletionCriteriaConfig is not documented in the Python SDK: https://sagemaker.readthedocs.io/en/stable/api/training/tuner.html
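The failure mode described above can be illustrated with plain dictionaries, independent of the SDK. This is a minimal sketch, not the SDK's code: the variable names `config` and `nested_error` are made up for illustration, and the key strings are taken from the expected output in this issue. It assumes the error raised is a `KeyError`, which is what Python raises when a nested assignment reads a parent key that is missing.

```python
# Plain-dict illustration of the failure mode; this is not the SDK's code.
config = {}

# A top-level assignment creates the key, so this works.
config["TargetObjectiveMetricValue"] = 0

# A nested assignment first looks up config["BestObjectiveNotImproving"].
# That key does not exist yet, so Python raises KeyError.
try:
    config["BestObjectiveNotImproving"]["MaxNumberOfTrainingJobsNotImproving"] = 10
    nested_error = None
except KeyError as exc:
    nested_error = exc

# Creating the inner dictionary first avoids the error.
config["BestObjectiveNotImproving"] = {}
config["BestObjectiveNotImproving"]["MaxNumberOfTrainingJobsNotImproving"] = 10
```

This also suggests why setting only target_objective_metric_value would succeed: it needs just one level of the dictionary.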
To reproduce
from sagemaker.tuner import (
    TuningJobCompletionCriteriaConfig,
)

test = TuningJobCompletionCriteriaConfig(
    max_number_of_training_jobs_not_improving=10,
    complete_on_convergence=True,
    target_objective_metric_value=0,
)
test.to_input_req()
Results in:
When only target_objective_metric_value is set, it works as expected. This is because that argument maps to a single top-level key of completion_criteria_config in to_input_req, so no missing parent key is involved, which confirms the assumption in the description of the issue.
Expected behavior
Given this input:
from sagemaker.tuner import (
    TuningJobCompletionCriteriaConfig,
)

test = TuningJobCompletionCriteriaConfig(
    max_number_of_training_jobs_not_improving=10,
    complete_on_convergence=True,
    target_objective_metric_value=0,
)
test.to_input_req()
I would expect the function to return:
{'BestObjectiveNotImproving': {'MaxNumberOfTrainingJobsNotImproving': 10},
'TargetObjectiveMetricValue': 0,
'ConvergenceDetected': {'CompleteOnConvergence': 'Enabled'}}
This can be achieved by modifying the to_input_req method of the TuningJobCompletionCriteriaConfig class to:
def to_input_req(self):
    """Converts the ``self`` instance to the desired input request format.

    Examples:
        >>> completion_criteria_config = TuningJobCompletionCriteriaConfig(
                max_number_of_training_jobs_not_improving=5,
                complete_on_convergence=True,
                target_objective_metric_value=0.42,
            )
        >>> completion_criteria_config.to_input_req()
        {
            "BestObjectiveNotImproving": {
                "MaxNumberOfTrainingJobsNotImproving": 5
            },
            "ConvergenceDetected": {
                "CompleteOnConvergence": "Enabled"
            },
            "TargetObjectiveMetricValue": 0.42
        }

    Returns:
        dict: Containing the completion criteria configurations.
    """
    completion_criteria_config = {}
    if self.max_number_of_training_jobs_not_improving is not None:
        # Create the nested dictionary before assigning into it.
        completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING] = {}  ##### <---- CHANGED
        completion_criteria_config[BEST_OBJECTIVE_NOT_IMPROVING][
            MAX_NUMBER_OF_TRAINING_JOBS_NOT_IMPROVING
        ] = self.max_number_of_training_jobs_not_improving

    if self.target_objective_metric_value is not None:
        completion_criteria_config[
            TARGET_OBJECTIVE_METRIC_VALUE
        ] = self.target_objective_metric_value

    if self.complete_on_convergence is not None:
        # Create the nested dictionary before assigning into it.
        completion_criteria_config[CONVERGENCE_DETECTED] = {}  ##### <---- CHANGED
        completion_criteria_config[CONVERGENCE_DETECTED][COMPLETE_ON_CONVERGENCE_DETECTED] = (
            "Enabled" if self.complete_on_convergence else "Disabled"
        )

    return completion_criteria_config
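The proposed behaviour can be checked without installing the SDK. The sketch below uses a hypothetical stand-in class, `CompletionCriteriaSketch`, which is not part of sagemaker. It mirrors the three constructor arguments used in this issue and builds each nested dictionary as a literal, which is one possible way to avoid the missing-key problem. The key strings are copied from the expected output above.

```python
class CompletionCriteriaSketch:
    """Hypothetical stand-in for TuningJobCompletionCriteriaConfig (not SDK code)."""

    def __init__(
        self,
        max_number_of_training_jobs_not_improving=None,
        complete_on_convergence=None,
        target_objective_metric_value=None,
    ):
        self.max_number_of_training_jobs_not_improving = (
            max_number_of_training_jobs_not_improving
        )
        self.complete_on_convergence = complete_on_convergence
        self.target_objective_metric_value = target_objective_metric_value

    def to_input_req(self):
        """Build the request dictionary, creating nested dicts as literals."""
        config = {}
        if self.max_number_of_training_jobs_not_improving is not None:
            config["BestObjectiveNotImproving"] = {
                "MaxNumberOfTrainingJobsNotImproving": (
                    self.max_number_of_training_jobs_not_improving
                )
            }
        if self.target_objective_metric_value is not None:
            config["TargetObjectiveMetricValue"] = self.target_objective_metric_value
        if self.complete_on_convergence is not None:
            config["ConvergenceDetected"] = {
                "CompleteOnConvergence": (
                    "Enabled" if self.complete_on_convergence else "Disabled"
                )
            }
        return config


# All three criteria set, as in the reproduction steps.
full = CompletionCriteriaSketch(
    max_number_of_training_jobs_not_improving=10,
    complete_on_convergence=True,
    target_objective_metric_value=0,
).to_input_req()

# Only the target metric set, the case that already works today.
target_only = CompletionCriteriaSketch(
    target_objective_metric_value=0
).to_input_req()
```

Building each inner dictionary in a single literal avoids the two-step "create, then assign" pattern, so there is no window in which the parent key is missing.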
System information
A description of your system. Please provide:
- SageMaker Python SDK version: '2.140.1'
- Python version: 3.10.9
- CPU or GPU: CPU
- Custom Docker image (Y/N): N