diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 939a360e1b..e80412d665 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -52,6 +52,9 @@
     _region_supports_profiler,
     get_mp_parameters,
 )
+from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import Parameter
+from sagemaker.workflow.entities import Expression
 from sagemaker.inputs import TrainingInput
 from sagemaker.job import _Job
 from sagemaker.local import LocalSession
@@ -1456,7 +1459,10 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
 
         current_hyperparameters = estimator.hyperparameters()
         if current_hyperparameters is not None:
-            hyperparameters = {str(k): str(v) for (k, v) in current_hyperparameters.items()}
+            hyperparameters = {
+                str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v))
+                for (k, v) in current_hyperparameters.items()
+            }
 
         train_args = config.copy()
         train_args["input_mode"] = estimator.input_mode
diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py
index 699743b51b..11be13fc88 100644
--- a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -32,6 +32,7 @@
 from sagemaker.session import Session
 from sagemaker.network import NetworkConfig  # noqa: F401 # pylint: disable=unused-import
 from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import Parameter
 from sagemaker.workflow.entities import Expression
 from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
 from sagemaker.apiutils._base_types import ApiObject
@@ -292,7 +293,9 @@ def _normalize_inputs(self, inputs=None, kms_key=None):
                 if isinstance(file_input.source, Properties) or file_input.dataset_definition:
                     normalized_inputs.append(file_input)
                     continue
-
+                if isinstance(file_input.s3_input.s3_uri, (Parameter, Expression, Properties)):
+                    normalized_inputs.append(file_input)
+                    continue
                 # If the source is a local path, upload it to S3
                 # and save the S3 uri in the ProcessingInput source.
                 parse_result = urlparse(file_input.s3_input.s3_uri)
@@ -340,8 +343,7 @@ def _normalize_outputs(self, outputs=None):
                 # Generate a name for the ProcessingOutput if it doesn't have one.
                 if output.output_name is None:
                     output.output_name = "output-{}".format(count)
-                # if the output's destination is a workflow expression, do no normalization
-                if isinstance(output.destination, Expression):
+                if isinstance(output.destination, (Parameter, Expression, Properties)):
                     normalized_outputs.append(output)
                     continue
                 # If the output's destination is not an s3_uri, create one.
@@ -1099,7 +1101,7 @@ def _create_s3_input(self):
             self.s3_data_type = self.s3_input.s3_data_type
             self.s3_input_mode = self.s3_input.s3_input_mode
             self.s3_data_distribution_type = self.s3_input.s3_data_distribution_type
-        elif self.source and self.destination:
+        elif self.source is not None and self.destination is not None:
             self.s3_input = S3Input(
                 s3_uri=self.source,
                 local_path=self.destination,
diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py
index d1d60fa39c..d3d42e1b49 100644
--- a/src/sagemaker/workflow/pipeline.py
+++ b/src/sagemaker/workflow/pipeline.py
@@ -83,7 +83,7 @@ def create(
         Args:
             role_arn (str): The role arn that is assumed by the pipeline to create step
                 artifacts.
-            pipeline_description (str): A description of the pipeline.
+            description (str): A description of the pipeline.
             experiment_name (str): The name of the experiment.
             tags (List[Dict[str, str]]): A list of {"Key": "string", "Value": "string"} dicts as
                 tags.
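Note (not part of the diff): together, the `estimator.py` and `processing.py` changes stop workflow entities from being `str()`-coerced or URL-parsed at pipeline definition time, so they can resolve at execution time instead. A minimal usage sketch of what this enables; the image URI, role ARN, bucket, and parameter names below are placeholders, not values from this change:

```python
from sagemaker.estimator import Estimator
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import TrainingStep

# Hypothetical pipeline parameters for illustration.
epochs = ParameterInteger(name="TrainingEpochs", default_value=5)
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")

estimator = Estimator(
    image_uri="<training-image-uri>",    # placeholder
    role="<execution-role-arn>",         # placeholder
    instance_count=1,
    instance_type=instance_type,         # a Parameter is now accepted here
    hyperparameters={"epochs": epochs},  # survives _get_train_args without str()
)
step = TrainingStep(
    name="Train",
    estimator=estimator,
    inputs=TrainingInput(s3_data="s3://<bucket>/train"),  # placeholder URI
)
pipeline = Pipeline(
    name="MyPipeline",
    parameters=[epochs, instance_type],
    steps=[step],
)
# Parameters render as {"Get": "Parameters.<Name>"} references in the definition.
definition = pipeline.definition()
```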
diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py
index 02bb1545e6..1a1cff6815 100644
--- a/tests/unit/sagemaker/workflow/test_steps.py
+++ b/tests/unit/sagemaker/workflow/test_steps.py
@@ -35,6 +35,7 @@
 from sagemaker.network import NetworkConfig
 from sagemaker.transformer import Transformer
 from sagemaker.workflow.properties import Properties
+from sagemaker.workflow.parameters import ParameterString, ParameterInteger
 from sagemaker.workflow.steps import (
     ProcessingStep,
     Step,
@@ -112,16 +113,27 @@ def test_custom_step():
 
 
 def test_training_step(sagemaker_session):
+    instance_type_parameter = ParameterString(name="InstanceType", default_value="c4.4xlarge")
+    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
+    data_source_uri_parameter = ParameterString(
+        name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
+    )
+    training_epochs_parameter = ParameterInteger(name="TrainingEpochs", default_value=5)
+    training_batch_size_parameter = ParameterInteger(name="TrainingBatchSize", default_value=500)
     estimator = Estimator(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
-        instance_type="c4.4xlarge",
+        instance_count=instance_count_parameter,
+        instance_type=instance_type_parameter,
         profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
+        hyperparameters={
+            "batch-size": training_batch_size_parameter,
+            "epochs": training_epochs_parameter,
+        },
         rules=[],
         sagemaker_session=sagemaker_session,
     )
-    inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
+    inputs = TrainingInput(s3_data=data_source_uri_parameter)
     cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
     step = TrainingStep(
         name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
@@ -131,6 +143,10 @@
         "Type": "Training",
         "Arguments": {
             "AlgorithmSpecification": {"TrainingImage": IMAGE_URI, "TrainingInputMode": "File"},
+            "HyperParameters": {
+                "batch-size": training_batch_size_parameter,
+                "epochs": training_epochs_parameter,
+            },
             "InputDataConfig": [
                 {
                     "ChannelName": "training",
@@ -138,15 +154,15 @@
                         "S3DataSource": {
                             "S3DataDistributionType": "FullyReplicated",
                             "S3DataType": "S3Prefix",
-                            "S3Uri": f"s3://{BUCKET}/train_manifest",
+                            "S3Uri": data_source_uri_parameter,
                         }
                     },
                 }
             ],
             "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
             "ResourceConfig": {
-                "InstanceCount": 1,
-                "InstanceType": "c4.4xlarge",
+                "InstanceCount": instance_count_parameter,
+                "InstanceType": instance_type_parameter,
                 "VolumeSizeInGB": 30,
             },
             "RoleArn": ROLE,
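Note (not part of the diff): the expected request dict above intentionally holds the `Parameter` objects themselves; they are only rendered into `{"Get": ...}` references when the pipeline definition is serialized. A hedged illustration, assuming the SDK's `expr` property on workflow parameters:

```python
from sagemaker.workflow.parameters import ParameterString

# A workflow parameter serializes to a reference expression in the pipeline
# definition rather than a literal value.
param = ParameterString(name="InstanceType", default_value="c4.4xlarge")
assert param.expr == {"Get": "Parameters.InstanceType"}
```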
@@ -162,16 +178,21 @@
 
 
 def test_processing_step(sagemaker_session):
+    processing_input_data_uri_parameter = ParameterString(
+        name="ProcessingInputDataUri", default_value=f"s3://{BUCKET}/processing_manifest"
+    )
+    instance_type_parameter = ParameterString(name="InstanceType", default_value="ml.m4.4xlarge")
+    instance_count_parameter = ParameterInteger(name="InstanceCount", default_value=1)
     processor = Processor(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
-        instance_type="ml.m4.4xlarge",
+        instance_count=instance_count_parameter,
+        instance_type=instance_type_parameter,
         sagemaker_session=sagemaker_session,
     )
     inputs = [
         ProcessingInput(
-            source=f"s3://{BUCKET}/processing_manifest",
+            source=processing_input_data_uri_parameter,
             destination="processing_manifest",
         )
     ]
@@ -198,14 +219,14 @@
                         "S3DataDistributionType": "FullyReplicated",
                         "S3DataType": "S3Prefix",
                         "S3InputMode": "File",
-                        "S3Uri": "s3://my-bucket/processing_manifest",
+                        "S3Uri": processing_input_data_uri_parameter,
                     },
                 }
             ],
             "ProcessingResources": {
                 "ClusterConfig": {
-                    "InstanceCount": 1,
-                    "InstanceType": "ml.m4.4xlarge",
+                    "InstanceCount": instance_count_parameter,
+                    "InstanceType": instance_type_parameter,
                     "VolumeSizeInGB": 30,
                 }
             },
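Note (not part of the diff): on the processing side, a short sketch of the behavior these tests pin down; the image URI, role ARN, parameter name, and paths are placeholders:

```python
from sagemaker.processing import ProcessingInput, Processor
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.steps import ProcessingStep

# Hypothetical pipeline parameter for illustration.
input_data = ParameterString(name="InputDataUri", default_value="s3://<bucket>/data")

processor = Processor(
    image_uri="<processing-image-uri>",  # placeholder
    role="<execution-role-arn>",         # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# With the _normalize_inputs change, a Parameter source is appended to the
# normalized inputs as-is instead of being urlparse()-d or uploaded to S3.
step = ProcessingStep(
    name="Process",
    processor=processor,
    inputs=[ProcessingInput(source=input_data, destination="/opt/ml/processing/input")],
)
```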