From ee9654b259fcb149d2910f86e0cccf47cb2f2bf9 Mon Sep 17 00:00:00 2001 From: Carolyn Nguyen Date: Thu, 21 Oct 2021 18:43:41 -0700 Subject: [PATCH] feat: Support placeholders for TrainingStep --- src/stepfunctions/steps/sagemaker.py | 38 ++++---- tests/integ/test_sagemaker_steps.py | 58 +++++++++++++ tests/unit/test_sagemaker_steps.py | 124 +++++++++++++++++---------- 3 files changed, 158 insertions(+), 62 deletions(-) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 6d8ccc1..7595e97 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -74,11 +74,13 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.) * (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator. mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator. - experiment_config (dict, optional): Specify the experiment config for the training. (Default: None) + experiment_config (dict or Placeholder, optional): Specify the experiment config for the training. (Default: None) wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (default: True) - tags (list[dict], optional): `List to tags `_ to associate with the resource. + tags (list[dict] or Placeholder, optional): `List of tags `_ to associate with the resource. 
output_data_config_path (str or Placeholder, optional): S3 location for saving the training result (model artifacts and output files). If specified, it overrides the `output_path` property of `estimator`. + parameters (dict, optional): The value of this field is merged with other arguments to become the request payload for SageMaker `CreateTrainingJob`_. (Default: None) + You can use `parameters` to override the value provided by other arguments and specify any field's value dynamically using `Placeholders`_. """ self.estimator = estimator self.job_name = job_name @@ -105,44 +107,48 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non data = data.to_jsonpath() if isinstance(job_name, str): - parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) + training_parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) else: - parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size) + training_parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size) if estimator.debugger_hook_config != None and estimator.debugger_hook_config is not False: - parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict() + training_parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict() if estimator.rules != None: - parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules] + training_parameters['DebugRuleConfigurations'] = [rule.to_debugger_rule_config_dict() for rule in estimator.rules] if isinstance(job_name, Placeholder): - parameters['TrainingJobName'] = job_name + training_parameters['TrainingJobName'] = job_name if output_data_config_path is not None: - parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path + 
training_parameters['OutputDataConfig']['S3OutputPath'] = output_data_config_path if data is not None and is_data_placeholder: # Replace the 'S3Uri' key with one that supports JSONpath value. # Support for uri str only: The list will only contain 1 element - data_uri = parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None) - parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri + data_uri = training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource'].pop('S3Uri', None) + training_parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri if hyperparameters is not None: if not isinstance(hyperparameters, Placeholder): if estimator.hyperparameters() is not None: hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters()) - parameters['HyperParameters'] = hyperparameters + training_parameters['HyperParameters'] = hyperparameters if experiment_config is not None: - parameters['ExperimentConfig'] = experiment_config + training_parameters['ExperimentConfig'] = experiment_config - if 'S3Operations' in parameters: - del parameters['S3Operations'] + if 'S3Operations' in training_parameters: + del training_parameters['S3Operations'] if tags: - parameters['Tags'] = tags_dict_to_kv_list(tags) + training_parameters['Tags'] = tags if isinstance(tags, Placeholder) else tags_dict_to_kv_list(tags) - kwargs[Field.Parameters.value] = parameters + if Field.Parameters.value in kwargs and isinstance(kwargs[Field.Parameters.value], dict): + # Update training parameters with input parameters + merge_dicts(training_parameters, kwargs[Field.Parameters.value]) + + kwargs[Field.Parameters.value] = training_parameters super(TrainingStep, self).__init__(state_id, **kwargs) def get_expected_model(self, model_name=None): diff --git a/tests/integ/test_sagemaker_steps.py b/tests/integ/test_sagemaker_steps.py index a04f110..4024a45 100644 --- 
a/tests/integ/test_sagemaker_steps.py +++ b/tests/integ/test_sagemaker_steps.py @@ -107,6 +107,64 @@ def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client, sf # End of Cleanup +def test_training_step_with_placeholders(pca_estimator_fixture, record_set_fixture, sfn_client, sfn_role_arn): + execution_input = ExecutionInput(schema={ + 'JobName': str, + 'HyperParameters': str, + 'InstanceCount': int, + 'InstanceType': str, + 'MaxRun': int + }) + + parameters = { + 'HyperParameters': execution_input['HyperParameters'], + 'ResourceConfig': { + 'InstanceCount': execution_input['InstanceCount'], + 'InstanceType': execution_input['InstanceType'] + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': execution_input['MaxRun'] + } + } + + training_step = TrainingStep('create_training_job_step', estimator=pca_estimator_fixture, + job_name=execution_input['JobName'], data=record_set_fixture, mini_batch_size=200, + parameters=parameters) + training_step.add_retry(SAGEMAKER_RETRY_STRATEGY) + workflow_graph = Chain([training_step]) + + with timeout(minutes=DEFAULT_TIMEOUT_MINUTES): + # Create workflow and check definition + workflow = create_workflow_and_check_definition( + workflow_graph=workflow_graph, + workflow_name=unique_name_from_base("integ-test-training-step-workflow"), + sfn_client=sfn_client, + sfn_role_arn=sfn_role_arn + ) + + inputs = { + 'JobName': generate_job_name(), + 'HyperParameters': { + "num_components": "48", + "feature_dim": "784", + "mini_batch_size": "250" + }, + 'InstanceCount': INSTANCE_COUNT, + 'InstanceType': INSTANCE_TYPE, + 'MaxRun': 100000 + } + + # Execute workflow + execution = workflow.execute(inputs=inputs) + execution_output = execution.get_output(wait=True) + + # Check workflow output + assert execution_output.get("TrainingJobStatus") == "Completed" + + # Cleanup + state_machine_delete_wait(sfn_client, workflow.state_machine_arn) + + def test_model_step(trained_estimator, sfn_client, sagemaker_session, sfn_role_arn): # 
Build workflow definition model_name = generate_job_name() diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index cc17bae..4c1000d 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -335,68 +335,100 @@ def test_training_step_creation_with_placeholders(pca_estimator): execution_input = ExecutionInput(schema={ 'Data': str, 'OutputPath': str, - 'HyperParameters': str + 'HyperParameters': str, + 'ExperimentConfig': str, + 'Tags': str, + 'InstanceCount': int, + 'InstanceType': str, + 'MaxRun': int, + 'MetricDefinitions': str, + 'MaxWait': int, + 'CheckpointS3Uri': str, + 'CheckpointLocalPath': str, + 'EnableSagemakerMetrics': bool, + 'EnableNetworkIsolation': bool, + 'Environment': str }) step_input = StepInput(schema={ 'JobName': str, }) - step = TrainingStep('Training', - estimator=pca_estimator, - job_name=step_input['JobName'], - data=execution_input['Data'], - output_data_config_path=execution_input['OutputPath'], - experiment_config={ - 'ExperimentName': 'pca_experiment', - 'TrialName': 'pca_trial', - 'TrialComponentDisplayName': 'Training' - }, - tags=DEFAULT_TAGS, - hyperparameters=execution_input['HyperParameters'] - ) - assert step.to_dict() == { - 'Type': 'Task', - 'Parameters': { + parameters = { 'AlgorithmSpecification': { 'TrainingImage': PCA_IMAGE, - 'TrainingInputMode': 'File' + 'TrainingInputMode': 'File', + 'MetricDefinitions': execution_input['MetricDefinitions'], + 'EnableSageMakerMetricsTimeSeries': execution_input['EnableSagemakerMetrics'] }, - 'OutputDataConfig': { - 'S3OutputPath.$': "$$.Execution.Input['OutputPath']" + 'CheckpointConfig': { + 'S3Uri': execution_input['CheckpointS3Uri'], + 'LocalPath': execution_input['CheckpointLocalPath'] }, + 'EnableNetworkIsolation': execution_input['EnableNetworkIsolation'], 'StoppingCondition': { - 'MaxRuntimeInSeconds': 86400 + 'MaxRuntimeInSeconds': execution_input['MaxRun'], + 'MaxWaitTimeInSeconds': execution_input['MaxWait'] 
}, 'ResourceConfig': { - 'InstanceCount': 1, - 'InstanceType': 'ml.c4.xlarge', - 'VolumeSizeInGB': 30 + 'InstanceCount': execution_input['InstanceCount'], + 'InstanceType': execution_input['InstanceType'] }, - 'RoleArn': EXECUTION_ROLE, - 'HyperParameters.$': "$$.Execution.Input['HyperParameters']", - 'InputDataConfig': [ - { - 'ChannelName': 'training', - 'DataSource': { - 'S3DataSource': { - 'S3DataDistributionType': 'FullyReplicated', - 'S3DataType': 'S3Prefix', - 'S3Uri.$': "$$.Execution.Input['Data']" - } + 'Environment': execution_input['Environment'], + 'ExperimentConfig': execution_input['ExperimentConfig'] + } + + step = TrainingStep('Training', + estimator=pca_estimator, + job_name=step_input['JobName'], + data=execution_input['Data'], + output_data_config_path=execution_input['OutputPath'], + experiment_config=execution_input['ExperimentConfig'], + tags=execution_input['Tags'], + mini_batch_size=1000, + hyperparameters=execution_input['HyperParameters'], + parameters=parameters + ) + assert step.to_dict()['Parameters'] == { + 'AlgorithmSpecification': { + 'EnableSageMakerMetricsTimeSeries.$': "$$.Execution.Input['EnableSagemakerMetrics']", + 'MetricDefinitions.$': "$$.Execution.Input['MetricDefinitions']", + 'TrainingImage': PCA_IMAGE, + 'TrainingInputMode': 'File' + }, + 'CheckpointConfig': {'LocalPath.$': "$$.Execution.Input['CheckpointLocalPath']", + 'S3Uri.$': "$$.Execution.Input['CheckpointS3Uri']"}, + 'EnableNetworkIsolation.$': "$$.Execution.Input['EnableNetworkIsolation']", + 'Environment.$': "$$.Execution.Input['Environment']", + 'OutputDataConfig': { + 'S3OutputPath.$': "$$.Execution.Input['OutputPath']" + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds.$': "$$.Execution.Input['MaxRun']", + 'MaxWaitTimeInSeconds.$': "$$.Execution.Input['MaxWait']" + }, + 'ResourceConfig': { + 'InstanceCount.$': "$$.Execution.Input['InstanceCount']", + 'InstanceType.$': "$$.Execution.Input['InstanceType']", + 'VolumeSizeInGB': 30 + }, + 'RoleArn': 
EXECUTION_ROLE, + 'HyperParameters.$': "$$.Execution.Input['HyperParameters']", + 'InputDataConfig': [ + { + 'ChannelName': 'training', + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'FullyReplicated', + 'S3DataType': 'S3Prefix', + 'S3Uri.$': "$$.Execution.Input['Data']" } } - ], - 'ExperimentConfig': { - 'ExperimentName': 'pca_experiment', - 'TrialName': 'pca_trial', - 'TrialComponentDisplayName': 'Training' - }, - 'TrainingJobName.$': "$['JobName']", - 'Tags': DEFAULT_TAGS_LIST - }, - 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', - 'End': True + } + ], + 'ExperimentConfig.$': "$$.Execution.Input['ExperimentConfig']", + 'TrainingJobName.$': "$['JobName']", + 'Tags.$': "$$.Execution.Input['Tags']" }