diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index b1f8d45241..f655645ffd 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -74,9 +74,7 @@ get_config_value, name_from_base, ) -from sagemaker.workflow.entities import Expression -from sagemaker.workflow.parameters import Parameter -from sagemaker.workflow.properties import Properties +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -602,7 +600,7 @@ def _json_encode_hyperparameters(hyperparameters: Dict[str, Any]) -> Dict[str, A current_hyperparameters = hyperparameters if current_hyperparameters is not None: hyperparameters = { - str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else json.dumps(v)) + str(k): (v.to_string() if isinstance(v, PipelineVariable) else json.dumps(v)) for (k, v) in current_hyperparameters.items() } return hyperparameters @@ -1813,7 +1811,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config): current_hyperparameters = estimator.hyperparameters() if current_hyperparameters is not None: hyperparameters = { - str(k): (v if isinstance(v, (Parameter, Expression, Properties)) else str(v)) + str(k): (v.to_string() if isinstance(v, PipelineVariable) else str(v)) for (k, v) in current_hyperparameters.items() } diff --git a/src/sagemaker/parameter.py b/src/sagemaker/parameter.py index a7f8440f3d..8780b7e93b 100644 --- a/src/sagemaker/parameter.py +++ b/src/sagemaker/parameter.py @@ -14,9 +14,8 @@ from __future__ import absolute_import import json -from sagemaker.workflow.parameters import Parameter as PipelineParameter -from sagemaker.workflow.functions import JsonGet as PipelineJsonGet -from sagemaker.workflow.functions import Join as PipelineJoin + +from sagemaker.workflow.entities import PipelineVariable class ParameterRange(object): @@ -73,11 +72,11 @@ def as_tuning_range(self, name): return { "Name": name, "MinValue": str(self.min_value) - if not isinstance(self.min_value, (PipelineParameter, PipelineJsonGet, PipelineJoin)) - else self.min_value, + if not isinstance(self.min_value, PipelineVariable) + else self.min_value.to_string(), "MaxValue": str(self.max_value) - if not isinstance(self.max_value, (PipelineParameter, PipelineJsonGet, PipelineJoin)) - else self.max_value, + if not isinstance(self.max_value, PipelineVariable) + else self.max_value.to_string(), "ScalingType": self.scaling_type, } @@ -112,8 +111,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called """ values = values if isinstance(values, list) else [values] self.values = [ - str(v) if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) else v - for v in values + str(v) if not isinstance(v, PipelineVariable) else v.to_string() for v in values ] def as_tuning_range(self, name): diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index f661e26e04..85cf811d64 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -38,6 +38,7 @@ IntegerParameter, ParameterRange, ) +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.parameters import Parameter as PipelineParameter from sagemaker.workflow.functions import JsonGet as PipelineJsonGet from sagemaker.workflow.functions import Join as PipelineJoin @@ -376,9 +377,7 @@ def _prepare_static_hyperparameters( """Prepare static hyperparameters for one estimator before tuning.""" # Remove any hyperparameter that will be tuned static_hyperparameters = { - str(k): str(v) - if not isinstance(v, (PipelineParameter, PipelineJsonGet, PipelineJoin)) - else v + str(k): str(v) if not isinstance(v, PipelineVariable) else v.to_string() for (k, v) in estimator.hyperparameters().items() } for hyperparameter_name in hyperparameter_ranges.keys(): diff --git a/src/sagemaker/workflow/entities.py b/src/sagemaker/workflow/entities.py index 1f85444454..f530c9b669 100644 --- a/src/sagemaker/workflow/entities.py +++ b/src/sagemaker/workflow/entities.py @@ -16,7 +16,7 @@ import abc from enum import EnumMeta -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Optional PrimitiveType = Union[str, int, bool, float, None] RequestType = Union[Dict[str, Any], List[Dict[str, Any]]] @@ -57,3 +57,80 @@ class Expression(abc.ABC): @abc.abstractmethod def expr(self) -> RequestType: """Get the expression structure for workflow service calls.""" + + +class PipelineVariable(Expression): + """Base object for pipeline variables + + PipelineVariables must implement the expr property. + """ + + def __add__(self, other: Union[Expression, PrimitiveType]): + """Add function for PipelineVariable + + Args: + other (Union[Expression, PrimitiveType]): The other object to be concatenated. + + Always raise an error since pipeline variables do not support concatenation + """ + + raise TypeError("Pipeline variables do not support concatenation.") + + def __str__(self): + """Override built-in String function for PipelineVariable""" + raise TypeError("Pipeline variables do not support __str__ operation.") + + def __int__(self): + """Override built-in Integer function for PipelineVariable""" + raise TypeError("Pipeline variables do not support __int__ operation.") + + def __float__(self): + """Override built-in Float function for PipelineVariable""" + raise TypeError("Pipeline variables do not support __float__ operation.") + + def to_string(self): + """Prompt the pipeline to convert the pipeline variable to String in runtime""" + from sagemaker.workflow.functions import Join + + return Join(on="", values=[self]) + + @property + @abc.abstractmethod + def expr(self) -> RequestType: + """Get the expression structure for workflow service calls.""" + + def startswith( + self, + prefix: Union[str, tuple], # pylint: disable=unused-argument + start: Optional[int] = None, # pylint: disable=unused-argument + end: Optional[int] = None, # pylint: disable=unused-argument + ) -> bool: + """Simulate the Python string's built-in method: startswith + + Args: + prefix (str, tuple): The (tuple of) string to be checked. + start (int): To set the start index of the matching boundary (default: None). + end (int): To set the end index of the matching boundary (default: None). + + Return: + bool: Always return False as Pipeline variables are parsed during execution runtime + """ + return False + + def endswith( + self, + suffix: Union[str, tuple], # pylint: disable=unused-argument + start: Optional[int] = None, # pylint: disable=unused-argument + end: Optional[int] = None, # pylint: disable=unused-argument + ) -> bool: + """Simulate the Python string's built-in method: endswith + + Args: + suffix (str, tuple): The (tuple of) string to be checked. + start (int): To set the start index of the matching boundary (default: None). + end (int): To set the end index of the matching boundary (default: None). + + Return: + bool: Always return False as Pipeline variables are parsed during execution runtime + """ + return False diff --git a/src/sagemaker/workflow/execution_variables.py b/src/sagemaker/workflow/execution_variables.py index d81b359bde..22474c8856 100644 --- a/src/sagemaker/workflow/execution_variables.py +++ b/src/sagemaker/workflow/execution_variables.py @@ -14,12 +14,12 @@ from __future__ import absolute_import from sagemaker.workflow.entities import ( - Expression, RequestType, + PipelineVariable, ) -class ExecutionVariable(Expression): +class ExecutionVariable(PipelineVariable): """Pipeline execution variables for workflow.""" def __init__(self, name: str): @@ -30,6 +30,13 @@ def __init__(self, name: str): """ self.name = name + def to_string(self) -> PipelineVariable: + """Prompt the pipeline to convert the pipeline variable to String in runtime + + As ExecutionVariable is treated as String in runtime, no extra actions are needed. + """ + return self + @property def expr(self) -> RequestType: """The 'Get' expression dict for an `ExecutionVariable`.""" diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index e0076322de..36bd69fbff 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -17,12 +17,12 @@ import attr -from sagemaker.workflow.entities import Expression +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.properties import PropertyFile @attr.s -class Join(Expression): +class Join(PipelineVariable): """Join together properties. Examples: @@ -38,15 +38,23 @@ class Join(Expression): Attributes: values (List[Union[PrimitiveType, Parameter, Expression]]): The primitive type values, parameters, step properties, expressions to join. - on_str (str): The string to join the values on (Defaults to ""). + on (str): The string to join the values on (Defaults to ""). """ on: str = attr.ib(factory=str) values: List = attr.ib(factory=list) + def to_string(self) -> PipelineVariable: + """Prompt the pipeline to convert the pipeline variable to String in runtime + + As Join is treated as String in runtime, no extra actions are needed. + """ + return self + @property def expr(self): """The expression dict for a `Join` function.""" + return { "Std:Join": { "On": self.on, @@ -58,7 +66,7 @@ def expr(self): @attr.s -class JsonGet(Expression): +class JsonGet(PipelineVariable): """Get JSON properties from PropertyFiles. Attributes: diff --git a/src/sagemaker/workflow/parameters.py b/src/sagemaker/workflow/parameters.py index 40ad70b014..f88e5c2097 100644 --- a/src/sagemaker/workflow/parameters.py +++ b/src/sagemaker/workflow/parameters.py @@ -24,6 +24,7 @@ Entity, PrimitiveType, RequestType, + PipelineVariable, ) @@ -48,7 +49,7 @@ def python_type(self) -> Type: @attr.s -class Parameter(Entity): +class Parameter(PipelineVariable, Entity): """Pipeline parameter for workflow. Attributes: @@ -170,6 +171,13 @@ def __hash__(self): """Hash function for parameter types""" return hash(tuple(self.to_request())) + def to_string(self) -> PipelineVariable: + """Prompt the pipeline to convert the pipeline variable to String in runtime + + As ParameterString is treated as String in runtime, no extra actions are needed. + """ + return self + def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" request_dict = super(ParameterString, self).to_request() diff --git a/src/sagemaker/workflow/properties.py b/src/sagemaker/workflow/properties.py index 6e9aba4408..480fddada1 100644 --- a/src/sagemaker/workflow/properties.py +++ b/src/sagemaker/workflow/properties.py @@ -13,16 +13,17 @@ """The properties definitions for workflow.""" from __future__ import absolute_import +from abc import ABCMeta from typing import Dict, Union, List import attr import botocore.loaders -from sagemaker.workflow.entities import Expression +from sagemaker.workflow.entities import Expression, PipelineVariable -class PropertiesMeta(type): +class PropertiesMeta(ABCMeta): """Load an internal shapes attribute from the botocore service model for sagemaker and emr service. @@ -44,7 +45,7 @@ def __new__(mcs, *args, **kwargs): return super().__new__(mcs, *args, **kwargs) -class Properties(metaclass=PropertiesMeta): +class Properties(PipelineVariable, metaclass=PropertiesMeta): """Properties for use in workflow expressions.""" def __init__( diff --git a/tests/integ/sagemaker/workflow/test_pipeline_var_behaviors.py b/tests/integ/sagemaker/workflow/test_pipeline_var_behaviors.py new file mode 100644 index 0000000000..f161901691 --- /dev/null +++ b/tests/integ/sagemaker/workflow/test_pipeline_var_behaviors.py @@ -0,0 +1,119 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from botocore.exceptions import WaiterError + +from sagemaker import get_execution_role, utils +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionGreaterThan +from sagemaker.workflow.fail_step import FailStep +from sagemaker.workflow.functions import Join +from sagemaker.workflow.parameters import ParameterString, ParameterInteger +from sagemaker.workflow.pipeline import Pipeline + + +@pytest.fixture +def role(sagemaker_session): + return get_execution_role(sagemaker_session) + + +@pytest.fixture +def pipeline_name(): + return utils.unique_name_from_base("my-pipeline-vars") + + +def test_ppl_var_to_string_and_add(sagemaker_session, role, pipeline_name): + param_str = ParameterString(name="MyString", default_value="1") + param_int = ParameterInteger(name="MyInteger", default_value=3) + + cond = ConditionGreaterThan(left=param_str, right=param_int.to_string()) + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[], + else_steps=[], + ) + join_fn1 = Join( + on=" ", + values=[ + "condition greater than check return:", + step_cond.properties.Outcome.to_string(), + "and left side param str is", + param_str, + "and right side param int is", + param_int, + ], + ) + + step_fail = FailStep( + name="FailStep", + error_message=join_fn1, + ) + pipeline = Pipeline( + name=pipeline_name, + parameters=[param_str, param_int], + steps=[step_cond, step_fail], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + pipeline_arn = response["PipelineArn"] + execution = pipeline.start() + response = execution.describe() + assert response["PipelineArn"] == pipeline_arn + + try: + execution.wait(delay=30, max_attempts=60) + except WaiterError: + pass + execution_steps = execution.list_steps() + + assert len(execution_steps) == 2 + for execution_step in execution_steps: + if execution_step["StepName"] == "CondStep": + assert execution_step["StepStatus"] == "Succeeded" + continue + assert execution_step["StepName"] == "FailStep" + assert execution_step["StepStatus"] == "Failed" + assert ( + execution_step["FailureReason"] == "condition greater than check return: false " + "and left side param str is 1 and right side param int is 3" + ) + + # Update int param to update cond step outcome + execution = pipeline.start(parameters={"MyInteger": 0}) + try: + execution.wait(delay=30, max_attempts=60) + except WaiterError: + pass + execution_steps = execution.list_steps() + + assert len(execution_steps) == 2 + for execution_step in execution_steps: + if execution_step["StepName"] == "CondStep": + assert execution_step["StepStatus"] == "Succeeded" + continue + assert execution_step["StepName"] == "FailStep" + assert execution_step["StepStatus"] == "Failed" + assert ( + execution_step["FailureReason"] == "condition greater than check return: true " + "and left side param str is 1 and right side param int is 0" + ) + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/integ/sagemaker/workflow/test_tuning_steps.py b/tests/integ/sagemaker/workflow/test_tuning_steps.py index 7cfb542cb6..347420c7e0 100644 --- a/tests/integ/sagemaker/workflow/test_tuning_steps.py +++ b/tests/integ/sagemaker/workflow/test_tuning_steps.py @@ -44,7 +44,7 @@ def role(sagemaker_session): @pytest.fixture def pipeline_name(): - return utils.unique_name_from_base("my-pipeline-training") + return utils.unique_name_from_base("my-pipeline-tuning") @pytest.fixture @@ -105,8 +105,8 @@ def test_tuning_single_algo( max_retry_attempts=3, ) - min_batch_size = ParameterString(name="MinBatchSize", default_value="64") - max_batch_size = ParameterString(name="MaxBatchSize", default_value="128") + min_batch_size = ParameterInteger(name="MinBatchSize", default_value=64) + max_batch_size = ParameterInteger(name="MaxBatchSize", default_value=128) hyperparameter_ranges = { "batch-size": IntegerParameter(min_batch_size, max_batch_size), } diff --git a/tests/unit/sagemaker/workflow/test_entities.py b/tests/unit/sagemaker/workflow/test_entities.py index d431d7235f..03c4fd22a1 100644 --- a/tests/unit/sagemaker/workflow/test_entities.py +++ b/tests/unit/sagemaker/workflow/test_entities.py @@ -13,14 +13,26 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import json + import pytest from enum import Enum +from mock.mock import Mock, PropertyMock + +import sagemaker +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionGreaterThan from sagemaker.workflow.entities import ( DefaultEnumMeta, Entity, ) +from sagemaker.workflow.fail_step import FailStep +from sagemaker.workflow.functions import Join, JsonGet +from sagemaker.workflow.parameters import ParameterString, ParameterInteger +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.properties import PropertyFile, Properties class CustomEntity(Entity): @@ -46,6 +58,46 @@ def custom_entity_list(): return [CustomEntity(1), CustomEntity(2)] +@pytest.fixture +def boto_session(): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value="role") + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name="us-west-2") + session_mock.resource.return_value = resource_mock + + return session_mock + + +@pytest.fixture +def client(): + """Mock client. + + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture +def sagemaker_session(boto_session, client): + return sagemaker.session.Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket="my-bucket", + ) + + def test_entity(custom_entity): request_struct = {"foo": 1} assert custom_entity.to_request() == request_struct @@ -53,3 +105,115 @@ def test_entity(custom_entity): def test_default_enum_meta(): assert CustomEnum().value == 1 + + +def test_pipeline_variable_in_pipeline_definition(sagemaker_session): + param_str = ParameterString(name="MyString", default_value="1") + param_int = ParameterInteger(name="MyInteger", default_value=3) + + property_file = PropertyFile( + name="name", + output_name="result", + path="output", + ) + json_get_func2 = JsonGet( + step_name="my-step", + property_file=property_file, + json_path="my-json-path", + ) + prop = Properties("Steps.MyStep", "DescribeProcessingJobResponse") + + cond = ConditionGreaterThan(left=param_str, right=param_int.to_string()) + step_fail = FailStep( + name="MyFailStep", + error_message=Join( + on=" ", + values=[ + "Execution failed due to condition check fails, see:", + json_get_func2.to_string(), + prop.ProcessingOutputConfig.Outputs["MyOutputName"].S3Output.S3Uri.to_string(), + param_int, + ], + ), + ) + step_cond = ConditionStep( + name="MyCondStep", + conditions=[cond], + if_steps=[], + else_steps=[step_fail], + ) + pipeline = Pipeline( + name="MyPipeline", + parameters=[param_str, param_int], + steps=[step_cond], + sagemaker_session=sagemaker_session, + ) + + dsl = json.loads(pipeline.definition()) + assert dsl["Parameters"] == [ + {"Name": "MyString", "Type": "String", "DefaultValue": "1"}, + {"Name": "MyInteger", "Type": "Integer", "DefaultValue": 3}, + ] + assert len(dsl["Steps"]) == 1 + assert dsl["Steps"][0] == { + "Name": "MyCondStep", + "Type": "Condition", + "Arguments": { + "Conditions": [ + { + "Type": "GreaterThan", + "LeftValue": {"Get": "Parameters.MyString"}, + "RightValue": { + "Std:Join": { + "On": "", + "Values": [{"Get": "Parameters.MyInteger"}], + }, + }, + }, + ], + "IfSteps": [], + "ElseSteps": [ + { + "Name": "MyFailStep", + "Type": "Fail", + "Arguments": { + "ErrorMessage": { + "Std:Join": { + "On": " ", + "Values": [ + "Execution failed due to condition check fails, see:", + { + "Std:Join": { + "On": "", + "Values": [ + { + "Std:JsonGet": { + "PropertyFile": { + "Get": "Steps.my-step.PropertyFiles.name" + }, + "Path": "my-json-path", + } + }, + ], + }, + }, + { + "Std:Join": { + "On": "", + "Values": [ + { + "Get": "Steps.MyStep.ProcessingOutputConfig." + + "Outputs['MyOutputName'].S3Output.S3Uri" + }, + ], + }, + }, + {"Get": "Parameters.MyInteger"}, + ], + } + } + }, + } + ], + }, + } diff --git a/tests/unit/sagemaker/workflow/test_execution_variables.py b/tests/unit/sagemaker/workflow/test_execution_variables.py index 1acce9e10c..23a0484d01 100644 --- a/tests/unit/sagemaker/workflow/test_execution_variables.py +++ b/tests/unit/sagemaker/workflow/test_execution_variables.py @@ -10,12 +10,52 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + from sagemaker.workflow.execution_variables import ExecutionVariables def test_execution_variable(): var = ExecutionVariables.START_DATETIME assert var.expr == {"Get": "Execution.StartDateTime"} + + +def test_to_string(): + var = ExecutionVariables.START_DATETIME + + assert var.to_string() == var + + +def test_implicit_value(): + var = ExecutionVariables.START_DATETIME + + with pytest.raises(TypeError) as error: + str(var) + assert str(error.value) == "Pipeline variables do not support __str__ operation." + + with pytest.raises(TypeError) as error: + int(var) + assert str(error.value) == "Pipeline variables do not support __int__ operation." + + with pytest.raises(TypeError) as error: + float(var) + assert str(error.value) == "Pipeline variables do not support __float__ operation." + + +def test_string_builtin_funcs_that_return_bool(): + prop = ExecutionVariables.PIPELINE_NAME + # The execution var will only be parsed in runtime (Pipeline backend) so not able to tell in SDK + assert not prop.startswith("MyPipeline") + assert not prop.endswith("MyPipeline") + + +def test_add_func(): + var_start_datetime = ExecutionVariables.START_DATETIME + var_current_datetime = ExecutionVariables.CURRENT_DATETIME + + with pytest.raises(TypeError) as error: + var_start_datetime + var_current_datetime + + assert str(error.value) == "Pipeline variables do not support concatenation." diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 9b07a41d09..a9f03d7c6d 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -70,8 +70,46 @@ def test_join_expressions(): } -def test_json_get_expressions(): +def test_to_string_on_join(): + func = Join(values=[1, "a", False, 1.1]) + + assert func.to_string() == func + + +def test_implicit_value_on_join(): + func = Join(values=[1, "a", False, 1.1]) + + with pytest.raises(TypeError) as error: + str(func) + assert str(error.value) == "Pipeline variables do not support __str__ operation." + + with pytest.raises(TypeError) as error: + int(func) + assert str(error.value) == "Pipeline variables do not support __int__ operation." + + with pytest.raises(TypeError) as error: + float(func) + assert str(error.value) == "Pipeline variables do not support __float__ operation." + +def test_string_builtin_funcs_that_return_bool_on_join(): + func = Join(on=",", values=["s3:/", "my-bucket", "a"]) + # The func will only be parsed in runtime (Pipeline backend) so not able to tell in SDK + assert not func.startswith("s3") + assert not func.endswith("s3") + + +def test_add_func_of_join(): + func_join1 = Join(values=[1, "a"]) + param = ParameterInteger(name="MyInteger", default_value=3) + + with pytest.raises(TypeError) as error: + func_join1 + param + + assert str(error.value) == "Pipeline variables do not support concatenation." + + +def test_json_get_expressions(): assert JsonGet( step_name="my-step", property_file="my-property-file", @@ -88,7 +126,6 @@ def test_json_get_expressions(): output_name="result", path="output", ) - assert JsonGet( step_name="my-step", property_file=property_file, @@ -119,3 +156,80 @@ def test_json_get_expressions_with_invalid_step_name(): ).expr assert "Please give a valid step name as a string" in str(err.value) + + +def test_to_string_on_json_get(): + func = JsonGet( + step_name="my-step", + property_file="my-property-file", + json_path="my-json-path", + ) + + assert func.to_string().expr == { + "Std:Join": { + "On": "", + "Values": [ + { + "Std:JsonGet": { + "Path": "my-json-path", + "PropertyFile": {"Get": "Steps.my-step.PropertyFiles.my-property-file"}, + } + } + ], + }, + } + + +def test_implicit_value_on_json_get(): + func = JsonGet( + step_name="my-step", + property_file="my-property-file", + json_path="my-json-path", + ) + + with pytest.raises(TypeError) as error: + str(func) + assert str(error.value) == "Pipeline variables do not support __str__ operation." + + with pytest.raises(TypeError) as error: + int(func) + assert str(error.value) == "Pipeline variables do not support __int__ operation." + + with pytest.raises(TypeError) as error: + float(func) + assert str(error.value) == "Pipeline variables do not support __float__ operation." + + +def test_string_builtin_funcs_that_return_bool_on_json_get(): + func = JsonGet( + step_name="my-step", + property_file="my-property-file", + json_path="my-json-path", + ) + # The func will only be parsed in runtime (Pipeline backend) so not able to tell in SDK + assert not func.startswith("s3") + assert not func.endswith("s3") + + +def test_add_func_of_json_get(): + json_get_func1 = JsonGet( + step_name="my-step", + property_file="my-property-file", + json_path="my-json-path", + ) + + property_file = PropertyFile( + name="name", + output_name="result", + path="output", + ) + json_get_func2 = JsonGet( + step_name="my-step", + property_file=property_file, + json_path="my-json-path", + ) + + with pytest.raises(TypeError) as error: + json_get_func1 + json_get_func2 + + assert str(error.value) == "Pipeline variables do not support concatenation." diff --git a/tests/unit/sagemaker/workflow/test_parameters.py b/tests/unit/sagemaker/workflow/test_parameters.py index 552465ba0d..c9aa6cb983 100644 --- a/tests/unit/sagemaker/workflow/test_parameters.py +++ b/tests/unit/sagemaker/workflow/test_parameters.py @@ -35,16 +35,23 @@ def test_parameter(): def test_parameter_with_default(): param = ParameterFloat(name="MyFloat", default_value=1.2) assert param.to_request() == {"Name": "MyFloat", "Type": "Float", "DefaultValue": 1.2} + assert param.expr == {"Get": "Parameters.MyFloat"} + assert param.parameter_type.python_type == float def test_parameter_with_default_value_zero(): param = ParameterInteger(name="MyInteger", default_value=0) assert param.to_request() == {"Name": "MyInteger", "Type": "Integer", "DefaultValue": 0} + assert param.expr == {"Get": "Parameters.MyInteger"} + assert param.parameter_type.python_type == int def test_parameter_string_with_enum_values(): param = ParameterString("MyString", enum_values=["a", "b"]) assert param.to_request() == {"Name": "MyString", "Type": "String", "EnumValues": ["a", "b"]} + assert param.expr == {"Get": "Parameters.MyString"} + assert param.parameter_type.python_type == str + param = ParameterString("MyString", default_value="a", enum_values=["a", "b"]) assert param.to_request() == { "Name": "MyString", @@ -52,6 +59,8 @@ def test_parameter_string_with_enum_values(): "DefaultValue": "a", "EnumValues": ["a", "b"], } + assert param.expr == {"Get": "Parameters.MyString"} + assert param.parameter_type.python_type == str def test_parameter_with_invalid_default(): @@ -59,41 +68,65 @@ def test_parameter_with_invalid_default(): ParameterFloat(name="MyFloat", default_value="abc") -def test_parameter_string_implicit_value(): - param = ParameterString("MyString") - assert param.__str__() == "" - param1 = ParameterString("MyString", "1") - assert param1.__str__() == "1" - param2 = ParameterString("MyString", default_value="2") - assert param2.__str__() == "2" - param3 = ParameterString(name="MyString", default_value="3") - assert param3.__str__() == "3" - param3 = ParameterString(name="MyString", default_value="3", enum_values=["3"]) - assert param3.__str__() == "3" +def test_parameter_to_string_and_string_implicit_value(): + param = ParameterString("MyString", "1") + + assert param.to_string() == param + + with pytest.raises(TypeError) as error: + str(param) + + assert str(error.value) == "Pipeline variables do not support __str__ operation." def test_parameter_integer_implicit_value(): - param = ParameterInteger("MyInteger") - assert param.__int__() == 0 - param1 = ParameterInteger("MyInteger", 1) - assert param1.__int__() == 1 - param2 = ParameterInteger("MyInteger", default_value=2) - assert param2.__int__() == 2 - param3 = ParameterInteger(name="MyInteger", default_value=3) - assert param3.__int__() == 3 + param = ParameterInteger("MyInteger", 1) + + with pytest.raises(TypeError) as error: + int(param) + + assert str(error.value) == "Pipeline variables do not support __int__ operation." def test_parameter_float_implicit_value(): - param = ParameterFloat("MyFloat") - assert param.__float__() == 0.0 - param1 = ParameterFloat("MyFloat", 1.1) - assert param1.__float__() == 1.1 - param2 = ParameterFloat("MyFloat", default_value=2.1) - assert param2.__float__() == 2.1 - param3 = ParameterFloat(name="MyFloat", default_value=3.1) - assert param3.__float__() == 3.1 + param = ParameterFloat("MyFloat", 1.1) + + with pytest.raises(TypeError) as error: + float(param) + + assert str(error.value) == "Pipeline variables do not support __float__ operation." def test_parsable_parameter_string(): param = ParameterString("MyString", default_value="s3://foo/bar/baz.csv") assert urlparse(param).scheme == "s3" + + +def test_string_builtin_funcs_that_return_bool_on_parameter_string(): + param = ParameterString("MyString", default_value="s3://foo/bar/baz.csv") + # The param will only be parsed in runtime (Pipeline backend) so not able to tell in SDK + assert not param.startswith("s3") + assert not param.endswith("s3") + + +def test_add_func(): + param_str = ParameterString(name="MyString", default_value="s3://foo/bar/baz.csv") + param_int = ParameterInteger(name="MyInteger", default_value=3) + param_float = ParameterFloat(name="MyFloat", default_value=1.5) + param_bool = ParameterBoolean(name="MyBool") + + with pytest.raises(TypeError) as error: + param_str + param_int + assert str(error.value) == "Pipeline variables do not support concatenation." + + with pytest.raises(TypeError) as error: + param_int + param_float + assert str(error.value) == "Pipeline variables do not support concatenation." + + with pytest.raises(TypeError) as error: + param_float + param_bool + assert str(error.value) == "Pipeline variables do not support concatenation." + + with pytest.raises(TypeError) as error: + param_bool + param_str + assert str(error.value) == "Pipeline variables do not support concatenation." diff --git a/tests/unit/sagemaker/workflow/test_properties.py b/tests/unit/sagemaker/workflow/test_properties.py index 405de5c0b2..eac6168646 100644 --- a/tests/unit/sagemaker/workflow/test_properties.py +++ b/tests/unit/sagemaker/workflow/test_properties.py @@ -10,9 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. from __future__ import absolute_import +import pytest + from sagemaker.workflow.properties import Properties @@ -92,3 +93,47 @@ def test_properties_describe_model_package_output(): assert prop.ValidationSpecification.ValidationRole.expr == { "Get": "Steps.MyStep.ValidationSpecification.ValidationRole" } + + +def test_to_string(): + prop = Properties("Steps.MyStep", "DescribeTrainingJobResponse") + + assert prop.CreationTime.to_string().expr == { + "Std:Join": { + "On": "", + "Values": [{"Get": "Steps.MyStep.CreationTime"}], + }, + } + + +def test_implicit_value(): + prop = Properties("Steps.MyStep", "DescribeTrainingJobResponse") + + with pytest.raises(TypeError) as error: + str(prop.CreationTime) + assert str(error.value) == "Pipeline variables do not support __str__ operation." + + with pytest.raises(TypeError) as error: + int(prop.CreationTime) + assert str(error.value) == "Pipeline variables do not support __int__ operation." + + with pytest.raises(TypeError) as error: + float(prop.CreationTime) + assert str(error.value) == "Pipeline variables do not support __float__ operation." + + +def test_string_builtin_funcs_that_return_bool(): + prop = Properties("Steps.MyStep", "DescribeModelPackageOutput") + # The prop will only be parsed in runtime (Pipeline backend) so not able to tell in SDK + assert not prop.startswith("s3") + assert not prop.endswith("s3") + + +def test_add_func(): + prop_train = Properties("Steps.MyStepTrain", "DescribeTrainingJobResponse") + prop_model = Properties("Steps.MyStepModel", "DescribeModelPackageOutput") + + with pytest.raises(TypeError) as error: + prop_train + prop_model + + assert str(error.value) == "Pipeline variables do not support concatenation." diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index fd3bd7d0b9..64da232c9e 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -320,7 +320,19 @@ def test_training_step_base_estimator(sagemaker_session): cache_config=cache_config, ) step.add_depends_on(["AnotherTestStep"]) - assert step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[ + instance_type_parameter, + instance_count_parameter, + data_source_uri_parameter, + training_epochs_parameter, + training_batch_size_parameter, + ], + steps=[step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyTrainingStep", "Type": "Training", "Description": "TrainingStep description", @@ -329,8 +341,18 @@ def test_training_step_base_estimator(sagemaker_session): "Arguments": { "AlgorithmSpecification": {"TrainingImage": IMAGE_URI, "TrainingInputMode": "File"}, "HyperParameters": { - "batch-size": training_batch_size_parameter, - "epochs": training_epochs_parameter, + "batch-size": { + "Std:Join": { + "On": "", + "Values": [{"Get": "Parameters.TrainingBatchSize"}], + }, + }, + "epochs": { + "Std:Join": { + "On": "", + "Values": [{"Get": "Parameters.TrainingEpochs"}], + }, + }, }, "InputDataConfig": [ { @@ -339,15 +361,15 @@ def test_training_step_base_estimator(sagemaker_session): "S3DataSource": { "S3DataDistributionType": "FullyReplicated", "S3DataType": "S3Prefix", - "S3Uri": data_source_uri_parameter, + "S3Uri": {"Get": "Parameters.DataSourceS3Uri"}, } }, } ], "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"}, "ResourceConfig": { - "InstanceCount": instance_count_parameter, - "InstanceType": instance_type_parameter, + "InstanceCount": {"Get": "Parameters.InstanceCount"}, + "InstanceType": {"Get": "Parameters.InstanceType"}, "VolumeSizeInGB": 30, }, "RoleArn": ROLE, @@ -398,10 +420,22 @@ def test_training_step_tensorflow(sagemaker_session): step = TrainingStep( name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config ) - step_request = step.to_request() - step_request["Arguments"]["HyperParameters"].pop("sagemaker_program", None) - step_request["Arguments"].pop("ProfilerRuleConfigurations", None) - assert step_request == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[ + instance_type_parameter, + instance_count_parameter, + data_source_uri_parameter, + training_epochs_parameter, + training_batch_size_parameter, + ], + steps=[step], + sagemaker_session=sagemaker_session, + ) + dsl = json.loads(pipeline.definition())["Steps"][0] + dsl["Arguments"]["HyperParameters"].pop("sagemaker_program", None) + dsl["Arguments"].pop("ProfilerRuleConfigurations", None) + assert dsl == { "Name": "MyTrainingStep", "Type": "Training", "Arguments": { @@ -413,8 +447,8 @@ def test_training_step_tensorflow(sagemaker_session): "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": { - "InstanceCount": instance_count_parameter, - "InstanceType": instance_type_parameter, + "InstanceCount": {"Get": "Parameters.InstanceCount"}, + "InstanceType": {"Get": "Parameters.InstanceType"}, "VolumeSizeInGB": 30, }, "RoleArn": "DummyRole", @@ -423,7 +457,7 @@ def test_training_step_tensorflow(sagemaker_session): "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": data_source_uri_parameter, + "S3Uri": {"Get": "Parameters.DataSourceS3Uri"}, "S3DataDistributionType": "FullyReplicated", } }, @@ -431,13 +465,17 @@ def test_training_step_tensorflow(sagemaker_session): } ], "HyperParameters": { - "batch-size": training_batch_size_parameter, - "epochs": training_epochs_parameter, + "batch-size": { + "Std:Join": {"On": "", "Values": [{"Get": "Parameters.TrainingBatchSize"}]} + }, + "epochs": { + "Std:Join": {"On": "", "Values": [{"Get": "Parameters.TrainingEpochs"}]} + }, "sagemaker_submit_directory": '"s3://mybucket/source"', "sagemaker_container_log_level": "20", "sagemaker_region": '"us-west-2"', "sagemaker_distributed_dataparallel_enabled": "true", - "sagemaker_instance_type": instance_type_parameter, + "sagemaker_instance_type": {"Get": "Parameters.InstanceType"}, "sagemaker_distributed_dataparallel_custom_mpi_options": '""', }, "ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"}, @@ -1153,7 +1191,16 @@ def test_multi_algo_tuning_step(sagemaker_session): }, ) - assert tuning_step.to_request() == { + pipeline = Pipeline( + name="MyPipeline", + parameters=[data_source_uri_parameter, instance_count, initial_lr_param], + steps=[tuning_step], + sagemaker_session=sagemaker_session, + ) + + dsl = json.loads(pipeline.definition()) + + assert dsl["Steps"][0] == { "Name": "MyTuningStep", "Type": "Tuning", "Arguments": { @@ -1179,7 +1226,7 @@ def test_multi_algo_tuning_step(sagemaker_session): "RoleArn": "DummyRole", "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, "ResourceConfig": { - "InstanceCount": 1, + "InstanceCount": {"Get": "Parameters.InstanceCount"}, "InstanceType": "ml.c5.4xlarge", "VolumeSizeInGB": 30, }, @@ -1193,7 +1240,7 @@ def test_multi_algo_tuning_step(sagemaker_session): "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": data_source_uri_parameter, + "S3Uri": {"Get": "Parameters.DataSourceS3Uri"}, "S3DataDistributionType": "FullyReplicated", } }, @@ -1206,7 +1253,7 @@ def test_multi_algo_tuning_step(sagemaker_session): "ContinuousParameterRanges": [ { "Name": "learning_rate", - "MinValue": initial_lr_param, + "MinValue": {"Get": "Parameters.InitialLR"}, "MaxValue": "0.05", "ScalingType": "Auto", }, @@ -1246,7 +1293,7 @@ def test_multi_algo_tuning_step(sagemaker_session): "RoleArn": "DummyRole", "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, "ResourceConfig": { - "InstanceCount": 1, + "InstanceCount": {"Get": "Parameters.InstanceCount"}, "InstanceType": "ml.c5.4xlarge", "VolumeSizeInGB": 30, }, @@ -1260,7 +1307,7 @@ def test_multi_algo_tuning_step(sagemaker_session): "DataSource": { "S3DataSource": { "S3DataType": "S3Prefix", - "S3Uri": data_source_uri_parameter, + "S3Uri": {"Get": "Parameters.DataSourceS3Uri"}, "S3DataDistributionType": "FullyReplicated", } }, @@ -1273,7 +1320,7 @@ def test_multi_algo_tuning_step(sagemaker_session): "ContinuousParameterRanges": [ { "Name": "learning_rate", - "MinValue": initial_lr_param, + "MinValue": {"Get": "Parameters.InitialLR"}, "MaxValue": "0.05", "ScalingType": "Auto", }, diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index d6f1f5a648..acfbe7b2db 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -86,7 +86,19 @@ def test_prepare_for_training(tuner): assert tuner._current_job_name.startswith(IMAGE_NAME) assert len(tuner.static_hyperparameters) == 3 assert tuner.static_hyperparameters["another_one"] == "0" - assert tuner.static_hyperparameters["hp1"] == hp1 + assert tuner.static_hyperparameters["hp1"].expr == { + "Std:Join": { + "On": "", + "Values": [ + { + "Std:JsonGet": { + "PropertyFile": {"Get": "Steps.stepname.PropertyFiles.pf"}, + "Path": "jp", + }, + }, + ], + } + } assert tuner.static_hyperparameters["hp2"] == hp2 @@ -1177,8 +1189,27 @@ def test_integer_parameter_ranges_with_pipeline_parameter(): assert len(ranges.keys()) == 4 assert ranges["Name"] == "some" - assert ranges["MinValue"] == min - assert ranges["MaxValue"] == max + assert ranges["MinValue"].expr == { + "Std:Join": { + "On": "", + "Values": [ + {"Get": "Parameters.p"}, + ], + } + } + assert ranges["MaxValue"].expr == { + "Std:Join": { + "On": "", + "Values": [ + { + "Std:JsonGet": { + "PropertyFile": {"Get": "Steps.sn.PropertyFiles.pf"}, + "Path": "jp", + } + } + ], + } + } assert ranges["ScalingType"] == scale