diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index f9915ce6be..e39a08323b 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import abc +import logging import warnings from enum import Enum @@ -54,6 +55,14 @@ if TYPE_CHECKING: from sagemaker.workflow.step_collections import StepCollection +logger = logging.getLogger(__name__) + +JOB_ARG_IGNORE_WARN_MSG_TEMPLATE = ( + "The job-specific arguments (%s) supplied to the step will be ignored, " + "because `step_args` is present. These job-specific arguments should be supplied " + "in %s to generate the `step_args`." +) + class StepTypeEnum(Enum, metaclass=DefaultEnumMeta): """Enum of `Step` types.""" @@ -424,6 +433,9 @@ def __init__( error_message="The step_args of TrainingStep must be obtained from estimator.fit().", ) + if inputs: + logger.warning(JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "estimator.fit()") + self.step_args = step_args self.estimator = estimator self.inputs = inputs @@ -680,6 +692,11 @@ def __init__( "from transformer.transform().", ) + if inputs: + logger.warning( + JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "transformer.transform()" + ) + self.step_args = step_args self.transformer = transformer self.inputs = inputs @@ -817,6 +834,13 @@ def __init__( error_message="The step_args of ProcessingStep must be obtained from processor.run().", ) + if job_arguments or inputs or outputs or code or kms_key: + logger.warning( + JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, + "`job_arguments`, `inputs`, `outputs`, `code` and `kms_key`", + "processor.run()", + ) + self.step_args = step_args self.processor = processor self.inputs = inputs @@ -997,6 +1021,9 @@ def __init__( error_message="The step_args of TuningStep must be obtained from tuner.fit().", ) + if inputs: + logger.warning(JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "tuner.fit()") + self.step_args = step_args self.tuner = tuner 
self.inputs = inputs diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index bdeca6190b..3b553fe7e2 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -363,8 +363,8 @@ def override_pipeline_parameter_var(func): We should remove this decorator after the grace period. """ warning_msg_template = ( - "The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. " - "The default_value of this Parameter object will be used to override it. " + "The input argument %s of function (%s) is a pipeline variable (%s), which will not work. " + "Thus the default_value of this Parameter object will be used to override it. " "Please make sure the default_value is valid." ) diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 4ce6e5302c..72acf4b50e 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -1115,3 +1115,44 @@ def test_processor_with_role_as_pipeline_parameter( step_def = json.loads(pipeline.definition())["Steps"][0] assert step_def["Arguments"]["RoleArn"] == {"Get": f"Parameters.{_PARAM_ROLE_NAME}"} + + +def test_processing_step_with_extra_job_args(pipeline_session, processing_input): + processor = Processor( + image_uri=IMAGE_URI, + role=ROLE, + instance_count=1, + instance_type=INSTANCE_TYPE, + sagemaker_session=pipeline_session, + ) + + step_args = processor.run(inputs=processing_input) + + ignored_code = f"s3://{BUCKET}/my-code-to-be-ignored" + ignored_kms_key = "ignored_kms_key" + step = ProcessingStep( + name="MyProcessingStep", + step_args=step_args, + description="ProcessingStep description", + display_name="MyProcessingStep", + code=ignored_code, + kms_key=ignored_kms_key, + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, + ) + step_args = 
get_step_args_helper(step_args, "Processing") + pipeline_def = pipeline.definition() + step_def = json.loads(pipeline_def)["Steps"][0] + assert step_def == { + "Name": "MyProcessingStep", + "Description": "ProcessingStep description", + "DisplayName": "MyProcessingStep", + "Type": "Processing", + "Arguments": step_args, + } + assert ignored_code not in pipeline_def + assert ignored_kms_key not in pipeline_def diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index b76aca2d96..c1436c13f8 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -801,3 +801,41 @@ def test_insert_wrong_step_args_into_training_step(inputs, pipeline_session): ) assert "The step_args of TrainingStep must be obtained from estimator.fit()" in str(error.value) + + +def test_training_step_estimator_with_extra_job_args(pipeline_session, training_input): + estimator = Estimator( + role=ROLE, + instance_count=1, + instance_type=INSTANCE_TYPE, + sagemaker_session=pipeline_session, + image_uri=IMAGE_URI, + ) + + step_args = estimator.fit(inputs=training_input) + + ignored_train_input = f"s3://{BUCKET}/my-training-input-to-be-ignored" + step = TrainingStep( + name="MyTrainingStep", + step_args=step_args, + description="TrainingStep description", + display_name="MyTrainingStep", + inputs=ignored_train_input, + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, + ) + step_args = get_step_args_helper(step_args, "Training") + pipeline_def = pipeline.definition() + step_def = json.loads(pipeline_def)["Steps"][0] + assert step_def == { + "Name": "MyTrainingStep", + "Description": "TrainingStep description", + "DisplayName": "MyTrainingStep", + "Type": "Training", + "Arguments": step_args, + } + assert ignored_train_input not in pipeline_def diff --git a/tests/unit/sagemaker/workflow/test_transform_step.py 
b/tests/unit/sagemaker/workflow/test_transform_step.py index 30960e0c4a..b28152b8ab 100644 --- a/tests/unit/sagemaker/workflow/test_transform_step.py +++ b/tests/unit/sagemaker/workflow/test_transform_step.py @@ -287,3 +287,47 @@ def test_insert_wrong_step_args_into_transform_step(inputs, pipeline_session): assert "The step_args of TransformStep must be obtained from transformer.transform()" in str( error.value ) + + +def test_transform_step_with_extra_job_args(pipeline_session): + transformer = Transformer( + model_name="my_model", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path="s3://my-bucket/my-output-path", + sagemaker_session=pipeline_session, + ) + transform_inputs = TransformInput(data="s3://my-bucket/my-data") + + step_args = transformer.transform( + data=transform_inputs.data, + data_type=transform_inputs.data_type, + content_type=transform_inputs.content_type, + compression_type=transform_inputs.compression_type, + split_type=transform_inputs.split_type, + input_filter=transform_inputs.input_filter, + output_filter=transform_inputs.output_filter, + join_source=transform_inputs.join_source, + model_client_config=transform_inputs.model_client_config, + ) + + ignored_data_path = "s3://my-bucket/my-data-to-be-ignored" + step = TransformStep( + name="MyTransformStep", step_args=step_args, inputs=TransformInput(data=ignored_data_path) + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, + ) + + step_args = get_step_args_helper(step_args, "Transform") + pipeline_def = pipeline.definition() + step_def = json.loads(pipeline_def)["Steps"][0] + assert step_def == { + "Name": "MyTransformStep", + "Type": "Transform", + "Arguments": step_args, + } + assert ignored_data_path not in pipeline_def diff --git a/tests/unit/sagemaker/workflow/test_tuning_step.py b/tests/unit/sagemaker/workflow/test_tuning_step.py index 52721bc15d..caf4c7aedc 100644 --- a/tests/unit/sagemaker/workflow/test_tuning_step.py 
+++ b/tests/unit/sagemaker/workflow/test_tuning_step.py @@ -301,3 +301,70 @@ def test_insert_wrong_step_args_into_tuning_step(inputs, pipeline_session): ) assert "The step_args of TuningStep must be obtained from tuner.fit()" in str(error.value) + + +def test_tuning_step_with_extra_job_args(pipeline_session, entry_point): + pytorch_estimator = PyTorch( + entry_point=entry_point, + role=ROLE, + framework_version="1.5.0", + py_version="py3", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=pipeline_session, + enable_sagemaker_metrics=True, + max_retry_attempts=3, + ) + + hyperparameter_ranges = { + "batch-size": IntegerParameter(64, 128), + } + + tuner = HyperparameterTuner( + estimator=pytorch_estimator, + objective_metric_name="test:acc", + objective_type="Maximize", + hyperparameter_ranges=hyperparameter_ranges, + metric_definitions=[{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}], + max_jobs=2, + max_parallel_jobs=2, + ) + + step_args = tuner.fit(inputs=TrainingInput(s3_data="s3://my-bucket/my-training-input")) + + ignored_input = "s3://my-bucket/my-input-to-be-ignored" + step = TuningStep( + name="MyTuningStep", + step_args=step_args, + inputs=ignored_input, + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, + ) + + step_args = get_step_args_helper(step_args, "HyperParameterTuning") + pipeline_def = pipeline.definition() + step_def = json.loads(pipeline_def)["Steps"][0] + + # delete sagemaker_job_name b/c of timestamp collision + del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_job_name"] + del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][ + "sagemaker_job_name" + ] + + # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled after + # caching improvements phase 2. 
+ del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_submit_directory"] + del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][ + "sagemaker_submit_directory" + ] + + assert step_def == { + "Name": "MyTuningStep", + "Type": "Tuning", + "Arguments": step_args, + } + assert ignored_input not in pipeline_def