diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 4cf48a498b..a022aeec61 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4206,7 +4206,7 @@ def get_model_package_args( description (str): Model Package description (default: None). tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs (default: None). - container_def_list (list): A list of container defintiions (default: None). + container_def_list (list): A list of container definitions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). customer_metadata_properties (dict[str, str]): A dictionary of key-value paired metadata properties (default: None). diff --git a/src/sagemaker/workflow/_repack_model.py b/src/sagemaker/workflow/_repack_model.py index 3cfa6760b3..af3fbb7d62 100644 --- a/src/sagemaker/workflow/_repack_model.py +++ b/src/sagemaker/workflow/_repack_model.py @@ -32,9 +32,15 @@ # we'll go ahead and use the copy_tree function anyways because this # repacking is some short-lived hackery, right?? from distutils.dir_util import copy_tree +from typing import Optional -def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover +def repack( + inference_script: str, + model_archive: str, + dependencies: Optional[str] = None, + source_dir: Optional[str] = None, +): # pragma: no cover """Repack custom dependencies and code into an existing model TAR archive Args: diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index c8dbfc610d..418832a61f 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -17,7 +17,8 @@ import shutil import tarfile import tempfile -from typing import List, Union +from typing import List, Union, Dict, Optional +from sagemaker.session import Session from sagemaker import image_uris from sagemaker.inputs import TrainingInput from sagemaker.estimator import EstimatorBase @@ -33,6 +34,9 @@ ) from sagemaker.utils import _save_model, download_file_from_url from sagemaker.workflow.retry import RetryPolicy +from sagemaker.model_metrics import ModelMetrics +from sagemaker.metadata_properties import MetadataProperties +from sagemaker.drift_check_baselines import DriftCheckBaselines FRAMEWORK_VERSION = "0.23-1" INSTANCE_TYPE = "ml.m5.large" @@ -49,18 +53,18 @@ class _RepackModelStep(TrainingStep): def __init__( self, name: str, - sagemaker_session, - role, + sagemaker_session: Session, + role: str, model_data: str, entry_point: str, - display_name: str = None, - description: str = None, - source_dir: str = None, - dependencies: List = None, - depends_on: Union[List[str], List[Step]] = None, - retry_policies: List[RetryPolicy] = None, - subnets=None, - security_group_ids=None, + display_name: Optional[str] = None, + description: Optional[str] = None, + source_dir: Optional[str] = None, + dependencies: Optional[List] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + retry_policies: Optional[List[RetryPolicy]] = None, + subnets: Optional[List[str]] = None, + security_group_ids: Optional[List[str]] = None, **kwargs, ): """Base class initializer. @@ -237,7 +241,7 @@ def _inject_repack_script(self): def arguments(self) -> RequestType: """The arguments dict that are used to call `create_training_job`. - This first prepares the source bundle for repackinglby placing artifacts + This first prepares the source bundle for repacking by placing artifacts in locations which the training container will make available to the repacking script and then gets the arguments for the training job. """ @@ -278,26 +282,26 @@ class _RegisterModelStep(ConfigurableRetryStep): def __init__( self, name: str, - content_types, - response_types, - inference_instances, - transform_instances, - estimator: EstimatorBase = None, - model_data=None, - model_package_group_name=None, - model_metrics=None, - metadata_properties=None, - approval_status="PendingManualApproval", - image_uri=None, - compile_model_family=None, - display_name: str = None, - description=None, - depends_on: Union[List[str], List[Step]] = None, - retry_policies: List[RetryPolicy] = None, - tags=None, - container_def_list=None, - drift_check_baselines=None, - customer_metadata_properties=None, + content_types: List, + response_types: List, + inference_instances: List, + transform_instances: List, + estimator: Optional[EstimatorBase] = None, + model_data: Optional[str] = None, + model_package_group_name: Optional[str] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + approval_status: str = "PendingManualApproval", + image_uri: Optional[str] = None, + compile_model_family: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + retry_policies: Optional[List[RetryPolicy]] = None, + tags: Optional[List[Dict[str, str]]] = None, + container_def_list: Optional[List] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, str]] = None, **kwargs, ): """Constructor of a register model step. @@ -334,6 +338,9 @@ def __init__( depends_on (List[str] or List[Step]): A list of step names or instances this step depends on retry_policies (List[RetryPolicy]): The list of retry policies for the current step + tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs + (default: None). + container_def_list (list): A list of container definitions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). customer_metadata_properties (dict[str, str]): A dictionary of key-value paired metadata properties (default: None). diff --git a/src/sagemaker/workflow/callback_step.py b/src/sagemaker/workflow/callback_step.py index f88b56c9f5..d29a2a4ec2 100644 --- a/src/sagemaker/workflow/callback_step.py +++ b/src/sagemaker/workflow/callback_step.py @@ -10,10 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The CallbackStep definitions for workflow.""" from __future__ import absolute_import -from typing import List, Dict, Union +from typing import List, Dict, Union, Optional from enum import Enum import attr @@ -81,12 +81,12 @@ def __init__( self, name: str, sqs_queue_url: str, - inputs: dict, + inputs: Dict, outputs: List[CallbackOutput], - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, ): """Constructs a CallbackStep. @@ -99,7 +99,7 @@ def __init__( display_name (str): The display name of the callback step. description (str): The description of the callback step. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str] or List[Step]): A list of step names or step instances + depends_on (Union[List[str], List[Step]]): A list of step names or step instances this `sagemaker.workflow.steps.CallbackStep` depends on """ super(CallbackStep, self).__init__( diff --git a/src/sagemaker/workflow/check_job_config.py b/src/sagemaker/workflow/check_job_config.py index eaba149823..3b25ff6727 100644 --- a/src/sagemaker/workflow/check_job_config.py +++ b/src/sagemaker/workflow/check_job_config.py @@ -14,9 +14,10 @@ from __future__ import absolute_import import logging -from typing import Optional +from typing import Optional, Dict, List from sagemaker import Session +from sagemaker.network import NetworkConfig from sagemaker.model_monitor import ( ModelMonitor, DefaultModelMonitor, @@ -31,18 +32,18 @@ class CheckJobConfig: def __init__( self, - role, - instance_count=1, - instance_type="ml.m5.xlarge", - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + instance_count: int = 1, + instance_type: str = "ml.m5.xlarge", + volume_size_in_gb: int = 30, + volume_kms_key: str = None, + output_kms_key: str = None, + max_runtime_in_seconds: int = None, + base_job_name: str = None, + sagemaker_session: Session = None, + env: Dict[str, str] = None, + tags: List[Dict[str, str]] = None, + network_config: NetworkConfig = None, ): """Constructs a CheckJobConfig instance. @@ -65,8 +66,9 @@ def __init__( manages interactions with Amazon SageMaker APIs and any other AWS services needed (default: None). If not specified, one is created using the default AWS configuration chain. - env (dict): Environment variables to be passed to the job (default: None). - tags ([dict]): List of tags to be passed to the job (default: None). + env (Dict): Environment variables to be passed to the job (default: None). + tags (List[Dict[str, str]]): A list of dictionaries containing key-value pairs + (default: None). network_config (sagemaker.network.NetworkConfig): A NetworkConfig object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets (default: None). diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index 5921e5099a..9f9b21ef81 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The ClarifyCheckStep definitions for workflow.""" from __future__ import absolute_import import copy @@ -180,7 +180,7 @@ def __init__( description (str): The description of the ClarifyCheckStep step (default: None). cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance (default: None). - depends_on (List[str] or List[Step]): A list of step names or step instances + depends_on (Union[List[str], List[Step]]): A list of step names or step instances this `sagemaker.workflow.steps.ClarifyCheckStep` depends on (default: None). """ if ( diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index bb40ca05f1..d99a614b18 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -10,10 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The ConditionStep definitions for workflow.""" from __future__ import absolute_import -from typing import List, Union +from typing import List, Union, Optional import attr @@ -41,12 +41,12 @@ class ConditionStep(Step): def __init__( self, name: str, - depends_on: Union[List[str], List[Step]] = None, - display_name: str = None, - description: str = None, - conditions: List[Condition] = None, - if_steps: List[Union[Step, StepCollection]] = None, - else_steps: List[Union[Step, StepCollection]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + conditions: Optional[List[Condition]] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + if_steps: Optional[List[Union[Step, StepCollection]]] = None, + else_steps: Optional[List[Union[Step, StepCollection]]] = None, ): """Construct a ConditionStep for pipelines to support conditional branching. @@ -60,6 +60,8 @@ def __init__( description (str): The description of the condition step. conditions (List[Condition]): A list of `sagemaker.workflow.conditions.Condition` instances. + depends_on (List[str] or List[Step]): A list of step names or step instances + this `sagemaker.workflow.steps.ConditionStep` depends on (default: None). if_steps (List[Union[Step, StepCollection]]): A list of `sagemaker.workflow.steps.Step` or `sagemaker.workflow.step_collections.StepCollection` instances that are marked as ready for execution if the list of conditions evaluates to True. diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py index 8b244c78f2..8a5f2c920e 100644 --- a/src/sagemaker/workflow/emr_step.py +++ b/src/sagemaker/workflow/emr_step.py @@ -10,10 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The EMRStep definitions for workflow.""" from __future__ import absolute_import -from typing import List +from typing import List, Union from sagemaker.workflow.entities import ( RequestType, @@ -28,17 +28,21 @@ class EMRStepConfig: """Config for a Hadoop Jar step.""" def __init__( - self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None + self, + jar: str, + args: List[str] = None, + main_class: str = None, + properties: List[dict] = None, ): """Create a definition for input data used by an EMR cluster(job flow) step. See AWS documentation on the ``StepConfig`` API for more details on the parameters. Args: + jar(str): A path to a JAR file run during the step. args(List[str]): A list of command line arguments passed to the JAR file's main function when executed. - jar(str): A path to a JAR file run during the step. main_class(str): The name of the main class in the specified Java file. properties(List(dict)): A list of key-value pairs that are set when the step runs. """ @@ -70,7 +74,7 @@ def __init__( description: str, cluster_id: str, step_config: EMRStepConfig, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, cache_config: CacheConfig = None, ): """Constructs a EMRStep. @@ -81,7 +85,7 @@ def __init__( description(str): The description of the EMR step. cluster_id(str): The ID of the running EMR cluster. step_config(EMRStepConfig): One StepConfig to be executed by the job flow. - depends_on(List[str]): + depends_on(Union[List[str], List[Step]]): A list of step names this `sagemaker.workflow.steps.EMRStep` depends on cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. diff --git a/src/sagemaker/workflow/fail_step.py b/src/sagemaker/workflow/fail_step.py index cc908a2a2a..06d60a3eab 100644 --- a/src/sagemaker/workflow/fail_step.py +++ b/src/sagemaker/workflow/fail_step.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The `Step` definitions for SageMaker Pipelines Workflows.""" +"""The FailStep definitions for workflow.""" from __future__ import absolute_import from typing import List, Union @@ -45,9 +45,8 @@ def __init__( display_name (str): The display name of the `FailStep`. The display name provides better UI readability. (default: None). description (str): The description of the `FailStep` (default: None). - depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances - that this `FailStep` depends on. - If a listed `Step` name does not exist, an error is returned (default: None). + depends_on (Union[List[str], List[Step]]): A list of step names or step instances + this `sagemaker.workflow.steps.FailStep` depends on """ super(FailStep, self).__init__( name, display_name, description, StepTypeEnum.FAIL, depends_on diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 36bd69fbff..b5d776528a 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The functions for workflow.""" from __future__ import absolute_import from typing import List, Union @@ -36,13 +36,13 @@ class Join(PipelineVariable): content_type="text/csv") Attributes: - values (List[Union[PrimitiveType, Parameter, Expression]]): - The primitive type values, parameters, step properties, expressions to join. on (str): The string to join the values on (Defaults to ""). + values (List[PipelineVariable]): + The PipelineVariable(s) to join. """ on: str = attr.ib(factory=str) - values: List = attr.ib(factory=list) + values: List[PipelineVariable] = attr.ib(factory=list) def to_string(self) -> PipelineVariable: """Prompt the pipeline to convert the pipeline variable to String in runtime diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index 96f8de3a3b..257fb89f85 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -10,10 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The LambdaStep definitions for workflow.""" from __future__ import absolute_import -from typing import List, Dict +from typing import List, Dict, Union from enum import Enum import attr @@ -84,25 +84,25 @@ def __init__( lambda_func: Lambda, display_name: str = None, description: str = None, - inputs: dict = None, + inputs: Dict = None, outputs: List[LambdaOutput] = None, cache_config: CacheConfig = None, - depends_on: List[str] = None, + depends_on: Union[List[str], List[Step]] = None, ): """Constructs a LambdaStep. Args: name (str): The name of the lambda step. - display_name (str): The display name of the Lambda step. - description (str): The description of the Lambda step. - lambda_func (str): An instance of sagemaker.lambda_helper.Lambda. + lambda_func (Lambda): An instance of sagemaker.lambda_helper.Lambda. If lambda arn is specified in the instance, LambdaStep just invokes the function, else lambda function will be created while creating the pipeline. - inputs (dict): Input arguments that will be provided + display_name (str): The display name of the Lambda step. + description (str): The description of the Lambda step. + inputs (Dict): Input arguments that will be provided to the lambda function. outputs (List[LambdaOutput]): List of outputs from the lambda function. cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance. - depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.LambdaStep` + depends_on (Union[List[str], List[Step]]): A list of step names this `sagemaker.workflow.steps.LambdaStep` depends on """ super(LambdaStep, self).__init__( diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index e1ad50e6cf..04c0ffc5cb 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -177,8 +177,8 @@ def describe(self) -> Dict[str, Any]: def update( self, role_arn: str, - description: str = None, - parallelism_config: ParallelismConfiguration = None, + description: Optional[str] = None, + parallelism_config: Optional[ParallelismConfiguration] = None, ) -> Dict[str, Any]: """Updates a Pipeline in the Workflow service. @@ -393,7 +393,7 @@ def _map_callback_outputs(steps: List[Step]): """Iterate over the provided steps, building a map of callback output parameters to step names. Args: - step (List[Step]): The steps list. + steps (List[Step]): The steps list. """ callback_output_map = {} @@ -410,7 +410,7 @@ def _map_lambda_outputs(steps: List[Step]): """Iterate over the provided steps, building a map of lambda output parameters to step names. Args: - step (List[Step]): The steps list. + steps (List[Step]): The steps list. """ lambda_output_map = {} @@ -429,7 +429,7 @@ def update_args(args: Dict[str, Any], **kwargs): This handles the case when the service API doesn't like NoneTypes for argument values. Args: - request_args (Dict[str, Any]): The request arguments dict. + args (Dict[str, Any]): The request arguments dict. kwargs: key, value pairs to update the args dict with. """ for key, value in kwargs.items(): diff --git a/src/sagemaker/workflow/properties.py b/src/sagemaker/workflow/properties.py index 480fddada1..a90891c195 100644 --- a/src/sagemaker/workflow/properties.py +++ b/src/sagemaker/workflow/properties.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from abc import ABCMeta -from typing import Dict, Union, List +from typing import Dict, Union, List, Optional import attr @@ -51,9 +51,9 @@ class Properties(PipelineVariable, metaclass=PropertiesMeta): def __init__( self, path: str, - shape_name: str = None, - shape_names: List[str] = None, - service_name: str = "sagemaker", + shape_name: Optional[str] = None, + shape_names: Optional[List[str]] = None, + service_name: Optional[str] = "sagemaker", ): """Create a Properties instance representing the given shape. @@ -61,6 +61,7 @@ def __init__( path (str): The parent path of the Properties instance. shape_name (str): The botocore service model shape name. shape_names (str): A List of the botocore service model shape name. + service_name (str): The botocore service name. """ self._path = path shape_names = [] if shape_names is None else shape_names @@ -98,7 +99,9 @@ def expr(self): class PropertiesList(Properties): """PropertiesList for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): + def __init__( + self, path: str, shape_name: Optional[str] = None, service_name: Optional[str] = "sagemaker" + ): """Create a PropertiesList instance representing the given shape. Args: @@ -132,7 +135,9 @@ def __getitem__(self, item: Union[int, str]): class PropertiesMap(Properties): """PropertiesMap for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): + def __init__( + self, path: str, shape_name: Optional[str] = None, service_name: Optional[str] = "sagemaker" + ): """Create a PropertiesMap instance representing the given shape. Args: @@ -168,9 +173,9 @@ class PropertyFile(Expression): """Provides a property file struct. Attributes: - name: The name of the property file for reference with `JsonGet` functions. - output_name: The name of the processing job output channel. - path: The path to the file at the output channel location. + name (str): The name of the property file for reference with `JsonGet` functions. + output_name (str): The name of the processing job output channel. + path (str): The path to the file at the output channel location. """ name: str = attr.ib() diff --git a/src/sagemaker/workflow/quality_check_step.py b/src/sagemaker/workflow/quality_check_step.py index 76b9f5f022..b47d0039d0 100644 --- a/src/sagemaker/workflow/quality_check_step.py +++ b/src/sagemaker/workflow/quality_check_step.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The QualityCheckStep definitions for workflow.""" from __future__ import absolute_import from abc import ABC -from typing import List, Union +from typing import List, Union, Optional import os import pathlib import attr @@ -119,13 +119,13 @@ def __init__( check_job_config: CheckJobConfig, skip_check: Union[bool, PipelineNonPrimitiveInputTypes] = False, register_new_baseline: Union[bool, PipelineNonPrimitiveInputTypes] = False, - model_package_group_name: Union[str, PipelineNonPrimitiveInputTypes] = None, - supplied_baseline_statistics: Union[str, PipelineNonPrimitiveInputTypes] = None, - supplied_baseline_constraints: Union[str, PipelineNonPrimitiveInputTypes] = None, - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, + model_package_group_name: Optional[Union[str, PipelineNonPrimitiveInputTypes]] = None, + supplied_baseline_statistics: Optional[Union[str, PipelineNonPrimitiveInputTypes]] = None, + supplied_baseline_constraints: Optional[Union[str, PipelineNonPrimitiveInputTypes]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, ): """Constructs a QualityCheckStep. diff --git a/src/sagemaker/workflow/retry.py b/src/sagemaker/workflow/retry.py index 177e13e3d4..5a7cb70de3 100644 --- a/src/sagemaker/workflow/retry.py +++ b/src/sagemaker/workflow/retry.py @@ -10,11 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Pipeline parameters and conditions for workflow.""" +"""Pipeline step configurable retry strategy.""" from __future__ import absolute_import from enum import Enum -from typing import List +from typing import List, Optional import attr from sagemaker.workflow.entities import Entity, DefaultEnumMeta, RequestType @@ -133,8 +133,8 @@ def __init__( exception_types: List[StepExceptionTypeEnum], backoff_rate: float = 2.0, interval_seconds: int = 1, - max_attempts: int = None, - expire_after_mins: int = None, + max_attempts: Optional[int] = None, + expire_after_mins: Optional[int] = None, ): super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins) for exception_type in exception_types: @@ -155,7 +155,7 @@ class SageMakerJobStepRetryPolicy(RetryPolicy): Attributes: exception_types (List[SageMakerJobExceptionTypeEnum]): The SageMaker exception to match for this policy. The SageMaker exceptions - captured here are the exceptions thrown by synchronously + captured here are the exceptions thrown synchronously by creating the job. For instance the resource limit exception. failure_reason_types (List[SageMakerJobExceptionTypeEnum]): the SageMaker failure reason types to match for this policy. The failure reason type @@ -173,12 +173,12 @@ class SageMakerJobStepRetryPolicy(RetryPolicy): def __init__( self, - exception_types: List[SageMakerJobExceptionTypeEnum] = None, - failure_reason_types: List[SageMakerJobExceptionTypeEnum] = None, + exception_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None, + failure_reason_types: Optional[List[SageMakerJobExceptionTypeEnum]] = None, backoff_rate: float = 2.0, interval_seconds: int = 1, - max_attempts: int = None, - expire_after_mins: int = None, + max_attempts: Optional[int] = None, + expire_after_mins: Optional[int] = None, ): super().__init__(backoff_rate, interval_seconds, max_attempts, expire_after_mins) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 1280637006..7f5f8ee2c8 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -10,13 +10,14 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""The step definitions for workflow.""" +"""The StepCollections definitions for workflow.""" from __future__ import absolute_import -from typing import List, Union +from typing import List, Union, Dict, Callable, Optional import attr +from sagemaker.session import Session from sagemaker.estimator import EstimatorBase from sagemaker.model import Model from sagemaker import PipelineModel @@ -33,6 +34,9 @@ _RepackModelStep, ) from sagemaker.workflow.retry import RetryPolicy +from sagemaker.model_metrics import ModelMetrics +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.inputs import CreateModelInput, TransformInput @attr.s @@ -56,26 +60,26 @@ class RegisterModel(StepCollection): def __init__( self, name: str, - content_types, - response_types, - inference_instances, - transform_instances, - estimator: EstimatorBase = None, - model_data=None, - depends_on: Union[List[str], List[Step]] = None, - repack_model_step_retry_policies: List[RetryPolicy] = None, - register_model_step_retry_policies: List[RetryPolicy] = None, - model_package_group_name=None, - model_metrics=None, - approval_status=None, - image_uri=None, - compile_model_family=None, - display_name=None, - description=None, - tags=None, - model: Union[Model, PipelineModel] = None, - drift_check_baselines=None, - customer_metadata_properties=None, + content_types: List[str], + response_types: List[str], + inference_instances: List[str], + transform_instances: List[str], + estimator: Optional[EstimatorBase] = None, + model_data: Optional[str] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + repack_model_step_retry_policies: Optional[List[RetryPolicy]] = None, + register_model_step_retry_policies: Optional[List[RetryPolicy]] = None, + model_package_group_name: Optional[str] = None, + model_metrics: Optional[ModelMetrics] = None, + approval_status: Optional[str] = None, + image_uri: Optional[str] = None, + compile_model_family: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + model: Optional[Union[Model, PipelineModel]] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, str]] = None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -90,7 +94,7 @@ def __init__( generate inferences in real-time (default: None). transform_instances (list): A list of the instance types on which a transformation job can be run or on which an endpoint can be deployed (default: None). - depends_on (List[str] or List[Step]): The list of step names or step instances + depends_on (Union[List[str], List[Step]]): The list of step names or step instances the first step in the collection depends on repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies for the repack model step @@ -250,32 +254,32 @@ def __init__( self, name: str, estimator: EstimatorBase, - model_data, - model_inputs, - instance_count, - instance_type, - transform_inputs, - description: str = None, - display_name: str = None, + model_data: str, + model_inputs: CreateModelInput, + instance_count: int, + instance_type: str, + transform_inputs: TransformInput, + description: Optional[str] = None, + display_name: Optional[str] = None, # model arguments - image_uri=None, - predictor_cls=None, - env=None, + image_uri: Optional[str] = None, + predictor_cls: Optional[Callable[[str, Session], Predictor]] = None, + env: Optional[Dict[str, str]] = None, # transformer arguments - strategy=None, - assemble_with=None, - output_path=None, - output_kms_key=None, - accept=None, - max_concurrent_transforms=None, - max_payload=None, - tags=None, - volume_kms_key=None, - depends_on: Union[List[str], List[Step]] = None, + strategy: Optional[str] = None, + assemble_with: Optional[str] = None, + output_path: Optional[str] = None, + output_kms_key: Optional[str] = None, + accept: Optional[str] = None, + max_concurrent_transforms: Optional[int] = None, + max_payload: Optional[int] = None, + tags: Optional[List[Dict[str, str]]] = None, + volume_kms_key: Optional[str] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, # step retry policies - repack_model_step_retry_policies: List[RetryPolicy] = None, - model_step_retry_policies: List[RetryPolicy] = None, - transform_step_retry_policies: List[RetryPolicy] = None, + repack_model_step_retry_policies: Optional[List[RetryPolicy]] = None, + model_step_retry_policies: Optional[List[RetryPolicy]] = None, + transform_step_retry_policies: Optional[List[RetryPolicy]] = None, **kwargs, ): """Construct steps required for a Transformer step collection: @@ -295,9 +299,22 @@ def __init__( Args: name (str): The name of the Transform Step. - estimator: The estimator instance. + estimator (EstimatorBase): The estimator instance. + model_data (str): The S3 location of a SageMaker model data + ``.tar.gz`` file (default: None). + model_inputs (CreateModelInput): The create model input. instance_count (int): The number of EC2 instances to use. instance_type (str): The type of EC2 instance to use. + transform_inputs (TransformInput): The transform inputs for the transform step. + description (str): The description of the steps in the collection. (default: None). + display_name (str): The display name of the steps in the collection. (default: None). + image_uri (str): A Docker image URI for model. + predictor_cls (Callable[[str, Session], Predictor]): A + function to call to create a predictor (default: None). If not + None, ``deploy`` will return the result of invoking this + function on the created endpoint name. + env (dict[str, str]): Environment variables to run with ``image_uri`` + when hosted in SageMaker (default: None). strategy (str): The strategy used to decide how to batch records in a single request (default: None). Valid values: 'MultiRecord' and 'SingleRecord'. @@ -310,8 +327,16 @@ def __init__( accept (str): The accept header passed by the client to the inference endpoint. If it is supported by the endpoint, it will be the format of the batch transform output. - env (dict): The Environment variables to be set for use during the - transform job (default: None). + max_concurrent_transforms (int): The maximum number of HTTP requests + to be made to each individual transform container at one time. + max_payload (int): Maximum size of the payload in a single HTTP + request to the container in MB. + tags (List[dict[str, str]]): The list of tags to attach to the model package group. Note + that tags will only be applied to newly created model package groups; if the + name of an existing group is passed to "model_package_group_name", + tags will not be applied. + volume_kms_key (str): KMS key ID for encrypting the volume + attached to the ML compute instance (default: None). depends_on (List[str] or List[Step]): The list of step names or step instances the first step in the collection depends on repack_model_step_retry_policies (List[RetryPolicy]): The list of retry policies diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 1ef63ef915..413ad753d7 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -17,7 +17,7 @@ import warnings from enum import Enum -from typing import Dict, List, Union +from typing import Dict, List, Union, Optional from urllib.parse import urlparse import attr @@ -186,11 +186,22 @@ def __init__( self, name: str, step_type: StepTypeEnum, - display_name: str = None, - description: str = None, - depends_on: Union[List[str], List[Step]] = None, - retry_policies: List[RetryPolicy] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): + """Constructor of a configurable retry step. + + Args: + name (str): The name of the step. + step_type (StepTypeEnum): The type of the step. + display_name (str): The display name of the step. + description (str): The description of the step. + depends_on (Union[List[str], List[Step]]): A list of `Step` names or `Step` instances + this step depends on. + retry_policies (List[RetryPolicy]): A list of retry policies. + """ super().__init__( name=name, display_name=display_name, @@ -229,12 +240,12 @@ def __init__( self, name: str, estimator: EstimatorBase, - display_name: str = None, - description: str = None, - inputs: Union[TrainingInput, dict, str, FileSystemInput] = None, - cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, - retry_policies: List[RetryPolicy] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + inputs: Optional[Union[TrainingInput, dict, str, FileSystemInput]] = None, + cache_config: Optional[CacheConfig] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): """Construct a `TrainingStep`, given an `EstimatorBase` instance. @@ -328,10 +339,10 @@ def __init__( name: str, model: Union[Model, PipelineModel], inputs: CreateModelInput, - depends_on: Union[List[str], List[Step]] = None, - retry_policies: List[RetryPolicy] = None, - display_name: str = None, - description: str = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + retry_policies: Optional[List[RetryPolicy]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, ): """Construct a `CreateModelStep`, given an `sagemaker.model.Model` instance. @@ -403,11 +414,11 @@ def __init__( name: str, transformer: Transformer, inputs: TransformInput, - display_name: str = None, - description: str = None, - cache_config: CacheConfig = None, - depends_on: Union[List[str], List[Step]] = None, - retry_policies: List[RetryPolicy] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + cache_config: Optional[CacheConfig] = None, + depends_on: Optional[Union[List[str], List[Step]]] = None, + retry_policies: Optional[List[RetryPolicy]] = None, ): """Constructs a `TransformStep`, given a `Transformer` instance. diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 16a832a14f..9a965f0714 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -28,7 +28,7 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[R """Get the request structure for list of entities. Args: - entities (Sequence[Entity]): A list of entities. + entities (Sequence[Entity, StepCollection]): A list of entities or StepCollection Returns: list: A request structure for a workflow service call. """