27 changes: 27 additions & 0 deletions src/sagemaker/workflow/steps.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import

import abc
import logging
import warnings

from enum import Enum
@@ -54,6 +55,14 @@
if TYPE_CHECKING:
from sagemaker.workflow.step_collections import StepCollection

logger = logging.getLogger(__name__)

JOB_ARG_IGNORE_WARN_MSG_TEMPLATE = (
"The job specific arguments (%s) supplied to the step will be ignored, "
"because `step_args` is presented. These job specific arguments should be supplied "
Review comment (Member): nit: typo in `presented`; should be replaced with `present`.

"in %s to generate the `step_args`."
)
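An illustrative rendering of the template (the argument values mirror the `logger.warning` calls added later in this diff):

```python
# Illustrative only: how the template renders for a TrainingStep.
logger.warning(JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "estimator.fit()")
# Logged message:
# "The job-specific arguments (`inputs`) supplied to the step will be ignored,
# because `step_args` is present. These job-specific arguments should be
# supplied in estimator.fit() to generate the `step_args`."
```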


class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
"""Enum of `Step` types."""
@@ -424,6 +433,9 @@ def __init__(
error_message="The step_args of TrainingStep must be obtained from estimator.fit().",
)

if inputs:
logger.warning(JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "estimator.fit()")

self.step_args = step_args
self.estimator = estimator
self.inputs = inputs
@@ -680,6 +692,11 @@ def __init__(
"from transformer.transform().",
)

if inputs:
logger.warning(
JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "transformer.transform()"
)

self.step_args = step_args
self.transformer = transformer
self.inputs = inputs
@@ -817,6 +834,13 @@ def __init__(
error_message="The step_args of ProcessingStep must be obtained from processor.run().",
)

if job_arguments or inputs or outputs or code or kms_key:
Review comment (Member): should we check for each argument individually? Otherwise the warning message might get confusing.

logger.warning(
JOB_ARG_IGNORE_WARN_MSG_TEMPLATE,
"`job_arguments`, `inputs`, `outputs`, `code` and `kms_key`",
"processor.run()",
)
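A minimal sketch of the per-argument alternative suggested in the comment above (a hypothetical refactor, not part of this diff):

```python
# Hypothetical variant: name only the arguments the caller actually supplied,
# so the warning never claims more is ignored than really was.
ignored = {
    "`job_arguments`": job_arguments,
    "`inputs`": inputs,
    "`outputs`": outputs,
    "`code`": code,
    "`kms_key`": kms_key,
}
supplied = [name for name, value in ignored.items() if value]
if supplied:
    logger.warning(
        JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, ", ".join(supplied), "processor.run()"
    )
```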

self.step_args = step_args
self.processor = processor
self.inputs = inputs
@@ -997,6 +1021,9 @@ def __init__(
error_message="The step_args of TuningStep must be obtained from tuner.fit().",
)

if inputs:
logger.warning(JOB_ARG_IGNORE_WARN_MSG_TEMPLATE, "`inputs`", "tuner.fit()")

self.step_args = step_args
self.tuner = tuner
self.inputs = inputs
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/utilities.py
@@ -363,8 +363,8 @@ def override_pipeline_parameter_var(func):
We should remove this decorator after the grace period.
"""
warning_msg_template = (
"The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. "
"The default_value of this Parameter object will be used to override it. "
"The input argument %s of function (%s) is a pipeline variable (%s), which will not work. "
Review comment (Member): the obvious question here for any user will be "why will it not work?" We should get this reviewed by a doc writer to find the least confusing wording.

"Thus the default_value of this Parameter object will be used to override it. "
"Please make sure the default_value is valid."
)
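Background on the reviewer's question (an assumption about the mechanism, not the PR's official wording): a pipeline variable is a placeholder resolved only at pipeline execution time, so a function that needs a concrete value while the pipeline is being defined cannot dereference it; the decorator therefore falls back to the parameter's `default_value`. A small sketch:

```python
# Illustrative only: why a pipeline variable cannot be used directly here.
from sagemaker.workflow.parameters import ParameterString

role = ParameterString(name="Role", default_value="arn:aws:iam::111122223333:role/Demo")
# At pipeline definition time `role` is an unresolved placeholder that
# serializes to {"Get": "Parameters.Role"}; only role.default_value is a
# concrete string, which is why the decorator overrides with it.
```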

41 changes: 41 additions & 0 deletions tests/unit/sagemaker/workflow/test_processing_step.py
@@ -1115,3 +1115,44 @@ def test_processor_with_role_as_pipeline_parameter(

step_def = json.loads(pipeline.definition())["Steps"][0]
assert step_def["Arguments"]["RoleArn"] == {"Get": f"Parameters.{_PARAM_ROLE_NAME}"}


def test_processing_step_with_extra_job_args(pipeline_session, processing_input):
processor = Processor(
image_uri=IMAGE_URI,
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
sagemaker_session=pipeline_session,
)

step_args = processor.run(inputs=processing_input)

ignored_code = f"s3://{BUCKET}/my-code-to-be-ignored"
ignored_kms_key = "ignored_kms_key"
step = ProcessingStep(
name="MyProcessingStep",
step_args=step_args,
description="ProcessingStep description",
display_name="MyProcessingStep",
code=ignored_code,
kms_key=ignored_kms_key,
)

pipeline = Pipeline(
name="MyPipeline",
steps=[step],
sagemaker_session=pipeline_session,
)
step_args = get_step_args_helper(step_args, "Processing")
pipeline_def = pipeline.definition()
step_def = json.loads(pipeline_def)["Steps"][0]
assert step_def == {
"Name": "MyProcessingStep",
"Description": "ProcessingStep description",
"DisplayName": "MyProcessingStep",
"Type": "Processing",
"Arguments": step_args,
}
assert ignored_code not in pipeline_def
assert ignored_kms_key not in pipeline_def
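A possible companion test (a sketch, not in this PR): the assertions above confirm the ignored values are absent from the pipeline definition, but pytest's built-in `caplog` fixture could also verify that the warning itself is emitted:

```python
def test_processing_step_warns_on_extra_job_args(pipeline_session, processing_input, caplog):
    # Hypothetical companion test; assumes `import logging` at module top.
    processor = Processor(
        image_uri=IMAGE_URI,
        role=ROLE,
        instance_count=1,
        instance_type=INSTANCE_TYPE,
        sagemaker_session=pipeline_session,
    )
    step_args = processor.run(inputs=processing_input)

    with caplog.at_level(logging.WARNING, logger="sagemaker.workflow.steps"):
        ProcessingStep(name="MyProcessingStep", step_args=step_args, kms_key="my_kms_key")

    assert "will be ignored" in caplog.text
```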
38 changes: 38 additions & 0 deletions tests/unit/sagemaker/workflow/test_training_step.py
@@ -801,3 +801,41 @@ def test_insert_wrong_step_args_into_training_step(inputs, pipeline_session):
)

assert "The step_args of TrainingStep must be obtained from estimator.fit()" in str(error.value)


def test_training_step_estimator_with_extra_job_args(pipeline_session, training_input):
estimator = Estimator(
role=ROLE,
instance_count=1,
instance_type=INSTANCE_TYPE,
sagemaker_session=pipeline_session,
image_uri=IMAGE_URI,
)

step_args = estimator.fit(inputs=training_input)

ignored_train_input = f"s3://{BUCKET}/my-training-input-to-be-ignored"
step = TrainingStep(
name="MyTrainingStep",
step_args=step_args,
description="TrainingStep description",
display_name="MyTrainingStep",
inputs=ignored_train_input,
)

pipeline = Pipeline(
name="MyPipeline",
steps=[step],
sagemaker_session=pipeline_session,
)
step_args = get_step_args_helper(step_args, "Training")
pipeline_def = pipeline.definition()
step_def = json.loads(pipeline_def)["Steps"][0]
assert step_def == {
"Name": "MyTrainingStep",
"Description": "TrainingStep description",
"DisplayName": "MyTrainingStep",
"Type": "Training",
"Arguments": step_args,
}
assert ignored_train_input not in pipeline_def
44 changes: 44 additions & 0 deletions tests/unit/sagemaker/workflow/test_transform_step.py
@@ -287,3 +287,47 @@ def test_insert_wrong_step_args_into_transform_step(inputs, pipeline_session):
assert "The step_args of TransformStep must be obtained from transformer.transform()" in str(
error.value
)


def test_transform_step_with_extra_job_args(pipeline_session):
transformer = Transformer(
model_name="my_model",
instance_type="ml.m5.xlarge",
instance_count=1,
output_path="s3://my-bucket/my-output-path",
sagemaker_session=pipeline_session,
)
transform_inputs = TransformInput(data="s3://my-bucket/my-data")

step_args = transformer.transform(
data=transform_inputs.data,
data_type=transform_inputs.data_type,
content_type=transform_inputs.content_type,
compression_type=transform_inputs.compression_type,
split_type=transform_inputs.split_type,
input_filter=transform_inputs.input_filter,
output_filter=transform_inputs.output_filter,
join_source=transform_inputs.join_source,
model_client_config=transform_inputs.model_client_config,
)

ignored_data_path = "s3://my-bucket/my-data-to-be-ignored"
step = TransformStep(
name="MyTransformStep", step_args=step_args, inputs=TransformInput(data=ignored_data_path)
)

pipeline = Pipeline(
name="MyPipeline",
steps=[step],
sagemaker_session=pipeline_session,
)

step_args = get_step_args_helper(step_args, "Transform")
pipeline_def = pipeline.definition()
step_def = json.loads(pipeline_def)["Steps"][0]
assert step_def == {
"Name": "MyTransformStep",
"Type": "Transform",
"Arguments": step_args,
}
assert ignored_data_path not in pipeline_def
67 changes: 67 additions & 0 deletions tests/unit/sagemaker/workflow/test_tuning_step.py
@@ -301,3 +301,70 @@ def test_insert_wrong_step_args_into_tuning_step(inputs, pipeline_session):
)

assert "The step_args of TuningStep must be obtained from tuner.fit()" in str(error.value)


def test_tuning_step_with_extra_job_args(pipeline_session, entry_point):
pytorch_estimator = PyTorch(
entry_point=entry_point,
role=ROLE,
framework_version="1.5.0",
py_version="py3",
instance_count=1,
instance_type="ml.m5.xlarge",
sagemaker_session=pipeline_session,
enable_sagemaker_metrics=True,
max_retry_attempts=3,
)

hyperparameter_ranges = {
"batch-size": IntegerParameter(64, 128),
}

tuner = HyperparameterTuner(
estimator=pytorch_estimator,
objective_metric_name="test:acc",
objective_type="Maximize",
hyperparameter_ranges=hyperparameter_ranges,
metric_definitions=[{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
max_jobs=2,
max_parallel_jobs=2,
)

step_args = tuner.fit(inputs=TrainingInput(s3_data="s3://my-bucket/my-training-input"))

ignored_input = "s3://my-bucket/my-input-to-be-ignored"
step = TuningStep(
name="MyTuningStep",
step_args=step_args,
inputs=ignored_input,
)

pipeline = Pipeline(
name="MyPipeline",
steps=[step],
sagemaker_session=pipeline_session,
)

step_args = get_step_args_helper(step_args, "HyperParameterTuning")
pipeline_def = pipeline.definition()
step_def = json.loads(pipeline_def)["Steps"][0]

# delete sagemaker_job_name b/c of timestamp collision
del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_job_name"]
del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
"sagemaker_job_name"
]

# delete S3 submit-directory entries for now because the job name includes a timestamp. These
# assertions will be re-enabled after caching improvements phase 2.
del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_submit_directory"]
del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
"sagemaker_submit_directory"
]

assert step_def == {
"Name": "MyTuningStep",
"Type": "Tuning",
"Arguments": step_args,
}
assert ignored_input not in pipeline_def