Skip to content

Commit 1504394

Browse files
committed
change: Add warning for ignored job args in pipeline steps
1 parent 8ac6ca8 commit 1504394

File tree

6 files changed

+219
-2
lines changed

6 files changed

+219
-2
lines changed

src/sagemaker/workflow/steps.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import logging
1718
import warnings
1819

1920
from enum import Enum
@@ -54,6 +55,14 @@
5455
if TYPE_CHECKING:
5556
from sagemaker.workflow.step_collections import StepCollection
5657

58+
logger = logging.getLogger(__name__)
59+
60+
JOB_ARG_IGNORE_WARN_MSG_TEMPLATE = (
61+
"The job specific arguments (%s) supplied to the step will be ignored, "
62+
"because `step_args` is presented. These job specific arguments should be supplied "
63+
"in %s to generate the `step_args`."
64+
)
65+
5766

5867
class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5968
"""Enum of `Step` types."""
@@ -424,6 +433,9 @@ def __init__(
424433
error_message="The step_args of TrainingStep must be obtained from estimator.fit().",
425434
)
426435

436+
if inputs:
437+
logger.warning(JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "estimator.fit()")
438+
427439
self.step_args = step_args
428440
self.estimator = estimator
429441
self.inputs = inputs
@@ -680,6 +692,11 @@ def __init__(
680692
"from transformer.transform().",
681693
)
682694

695+
if inputs:
696+
logger.warning(
697+
JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "transformer.transform()"
698+
)
699+
683700
self.step_args = step_args
684701
self.transformer = transformer
685702
self.inputs = inputs
@@ -817,6 +834,13 @@ def __init__(
817834
error_message="The step_args of ProcessingStep must be obtained from processor.run().",
818835
)
819836

837+
if job_arguments or inputs or outputs or code or kms_key:
838+
logger.warning(
839+
JOB_ARG_IGNORE_WARN_MSG_TEMPLATE,
840+
"`job_arguments`, `inputs`, `outputs`, `code` and `kms_key`",
841+
"processor.run()",
842+
)
843+
820844
self.step_args = step_args
821845
self.processor = processor
822846
self.inputs = inputs
@@ -997,6 +1021,9 @@ def __init__(
9971021
error_message="The step_args of TuningStep must be obtained from tuner.fit().",
9981022
)
9991023

1024+
if inputs:
1025+
logger.warning(JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "tuner.fit()")
1026+
10001027
self.step_args = step_args
10011028
self.tuner = tuner
10021029
self.inputs = inputs

src/sagemaker/workflow/utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,8 @@ def override_pipeline_parameter_var(func):
363363
We should remove this decorator after the grace period.
364364
"""
365365
warning_msg_template = (
366-
"The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. "
367-
"The default_value of this Parameter object will be used to override it. "
366+
"The input argument %s of function (%s) is a pipeline variable (%s), which will not work. "
367+
"Thus the default_value of this Parameter object will be used to override it. "
368368
"Please make sure the default_value is valid."
369369
)
370370

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,3 +1115,44 @@ def test_processor_with_role_as_pipeline_parameter(
11151115

11161116
step_def = json.loads(pipeline.definition())["Steps"][0]
11171117
assert step_def["Arguments"]["RoleArn"] == {"Get": f"Parameters.{_PARAM_ROLE_NAME}"}
1118+
1119+
1120+
def test_processing_step_with_extra_job_args(pipeline_session, processing_input):
1121+
processor = Processor(
1122+
image_uri=IMAGE_URI,
1123+
role=ROLE,
1124+
instance_count=1,
1125+
instance_type=INSTANCE_TYPE,
1126+
sagemaker_session=pipeline_session,
1127+
)
1128+
1129+
step_args = processor.run(inputs=processing_input)
1130+
1131+
ignored_code = f"s3://{BUCKET}/my-code-to-be-ignored"
1132+
ignored_kms_key = "ignored_kms_key"
1133+
step = ProcessingStep(
1134+
name="MyProcessingStep",
1135+
step_args=step_args,
1136+
description="ProcessingStep description",
1137+
display_name="MyProcessingStep",
1138+
code=ignored_code,
1139+
kms_key=ignored_kms_key,
1140+
)
1141+
1142+
pipeline = Pipeline(
1143+
name="MyPipeline",
1144+
steps=[step],
1145+
sagemaker_session=pipeline_session,
1146+
)
1147+
step_args = get_step_args_helper(step_args, "Processing")
1148+
pipeline_def = pipeline.definition()
1149+
step_def = json.loads(pipeline_def)["Steps"][0]
1150+
assert step_def == {
1151+
"Name": "MyProcessingStep",
1152+
"Description": "ProcessingStep description",
1153+
"DisplayName": "MyProcessingStep",
1154+
"Type": "Processing",
1155+
"Arguments": step_args,
1156+
}
1157+
assert ignored_code not in pipeline_def
1158+
assert ignored_kms_key not in pipeline_def

tests/unit/sagemaker/workflow/test_training_step.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,3 +801,41 @@ def test_insert_wrong_step_args_into_training_step(inputs, pipeline_session):
801801
)
802802

803803
assert "The step_args of TrainingStep must be obtained from estimator.fit()" in str(error.value)
804+
805+
806+
def test_training_step_estimator_with_extra_job_args(pipeline_session, training_input):
807+
estimator = Estimator(
808+
role=ROLE,
809+
instance_count=1,
810+
instance_type=INSTANCE_TYPE,
811+
sagemaker_session=pipeline_session,
812+
image_uri=IMAGE_URI,
813+
)
814+
815+
step_args = estimator.fit(inputs=training_input)
816+
817+
ignored_train_input = f"s3://{BUCKET}/my-training-input-to-be-ignored"
818+
step = TrainingStep(
819+
name="MyTrainingStep",
820+
step_args=step_args,
821+
description="TrainingStep description",
822+
display_name="MyTrainingStep",
823+
inputs=ignored_train_input,
824+
)
825+
826+
pipeline = Pipeline(
827+
name="MyPipeline",
828+
steps=[step],
829+
sagemaker_session=pipeline_session,
830+
)
831+
step_args = get_step_args_helper(step_args, "Training")
832+
pipeline_def = pipeline.definition()
833+
step_def = json.loads(pipeline_def)["Steps"][0]
834+
assert step_def == {
835+
"Name": "MyTrainingStep",
836+
"Description": "TrainingStep description",
837+
"DisplayName": "MyTrainingStep",
838+
"Type": "Training",
839+
"Arguments": step_args,
840+
}
841+
assert ignored_train_input not in pipeline_def

tests/unit/sagemaker/workflow/test_transform_step.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,47 @@ def test_insert_wrong_step_args_into_transform_step(inputs, pipeline_session):
287287
assert "The step_args of TransformStep must be obtained from transformer.transform()" in str(
288288
error.value
289289
)
290+
291+
292+
def test_transform_step_with_extra_job_args(pipeline_session):
293+
transformer = Transformer(
294+
model_name="my_model",
295+
instance_type="ml.m5.xlarge",
296+
instance_count=1,
297+
output_path="s3://my-bucket/my-output-path",
298+
sagemaker_session=pipeline_session,
299+
)
300+
transform_inputs = TransformInput(data="s3://my-bucket/my-data")
301+
302+
step_args = transformer.transform(
303+
data=transform_inputs.data,
304+
data_type=transform_inputs.data_type,
305+
content_type=transform_inputs.content_type,
306+
compression_type=transform_inputs.compression_type,
307+
split_type=transform_inputs.split_type,
308+
input_filter=transform_inputs.input_filter,
309+
output_filter=transform_inputs.output_filter,
310+
join_source=transform_inputs.join_source,
311+
model_client_config=transform_inputs.model_client_config,
312+
)
313+
314+
ignored_data_path = "s3://my-bucket/my-data-to-be-ignored"
315+
step = TransformStep(
316+
name="MyTransformStep", step_args=step_args, inputs=TransformInput(data=ignored_data_path)
317+
)
318+
319+
pipeline = Pipeline(
320+
name="MyPipeline",
321+
steps=[step],
322+
sagemaker_session=pipeline_session,
323+
)
324+
325+
step_args = get_step_args_helper(step_args, "Transform")
326+
pipeline_def = pipeline.definition()
327+
step_def = json.loads(pipeline_def)["Steps"][0]
328+
assert step_def == {
329+
"Name": "MyTransformStep",
330+
"Type": "Transform",
331+
"Arguments": step_args,
332+
}
333+
assert ignored_data_path not in pipeline_def

tests/unit/sagemaker/workflow/test_tuning_step.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,70 @@ def test_insert_wrong_step_args_into_tuning_step(inputs, pipeline_session):
301301
)
302302

303303
assert "The step_args of TuningStep must be obtained from tuner.fit()" in str(error.value)
304+
305+
306+
def test_tuning_step_with_extra_job_args(pipeline_session, entry_point):
307+
pytorch_estimator = PyTorch(
308+
entry_point=entry_point,
309+
role=ROLE,
310+
framework_version="1.5.0",
311+
py_version="py3",
312+
instance_count=1,
313+
instance_type="ml.m5.xlarge",
314+
sagemaker_session=pipeline_session,
315+
enable_sagemaker_metrics=True,
316+
max_retry_attempts=3,
317+
)
318+
319+
hyperparameter_ranges = {
320+
"batch-size": IntegerParameter(64, 128),
321+
}
322+
323+
tuner = HyperparameterTuner(
324+
estimator=pytorch_estimator,
325+
objective_metric_name="test:acc",
326+
objective_type="Maximize",
327+
hyperparameter_ranges=hyperparameter_ranges,
328+
metric_definitions=[{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
329+
max_jobs=2,
330+
max_parallel_jobs=2,
331+
)
332+
333+
step_args = tuner.fit(inputs=TrainingInput(s3_data="s3://my-bucket/my-training-input"))
334+
335+
ignored_input = "s3://my-bucket/my-input-to-be-ignored"
336+
step = TuningStep(
337+
name="MyTuningStep",
338+
step_args=step_args,
339+
inputs=ignored_input,
340+
)
341+
342+
pipeline = Pipeline(
343+
name="MyPipeline",
344+
steps=[step],
345+
sagemaker_session=pipeline_session,
346+
)
347+
348+
step_args = get_step_args_helper(step_args, "HyperParameterTuning")
349+
pipeline_def = pipeline.definition()
350+
step_def = json.loads(pipeline_def)["Steps"][0]
351+
352+
# delete sagemaker_job_name b/c of timestamp collision
353+
del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_job_name"]
354+
del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
355+
"sagemaker_job_name"
356+
]
357+
358+
# delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled after
359+
# caching improvements phase 2.
360+
del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_submit_directory"]
361+
del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
362+
"sagemaker_submit_directory"
363+
]
364+
365+
assert step_def == {
366+
"Name": "MyTuningStep",
367+
"Type": "Tuning",
368+
"Arguments": step_args,
369+
}
370+
assert ignored_input not in pipeline_def

0 commit comments

Comments
 (0)