From f1d16b6663374a951a8a2fef5ac1dc754205993e Mon Sep 17 00:00:00 2001
From: Dewen Qi
Date: Wed, 22 Jun 2022 11:03:17 -0700
Subject: [PATCH 1/6] change: Add Pipeline annotation in model base class and
 tensorflow estimator

Model annotate update

change: Add PipelineVariable annotation to composite argument of training
---
 src/sagemaker/amazon/amazon_estimator.py      |   9 +-
 src/sagemaker/debugger/debugger.py            |  45 ++++---
 src/sagemaker/debugger/profiler_config.py     |   9 +-
 src/sagemaker/drift_check_baselines.py        |  22 +--
 src/sagemaker/estimator.py                    |  10 +-
 .../huggingface/training_compiler/config.py   |   6 +-
 src/sagemaker/inputs.py                       |  25 ++--
 src/sagemaker/metadata_properties.py          |  12 +-
 src/sagemaker/model.py                        | 126 +++++++++---------
 src/sagemaker/model_metrics.py                |  32 +++--
 .../serverless/serverless_inference_config.py |   4 +-
 src/sagemaker/session.py                      |   4 +-
 src/sagemaker/tensorflow/estimator.py         |  16 ++-
 13 files changed, 189 insertions(+), 131 deletions(-)

diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py
index 09e77d612a..eaf4644da6 100644
--- a/src/sagemaker/amazon/amazon_estimator.py
+++ b/src/sagemaker/amazon/amazon_estimator.py
@@ -16,6 +16,7 @@
 import json
 import logging
 import tempfile
+from typing import Union

 from six.moves.urllib.parse import urlparse

@@ -27,6 +28,7 @@
 from sagemaker.estimator import EstimatorBase, _TrainingJob
 from sagemaker.inputs import FileSystemInput, TrainingInput
 from sagemaker.utils import sagemaker_timestamp
+from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline

 logger = logging.getLogger(__name__)

@@ -304,7 +306,12 @@ class RecordSet(object):
     """Placeholder docstring"""

     def __init__(
-        self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train"
+        self,
+        s3_data: Union[str, PipelineVariable],
+        num_records: int,
+        feature_dim: int,
+        s3_data_type: Union[str, PipelineVariable] = "ManifestFile",
+        channel: Union[str, PipelineVariable] = "train",
     ):
         """A collection of Amazon :class:~`Record` objects serialized and stored in S3.
diff --git a/src/sagemaker/debugger/debugger.py b/src/sagemaker/debugger/debugger.py index d2d53547f1..23f7b651a3 100644 --- a/src/sagemaker/debugger/debugger.py +++ b/src/sagemaker/debugger/debugger.py @@ -24,12 +24,15 @@ from abc import ABC +from typing import Union, Optional, List, Dict + import attr import smdebug_rulesconfig as rule_configs from sagemaker import image_uris from sagemaker.utils import build_dict +from sagemaker.workflow.entities import PipelineVariable framework_name = "debugger" DEBUGGER_FLAG = "USE_SMDEBUG" @@ -311,17 +314,17 @@ def sagemaker( @classmethod def custom( cls, - name, - image_uri, - instance_type, - volume_size_in_gb, - source=None, - rule_to_invoke=None, - container_local_output_path=None, - s3_output_path=None, - other_trials_s3_input_paths=None, - rule_parameters=None, - collections_to_save=None, + name: str, + image_uri: Union[str, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable], + source: Optional[str] = None, + rule_to_invoke: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None, + rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collections_to_save: Optional[List["CollectionConfig"]] = None, actions=None, ): """Initialize a ``Rule`` object for a *custom* debugging rule. @@ -610,10 +613,10 @@ class DebuggerHookConfig(object): def __init__( self, - s3_output_path=None, - container_local_output_path=None, - hook_parameters=None, - collection_configs=None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collection_configs: Optional[List["CollectionConfig"]] = None, ): """Initialize the DebuggerHookConfig instance. @@ -679,7 +682,11 @@ def _to_request_dict(self): class TensorBoardOutputConfig(object): """Create a tensor ouput configuration object for debugging visualizations on TensorBoard.""" - def __init__(self, s3_output_path, container_local_output_path=None): + def __init__( + self, + s3_output_path: Union[str, PipelineVariable], + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + ): """Initialize the TensorBoardOutputConfig instance. Args: @@ -708,7 +715,11 @@ def _to_request_dict(self): class CollectionConfig(object): """Creates tensor collections for SageMaker Debugger.""" - def __init__(self, name, parameters=None): + def __init__( + self, + name: Union[str, PipelineVariable], + parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + ): """Constructor for collection configuration. 
         Args:

diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py
index 371d161bbe..807ba91e79 100644
--- a/src/sagemaker/debugger/profiler_config.py
+++ b/src/sagemaker/debugger/profiler_config.py
@@ -13,7 +13,10 @@
 """Configuration for collecting system and framework metrics in SageMaker training jobs."""
 from __future__ import absolute_import

+from typing import Optional, Union
+
 from sagemaker.debugger.framework_profile import FrameworkProfile
+from sagemaker.workflow.entities import PipelineVariable


 class ProfilerConfig(object):
@@ -26,9 +29,9 @@ class ProfilerConfig(object):

     def __init__(
         self,
-        s3_output_path=None,
-        system_monitor_interval_millis=None,
-        framework_profile_params=None,
+        s3_output_path: Optional[Union[str, PipelineVariable]] = None,
+        system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None,
+        framework_profile_params: Optional[FrameworkProfile] = None,
     ):
         """Initialize a ``ProfilerConfig`` instance.

diff --git a/src/sagemaker/drift_check_baselines.py b/src/sagemaker/drift_check_baselines.py
index 24aa4787d0..9c3b8dbd57 100644
--- a/src/sagemaker/drift_check_baselines.py
+++ b/src/sagemaker/drift_check_baselines.py
@@ -13,21 +13,25 @@
 """This file contains code related to drift check baselines"""
 from __future__ import absolute_import

+from typing import Optional
+
+from sagemaker.model_metrics import MetricsSource, FileSource
+

 class DriftCheckBaselines(object):
     """Accepts drift check baselines parameters for conversion to request dict."""

     def __init__(
         self,
-        model_statistics=None,
-        model_constraints=None,
-        model_data_statistics=None,
-        model_data_constraints=None,
-        bias_config_file=None,
-        bias_pre_training_constraints=None,
-        bias_post_training_constraints=None,
-        explainability_constraints=None,
-        explainability_config_file=None,
+        model_statistics: Optional[MetricsSource] = None,
+        model_constraints: Optional[MetricsSource] = None,
+        model_data_statistics: Optional[MetricsSource] = None,
+        model_data_constraints: Optional[MetricsSource] = None,
+        bias_config_file: Optional[FileSource] = None,
+        bias_pre_training_constraints: Optional[MetricsSource] = None,
+        bias_post_training_constraints: Optional[MetricsSource] = None,
+        explainability_constraints: Optional[MetricsSource] = None,
+        explainability_config_file: Optional[FileSource] = None,
     ):
         """Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict.

diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 9d0c30ff27..fe87186646 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -2874,7 +2874,15 @@ def _validate_and_set_debugger_configs(self):
         # Disable debugger if checkpointing is enabled by the customer
         if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config:
             if self._framework_name in {"mxnet", "pytorch", "tensorflow"}:
-                if self.instance_count > 1 or (
+                if is_pipeline_variable(self.instance_count):
+                    logger.warning(
+                        "SMDebug does not currently support distributed training jobs "
+                        "with checkpointing enabled. Therefore, since instance_count is "
+                        "parameterized and may be set to any value at execution time, "
+                        "debugger_hook_config is disabled."
+ ) + self.debugger_hook_config = False + elif self.instance_count > 1 or ( hasattr(self, "distribution") and self.distribution is not None # pylint: disable=no-member ): diff --git a/src/sagemaker/huggingface/training_compiler/config.py b/src/sagemaker/huggingface/training_compiler/config.py index 07a3bcf9b7..b19fb2be2b 100644 --- a/src/sagemaker/huggingface/training_compiler/config.py +++ b/src/sagemaker/huggingface/training_compiler/config.py @@ -13,8 +13,10 @@ """Configuration for the SageMaker Training Compiler.""" from __future__ import absolute_import import logging +from typing import Union from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig): def __init__( self, - enabled=True, - debug=False, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, ): """This class initializes a ``TrainingCompilerConfig`` instance. diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 3481c138bd..0fca307a97 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -13,8 +13,11 @@ """Amazon SageMaker channel configurations for S3 data sources and file system data sources""" from __future__ import absolute_import, print_function +from typing import Union, Optional, List import attr +from sagemaker.workflow.entities import PipelineVariable + FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"] FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"] @@ -29,17 +32,17 @@ class TrainingInput(object): def __init__( self, - s3_data, - distribution=None, - compression=None, - content_type=None, - record_wrapping=None, - s3_data_type="S3Prefix", - instance_groups=None, - input_mode=None, - attribute_names=None, - target_attribute_name=None, - shuffle_config=None, + s3_data: Union[str, PipelineVariable], + distribution: Optional[Union[str, PipelineVariable]] = None, + compression: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, + record_wrapping: Optional[Union[str, PipelineVariable]] = None, + s3_data_type: Union[str, PipelineVariable] = "S3Prefix", + instance_groups: Optional[List[Union[str, PipelineVariable]]] = None, + input_mode: Optional[Union[str, PipelineVariable]] = None, + attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, + target_attribute_name: Optional[Union[str, PipelineVariable]] = None, + shuffle_config: Optional["ShuffleConfig"] = None, ): r"""Create a definition for input data used by an SageMaker training job. 
diff --git a/src/sagemaker/metadata_properties.py b/src/sagemaker/metadata_properties.py index 4bc77ed0ee..b25aff9168 100644 --- a/src/sagemaker/metadata_properties.py +++ b/src/sagemaker/metadata_properties.py @@ -13,16 +13,20 @@ """This file contains code related to metadata properties.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class MetadataProperties(object): """Accepts metadata properties parameters for conversion to request dict.""" def __init__( self, - commit_id=None, - repository=None, - generated_by=None, - project_id=None, + commit_id: Optional[Union[str, PipelineVariable]] = None, + repository: Optional[Union[str, PipelineVariable]] = None, + generated_by: Optional[Union[str, PipelineVariable]] = None, + project_id: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetadataProperties`` instance and turn parameters into dict. diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a2c6da4bb7..8772fa724f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -18,7 +18,7 @@ import logging import os import copy -from typing import List, Dict +from typing import List, Dict, Optional, Union import sagemaker from sagemaker import ( @@ -29,7 +29,11 @@ utils, git_utils, ) +from sagemaker.session import Session +from sagemaker.model_metrics import ModelMetrics from sagemaker.deprecations import removed_kwargs +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties from sagemaker.predictor import PredictorBase from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer @@ -37,10 +41,12 @@ from sagemaker.utils import ( unique_name_from_base, update_container_with_inference_params, + to_string, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession LOGGER = logging.getLogger("sagemaker") @@ -82,23 +88,23 @@ class Model(ModelBase): def __init__( self, - image_uri, - model_data=None, - role=None, - predictor_cls=None, - env=None, - name=None, - vpc_config=None, - sagemaker_session=None, - enable_network_isolation=False, - model_kms_key=None, - image_config=None, - source_dir=None, - code_location=None, - entry_point=None, - container_log_level=logging.INFO, - dependencies=None, - git_config=None, + image_uri: Union[str, PipelineVariable], + model_data: Optional[Union[str, PipelineVariable]] = None, + role: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, + sagemaker_session: Optional[Session] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + model_kms_key: Optional[str] = None, + image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + source_dir: Optional[str] = None, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, ): """Initialize an SageMaker ``Model``. 
@@ -298,28 +304,28 @@ def __init__( @runnable_by_pipeline def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - validation_specification=None, - domain=None, - task=None, - sample_payload_url=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + validation_specification: Optional[Union[str, PipelineVariable]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -349,11 +355,11 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). - sample_payload_url (str): The S3 path where the sample payload is stored - (default: None). task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). framework (str): Machine learning framework of the model package container image (default: None). 
framework_version (str): Framework version of the Model Package Container Image @@ -421,10 +427,10 @@ def register( @runnable_by_pipeline def create( self, - instance_type: str = None, - accelerator_type: str = None, - serverless_inference_config: ServerlessInferenceConfig = None, - tags: List[Dict[str, str]] = None, + instance_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, ): """Create a SageMaker Model Entity @@ -608,7 +614,7 @@ def _script_mode_env_vars(self): return { SCRIPT_PARAM_NAME.upper(): script_name or str(), DIR_PARAM_NAME.upper(): dir_name or str(), - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } @@ -1286,19 +1292,19 @@ class FrameworkModel(Model): def __init__( self, - model_data, - image_uri, - role, - entry_point, - source_dir=None, - predictor_cls=None, - env=None, - name=None, - container_log_level=logging.INFO, - code_location=None, - sagemaker_session=None, - dependencies=None, - git_config=None, + model_data: Union[str, PipelineVariable], + image_uri: Union[str, PipelineVariable], + role: str, + entry_point: str, + source_dir: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, **kwargs, ): """Initialize a ``FrameworkModel``. diff --git a/src/sagemaker/model_metrics.py b/src/sagemaker/model_metrics.py index acce4e13c9..83a43d3f18 100644 --- a/src/sagemaker/model_metrics.py +++ b/src/sagemaker/model_metrics.py @@ -13,20 +13,24 @@ """This file contains code related to model metrics, including metric source and file source.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class ModelMetrics(object): """Accepts model metrics parameters for conversion to request dict.""" def __init__( self, - model_statistics=None, - model_constraints=None, - model_data_statistics=None, - model_data_constraints=None, - bias=None, - explainability=None, - bias_pre_training=None, - bias_post_training=None, + model_statistics: Optional["MetricsSource"] = None, + model_constraints: Optional["MetricsSource"] = None, + model_data_statistics: Optional["MetricsSource"] = None, + model_data_constraints: Optional["MetricsSource"] = None, + bias: Optional["MetricsSource"] = None, + explainability: Optional["MetricsSource"] = None, + bias_pre_training: Optional["MetricsSource"] = None, + bias_post_training: Optional["MetricsSource"] = None, ): """Initialize a ``ModelMetrics`` instance and turn parameters into dict. @@ -99,9 +103,9 @@ class MetricsSource(object): def __init__( self, - content_type, - s3_uri, - content_digest=None, + content_type: Union[str, PipelineVariable], + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetricsSource`` instance and turn parameters into dict. 
@@ -127,9 +131,9 @@ class FileSource(object): def __init__( self, - s3_uri, - content_digest=None, - content_type=None, + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``FileSource`` instance and turn parameters into dict. diff --git a/src/sagemaker/serverless/serverless_inference_config.py b/src/sagemaker/serverless/serverless_inference_config.py index 39950f4f84..adc98a319a 100644 --- a/src/sagemaker/serverless/serverless_inference_config.py +++ b/src/sagemaker/serverless/serverless_inference_config.py @@ -27,8 +27,8 @@ class ServerlessInferenceConfig(object): def __init__( self, - memory_size_in_mb=2048, - max_concurrency=5, + memory_size_in_mb: int = 2048, + max_concurrency: int = 5, ): """Initialize a ServerlessInferenceConfig object for serverless inference configuration. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 145bf41cbe..221434d7db 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2633,7 +2633,9 @@ def _create_model_request( request["VpcConfig"] = vpc_config if enable_network_isolation: - request["EnableNetworkIsolation"] = True + # enable_network_isolation may be a pipeline variable which is + # parsed in execution time + request["EnableNetworkIsolation"] = enable_network_isolation return request diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 4db647e140..9533f475a1 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, Dict from packaging import version @@ -27,6 +28,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow import is_pipeline_variable from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -41,12 +43,12 @@ class TensorFlow(Framework): def __init__( self, - py_version=None, - framework_version=None, - model_dir=None, - image_uri=None, - distribution=None, - compiler_config=None, + py_version: Optional[str] = None, + framework_version: Optional[str] = None, + model_dir: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Optional[Dict[str, str]] = None, + compiler_config: Optional[TrainingCompilerConfig] = None, **kwargs, ): """Initialize a ``TensorFlow`` estimator. 
@@ -251,6 +253,8 @@ def _only_legacy_mode_supported(self):

     def _only_python_3_supported(self):
         """Placeholder docstring"""
+        if not self.framework_version:
+            return False
         return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION

     @classmethod

From 3560d7003ca817749684ccb00876c7f7ae39b0d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?jerrypeng73=F0=9F=98=8E?=
Date: Wed, 15 Jun 2022 15:14:19 -0700
Subject: [PATCH 2/6] change: Add PipelineVariable annotation for the rest of
 processors and estimators

annotations for processors + estimators / test mechanism change

remove debug print

reformatting
---
 src/sagemaker/algorithm.py                    |  70 +++---
 src/sagemaker/amazon/amazon_estimator.py      |  36 ++-
 .../amazon/factorization_machines.py          | 106 ++++----
 src/sagemaker/amazon/hyperparameter.py        |   1 +
 src/sagemaker/amazon/ipinsights.py            |  48 ++--
 src/sagemaker/amazon/kmeans.py                |  56 +++--
 src/sagemaker/amazon/knn.py                   |  47 ++--
 src/sagemaker/amazon/lda.py                   |  32 +--
 src/sagemaker/amazon/linear_learner.py        | 205 +++++++++-------
 src/sagemaker/amazon/ntm.py                   |  66 ++---
 src/sagemaker/amazon/object2vec.py            | 156 ++++++------
 src/sagemaker/amazon/pca.py                   |  30 +--
 src/sagemaker/amazon/randomcutforest.py       |  28 +--
 src/sagemaker/chainer/estimator.py            |  28 ++-
 src/sagemaker/clarify.py                      | 229 ++++++++++--------
 src/sagemaker/estimator.py                    |   1 +
 src/sagemaker/fw_utils.py                     |   2 +-
 src/sagemaker/huggingface/estimator.py        |  12 +-
 src/sagemaker/huggingface/processing.py       |  44 ++--
 src/sagemaker/mxnet/estimator.py              |   2 +
 src/sagemaker/mxnet/processing.py             |  39 +--
 src/sagemaker/processing.py                   |  10 +-
 src/sagemaker/pytorch/estimator.py            |  12 +-
 src/sagemaker/pytorch/processing.py           |  39 +--
 src/sagemaker/rl/estimator.py                 |  14 +-
 src/sagemaker/sklearn/estimator.py            |  25 +-
 src/sagemaker/sklearn/processing.py           |  32 +--
 src/sagemaker/tensorflow/processing.py        |  39 +--
 src/sagemaker/training_compiler/config.py     |  15 ++
 src/sagemaker/utils.py                        |   2 -
 src/sagemaker/workflow/step_collections.py    |   1 +
 src/sagemaker/xgboost/estimator.py            |  14 +-
 src/sagemaker/xgboost/processing.py           |  39 +--
 .../workflow/test_processing_step.py          |   4 +-
 .../sagemaker/workflow/test_training_step.py  |  28 ++-
 35 files changed, 850 insertions(+), 662 deletions(-)

diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py
index 1ab5ee3bcf..0ddf9fae78 100644
--- a/src/sagemaker/algorithm.py
+++ b/src/sagemaker/algorithm.py
@@ -13,15 +13,22 @@
 """Test docstring"""
 from __future__ import absolute_import

+from typing import Optional, Union, Dict, List
+
 import sagemaker
 import sagemaker.parameter
 from sagemaker import vpc_utils
 from sagemaker.deserializers import BytesDeserializer
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.estimator import EstimatorBase
+from sagemaker.inputs import TrainingInput, FileSystemInput
 from sagemaker.serializers import IdentitySerializer
 from sagemaker.transformer import Transformer
 from sagemaker.predictor import Predictor
+from sagemaker.session import Session
+from sagemaker.workflow.entities import PipelineVariable
+
+from sagemaker.workflow import is_pipeline_variable


 class AlgorithmEstimator(EstimatorBase):
@@ -37,28 +44,28 @@ class AlgorithmEstimator(EstimatorBase):

     def __init__(
         self,
-        algorithm_arn,
-        role,
-        instance_count,
-        instance_type,
-        volume_size=30,
-        volume_kms_key=None,
-        max_run=24 * 60 * 60,
-        input_mode="File",
-        output_path=None,
-        output_kms_key=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        hyperparameters=None,
-        tags=None,
-        subnets=None,
-        security_group_ids=None,
-        model_uri=None,
-        model_channel_name="model",
-        metric_definitions=None,
-        encrypt_inter_container_traffic=False,
-        use_spot_instances=False,
-        max_wait=None,
+        algorithm_arn: str,
+        role: str,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        volume_size: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
+        input_mode: Union[str, PipelineVariable] = "File",
+        output_path: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        subnets: Optional[List[Union[str, PipelineVariable]]] = None,
+        security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
+        model_uri: Optional[str] = None,
+        model_channel_name: Union[str, PipelineVariable] = "model",
+        metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
+        use_spot_instances: Union[bool, PipelineVariable] = False,
+        max_wait: Optional[Union[int, PipelineVariable]] = None,
         **kwargs  # pylint: disable=W0613
     ):
         """Initialize an ``AlgorithmEstimator`` instance.
@@ -186,7 +193,7 @@ def validate_train_spec(self):
         # Check that the input mode provided is compatible with the training input modes for the
         # algorithm.
         input_modes = self._algorithm_training_input_modes(train_spec["TrainingChannels"])
-        if self.input_mode not in input_modes:
+        if not is_pipeline_variable(self.input_mode) and self.input_mode not in input_modes:
             raise ValueError(
                 "Invalid input mode: %s. %s only supports: %s"
                 % (self.input_mode, algorithm_name, input_modes)
             )
@@ -194,14 +201,17 @@ def validate_train_spec(self):
         # Check that the training instance type is compatible with the algorithm.
         supported_instances = train_spec["SupportedTrainingInstanceTypes"]
-        if self.instance_type not in supported_instances:
+        if (
+            not is_pipeline_variable(self.instance_type)
+            and self.instance_type not in supported_instances
+        ):
             raise ValueError(
                 "Invalid instance_type: %s."
                 " %s supports the following instance types: %s"
                 % (self.instance_type, algorithm_name, supported_instances)
             )

         # Verify if distributed training is supported by the algorithm
-        if (
+        if not is_pipeline_variable(self.instance_count) and (
             self.instance_count > 1
             and "SupportsDistributedTraining" in train_spec
             and not train_spec["SupportsDistributedTraining"]
         ):
@@ -414,12 +424,18 @@ def _prepare_for_training(self, job_name=None):

         super(AlgorithmEstimator, self)._prepare_for_training(job_name)

-    def fit(self, inputs=None, wait=True, logs=True, job_name=None):
+    def fit(
+        self,
+        inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+    ):
         """Placeholder docstring"""
         if inputs:
             self._validate_input_channels(inputs)

-        super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
+        return super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)

     def _validate_input_channels(self, channels):
         """Placeholder docstring"""

diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py
index eaf4644da6..71c29a74c1 100644
--- a/src/sagemaker/amazon/amazon_estimator.py
+++ b/src/sagemaker/amazon/amazon_estimator.py
@@ -13,6 +13,8 @@
 """Placeholder docstring"""
 from __future__ import absolute_import

+from typing import Optional, Union, Dict
+
 import json
 import logging
 import tempfile
@@ -30,6 +32,9 @@
 from sagemaker.utils import sagemaker_timestamp
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.parameters import ParameterBoolean
+from sagemaker.workflow import is_pipeline_variable

 logger = logging.getLogger(__name__)

@@ -42,16 +47,16 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
     feature_dim = hp("feature_dim", validation.gt(0), data_type=int)
     mini_batch_size = hp("mini_batch_size", validation.gt(0), data_type=int)

-    repo_name = None
-    repo_version = None
+    repo_name: Optional[str] = None
+    repo_version: Optional[int] = None

     def __init__(
         self,
-        role,
-        instance_count=None,
-        instance_type=None,
-        data_location=None,
-        enable_network_isolation=False,
+        role: str,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        data_location: Optional[str] = None,
+        enable_network_isolation: Union[bool, ParameterBoolean] = False,
         **kwargs
     ):
         """Initialize an AmazonAlgorithmEstimatorBase.
@@ -115,6 +120,11 @@ def data_location(self):
     @data_location.setter
     def data_location(self, data_location):
         """Placeholder docstring"""
+        if is_pipeline_variable(data_location):
+            raise ValueError(
+                "data_location argument has to be a string "
+                "rather than a pipeline variable"
+            )
+
         if not data_location.startswith("s3://"):
             raise ValueError(
                 'Expecting an S3 URL beginning with "s3://". Got "{}"'.format(data_location)
             )
@@ -198,12 +208,12 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

     @runnable_by_pipeline
     def fit(
         self,
-        records,
-        mini_batch_size=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
+        records: "RecordSet",
+        mini_batch_size: Optional[int] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
     ):
         """Fit this Estimator on serialized Record objects, stored in S3.
diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 5e9c2098b9..0d3f570e60 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -37,83 +37,83 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase): sparse datasets economically. """ - repo_name = "factorization-machines" - repo_version = 1 + repo_name: str = "factorization-machines" + repo_version: int = 1 - num_factors = hp("num_factors", gt(0), "An integer greater than zero", int) - predictor_type = hp( + num_factors: hp = hp("num_factors", gt(0), "An integer greater than zero", int) + predictor_type: hp = hp( "predictor_type", isin("binary_classifier", "regressor"), 'Value "binary_classifier" or "regressor"', str, ) - epochs = hp("epochs", gt(0), "An integer greater than 0", int) - clip_gradient = hp("clip_gradient", (), "A float value", float) - eps = hp("eps", (), "A float value", float) - rescale_grad = hp("rescale_grad", (), "A float value", float) - bias_lr = hp("bias_lr", ge(0), "A non-negative float", float) - linear_lr = hp("linear_lr", ge(0), "A non-negative float", float) - factors_lr = hp("factors_lr", ge(0), "A non-negative float", float) - bias_wd = hp("bias_wd", ge(0), "A non-negative float", float) - linear_wd = hp("linear_wd", ge(0), "A non-negative float", float) - factors_wd = hp("factors_wd", ge(0), "A non-negative float", float) - bias_init_method = hp( + epochs: hp = hp("epochs", gt(0), "An integer greater than 0", int) + clip_gradient: hp = hp("clip_gradient", (), "A float value", float) + eps: hp = hp("eps", (), "A float value", float) + rescale_grad: hp = hp("rescale_grad", (), "A float value", float) + bias_lr: hp = hp("bias_lr", ge(0), "A non-negative float", float) + linear_lr: hp = hp("linear_lr", ge(0), "A non-negative float", float) + factors_lr: hp = hp("factors_lr", ge(0), "A non-negative float", float) + bias_wd: hp = hp("bias_wd", ge(0), "A non-negative float", float) + linear_wd: hp = hp("linear_wd", ge(0), "A non-negative float", float) + factors_wd: hp = hp("factors_wd", ge(0), "A non-negative float", float) + bias_init_method: hp = hp( "bias_init_method", isin("normal", "uniform", "constant"), 'Value "normal", "uniform" or "constant"', str, ) - bias_init_scale = hp("bias_init_scale", ge(0), "A non-negative float", float) - bias_init_sigma = hp("bias_init_sigma", ge(0), "A non-negative float", float) - bias_init_value = hp("bias_init_value", (), "A float value", float) - linear_init_method = hp( + bias_init_scale: hp = hp("bias_init_scale", ge(0), "A non-negative float", float) + bias_init_sigma: hp = hp("bias_init_sigma", ge(0), "A non-negative float", float) + bias_init_value: hp = hp("bias_init_value", (), "A float value", float) + linear_init_method: hp = hp( "linear_init_method", isin("normal", "uniform", "constant"), 'Value "normal", "uniform" or "constant"', str, ) - linear_init_scale = hp("linear_init_scale", ge(0), "A non-negative float", float) - linear_init_sigma = hp("linear_init_sigma", ge(0), "A non-negative float", float) - linear_init_value = hp("linear_init_value", (), "A float value", float) - factors_init_method = hp( + linear_init_scale: hp = hp("linear_init_scale", ge(0), "A non-negative float", float) + linear_init_sigma: hp = hp("linear_init_sigma", ge(0), "A non-negative float", float) + linear_init_value: hp = hp("linear_init_value", (), "A float value", float) + factors_init_method: hp = hp( "factors_init_method", isin("normal", "uniform", 
"constant"), 'Value "normal", "uniform" or "constant"', str, ) - factors_init_scale = hp("factors_init_scale", ge(0), "A non-negative float", float) - factors_init_sigma = hp("factors_init_sigma", ge(0), "A non-negative float", float) - factors_init_value = hp("factors_init_value", (), "A float value", float) + factors_init_scale: hp = hp("factors_init_scale", ge(0), "A non-negative float", float) + factors_init_sigma: hp = hp("factors_init_sigma", ge(0), "A non-negative float", float) + factors_init_value: hp = hp("factors_init_value", (), "A float value", float) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_factors=None, - predictor_type=None, - epochs=None, - clip_gradient=None, - eps=None, - rescale_grad=None, - bias_lr=None, - linear_lr=None, - factors_lr=None, - bias_wd=None, - linear_wd=None, - factors_wd=None, - bias_init_method=None, - bias_init_scale=None, - bias_init_sigma=None, - bias_init_value=None, - linear_init_method=None, - linear_init_scale=None, - linear_init_sigma=None, - linear_init_value=None, - factors_init_method=None, - factors_init_scale=None, - factors_init_sigma=None, - factors_init_value=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_factors: Optional[int] = None, + predictor_type: Optional[str] = None, + epochs: Optional[int] = None, + clip_gradient: Optional[float] = None, + eps: Optional[float] = None, + rescale_grad: Optional[float] = None, + bias_lr: Optional[float] = None, + linear_lr: Optional[float] = None, + factors_lr: Optional[float] = None, + bias_wd: Optional[float] = None, + linear_wd: Optional[float] = None, + factors_wd: Optional[float] = None, + bias_init_method: Optional[str] = None, + bias_init_scale: Optional[float] = None, + bias_init_sigma: Optional[float] = None, + bias_init_value: Optional[float] = None, + linear_init_method: Optional[str] = None, + linear_init_scale: Optional[float] = None, + linear_init_sigma: Optional[float] = None, + linear_init_value: Optional[float] = None, + factors_init_method: Optional[str] = None, + factors_init_scale: Optional[float] = None, + factors_init_sigma: Optional[float] = None, + factors_init_value: Optional[float] = None, **kwargs ): """Factorization Machines is :class:`Estimator` for general-purpose supervised learning. diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py index 856927cb13..36ba8019da 100644 --- a/src/sagemaker/amazon/hyperparameter.py +++ b/src/sagemaker/amazon/hyperparameter.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import json +from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow import is_pipeline_variable diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 097f6b45dc..50f6f03566 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -36,45 +36,45 @@ class IPInsights(AmazonAlgorithmEstimatorBase): as user IDs or account numbers. 
""" - repo_name = "ipinsights" - repo_version = 1 - MINI_BATCH_SIZE = 10000 + repo_name: str = "ipinsights" + repo_version: int = 1 + MINI_BATCH_SIZE: int = 10000 - num_entity_vectors = hp( + num_entity_vectors: hp = hp( "num_entity_vectors", (ge(1), le(250000000)), "An integer in [1, 250000000]", int ) - vector_dim = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int) + vector_dim: hp = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int) - batch_metrics_publish_interval = hp( + batch_metrics_publish_interval: hp = hp( "batch_metrics_publish_interval", (ge(1)), "An integer greater than 0", int ) - epochs = hp("epochs", (ge(1)), "An integer greater than 0", int) - learning_rate = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float) - num_ip_encoder_layers = hp( + epochs: hp = hp("epochs", (ge(1)), "An integer greater than 0", int) + learning_rate: hp = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float) + num_ip_encoder_layers: hp = hp( "num_ip_encoder_layers", (ge(0), le(100)), "An integer in [0, 100]", int ) - random_negative_sampling_rate = hp( + random_negative_sampling_rate: hp = hp( "random_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int ) - shuffled_negative_sampling_rate = hp( + shuffled_negative_sampling_rate: hp = hp( "shuffled_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int ) - weight_decay = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float) + weight_decay: hp = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_entity_vectors=None, - vector_dim=None, - batch_metrics_publish_interval=None, - epochs=None, - learning_rate=None, - num_ip_encoder_layers=None, - random_negative_sampling_rate=None, - shuffled_negative_sampling_rate=None, - weight_decay=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_entity_vectors: Optional[int] = None, + vector_dim: Optional[int] = None, + batch_metrics_publish_interval: Optional[int] = None, + epochs: Optional[int] = None, + learning_rate: Optional[float] = None, + num_ip_encoder_layers: Optional[int] = None, + random_negative_sampling_rate: Optional[int] = None, + shuffled_negative_sampling_rate: Optional[int] = None, + weight_decay: Optional[float] = None, **kwargs ): """This estimator is for IP Insights. diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 581e93e02a..18c976b99d 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -13,7 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Union, Optional +from typing import Union, Optional, List from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase @@ -27,6 +27,8 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.entities import PipelineVariable + class KMeans(AmazonAlgorithmEstimatorBase): """An unsupervised learning algorithm that attempts to find discrete groupings within data. @@ -36,23 +38,25 @@ class KMeans(AmazonAlgorithmEstimatorBase): the algorithm to use to determine similarity. 
""" - repo_name = "kmeans" - repo_version = 1 + repo_name: str = "kmeans" + repo_version: str = 1 - k = hp("k", gt(1), "An integer greater-than 1", int) - init_method = hp("init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str) - max_iterations = hp("local_lloyd_max_iter", gt(0), "An integer greater-than 0", int) - tol = hp("local_lloyd_tol", (ge(0), le(1)), "An float in [0, 1]", float) - num_trials = hp("local_lloyd_num_trials", gt(0), "An integer greater-than 0", int) - local_init_method = hp( + k: hp = hp("k", gt(1), "An integer greater-than 1", int) + init_method: hp = hp( + "init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str + ) + max_iterations: hp = hp("local_lloyd_max_iter", gt(0), "An integer greater-than 0", int) + tol: hp = hp("local_lloyd_tol", (ge(0), le(1)), "An float in [0, 1]", float) + num_trials: hp = hp("local_lloyd_num_trials", gt(0), "An integer greater-than 0", int) + local_init_method: hp = hp( "local_lloyd_init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str ) - half_life_time_size = hp( + half_life_time_size: hp = hp( "half_life_time_size", ge(0), "An integer greater-than-or-equal-to 0", int ) - epochs = hp("epochs", gt(0), "An integer greater-than 0", int) - center_factor = hp("extra_center_factor", gt(0), "An integer greater-than 0", int) - eval_metrics = hp( + epochs: hp = hp("epochs", gt(0), "An integer greater-than 0", int) + center_factor: hp = hp("extra_center_factor", gt(0), "An integer greater-than 0", int) + eval_metrics: hp = hp( name="eval_metrics", validation_message='A comma separated list of "msd" or "ssd"', data_type=list, @@ -60,19 +64,19 @@ class KMeans(AmazonAlgorithmEstimatorBase): def __init__( self, - role, - instance_count=None, - instance_type=None, - k=None, - init_method=None, - max_iterations=None, - tol=None, - num_trials=None, - local_init_method=None, - half_life_time_size=None, - epochs=None, - center_factor=None, - eval_metrics=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + k: Optional[int] = None, + init_method: Optional[str] = None, + max_iterations: Optional[int] = None, + tol: Optional[float] = None, + num_trials: Optional[int] = None, + local_init_method: Optional[str] = None, + half_life_time_size: Optional[int] = None, + epochs: Optional[int] = None, + center_factor: Optional[int] = None, + eval_metrics: Optional[List[Union[str, PipelineVariable]]] = None, **kwargs ): """A k-means clustering class :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`. diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index 14ba404ebf..2054c36483 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -37,54 +37,54 @@ class KNN(AmazonAlgorithmEstimatorBase): the average of their feature values as the predicted value. 
""" - repo_name = "knn" - repo_version = 1 + repo_name: str = "knn" + repo_version: int = 1 - k = hp("k", (ge(1)), "An integer greater than 0", int) - sample_size = hp("sample_size", (ge(1)), "An integer greater than 0", int) - predictor_type = hp( + k: hp = hp("k", (ge(1)), "An integer greater than 0", int) + sample_size: hp = hp("sample_size", (ge(1)), "An integer greater than 0", int) + predictor_type: hp = hp( "predictor_type", isin("classifier", "regressor"), 'One of "classifier" or "regressor"', str ) - dimension_reduction_target = hp( + dimension_reduction_target: hp = hp( "dimension_reduction_target", (ge(1)), "An integer greater than 0 and less than feature_dim", int, ) - dimension_reduction_type = hp( + dimension_reduction_type: hp = hp( "dimension_reduction_type", isin("sign", "fjlt"), 'One of "sign" or "fjlt"', str ) - index_metric = hp( + index_metric: hp = hp( "index_metric", isin("COSINE", "INNER_PRODUCT", "L2"), 'One of "COSINE", "INNER_PRODUCT", "L2"', str, ) - index_type = hp( + index_type: hp = hp( "index_type", isin("faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ"), 'One of "faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ"', str, ) - faiss_index_ivf_nlists = hp( + faiss_index_ivf_nlists: hp = hp( "faiss_index_ivf_nlists", (), '"auto" or an integer greater than 0', str ) - faiss_index_pq_m = hp("faiss_index_pq_m", (ge(1)), "An integer greater than 0", int) + faiss_index_pq_m: hp = hp("faiss_index_pq_m", (ge(1)), "An integer greater than 0", int) def __init__( self, - role, - instance_count=None, - instance_type=None, - k=None, - sample_size=None, - predictor_type=None, - dimension_reduction_type=None, - dimension_reduction_target=None, - index_type=None, - index_metric=None, - faiss_index_ivf_nlists=None, - faiss_index_pq_m=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + k: Optional[int] = None, + sample_size: Optional[int] = None, + predictor_type: Optional[str] = None, + dimension_reduction_type: Optional[str] = None, + dimension_reduction_target: Optional[int] = None, + index_type: Optional[str] = None, + index_metric: Optional[str] = None, + faiss_index_ivf_nlists: Optional[str] = None, + faiss_index_pq_m: Optional[int] = None, **kwargs ): """k-nearest neighbors (KNN) is :class:`Estimator` used for classification and regression. @@ -158,6 +158,7 @@ def __init__( self.index_metric = index_metric self.faiss_index_ivf_nlists = faiss_index_ivf_nlists self.faiss_index_pq_m = faiss_index_pq_m + if dimension_reduction_type and not dimension_reduction_target: raise ValueError( '"dimension_reduction_target" is required when "dimension_reduction_type" is set.' diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index 4158b6cc27..56b621dcb0 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -26,6 +26,7 @@ from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow import is_pipeline_variable class LDA(AmazonAlgorithmEstimatorBase): @@ -37,24 +38,24 @@ class LDA(AmazonAlgorithmEstimatorBase): word, and the categories are the topics. 
""" - repo_name = "lda" - repo_version = 1 + repo_name: str = "lda" + repo_version: int = 1 - num_topics = hp("num_topics", gt(0), "An integer greater than zero", int) - alpha0 = hp("alpha0", gt(0), "A positive float", float) - max_restarts = hp("max_restarts", gt(0), "An integer greater than zero", int) - max_iterations = hp("max_iterations", gt(0), "An integer greater than zero", int) - tol = hp("tol", gt(0), "A positive float", float) + num_topics: hp = hp("num_topics", gt(0), "An integer greater than zero", int) + alpha0: hp = hp("alpha0", gt(0), "A positive float", float) + max_restarts: hp = hp("max_restarts", gt(0), "An integer greater than zero", int) + max_iterations: hp = hp("max_iterations", gt(0), "An integer greater than zero", int) + tol: hp = hp("tol", gt(0), "A positive float", float) def __init__( self, - role, - instance_type=None, - num_topics=None, - alpha0=None, - max_restarts=None, - max_iterations=None, - tol=None, + role: str, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_topics: Optional[int] = None, + alpha0: Optional[float] = None, + max_restarts: Optional[int] = None, + max_iterations: Optional[int] = None, + tol: Optional[float] = None, **kwargs ): """Latent Dirichlet Allocation (LDA) is :class:`Estimator` used for unsupervised learning. @@ -124,7 +125,8 @@ def __init__( :class:`~sagemaker.estimator.EstimatorBase`. """ # this algorithm only supports single instance training - if kwargs.pop("instance_count", 1) != 1: + instance_count = kwargs.pop("instance_count", 1) + if is_pipeline_variable(instance_count) or instance_count != 1: print( "LDA only supports single instance training. Defaulting to 1 {}.".format( instance_type diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index d02ed2875f..91bae84b36 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -26,6 +26,7 @@ from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow import is_pipeline_variable class LinearLearner(AmazonAlgorithmEstimatorBase): @@ -39,12 +40,12 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): of the label y. 
""" - repo_name = "linear-learner" - repo_version = 1 + repo_name: str = "linear-learner" + repo_version: int = 1 - DEFAULT_MINI_BATCH_SIZE = 1000 + DEFAULT_MINI_BATCH_SIZE: int = 1000 - binary_classifier_model_selection_criteria = hp( + binary_classifier_model_selection_criteria: hp = hp( "binary_classifier_model_selection_criteria", isin( "accuracy", @@ -57,32 +58,36 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): ), data_type=str, ) - target_recall = hp("target_recall", (gt(0), lt(1)), "A float in (0,1)", float) - target_precision = hp("target_precision", (gt(0), lt(1)), "A float in (0,1)", float) - positive_example_weight_mult = hp( + target_recall: hp = hp("target_recall", (gt(0), lt(1)), "A float in (0,1)", float) + target_precision: hp = hp("target_precision", (gt(0), lt(1)), "A float in (0,1)", float) + positive_example_weight_mult: hp = hp( "positive_example_weight_mult", (), "A float greater than 0 or 'auto' or 'balanced'", str ) - epochs = hp("epochs", gt(0), "An integer greater-than 0", int) - predictor_type = hp( + epochs: hp = hp("epochs", gt(0), "An integer greater-than 0", int) + predictor_type: hp = hp( "predictor_type", isin("binary_classifier", "regressor", "multiclass_classifier"), 'One of "binary_classifier" or "multiclass_classifier" or "regressor"', str, ) - use_bias = hp("use_bias", (), "Either True or False", bool) - num_models = hp("num_models", gt(0), "An integer greater-than 0", int) - num_calibration_samples = hp("num_calibration_samples", gt(0), "An integer greater-than 0", int) - init_method = hp("init_method", isin("uniform", "normal"), 'One of "uniform" or "normal"', str) - init_scale = hp("init_scale", gt(0), "A float greater-than 0", float) - init_sigma = hp("init_sigma", gt(0), "A float greater-than 0", float) - init_bias = hp("init_bias", (), "A number", float) - optimizer = hp( + use_bias: hp = hp("use_bias", (), "Either True or False", bool) + num_models: hp = hp("num_models", gt(0), "An integer greater-than 0", int) + num_calibration_samples: hp = hp( + "num_calibration_samples", gt(0), "An integer greater-than 0", int + ) + init_method: hp = hp( + "init_method", isin("uniform", "normal"), 'One of "uniform" or "normal"', str + ) + init_scale: hp = hp("init_scale", gt(0), "A float greater-than 0", float) + init_sigma: hp = hp("init_sigma", gt(0), "A float greater-than 0", float) + init_bias: hp = hp("init_bias", (), "A number", float) + optimizer: hp = hp( "optimizer", isin("sgd", "adam", "rmsprop", "auto"), 'One of "sgd", "adam", "rmsprop" or "auto', str, ) - loss = hp( + loss: hp = hp( "loss", isin( "logistic", @@ -100,83 +105,89 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): ' "eps_insensitive_absolute_loss", "quantile_loss", "huber_loss", "softmax_loss" or "auto"', str, ) - wd = hp("wd", ge(0), "A float greater-than or equal to 0", float) - l1 = hp("l1", ge(0), "A float greater-than or equal to 0", float) - momentum = hp("momentum", (ge(0), lt(1)), "A float in [0,1)", float) - learning_rate = hp("learning_rate", gt(0), "A float greater-than 0", float) - beta_1 = hp("beta_1", (ge(0), lt(1)), "A float in [0,1)", float) - beta_2 = hp("beta_2", (ge(0), lt(1)), "A float in [0,1)", float) - bias_lr_mult = hp("bias_lr_mult", gt(0), "A float greater-than 0", float) - bias_wd_mult = hp("bias_wd_mult", ge(0), "A float greater-than or equal to 0", float) - use_lr_scheduler = hp("use_lr_scheduler", (), "A boolean", bool) - lr_scheduler_step = hp("lr_scheduler_step", gt(0), "An integer greater-than 0", int) - lr_scheduler_factor = 
hp("lr_scheduler_factor", (gt(0), lt(1)), "A float in (0,1)", float) - lr_scheduler_minimum_lr = hp("lr_scheduler_minimum_lr", gt(0), "A float greater-than 0", float) - normalize_data = hp("normalize_data", (), "A boolean", bool) - normalize_label = hp("normalize_label", (), "A boolean", bool) - unbias_data = hp("unbias_data", (), "A boolean", bool) - unbias_label = hp("unbias_label", (), "A boolean", bool) - num_point_for_scaler = hp("num_point_for_scaler", gt(0), "An integer greater-than 0", int) - margin = hp("margin", ge(0), "A float greater-than or equal to 0", float) - quantile = hp("quantile", (gt(0), lt(1)), "A float in (0,1)", float) - loss_insensitivity = hp("loss_insensitivity", gt(0), "A float greater-than 0", float) - huber_delta = hp("huber_delta", ge(0), "A float greater-than or equal to 0", float) - early_stopping_patience = hp("early_stopping_patience", gt(0), "An integer greater-than 0", int) - early_stopping_tolerance = hp( + wd: hp = hp("wd", ge(0), "A float greater-than or equal to 0", float) + l1: hp = hp("l1", ge(0), "A float greater-than or equal to 0", float) + momentum: hp = hp("momentum", (ge(0), lt(1)), "A float in [0,1)", float) + learning_rate: hp = hp("learning_rate", gt(0), "A float greater-than 0", float) + beta_1: hp = hp("beta_1", (ge(0), lt(1)), "A float in [0,1)", float) + beta_2: hp = hp("beta_2", (ge(0), lt(1)), "A float in [0,1)", float) + bias_lr_mult: hp = hp("bias_lr_mult", gt(0), "A float greater-than 0", float) + bias_wd_mult: hp = hp("bias_wd_mult", ge(0), "A float greater-than or equal to 0", float) + use_lr_scheduler: hp = hp("use_lr_scheduler", (), "A boolean", bool) + lr_scheduler_step: hp = hp("lr_scheduler_step", gt(0), "An integer greater-than 0", int) + lr_scheduler_factor: hp = hp("lr_scheduler_factor", (gt(0), lt(1)), "A float in (0,1)", float) + lr_scheduler_minimum_lr: hp = hp( + "lr_scheduler_minimum_lr", gt(0), "A float greater-than 0", float + ) + normalize_data: hp = hp("normalize_data", (), "A boolean", bool) + normalize_label: hp = hp("normalize_label", (), "A boolean", bool) + unbias_data: hp = hp("unbias_data", (), "A boolean", bool) + unbias_label: hp = hp("unbias_label", (), "A boolean", bool) + num_point_for_scaler: hp = hp("num_point_for_scaler", gt(0), "An integer greater-than 0", int) + margin: hp = hp("margin", ge(0), "A float greater-than or equal to 0", float) + quantile: hp = hp("quantile", (gt(0), lt(1)), "A float in (0,1)", float) + loss_insensitivity: hp = hp("loss_insensitivity", gt(0), "A float greater-than 0", float) + huber_delta: hp = hp("huber_delta", ge(0), "A float greater-than or equal to 0", float) + early_stopping_patience: hp = hp( + "early_stopping_patience", gt(0), "An integer greater-than 0", int + ) + early_stopping_tolerance: hp = hp( "early_stopping_tolerance", gt(0), "A float greater-than 0", float ) - num_classes = hp("num_classes", (gt(0), le(1000000)), "An integer in [1,1000000]", int) - accuracy_top_k = hp("accuracy_top_k", (gt(0), le(1000000)), "An integer in [1,1000000]", int) - f_beta = hp("f_beta", gt(0), "A float greater-than 0", float) - balance_multiclass_weights = hp("balance_multiclass_weights", (), "A boolean", bool) + num_classes: hp = hp("num_classes", (gt(0), le(1000000)), "An integer in [1,1000000]", int) + accuracy_top_k: hp = hp( + "accuracy_top_k", (gt(0), le(1000000)), "An integer in [1,1000000]", int + ) + f_beta: hp = hp("f_beta", gt(0), "A float greater-than 0", float) + balance_multiclass_weights: hp = hp("balance_multiclass_weights", (), "A boolean", bool) def 
__init__( self, - role, - instance_count=None, - instance_type=None, - predictor_type=None, - binary_classifier_model_selection_criteria=None, - target_recall=None, - target_precision=None, - positive_example_weight_mult=None, - epochs=None, - use_bias=None, - num_models=None, - num_calibration_samples=None, - init_method=None, - init_scale=None, - init_sigma=None, - init_bias=None, - optimizer=None, - loss=None, - wd=None, - l1=None, - momentum=None, - learning_rate=None, - beta_1=None, - beta_2=None, - bias_lr_mult=None, - bias_wd_mult=None, - use_lr_scheduler=None, - lr_scheduler_step=None, - lr_scheduler_factor=None, - lr_scheduler_minimum_lr=None, - normalize_data=None, - normalize_label=None, - unbias_data=None, - unbias_label=None, - num_point_for_scaler=None, - margin=None, - quantile=None, - loss_insensitivity=None, - huber_delta=None, - early_stopping_patience=None, - early_stopping_tolerance=None, - num_classes=None, - accuracy_top_k=None, - f_beta=None, - balance_multiclass_weights=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + predictor_type: Optional[str] = None, + binary_classifier_model_selection_criteria: Optional[str] = None, + target_recall: Optional[float] = None, + target_precision: Optional[float] = None, + positive_example_weight_mult: Optional[str] = None, + epochs: Optional[int] = None, + use_bias: Optional[bool] = None, + num_models: Optional[int] = None, + num_calibration_samples: Optional[int] = None, + init_method: Optional[str] = None, + init_scale: Optional[float] = None, + init_sigma: Optional[float] = None, + init_bias: Optional[float] = None, + optimizer: Optional[str] = None, + loss: Optional[str] = None, + wd: Optional[float] = None, + l1: Optional[float] = None, + momentum: Optional[float] = None, + learning_rate: Optional[float] = None, + beta_1: Optional[float] = None, + beta_2: Optional[float] = None, + bias_lr_mult: Optional[float] = None, + bias_wd_mult: Optional[float] = None, + use_lr_scheduler: Optional[bool] = None, + lr_scheduler_step: Optional[int] = None, + lr_scheduler_factor: Optional[float] = None, + lr_scheduler_minimum_lr: Optional[float] = None, + normalize_data: Optional[bool] = None, + normalize_label: Optional[bool] = None, + unbias_data: Optional[bool] = None, + unbias_label: Optional[bool] = None, + num_point_for_scaler: Optional[int] = None, + margin: Optional[float] = None, + quantile: Optional[float] = None, + loss_insensitivity: Optional[float] = None, + huber_delta: Optional[float] = None, + early_stopping_patience: Optional[int] = None, + early_stopping_tolerance: Optional[float] = None, + num_classes: Optional[int] = None, + accuracy_top_k: Optional[int] = None, + f_beta: Optional[float] = None, + balance_multiclass_weights: Optional[bool] = None, **kwargs ): """An :class:`Estimator` for binary classification and regression. @@ -424,10 +435,16 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): num_records = records.num_records # mini_batch_size can't be greater than number of records or training job fails - default_mini_batch_size = min( - self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)) - ) - mini_batch_size = mini_batch_size or default_mini_batch_size + if not mini_batch_size: + if is_pipeline_variable(self.instance_count): + raise ValueError( + "instance_count can not be a pipeline variable when mini_batch_size is not given." 
+ ) + else: + mini_batch_size = min( + self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)) + ) + super(LinearLearner, self)._prepare_for_training( records, mini_batch_size=mini_batch_size, job_name=job_name ) diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 83c2f97348..8dccb0d079 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -13,7 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Union, Optional +from typing import Optional, Union, List from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase @@ -36,53 +36,59 @@ class NTM(AmazonAlgorithmEstimatorBase): "mileage", and "speed" are likely to share a topic on "transportation" for example. """ - repo_name = "ntm" - repo_version = 1 + repo_name: str = "ntm" + repo_version: int = 1 - num_topics = hp("num_topics", (ge(2), le(1000)), "An integer in [2, 1000]", int) - encoder_layers = hp( + num_topics: hp = hp("num_topics", (ge(2), le(1000)), "An integer in [2, 1000]", int) + encoder_layers: hp = hp( name="encoder_layers", validation_message="A comma separated list of " "positive integers", data_type=list, ) - epochs = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) - encoder_layers_activation = hp( + epochs: hp = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) + encoder_layers_activation: hp = hp( "encoder_layers_activation", isin("sigmoid", "tanh", "relu"), 'One of "sigmoid", "tanh" or "relu"', str, ) - optimizer = hp( + optimizer: hp = hp( "optimizer", isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"), 'One of "adagrad", "adam", "rmsprop", "sgd" and "adadelta"', str, ) - tolerance = hp("tolerance", (ge(1e-6), le(0.1)), "A float in [1e-6, 0.1]", float) - num_patience_epochs = hp("num_patience_epochs", (ge(1), le(10)), "An integer in [1, 10]", int) - batch_norm = hp(name="batch_norm", validation_message="Value must be a boolean", data_type=bool) - rescale_gradient = hp("rescale_gradient", (ge(1e-3), le(1.0)), "A float in [1e-3, 1.0]", float) - clip_gradient = hp("clip_gradient", ge(1e-3), "A float greater equal to 1e-3", float) - weight_decay = hp("weight_decay", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) - learning_rate = hp("learning_rate", (ge(1e-6), le(1.0)), "A float in [1e-6, 1.0]", float) + tolerance: hp = hp("tolerance", (ge(1e-6), le(0.1)), "A float in [1e-6, 0.1]", float) + num_patience_epochs: hp = hp( + "num_patience_epochs", (ge(1), le(10)), "An integer in [1, 10]", int + ) + batch_norm: hp = hp( + name="batch_norm", validation_message="Value must be a boolean", data_type=bool + ) + rescale_gradient: hp = hp( + "rescale_gradient", (ge(1e-3), le(1.0)), "A float in [1e-3, 1.0]", float + ) + clip_gradient: hp = hp("clip_gradient", ge(1e-3), "A float greater equal to 1e-3", float) + weight_decay: hp = hp("weight_decay", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) + learning_rate: hp = hp("learning_rate", (ge(1e-6), le(1.0)), "A float in [1e-6, 1.0]", float) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_topics=None, - encoder_layers=None, - epochs=None, - encoder_layers_activation=None, - optimizer=None, - tolerance=None, - num_patience_epochs=None, - batch_norm=None, - rescale_gradient=None, - clip_gradient=None, - weight_decay=None, - learning_rate=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, 
PipelineVariable]] = None, + num_topics: Optional[int] = None, + encoder_layers: Optional[List] = None, + epochs: Optional[int] = None, + encoder_layers_activation: Optional[str] = None, + optimizer: Optional[str] = None, + tolerance: Optional[float] = None, + num_patience_epochs: Optional[int] = None, + batch_norm: Optional[bool] = None, + rescale_gradient: Optional[float] = None, + clip_gradient: Optional[float] = None, + weight_decay: Optional[float] = None, + learning_rate: Optional[float] = None, **kwargs ): """Neural Topic Model (NTM) is :class:`Estimator` used for unsupervised learning. diff --git a/src/sagemaker/amazon/object2vec.py b/src/sagemaker/amazon/object2vec.py index 1fbd846cbf..492f24f230 100644 --- a/src/sagemaker/amazon/object2vec.py +++ b/src/sagemaker/amazon/object2vec.py @@ -53,132 +53,142 @@ class Object2Vec(AmazonAlgorithmEstimatorBase): objects in the original space in the embedding space. """ - repo_name = "object2vec" - repo_version = 1 - MINI_BATCH_SIZE = 32 - - enc_dim = hp("enc_dim", (ge(4), le(10000)), "An integer in [4, 10000]", int) - mini_batch_size = hp("mini_batch_size", (ge(1), le(10000)), "An integer in [1, 10000]", int) - epochs = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) - early_stopping_patience = hp( + repo_name: str = "object2vec" + repo_version: int = 1 + MINI_BATCH_SIZE: int = 32 + + enc_dim: hp = hp("enc_dim", (ge(4), le(10000)), "An integer in [4, 10000]", int) + mini_batch_size: hp = hp("mini_batch_size", (ge(1), le(10000)), "An integer in [1, 10000]", int) + epochs: hp = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) + early_stopping_patience: hp = hp( "early_stopping_patience", (ge(1), le(5)), "An integer in [1, 5]", int ) - early_stopping_tolerance = hp( + early_stopping_tolerance: hp = hp( "early_stopping_tolerance", (ge(1e-06), le(0.1)), "A float in [1e-06, 0.1]", float ) - dropout = hp("dropout", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) - weight_decay = hp("weight_decay", (ge(0.0), le(10000.0)), "A float in [0.0, 10000.0]", float) - bucket_width = hp("bucket_width", (ge(0), le(100)), "An integer in [0, 100]", int) - num_classes = hp("num_classes", (ge(2), le(30)), "An integer in [2, 30]", int) - mlp_layers = hp("mlp_layers", (ge(1), le(10)), "An integer in [1, 10]", int) - mlp_dim = hp("mlp_dim", (ge(2), le(10000)), "An integer in [2, 10000]", int) - mlp_activation = hp( + dropout: hp = hp("dropout", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) + weight_decay: hp = hp( + "weight_decay", (ge(0.0), le(10000.0)), "A float in [0.0, 10000.0]", float + ) + bucket_width: hp = hp("bucket_width", (ge(0), le(100)), "An integer in [0, 100]", int) + num_classes: hp = hp("num_classes", (ge(2), le(30)), "An integer in [2, 30]", int) + mlp_layers: hp = hp("mlp_layers", (ge(1), le(10)), "An integer in [1, 10]", int) + mlp_dim: hp = hp("mlp_dim", (ge(2), le(10000)), "An integer in [2, 10000]", int) + mlp_activation: hp = hp( "mlp_activation", isin("tanh", "relu", "linear"), 'One of "tanh", "relu", "linear"', str ) - output_layer = hp( + output_layer: hp = hp( "output_layer", isin("softmax", "mean_squared_error"), 'One of "softmax", "mean_squared_error"', str, ) - optimizer = hp( + optimizer: hp = hp( "optimizer", isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"), 'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', str, ) - learning_rate = hp("learning_rate", (ge(1e-06), le(1.0)), "A float in [1e-06, 1.0]", float) + learning_rate: hp = hp("learning_rate", (ge(1e-06), le(1.0)), "A float in 
[1e-06, 1.0]", float) - negative_sampling_rate = hp( + negative_sampling_rate: hp = hp( "negative_sampling_rate", (ge(0), le(100)), "An integer in [0, 100]", int ) - comparator_list = hp( + comparator_list: hp = hp( "comparator_list", _list_check_subset(["hadamard", "concat", "abs_diff"]), 'Comma-separated of hadamard, concat, abs_diff. E.g. "hadamard,abs_diff"', str, ) - tied_token_embedding_weight = hp( + tied_token_embedding_weight: hp = hp( "tied_token_embedding_weight", (), "Either True or False", bool ) - token_embedding_storage_type = hp( + token_embedding_storage_type: hp = hp( "token_embedding_storage_type", isin("dense", "row_sparse"), 'One of "dense", "row_sparse"', str, ) - enc0_network = hp( + enc0_network: hp = hp( "enc0_network", isin("hcnn", "bilstm", "pooled_embedding"), 'One of "hcnn", "bilstm", "pooled_embedding"', str, ) - enc1_network = hp( + enc1_network: hp = hp( "enc1_network", isin("hcnn", "bilstm", "pooled_embedding", "enc0"), 'One of "hcnn", "bilstm", "pooled_embedding", "enc0"', str, ) - enc0_cnn_filter_width = hp("enc0_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int) - enc1_cnn_filter_width = hp("enc1_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int) - enc0_max_seq_len = hp("enc0_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) - enc1_max_seq_len = hp("enc1_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) - enc0_token_embedding_dim = hp( + enc0_cnn_filter_width: hp = hp( + "enc0_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int + ) + enc1_cnn_filter_width: hp = hp( + "enc1_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int + ) + enc0_max_seq_len: hp = hp("enc0_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) + enc1_max_seq_len: hp = hp("enc1_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) + enc0_token_embedding_dim: hp = hp( "enc0_token_embedding_dim", (ge(2), le(1000)), "An integer in [2, 1000]", int ) - enc1_token_embedding_dim = hp( + enc1_token_embedding_dim: hp = hp( "enc1_token_embedding_dim", (ge(2), le(1000)), "An integer in [2, 1000]", int ) - enc0_vocab_size = hp("enc0_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int) - enc1_vocab_size = hp("enc1_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int) - enc0_layers = hp("enc0_layers", (ge(1), le(4)), "An integer in [1, 4]", int) - enc1_layers = hp("enc1_layers", (ge(1), le(4)), "An integer in [1, 4]", int) - enc0_freeze_pretrained_embedding = hp( + enc0_vocab_size: hp = hp( + "enc0_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int + ) + enc1_vocab_size: hp = hp( + "enc1_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int + ) + enc0_layers: hp = hp("enc0_layers", (ge(1), le(4)), "An integer in [1, 4]", int) + enc1_layers: hp = hp("enc1_layers", (ge(1), le(4)), "An integer in [1, 4]", int) + enc0_freeze_pretrained_embedding: hp = hp( "enc0_freeze_pretrained_embedding", (), "Either True or False", bool ) - enc1_freeze_pretrained_embedding = hp( + enc1_freeze_pretrained_embedding: hp = hp( "enc1_freeze_pretrained_embedding", (), "Either True or False", bool ) def __init__( self, - role, - instance_count=None, - instance_type=None, - epochs=None, - enc0_max_seq_len=None, - enc0_vocab_size=None, - enc_dim=None, - mini_batch_size=None, - early_stopping_patience=None, - early_stopping_tolerance=None, - dropout=None, - weight_decay=None, - bucket_width=None, - num_classes=None, - mlp_layers=None, - mlp_dim=None, - 
mlp_activation=None, - output_layer=None, - optimizer=None, - learning_rate=None, - negative_sampling_rate=None, - comparator_list=None, - tied_token_embedding_weight=None, - token_embedding_storage_type=None, - enc0_network=None, - enc1_network=None, - enc0_cnn_filter_width=None, - enc1_cnn_filter_width=None, - enc1_max_seq_len=None, - enc0_token_embedding_dim=None, - enc1_token_embedding_dim=None, - enc1_vocab_size=None, - enc0_layers=None, - enc1_layers=None, - enc0_freeze_pretrained_embedding=None, - enc1_freeze_pretrained_embedding=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + epochs: Optional[int] = None, + enc0_max_seq_len: Optional[int] = None, + enc0_vocab_size: Optional[int] = None, + enc_dim: Optional[int] = None, + mini_batch_size: Optional[int] = None, + early_stopping_patience: Optional[int] = None, + early_stopping_tolerance: Optional[float] = None, + dropout: Optional[float] = None, + weight_decay: Optional[float] = None, + bucket_width: Optional[int] = None, + num_classes: Optional[int] = None, + mlp_layers: Optional[int] = None, + mlp_dim: Optional[int] = None, + mlp_activation: Optional[str] = None, + output_layer: Optional[str] = None, + optimizer: Optional[str] = None, + learning_rate: Optional[float] = None, + negative_sampling_rate: Optional[float] = None, + comparator_list: Optional[str] = None, + tied_token_embedding_weight: Optional[float] = None, + token_embedding_storage_type: Optional[str] = None, + enc0_network: Optional[str] = None, + enc1_network: Optional[str] = None, + enc0_cnn_filter_width: Optional[int] = None, + enc1_cnn_filter_width: Optional[int] = None, + enc1_max_seq_len: Optional[int] = None, + enc0_token_embedding_dim: Optional[int] = None, + enc1_token_embedding_dim: Optional[int] = None, + enc1_vocab_size: Optional[int] = None, + enc0_layers: Optional[int] = None, + enc1_layers: Optional[int] = None, + enc0_freeze_pretrained_embedding: Optional[bool] = None, + enc1_freeze_pretrained_embedding: Optional[bool] = None, **kwargs ): """Object2Vec is :class:`Estimator` used for anomaly detection. diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index e3127fd7a1..13440533d5 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -35,22 +35,24 @@ class PCA(AmazonAlgorithmEstimatorBase): retain as much information as possible. 
""" - repo_name = "pca" - repo_version = 1 + repo_name: str = "pca" + repo_version: int = 1 - DEFAULT_MINI_BATCH_SIZE = 500 + DEFAULT_MINI_BATCH_SIZE: int = 500 - num_components = hp("num_components", gt(0), "Value must be an integer greater than zero", int) - algorithm_mode = hp( + num_components: hp = hp( + "num_components", gt(0), "Value must be an integer greater than zero", int + ) + algorithm_mode: hp = hp( "algorithm_mode", isin("regular", "randomized"), 'Value must be one of "regular" and "randomized"', str, ) - subtract_mean = hp( + subtract_mean: hp = hp( name="subtract_mean", validation_message="Value must be a boolean", data_type=bool ) - extra_components = hp( + extra_components: hp = hp( name="extra_components", validation_message="Value must be an integer greater than or equal to 0, or -1.", data_type=int, @@ -58,13 +60,13 @@ class PCA(AmazonAlgorithmEstimatorBase): def __init__( self, - role, - instance_count=None, - instance_type=None, - num_components=None, - algorithm_mode=None, - subtract_mean=None, - extra_components=None, + role: str, + instance_count: Optional[int] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_components: Optional[int] = None, + algorithm_mode: Optional[str] = None, + subtract_mean: Optional[bool] = None, + extra_components: Optional[int] = None, **kwargs ): """A Principal Components Analysis (PCA) diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index c38d75e3e4..16e4a09955 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -13,7 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Optional, Union +from typing import Optional, Union, List from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase @@ -36,30 +36,30 @@ class RandomCutForest(AmazonAlgorithmEstimatorBase): or unclassifiable data points. """ - repo_name = "randomcutforest" - repo_version = 1 - MINI_BATCH_SIZE = 1000 + repo_name: str = "randomcutforest" + repo_version: int = 1 + MINI_BATCH_SIZE: int = 1000 - eval_metrics = hp( + eval_metrics: hp = hp( name="eval_metrics", validation_message='A comma separated list of "accuracy" or "precision_recall_fscore"', data_type=list, ) - num_trees = hp("num_trees", (ge(50), le(1000)), "An integer in [50, 1000]", int) - num_samples_per_tree = hp( + num_trees: hp = hp("num_trees", (ge(50), le(1000)), "An integer in [50, 1000]", int) + num_samples_per_tree: hp = hp( "num_samples_per_tree", (ge(1), le(2048)), "An integer in [1, 2048]", int ) - feature_dim = hp("feature_dim", (ge(1), le(10000)), "An integer in [1, 10000]", int) + feature_dim: hp = hp("feature_dim", (ge(1), le(10000)), "An integer in [1, 10000]", int) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_samples_per_tree=None, - num_trees=None, - eval_metrics=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_samples_per_tree: Optional[int] = None, + num_trees: Optional[int] = None, + eval_metrics: Optional[List] = None, **kwargs ): """An `Estimator` class implementing a Random Cut Forest. 
diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 12c22eae91..4719396f62 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging from typing import Union, Optional @@ -34,26 +36,26 @@ class Chainer(Framework): """Handle end-to-end training and deployment of custom Chainer code.""" - _framework_name = "chainer" + _framework_name: str = "chainer" # Hyperparameters - _use_mpi = "sagemaker_use_mpi" - _num_processes = "sagemaker_num_processes" - _process_slots_per_host = "sagemaker_process_slots_per_host" - _additional_mpi_options = "sagemaker_additional_mpi_options" + _use_mpi: str = "sagemaker_use_mpi" + _num_processes: str = "sagemaker_num_processes" + _process_slots_per_host: str = "sagemaker_process_slots_per_host" + _additional_mpi_options: str = "sagemaker_additional_mpi_options" def __init__( self, entry_point: Union[str, PipelineVariable], - use_mpi=None, - num_processes=None, - process_slots_per_host=None, - additional_mpi_options=None, + use_mpi: Optional[Union[bool, PipelineVariable]] = None, + num_processes: Optional[Union[int, PipelineVariable]] = None, + process_slots_per_host: Optional[Union[int, PipelineVariable]] = None, + additional_mpi_options: Optional[Union[str, PipelineVariable]] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - framework_version=None, - py_version=None, - image_uri=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + framework_version: Optional[str] = None, + py_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, **kwargs ): """This ``Estimator`` executes an Chainer script in a managed execution environment. 
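With the widened Chainer annotations above, the MPI-related settings can come from pipeline parameters. A sketch of the intended call pattern, not part of the patch (entry point and role are placeholders):

    from sagemaker.chainer import Chainer
    from sagemaker.workflow.parameters import ParameterInteger

    num_processes = ParameterInteger(name="NumProcesses", default_value=2)

    chainer = Chainer(
        entry_point="train.py",  # placeholder script
        framework_version="5.0.0",
        py_version="py3",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        instance_count=1,
        instance_type="ml.p3.2xlarge",
        use_mpi=True,
        num_processes=num_processes,  # Union[int, PipelineVariable] per the new signature
    )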
diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 6590d30514..7ebf28e3af 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -17,6 +17,8 @@ """ from __future__ import absolute_import, print_function +from typing import Union, List, Optional, Dict, Any + import copy import json import logging @@ -26,25 +28,29 @@ import tempfile from abc import ABC, abstractmethod from sagemaker import image_uris, s3, utils +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow import is_pipeline_variable logger = logging.getLogger(__name__) -class DataConfig: +class DataConfig: #TODO: add PipelineVariable to rest of fields """Config object related to configurations of the input and output dataset.""" def __init__( self, - s3_data_input_path, - s3_output_path, - s3_analysis_config_output_path=None, - label=None, - headers=None, - features=None, - dataset_type="text/csv", - s3_compression_type="None", - joinsource=None, + s3_data_input_path: Union[str, PipelineVariable], + s3_output_path: Union[str, PipelineVariable], + s3_analysis_config_output_path: Optional[str] = None, + label: Optional[str] = None, + headers: Optional[List[str]] = None, + features: Optional[List[str]] = None, + dataset_type: str = "text/csv", + s3_compression_type: str = "None", + joinsource: str = None, facet_dataset_uri=None, facet_headers=None, predicted_label_dataset_uri=None, @@ -186,10 +192,10 @@ class BiasConfig: def __init__( self, - label_values_or_threshold, - facet_name, - facet_values_or_threshold=None, - group_name=None, + label_values_or_threshold: Union[int, float, str], + facet_name: Union[str, int, List[str], List[int]], + facet_values_or_threshold: Optional[Union[int, float, str]] = None, + group_name: Optional[str] = None, ): """Initializes a configuration of the sensitive groups in the dataset. @@ -265,21 +271,21 @@ def get_config(self): return copy.deepcopy(self.analysis_config) -class ModelConfig: +class ModelConfig: # TODO add pipeline annotation """Config object related to a model and its endpoint to be created.""" def __init__( self, - model_name, - instance_count, - instance_type, - accept_type=None, - content_type=None, - content_template=None, - custom_attributes=None, - accelerator_type=None, - endpoint_name_prefix=None, - target_model=None, + model_name: str, + instance_count: int, + instance_type: str, + accept_type: Optional[str] = None, + content_type: Optional[str] = None, + content_template: Optional[str] = None, + custom_attributes: Optional[str] = None, + accelerator_type: Optional[str] = None, + endpoint_name_prefix: Optional[str] = None, + target_model: Optional[str] = None, ): r"""Initializes a configuration of a model and the endpoint to be created for it. @@ -378,10 +384,10 @@ class ModelPredictedLabelConfig: def __init__( self, - label=None, - probability=None, - probability_threshold=None, - label_headers=None, + label: Optional[Union[str, int]] = None, + probability: Optional[Union[str, int]] = None, + probability_threshold: Optional[float] = None, + label_headers: Optional[List[str]] = None, ): """Initializes a model output config to extract the predicted label or predicted score(s). @@ -473,7 +479,9 @@ class PDPConfig(ExplainabilityConfig): and the corresponding values are included in the analysis output. 
""" # noqa E501 - def __init__(self, features=None, grid_resolution=15, top_k_features=10): + def __init__( + self, features: Optional[List] = None, grid_resolution: int = 15, top_k_features: int = 10 + ): """Initializes PDP config. Args: @@ -641,8 +649,8 @@ class TextConfig: def __init__( self, - granularity, - language, + granularity: str, + language: str, ): """Initializes a text configuration. @@ -697,13 +705,13 @@ class ImageConfig: def __init__( self, - model_type, - num_segments=None, - feature_extraction_method=None, - segment_compactness=None, - max_objects=None, - iou_threshold=None, - context=None, + model_type: str, + num_segments: Optional[int] = None, + feature_extraction_method: Optional[str] = None, + segment_compactness: Optional[float] = None, + max_objects: Optional[int] = None, + iou_threshold: Optional[float] = None, + context: Optional[float] = None, ): """Initializes a config object for Computer Vision (CV) Image explainability. @@ -778,15 +786,15 @@ class SHAPConfig(ExplainabilityConfig): def __init__( self, - baseline=None, - num_samples=None, - agg_method=None, - use_logit=False, - save_local_shap_values=True, - seed=None, - num_clusters=None, - text_config=None, - image_config=None, + baseline: Optional[Union[str, List]] = None, + num_samples: Optional[int] = None, + agg_method: Optional[str] = None, + use_logit: Optional[bool] = None, + save_local_shap_values: Optional[bool] = None, + seed: Optional[int] = None, + num_clusters: Optional[int] = None, + text_config: Optional[TextConfig] = None, + image_config: Optional[ImageConfig] = None, ): """Initializes config for SHAP analysis. @@ -866,19 +874,19 @@ class SageMakerClarifyProcessor(Processor): def __init__( self, - role, - instance_count, - instance_type, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, - job_name_prefix=None, - version=None, + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, + job_name_prefix: Optional[str] = None, + version: Optional[str] = None, ): """Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations. 
@@ -949,13 +957,13 @@ def run(self, **_): def _run( self, - data_config, - analysis_config, - wait, - logs, - job_name, - kms_key, - experiment_config, + data_config: DataConfig, + analysis_config: Dict[str, Any], + wait: bool, + logs: bool, + job_name: str, + kms_key: str, + experiment_config: Dict[str, str], ): """Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container @@ -991,6 +999,15 @@ def _run( analysis_config_file = os.path.join(tmpdirname, "analysis_config.json") with open(analysis_config_file, "w") as f: json.dump(analysis_config, f) + + if ( + is_pipeline_variable(data_config.s3_data_input_path) + and not data_config.s3_analysis_config_output_path + ): + raise ValueError( + "If s3_data_input_path for DataConfig is a pipeline variable, " + "s3_analysis_config_output_path must not be null" + ) s3_analysis_config_file = _upload_analysis_config( analysis_config_file, data_config.s3_analysis_config_output_path or data_config.s3_output_path, @@ -1033,14 +1050,14 @@ def _run( def run_pre_training_bias( self, - data_config, - data_bias_config, - methods="all", - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + data_bias_config: BiasConfig, + methods: str = "all", + wait: bool = True, + logs: bool = True, + job_name: str = None, + kms_key: str = None, + experiment_config: Dict[str, str] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods @@ -1103,16 +1120,16 @@ def run_pre_training_bias( def run_post_training_bias( self, - data_config, - data_bias_config, - model_config, - model_predicted_label_config, - methods="all", - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + data_bias_config: BiasConfig, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + methods: str = "all", + wait: bool = True, + logs: bool = True, + job_name: str = None, + kms_key: str = None, + experiment_config: Dict[str, str] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute posttraining bias @@ -1192,17 +1209,17 @@ def run_post_training_bias( def run_bias( self, - data_config, - bias_config, - model_config, - model_predicted_label_config=None, - pre_training_methods="all", - post_training_methods="all", - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + bias_config: BiasConfig, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig = None, + pre_training_methods: str = "all", + post_training_methods: str = "all", + wait: bool = True, + logs: bool = True, + job_name: str = None, + kms_key: str = None, + experiment_config: Dict[str, str] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods @@ -1298,15 +1315,15 @@ def run_bias( def run_explainability( self, - data_config, - model_config, - explainability_config, - model_scores=None, - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + model_config: ModelConfig, + explainability_config: Union[ExplainabilityConfig, List], + model_scores: Union[int, ModelPredictedLabelConfig] = None, + wait: bool = True, + logs: bool = True, + job_name: str = None, + kms_key: str = None, + experiment_config: Dict[str, str] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions. 
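The new `_run` guard above requires a concrete `s3_analysis_config_output_path` whenever the input path is a pipeline variable, since the upload destination for `analysis_config.json` can no longer be derived at definition time. A sketch of a `DataConfig` that satisfies the check, not part of the patch (bucket names and headers are placeholders):

    from sagemaker.clarify import DataConfig
    from sagemaker.workflow.parameters import ParameterString

    s3_data = ParameterString(name="ClarifyInputUri")  # resolved at execution time

    data_config = DataConfig(
        s3_data_input_path=s3_data,
        s3_output_path="s3://my-bucket/clarify/output",  # placeholder
        # Required by the new check: a static location for analysis_config.json.
        s3_analysis_config_output_path="s3://my-bucket/clarify/config",  # placeholder
        label="target",
        headers=["target", "f0", "f1"],
        dataset_type="text/csv",
    )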
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index fe87186646..b5358baadf 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -56,6 +56,7 @@ get_jumpstart_base_name_if_jumpstart_model, update_inference_tags_with_jumpstart_training_tags, ) +from sagemaker.debugger import RuleBase from sagemaker.local import LocalSession from sagemaker.model import ( CONTAINER_LOG_LEVEL_PARAM_NAME, diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 613bbd3742..66d1ed80f3 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -555,7 +555,7 @@ def validate_smdistributed( if "smdistributed" not in distribution: # Distribution strategy other than smdistributed is selected return - if is_pipeline_variable(instance_type): + if is_pipeline_variable(instance_type) or (image_uri and is_pipeline_variable(image_uri)): # The instance_type is not available in compile time. # Rather, it's given in Pipeline execution time return diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 628c14dc8e..fdc6ef5868 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging import re from typing import Optional, Union, Dict @@ -202,8 +204,14 @@ def __init__( f"Instead got {type(compiler_config)}" ) raise ValueError(error_string) - if compiler_config: - compiler_config.validate(self) + + compiler_config.validate( + image_uri=image_uri, + instance_type=instance_type, + distribution=distribution, + ) + + self.distribution = distribution or {} self.compiler_config = compiler_config def _validate_args(self, image_uri): diff --git a/src/sagemaker/huggingface/processing.py b/src/sagemaker/huggingface/processing.py index 3f3813b778..63810b0eb9 100644 --- a/src/sagemaker/huggingface/processing.py +++ b/src/sagemaker/huggingface/processing.py @@ -17,9 +17,15 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List, Dict + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.huggingface.estimator import HuggingFace +from sagemaker.workflow.entities import PipelineVariable + class HuggingFaceProcessor(FrameworkProcessor): """Handles Amazon SageMaker processing tasks for jobs using HuggingFace containers.""" @@ -28,25 +34,25 @@ class HuggingFaceProcessor(FrameworkProcessor): def __init__( self, - role, - instance_count, - instance_type, - transformers_version=None, - tensorflow_version=None, - pytorch_version=None, - py_version="py36", - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + transformers_version: Optional[str] = None, + tensorflow_version: Optional[str] = None, + pytorch_version: Optional[str] = None, + py_version: str = "py36", + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, 
PipelineVariable]] = None, + code_location: Optional[str] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a HuggingFace execution environment. diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 3f0c054929..5784b676aa 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Dict, Union + import logging from typing import Union, Optional, Dict diff --git a/src/sagemaker/mxnet/processing.py b/src/sagemaker/mxnet/processing.py index 663b08be4b..71bce7cdff 100644 --- a/src/sagemaker/mxnet/processing.py +++ b/src/sagemaker/mxnet/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List, Dict + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.mxnet.estimator import MXNet from sagemaker.processing import FrameworkProcessor +from sagemaker.workflow.entities import PipelineVariable class MXNetProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class MXNetProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", # New kwarg + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, # New arg + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a managed MXNet execution environment. 
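The framework processors annotated above all follow the same pattern: `instance_count`, `instance_type`, and the volume/KMS settings may be pipeline variables. A construction sketch, not part of the patch (role and version pairing are placeholders):

    from sagemaker.mxnet.processing import MXNetProcessor
    from sagemaker.workflow.parameters import ParameterInteger, ParameterString
    from sagemaker.workflow.pipeline_context import PipelineSession

    processor = MXNetProcessor(
        framework_version="1.8.0",
        py_version="py37",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        instance_count=ParameterInteger(name="ProcInstanceCount", default_value=1),
        instance_type=ParameterString(name="ProcInstanceType", default_value="ml.m5.xlarge"),
        sagemaker_session=PipelineSession(),
    )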
diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 8da6e04768..1b57cd35d3 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -40,6 +40,8 @@ from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition from sagemaker.apiutils._base_types import ApiObject from sagemaker.s3 import S3Uploader +from sagemaker.workflow.entities import PipelineVariable + logger = logging.getLogger(__name__) @@ -773,10 +775,10 @@ def start_new(cls, processor, inputs, outputs, experiment_config): process_args = cls._get_process_args(processor, inputs, outputs, experiment_config) # Print the job name and the user's inputs and outputs as lists of dictionaries. - print() - print("Job Name: ", process_args["job_name"]) - print("Inputs: ", process_args["inputs"]) - print("Outputs: ", process_args["output_config"]["Outputs"]) + # print() + # print("Job Name: ", process_args["job_name"]) + # print("Inputs: ", process_args["inputs"]) + # print("Outputs: ", process_args["output_config"]["Outputs"]) # Call sagemaker_session.process using the arguments dictionary. processor.sagemaker_session.process(**process_args) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 622e79084c..42326baea9 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging from typing import Union, Optional @@ -45,12 +47,12 @@ class PyTorch(Framework): def __init__( self, entry_point: Union[str, PipelineVariable], - framework_version=None, - py_version=None, + framework_version: Optional[str] = None, + py_version: Optional[str] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - image_uri=None, - distribution=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Dict = None, **kwargs ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. 
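Per the PyTorch `hyperparameters` annotation above, individual values may now be pipeline variables while the keys stay plain strings. A sketch, not part of the patch (entry point, role, and version pairing are placeholders):

    from sagemaker.pytorch import PyTorch
    from sagemaker.workflow.parameters import ParameterString

    estimator = PyTorch(
        entry_point="train.py",  # placeholder script
        framework_version="1.8.0",
        py_version="py36",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        instance_count=1,
        instance_type="ml.p3.2xlarge",
        hyperparameters={
            "epochs": "10",
            # A value only known at pipeline execution time:
            "data-version": ParameterString(name="DataVersion", default_value="v1"),
        },
    )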
diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index a6581efac6..3717fa9ccd 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List, Dict + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.pytorch.estimator import PyTorch +from sagemaker.workflow.entities import PipelineVariable class PyTorchProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class PyTorchProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", # New kwarg + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a PyTorch execution environment. diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index b004dd87b8..650bba1c74 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, List, Dict + import enum import logging import re @@ -77,13 +79,13 @@ class RLEstimator(Framework): def __init__( self, entry_point: Union[str, PipelineVariable], - toolkit=None, - toolkit_version=None, - framework=None, + toolkit: Optional[RLToolkit] = None, + toolkit_version: Optional[str] = None, + framework: Optional[Framework] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - image_uri=None, - metric_definitions=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, **kwargs ): """Creates an RLEstimator for managed Reinforcement Learning (RL). 
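A note on the RLEstimator hunk above: the new `framework` annotation references the `Framework` estimator base class, but the value passed at call sites is an `RLFramework` enum member, as in this sketch (not part of the patch; entry point, role, and toolkit/framework pairing are placeholders):

    from sagemaker.rl import RLEstimator, RLFramework, RLToolkit

    rl = RLEstimator(
        entry_point="train_coach.py",  # placeholder script
        toolkit=RLToolkit.COACH,
        toolkit_version="0.11.0",
        framework=RLFramework.MXNET,  # RLFramework enum, not the Framework base class
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        instance_count=1,
        instance_type="ml.m5.xlarge",
        metric_definitions=[
            # Regex values may now be pipeline variables per the new annotation.
            {"Name": "episode_reward_mean", "Regex": r"episode_reward_mean: ([-0-9.]+)"},
        ],
    )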
diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py
index e13fbb764c..7680a122fe 100644
--- a/src/sagemaker/sklearn/estimator.py
+++ b/src/sagemaker/sklearn/estimator.py
@@ -13,6 +13,8 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
+from typing import Optional, Union, Dict
+
 import logging
 from typing import Union, Optional
 
@@ -28,6 +30,7 @@
 from sagemaker.sklearn.model import SKLearnModel
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow import is_pipeline_variable
 
 logger = logging.getLogger("sagemaker")
 
@@ -40,12 +43,12 @@ class SKLearn(Framework):
     def __init__(
         self,
         entry_point: Union[str, PipelineVariable],
-        framework_version=None,
-        py_version="py3",
+        framework_version: Optional[str] = None,
+        py_version: str = "py3",
         source_dir: Optional[Union[str, PipelineVariable]] = None,
-        hyperparameters=None,
-        image_uri=None,
-        image_uri_region=None,
+        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        image_uri_region: Optional[str] = None,
         **kwargs
     ):
         """Creates a SKLearn Estimator for Scikit-learn environment.
@@ -128,11 +131,17 @@ def __init__(
         self.framework_version = framework_version
         self.py_version = py_version
 
+        if instance_type:
+            if is_pipeline_variable(instance_type):
+                raise ValueError("instance_type argument cannot be a pipeline variable")
         # SciKit-Learn does not support distributed training or training on GPU instance types.
         # Fail fast.
         _validate_not_gpu_instance_type(instance_type)
 
         if instance_count:
+            if is_pipeline_variable(instance_count):
+                raise ValueError("instance_count argument cannot be a pipeline variable")
+
             if instance_count != 1:
                 raise AttributeError(
                     "Scikit-Learn does not support distributed training. Please remove the "
@@ -148,6 +157,12 @@ def __init__(
             )
 
         if image_uri is None:
+
+            if is_pipeline_variable(instance_type):
+                raise ValueError(
+                    "instance_type argument cannot be a pipeline variable when image_uri is not given."
+ ) + self.image_uri = image_uris.retrieve( SKLearn._framework_name, image_uri_region or self.sagemaker_session.boto_region_name, diff --git a/src/sagemaker/sklearn/processing.py b/src/sagemaker/sklearn/processing.py index c5445e31f4..7a4a953a3d 100644 --- a/src/sagemaker/sklearn/processing.py +++ b/src/sagemaker/sklearn/processing.py @@ -17,9 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, List, Dict, Optional + +from sagemaker.network import NetworkConfig from sagemaker import image_uris, Session from sagemaker.processing import ScriptProcessor from sagemaker.sklearn import defaults +from sagemaker.workflow.entities import PipelineVariable class SKLearnProcessor(ScriptProcessor): @@ -27,20 +31,20 @@ class SKLearnProcessor(ScriptProcessor): def __init__( self, - framework_version, - role, - instance_type, - instance_count, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: str, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initialize an ``SKLearnProcessor`` instance. 
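The new SKLearn fail-fast checks surface at construction time rather than at pipeline execution. A sketch of what they reject, not part of the patch (entry point and role are placeholders):

    from sagemaker.sklearn import SKLearn
    from sagemaker.workflow.parameters import ParameterString

    try:
        SKLearn(
            entry_point="train.py",  # placeholder script
            framework_version="0.23-1",
            role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
            instance_count=1,
            # Rejected immediately: without image_uri, the image lookup needs a
            # concrete instance type, not a value resolved at execution time.
            instance_type=ParameterString(name="TrainInstanceType"),
        )
    except ValueError as err:
        print(err)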
diff --git a/src/sagemaker/tensorflow/processing.py b/src/sagemaker/tensorflow/processing.py index 38bf784b7c..aa489e6b8c 100644 --- a/src/sagemaker/tensorflow/processing.py +++ b/src/sagemaker/tensorflow/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, List, Dict, Optional + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.tensorflow.estimator import TensorFlow +from sagemaker.workflow.entities import PipelineVariable class TensorFlowProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class TensorFlowProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", # New kwarg + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, # New arg + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Session = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a TensorFlow execution environment. diff --git a/src/sagemaker/training_compiler/config.py b/src/sagemaker/training_compiler/config.py index cd1fbd5957..5e0c312242 100644 --- a/src/sagemaker/training_compiler/config.py +++ b/src/sagemaker/training_compiler/config.py @@ -14,6 +14,8 @@ from __future__ import absolute_import import logging +from sagemaker.workflow import is_pipeline_variable + logger = logging.getLogger(__name__) @@ -132,6 +134,19 @@ def validate( ValueError: Raised if the requested configuration is not compatible with SageMaker Training Compiler. """ + if estimator.image_uri: + error_helper_string = ( + "Overriding the image URI is currently not supported " + "for SageMaker Training Compiler." + "Specify the following parameters to run the Hugging Face training job " + "with SageMaker Training Compiler enabled: " + "transformer_version, tensorflow_version or pytorch_version, and compiler_config." 
+ ) + raise ValueError(error_helper_string) + + if is_pipeline_variable(estimator.instance_type): + # skip the validation if either instance type is a pipeline variable + return if "local" not in estimator.instance_type: requested_instance_class = estimator.instance_type.split(".")[ diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index d71b8e1433..fce0a73a02 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -36,7 +36,6 @@ from sagemaker.session_settings import SessionSettings from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string - ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" MAX_BUCKET_PATHS_COUNT = 5 S3_PREFIX = "s3://" @@ -94,7 +93,6 @@ def unique_name_from_base(base, max_length=63): def base_name_from_image(image, default_base_name=None): """Extract the base name of the image to use as the 'algorithm name' for the job. - Args: image (str): Image name. default_base_name (str): The default base name diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 1f85d56442..eae09f3b42 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import warnings +from sagemaker.deprecations import deprecated_class from typing import List, Union, Optional import attr diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index f6f0005f1f..252c211562 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging from typing import Union, Optional @@ -31,6 +33,8 @@ from sagemaker.xgboost.model import XGBoostModel from sagemaker.xgboost.utils import validate_py_version, validate_framework_version +from sagemaker.workflow.entities import PipelineVariable + logger = logging.getLogger("sagemaker") @@ -45,12 +49,12 @@ class XGBoost(Framework): def __init__( self, entry_point: Union[str, PipelineVariable], - framework_version, + framework_version: str, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - py_version="py3", - image_uri=None, - image_uri_region=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + py_version: str = "py3", + image_uri: Optional[Union[str, PipelineVariable]] = None, + image_uri_region: Optional[str] = None, **kwargs ): """An estimator that executes an XGBoost-based SageMaker Training Job. 
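With `framework_version` now required and annotated on the XGBoost estimator above, hyperparameter values may be pipeline variables here as well. A sketch, not part of the patch (entry point and role are placeholders):

    from sagemaker.workflow.parameters import ParameterString
    from sagemaker.xgboost import XGBoost

    xgb = XGBoost(
        entry_point="train.py",  # placeholder script
        framework_version="1.3-1",
        py_version="py3",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        instance_count=1,
        instance_type="ml.m5.xlarge",
        hyperparameters={
            "num_round": "50",
            "objective": ParameterString(name="Objective", default_value="reg:squarederror"),
        },
    )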
diff --git a/src/sagemaker/xgboost/processing.py b/src/sagemaker/xgboost/processing.py index 3ca4b2361f..41b557b731 100644 --- a/src/sagemaker/xgboost/processing.py +++ b/src/sagemaker/xgboost/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, List, Dict, Optional + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.xgboost.estimator import XGBoost +from sagemaker.workflow.entities import PipelineVariable class XGBoostProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class XGBoostProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", # New kwarg + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, # New arg + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in an XGBoost execution environment. 
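Under a `PipelineSession`, a processor built with the annotated signature above defers work until the step executes, which is how the framework-processor unit tests later in this series exercise it. A sketch, not part of the patch (role and script are placeholders):

    from sagemaker.workflow.parameters import ParameterInteger
    from sagemaker.workflow.pipeline_context import PipelineSession
    from sagemaker.xgboost.processing import XGBoostProcessor

    processor = XGBoostProcessor(
        framework_version="1.3-1",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder
        instance_count=ParameterInteger(name="ProcInstanceCount", default_value=1),
        instance_type="ml.m5.xlarge",
        volume_size_in_gb=ParameterInteger(name="ProcVolumeGB", default_value=30),
        sagemaker_session=PipelineSession(),
    )

    # With a PipelineSession, .run(...) returns step arguments for wiring into
    # a ProcessingStep instead of starting a job immediately.
    step_args = processor.run(code="preprocess.py")  # placeholder script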
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 262d0eb558..e320ac7ebb 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -356,7 +356,6 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu processor = ScriptProcessor( role=ROLE, image_uri=IMAGE_URI, - command=["python3"], instance_type=INSTANCE_TYPE, instance_count=1, volume_size_in_gb=100, @@ -430,6 +429,9 @@ def test_processing_step_with_framework_processor( == processing_output.destination ) + del step_args["AppSpecification"]["ContainerEntrypoint"] + del step_def["Arguments"]["AppSpecification"]["ContainerEntrypoint"] + del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"] diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index f043048095..4fe36dec00 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -56,6 +56,8 @@ from sagemaker.inputs import TrainingInput from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo REGION = "us-west-2" BUCKET = "my-bucket" @@ -232,18 +234,28 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar assert "Running within a PipelineSession" in str(w[-1].message) with warnings.catch_warnings(record=True) as w: - step = TrainingStep( + step_train = TrainingStep( name="MyTrainingStep", step_args=step_args, description="TrainingStep description", display_name="MyTrainingStep", - depends_on=["TestStep", "SecondTestStep"], + depends_on=["TestStep"], ) assert len(w) == 0 + step_condition = ConditionStep( + name="MyConditionStep", + conditions=[ + ConditionGreaterThanOrEqualTo( + left=step_train.properties.FinalMetricDataList["val:acc"].Value, right=0.95 + ) + ], + if_steps=[custom_step2], + ) + pipeline = Pipeline( name="MyPipeline", - steps=[step, custom_step1, custom_step2], + steps=[step_train, step_condition, custom_step1, custom_step2], parameters=[enable_network_isolation, encrypt_container_traffic], sagemaker_session=pipeline_session, ) @@ -251,15 +263,18 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar "Get": "Parameters.encrypt_container_traffic" } step_args.args["EnableNetworkIsolation"] = {"Get": "Parameters.encrypt_container_traffic"} + assert step_condition.conditions[0].left.expr == { + "Get": "Steps.MyTrainingStep.FinalMetricDataList['val:acc'].Value" + } assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyTrainingStep", "Description": "TrainingStep description", "DisplayName": "MyTrainingStep", "Type": "Training", - "DependsOn": ["TestStep", "SecondTestStep"], + "DependsOn": ["TestStep"], "Arguments": step_args.args, } - assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} + assert step_train.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list assert ordered(adjacency_list) == ordered( {"MyTrainingStep": [], "SecondTestStep": ["MyTrainingStep"], "TestStep": ["MyTrainingStep"]} @@ -337,6 +352,9 @@ def 
test_training_step_with_framework_estimator( estimator.sagemaker_session = pipeline_session step_args = estimator.fit(inputs=TrainingInput(s3_data=training_input)) + from sagemaker.workflow.retry import SageMakerJobStepRetryPolicy, SageMakerJobExceptionTypeEnum + from sagemaker.workflow.parameters import ParameterInteger + step = TrainingStep( name="MyTrainingStep", step_args=step_args, From 03cd5ad4e4023a13dd3ce78c060de7ae663b57f4 Mon Sep 17 00:00:00 2001 From: Dewen Qi Date: Wed, 22 Jun 2022 12:37:12 -0700 Subject: [PATCH 3/6] change: Implement test mechanism for Pipeline variables annotations for processors + estimators / test mechanism change annotations for processors + estimators / test mechanism change remove debug print reformatting update TM and resolve all AIs and all untested subclasses Add ppl var annotation to all composite object for training Adjust tm for recent model changes --- src/sagemaker/amazon/amazon_estimator.py | 2 - src/sagemaker/amazon/hyperparameter.py | 2 - src/sagemaker/amazon/kmeans.py | 4 +- src/sagemaker/amazon/linear_learner.py | 12 +- src/sagemaker/chainer/estimator.py | 1 - src/sagemaker/clarify.py | 4 +- src/sagemaker/estimator.py | 1 - src/sagemaker/huggingface/estimator.py | 11 +- src/sagemaker/inputs.py | 1 + src/sagemaker/mxnet/estimator.py | 1 - src/sagemaker/processing.py | 1 - src/sagemaker/pytorch/estimator.py | 1 - src/sagemaker/rl/estimator.py | 1 - src/sagemaker/sklearn/estimator.py | 7 - .../tensorflow/training_compiler/config.py | 7 +- src/sagemaker/utils.py | 1 + src/sagemaker/workflow/step_collections.py | 1 - src/sagemaker/xgboost/estimator.py | 2 - src/sagemaker/xgboost/processing.py | 6 +- .../test_mechanism/test_code/__init__.py | 755 ++++++++++++++++++ .../test_code/parameter_skip_checker.py | 382 +++++++++ ...est_pipeline_var_compatibility_template.py | 586 ++++++++++++++ .../test_mechanism/test_code/utilities.py | 267 +++++++ ...eline_var_compatibility_with_estimators.py | 440 ++++++++++ ..._pipeline_var_compatibility_with_models.py | 550 +++++++++++++ ...eline_var_compatibility_with_processors.py | 363 +++++++++ ...line_var_compatibility_with_transformer.py | 41 + ...t_pipeline_var_compatibility_with_tuner.py | 47 ++ .../sagemaker/workflow/test_training_step.py | 7 +- 29 files changed, 3459 insertions(+), 45 deletions(-) create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_code/parameter_skip_checker.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_code/utilities.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_models.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_processors.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_transformer.py create mode 100644 tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_tuner.py diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 71c29a74c1..dad5d54dcd 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ 
b/src/sagemaker/amazon/amazon_estimator.py
@@ -18,7 +18,6 @@
 import json
 import logging
 import tempfile
-from typing import Union
 
 from six.moves.urllib.parse import urlparse
 
@@ -32,7 +31,6 @@
 from sagemaker.utils import sagemaker_timestamp
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
-from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.parameters import ParameterBoolean
 from sagemaker.workflow import is_pipeline_variable
 
diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py
index 36ba8019da..973668ed56 100644
--- a/src/sagemaker/amazon/hyperparameter.py
+++ b/src/sagemaker/amazon/hyperparameter.py
@@ -16,8 +16,6 @@
 import json
 
 from sagemaker.workflow import is_pipeline_variable
-from sagemaker.workflow import is_pipeline_variable
-
 
 class Hyperparameter(object):
     """An algorithm hyperparameter with optional validation.
diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py
index 18c976b99d..1b925af6e4 100644
--- a/src/sagemaker/amazon/kmeans.py
+++ b/src/sagemaker/amazon/kmeans.py
@@ -27,8 +27,6 @@
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable
 
-from sagemaker.workflow.entities import PipelineVariable
-
 
 class KMeans(AmazonAlgorithmEstimatorBase):
     """An unsupervised learning algorithm that attempts to find discrete groupings within data.
@@ -76,7 +74,7 @@
         half_life_time_size: Optional[int] = None,
         epochs: Optional[int] = None,
         center_factor: Optional[int] = None,
-        eval_metrics: Optional[List[Union[str, PipelineVariable]]] = None,
+        eval_metrics: Optional[List[str]] = None,
         **kwargs
     ):
         """A k-means clustering class :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`.
diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py
index 91bae84b36..50d2c33d03 100644
--- a/src/sagemaker/amazon/linear_learner.py
+++ b/src/sagemaker/amazon/linear_learner.py
@@ -13,6 +13,7 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
+import logging
 from typing import Union, Optional
 
 from sagemaker import image_uris
@@ -28,6 +29,8 @@
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow import is_pipeline_variable
 
+logger = logging.getLogger(__name__)
+
 
 class LinearLearner(AmazonAlgorithmEstimatorBase):
     """A supervised learning algorithms used for solving classification or regression problems.
@@ -437,9 +440,14 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
         # mini_batch_size can't be greater than number of records or training job fails
         if not mini_batch_size:
             if is_pipeline_variable(self.instance_count):
-                raise ValueError(
-                    "instance_count can not be a pipeline variable when mini_batch_size is not given."
+                logger.warning(
+                    "mini_batch_size is not given in .fit() and instance_count is a "
+                    "pipeline variable (%s) which is only parsed at execution time. 
" + "Thus setting mini_batch_size to 1, as it can't be greater than " + "number of records per instance_count, otherwise the training job fails.", + type(self.instance_count), ) + mini_batch_size = 1 else: mini_batch_size = min( self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)) diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 4719396f62..112c67a4d8 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -16,7 +16,6 @@ from typing import Optional, Union, Dict import logging -from typing import Union, Optional from sagemaker.estimator import Framework, EstimatorBase from sagemaker.fw_utils import ( diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 7ebf28e3af..be16ccfb9e 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) -class DataConfig: #TODO: add PipelineVariable to rest of fields +class DataConfig: # TODO: add PipelineVariable to rest of fields """Config object related to configurations of the input and output dataset.""" def __init__( @@ -271,7 +271,7 @@ def get_config(self): return copy.deepcopy(self.analysis_config) -class ModelConfig: # TODO add pipeline annotation +class ModelConfig: # TODO add pipeline annotation """Config object related to a model and its endpoint to be created.""" def __init__( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index b5358baadf..fe87186646 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -56,7 +56,6 @@ get_jumpstart_base_name_if_jumpstart_model, update_inference_tags_with_jumpstart_training_tags, ) -from sagemaker.debugger import RuleBase from sagemaker.local import LocalSession from sagemaker.model import ( CONTAINER_LOG_LEVEL_PARAM_NAME, diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index fdc6ef5868..4d8b409eb4 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -17,7 +17,6 @@ import logging import re -from typing import Optional, Union, Dict from sagemaker.deprecations import renamed_kwargs from sagemaker.estimator import Framework, EstimatorBase @@ -204,14 +203,8 @@ def __init__( f"Instead got {type(compiler_config)}" ) raise ValueError(error_string) - - compiler_config.validate( - image_uri=image_uri, - instance_type=instance_type, - distribution=distribution, - ) - - self.distribution = distribution or {} + if compiler_config: + compiler_config.validate(self) self.compiler_config = compiler_config def _validate_args(self, image_uri): diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 0fca307a97..5bfcb0d672 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -14,6 +14,7 @@ from __future__ import absolute_import, print_function from typing import Union, Optional, List + import attr from sagemaker.workflow.entities import PipelineVariable diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 5784b676aa..48974a3413 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -16,7 +16,6 @@ from typing import Optional, Dict, Union import logging -from typing import Union, Optional, Dict from packaging.version import Version diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 1b57cd35d3..11272ccb63 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -40,7 +40,6 @@ from 
sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition from sagemaker.apiutils._base_types import ApiObject from sagemaker.s3 import S3Uploader -from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 42326baea9..55a4fff89d 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -16,7 +16,6 @@ from typing import Optional, Union, Dict import logging -from typing import Union, Optional from packaging.version import Version diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index 650bba1c74..b95f192ea8 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -18,7 +18,6 @@ import enum import logging import re -from typing import Union, Optional from sagemaker import image_uris, fw_utils from sagemaker.estimator import Framework, EstimatorBase diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 7680a122fe..72372f602c 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -16,7 +16,6 @@ from typing import Optional, Union, Dict import logging -from typing import Union, Optional from sagemaker import image_uris from sagemaker.deprecations import renamed_kwargs @@ -157,12 +156,6 @@ def __init__( ) if image_uri is None: - - if is_pipeline_variable(instance_type): - raise ValueError( - "instance_type argument cannot be a pipeline variable when image_uri is not given." - ) - self.image_uri = image_uris.retrieve( SKLearn._framework_name, image_uri_region or self.sagemaker_session.boto_region_name, diff --git a/src/sagemaker/tensorflow/training_compiler/config.py b/src/sagemaker/tensorflow/training_compiler/config.py index d14cc3359b..35516477a1 100644 --- a/src/sagemaker/tensorflow/training_compiler/config.py +++ b/src/sagemaker/tensorflow/training_compiler/config.py @@ -13,10 +13,13 @@ """Configuration for the SageMaker Training Compiler.""" from __future__ import absolute_import import logging +from typing import Union + from packaging.specifiers import SpecifierSet from packaging.version import Version from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -29,8 +32,8 @@ class TrainingCompilerConfig(BaseConfig): def __init__( self, - enabled=True, - debug=False, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, ): """This class initializes a ``TrainingCompilerConfig`` instance. diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index fce0a73a02..a7f07963fc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -93,6 +93,7 @@ def unique_name_from_base(base, max_length=63): def base_name_from_image(image, default_base_name=None): """Extract the base name of the image to use as the 'algorithm name' for the job. + Args: image (str): Image name. 
default_base_name (str): The default base name diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index eae09f3b42..1f85d56442 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -14,7 +14,6 @@ from __future__ import absolute_import import warnings -from sagemaker.deprecations import deprecated_class from typing import List, Union, Optional import attr diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index 252c211562..498d009dd0 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -16,7 +16,6 @@ from typing import Optional, Union, Dict import logging -from typing import Union, Optional from sagemaker import image_uris from sagemaker.deprecations import renamed_kwargs @@ -28,7 +27,6 @@ ) from sagemaker.session import Session from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -from sagemaker.workflow.entities import PipelineVariable from sagemaker.xgboost import defaults from sagemaker.xgboost.model import XGBoostModel from sagemaker.xgboost.utils import validate_py_version, validate_framework_version diff --git a/src/sagemaker/xgboost/processing.py b/src/sagemaker/xgboost/processing.py index 41b557b731..f3e9c08f02 100644 --- a/src/sagemaker/xgboost/processing.py +++ b/src/sagemaker/xgboost/processing.py @@ -33,17 +33,17 @@ class XGBoostProcessor(FrameworkProcessor): def __init__( self, - framework_version: str, # New arg + framework_version: str, role: str, instance_count: Union[int, PipelineVariable], instance_type: Union[str, PipelineVariable], - py_version: str = "py3", # New kwarg + py_version: str = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, command: Optional[List[str]] = None, volume_size_in_gb: Union[int, PipelineVariable] = 30, volume_kms_key: Optional[Union[str, PipelineVariable]] = None, output_kms_key: Optional[Union[str, PipelineVariable]] = None, - code_location: Optional[str] = None, # New arg + code_location: Optional[str] = None, max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py new file mode 100644 index 0000000000..4a9b2a37b0 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py @@ -0,0 +1,755 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
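+#
+# This package hosts the shared fixtures for the pipeline-variable
+# compatibility test mechanism. A minimal sketch of the intended entry
+# point (argument values are placeholders; see
+# test_pipeline_var_compatibility_template.py):
+#
+#     template = PipelineVarCompatiTestTemplate(
+#         clazz=Processor,
+#         default_args={CLAZZ_ARGS: {...}, FUNC_ARGS: {...}},
+#     )
+#     template.check_compatibility()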
+from __future__ import absolute_import + +import json +import os +from random import getrandbits +from typing import List + +from sagemaker import ModelMetrics, MetricsSource, FileSource, Predictor +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model import FrameworkModel +from sagemaker.parameter import IntegerParameter +from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.sparkml import SparkMLModel +from sagemaker.tensorflow import TensorFlow +from sagemaker.tuner import WarmStartConfig, WarmStartTypes +from mock import Mock, PropertyMock +from sagemaker import TrainingInput +from sagemaker.debugger import ( + Rule, + DebuggerHookConfig, + TensorBoardOutputConfig, + ProfilerConfig, + CollectionConfig, +) +from sagemaker.clarify import DataConfig +from sagemaker.network import NetworkConfig +from sagemaker.processing import ProcessingInput, ProcessingOutput +from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.rl.estimator import RLToolkit, RLFramework +from sagemaker.pytorch import PyTorch +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.execution_variables import ExecutionVariables +from sagemaker.workflow.functions import Join, JsonGet +from sagemaker.workflow.model_step import ModelStep +from sagemaker.workflow.parameters import ( + ParameterString, + ParameterInteger, + ParameterFloat, + ParameterBoolean, +) +from sagemaker.workflow.pipeline_context import PipelineSession +from sagemaker.workflow.properties import PropertyFile +from sagemaker.workflow.steps import ProcessingStep, TrainingStep, TuningStep, TransformStep +from tests.unit import DATA_DIR + +STR_VAL = "MyString" +ROLE = "DummyRole" +INSTANCE_TYPE = "ml.m5.xlarge" +BUCKET = "my-bucket" +REGION = "us-west-2" +IMAGE_URI = "fakeimage" +DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py" +TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies") +TENSORFLOW_ENTRY_POINT = os.path.join(TENSORFLOW_PATH, "inference.py") +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" + +DEFAULT_VALUE = "default_value" +CLAZZ_ARGS = "clazz_args" +FUNC_ARGS = "func_args" +REQUIRED = "required" +OPTIONAL = "optional" +COMMON = "common" +INIT = "init" +TYPE = "type" + + +class PipelineVariableEncoder(json.JSONEncoder): + """The Json Encoder for PipelineVariable""" + + def default(self, obj): + """To return a serializable object for the input object if it's a PipelineVariable + + or call the base implementation if it's not + + Args: + obj (object): The input object to be handled. 
+        """
+        if isinstance(obj, PipelineVariable):
+            return obj.expr
+        return json.JSONEncoder.default(self, obj)
+
+
+class MockProperties(PipelineVariable):
+    """A mock object for Pipeline Properties"""
+
+    def __init__(
+        self,
+        step_name: str,
+        path: str = None,
+        shape_name: str = None,
+        shape_names: List[str] = None,
+        service_name: str = "sagemaker",
+    ):
+        """Initialize a MockProperties object"""
+        self.step_name = step_name
+        self.path = path
+
+    @property
+    def expr(self):
+        """The 'Get' expression dict for a `Properties`."""
+        return {"Get": f"Steps.{self.step_name}.Outcome"}
+
+    @property
+    def _referenced_steps(self) -> List[str]:
+        """List of step names that this function depends on."""
+        return [self.step_name]
+
+
+def _generate_mock_pipeline_session():
+    """Generate a mock pipeline session"""
+    client_mock = Mock()
+    client_mock._client_config.user_agent = (
+        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+    )
+    client_mock.describe_algorithm.return_value = {
+        "TrainingSpecification": {
+            "TrainingChannels": [
+                {
+                    "SupportedContentTypes": ["application/json"],
+                    "SupportedInputModes": ["File"],
+                    "Name": "train",
+                }
+            ],
+            "SupportedTrainingInstanceTypes": ["ml.m5.xlarge", "ml.m4.xlarge"],
+            "SupportedHyperParameters": [
+                {
+                    "Name": "MyKey",
+                    "Type": "FreeText",
+                }
+            ],
+        },
+        "AlgorithmName": "algo-name",
+    }
+
+    role_mock = Mock()
+    type(role_mock).arn = PropertyMock(return_value=ROLE)
+    resource_mock = Mock()
+    resource_mock.Role.return_value = role_mock
+    session_mock = Mock(region_name=REGION)
+    session_mock.resource.return_value = resource_mock
+    session_mock.client.return_value = client_mock
+
+    return PipelineSession(
+        boto_session=session_mock,
+        sagemaker_client=client_mock,
+        default_bucket=BUCKET,
+    )
+
+
+def _generate_all_pipeline_vars() -> dict:
+    """Generate a dict with all kinds of Pipeline variables"""
+    # Parameter
+    ppl_param_str = ParameterString(name="MyString")
+    ppl_param_int = ParameterInteger(name="MyInt")
+    ppl_param_float = ParameterFloat(name="MyFloat")
+    ppl_param_bool = ParameterBoolean(name="MyBool")
+
+    # Function
+    ppl_join = Join(on=" ", values=[ppl_param_int, ppl_param_float, 1, "test"])
+    property_file = PropertyFile(
+        name="name",
+        output_name="result",
+        path="output",
+    )
+    ppl_json_get = JsonGet(
+        step_name="my-step",
+        property_file=property_file,
+        json_path="my-json-path",
+    )
+
+    # Properties
+    ppl_prop = MockProperties(step_name="MyPreStep")
+
+    # Execution Variables
+    ppl_exe_var = ExecutionVariables.PIPELINE_NAME
+
+    return dict(
+        str=[
+            (
+                ppl_param_str,
+                dict(origin=ppl_param_str.expr, to_string=ppl_param_str.to_string().expr),
+            ),
+            (ppl_join, dict(origin=ppl_join.expr, to_string=ppl_join.to_string().expr)),
+            (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)),
+            (ppl_prop, dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr)),
+            (ppl_exe_var, dict(origin=ppl_exe_var.expr, to_string=ppl_exe_var.to_string().expr)),
+        ],
+        int=[
+            (
+                ppl_param_int,
+                dict(origin=ppl_param_int.expr, to_string=ppl_param_int.to_string().expr),
+            ),
+            (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)),
+            (ppl_prop, dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr)),
+        ],
+        float=[
+            (
+                ppl_param_float,
+                dict(origin=ppl_param_float.expr, to_string=ppl_param_float.to_string().expr),
+            ),
+            (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)),
+            (
+                ppl_prop,
+                
dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr), + ), + ], + bool=[ + ( + ppl_param_bool, + dict(origin=ppl_param_bool.expr, to_string=ppl_param_bool.to_string().expr), + ), + (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)), + ( + ppl_prop, + dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr), + ), + ], + ) + + +IS_TRUE = bool(getrandbits(1)) +PIPELINE_SESSION = _generate_mock_pipeline_session() +PIPELINE_VARIABLES = _generate_all_pipeline_vars() + +# TODO: need to recursively assign with Pipeline Variable in later changes +FIXED_ARGUMENTS = dict( + common=dict( + instance_type=INSTANCE_TYPE, + role=ROLE, + sagemaker_session=PIPELINE_SESSION, + source_dir=f"s3://{BUCKET}/source", + entry_point=TENSORFLOW_ENTRY_POINT, + dependencies=[os.path.join(TENSORFLOW_PATH, "dependency.py")], + code_location=f"s3://{BUCKET}/code", + predictor_cls=Predictor, + model_metrics=ModelMetrics( + model_statistics=MetricsSource( + content_type=ParameterString(name="model_statistics_content_type"), + s3_uri=ParameterString(name="model_statistics_s3_uri"), + content_digest=ParameterString(name="model_statistics_content_digest"), + ) + ), + metadata_properties=MetadataProperties( + commit_id=ParameterString(name="meta_properties_commit_id"), + repository=ParameterString(name="meta_properties_repository"), + generated_by=ParameterString(name="meta_properties_generated_by"), + project_id=ParameterString(name="meta_properties_project_id"), + ), + drift_check_baselines=DriftCheckBaselines( + model_constraints=MetricsSource( + content_type=ParameterString(name="drift_constraints_content_type"), + s3_uri=ParameterString(name="drift_constraints_s3_uri"), + content_digest=ParameterString(name="drift_constraints_content_digest"), + ), + bias_config_file=FileSource( + content_type=ParameterString(name="drift_bias_content_type"), + s3_uri=ParameterString(name="drift_bias_s3_uri"), + content_digest=ParameterString(name="drift_bias_content_digest"), + ), + ), + model_package_name="my-model-pkg" if IS_TRUE else None, + model_package_group_name="my-model-pkg-group" if not IS_TRUE else None, + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.t2.medium", "ml.m5.xlarge"], + content_types=["application/json"], + response_types=["application/json"], + ), + processor=dict( + estimator_cls=PyTorch, + code=f"s3://{BUCKET}/code", + spark_event_logs_s3_uri=f"s3://{BUCKET}/my-spark-output-path", + framework_version="1.8", + network_config=NetworkConfig( + subnets=[ParameterString(name="nw_cfg_subnets")], + security_group_ids=[ParameterString(name="nw_cfg_security_group_ids")], + enable_network_isolation=ParameterBoolean(name="nw_cfg_enable_nw_isolation"), + encrypt_inter_container_traffic=ParameterBoolean( + name="nw_cfg_encrypt_inter_container_traffic" + ), + ), + inputs=[ + ProcessingInput( + source=ParameterString(name="proc_input_source"), + destination=ParameterString(name="proc_input_dest"), + s3_data_type=ParameterString(name="proc_input_s3_data_type"), + app_managed=ParameterBoolean(name="proc_input_app_managed"), + ), + ], + outputs=[ + ProcessingOutput( + source=ParameterString(name="proc_output_source"), + destination=ParameterString(name="proc_output_dest"), + app_managed=ParameterBoolean(name="proc_output_app_managed"), + ), + ], + data_config=DataConfig( + s3_data_input_path=ParameterString(name="clarify_processor_input"), + s3_output_path=ParameterString(name="clarify_processor_output"), + 
s3_analysis_config_output_path="s3://analysis_config_output_path", + ), + data_bias_config=DataConfig( + s3_data_input_path=ParameterString(name="clarify_processor_input"), + s3_output_path=ParameterString(name="clarify_processor_output"), + s3_analysis_config_output_path="s3://analysis_config_output_path", + ), + ), + estimator=dict( + image_uri_region="us-west-2", + input_mode="File", + records=RecordSet( + s3_data=ParameterString(name="records_s3_data"), + num_records=1000, + feature_dim=128, + s3_data_type=ParameterString(name="records_s3_data_type"), + channel=ParameterString(name="records_channel"), + ), + disable_profiler=False, + vector_dim=128, + enc_dim=128, + momentum=1e-6, + beta_1=1e-4, + beta_2=1e-4, + mini_batch_size=1000, + dropout=0.25, + num_classes=10, + mlp_dim=512, + mlp_activation="relu", + output_layer="softmax", + comparator_list="hadamard,concat,abs_diff", + token_embedding_storage_type="dense", + enc0_network="bilstm", + enc1_network="bilstm", + enc0_token_embedding_dim=256, + enc1_token_embedding_dim=256, + enc0_vocab_size=512, + enc1_vocab_size=512, + bias_init_method="normal", + factors_init_method="normal", + predictor_type="regressor", + linear_init_method="uniform", + toolkit=RLToolkit.RAY, + toolkit_version="1.6.0", + framework=RLFramework.PYTORCH, + algorithm_mode="regular", + num_topics=6, + k=6, + init_method="kmeans++", + local_init_method="kmeans++", + eval_metrics="mds,ssd", + tol=1e-4, + dimension_reduction_type="sign", + index_type="faiss.Flat", + faiss_index_ivf_nlists="auto", + index_metric="COSINE", + binary_classifier_model_selection_criteria="f1", + positive_example_weight_mult="auto", + loss="logistic", + target_recall=0.1, + target_precision=0.8, + early_stopping_tolerance=1e-4, + encoder_layers_activation="relu", + optimizer="adam", + tolerance=1e-4, + rescale_gradient=1e-2, + weight_decay=1e-6, + learning_rate=1e-4, + num_trees=50, + source_dir=f"s3://{BUCKET}/source", + entry_point=os.path.join(TENSORFLOW_PATH, "inference.py"), + dependencies=[os.path.join(TENSORFLOW_PATH, "dependency.py")], + code_location=f"s3://{BUCKET}/code", + output_path=f"s3://{BUCKET}/output", + model_uri=f"s3://{BUCKET}/model", + py_version="py2", + framework_version="2.1.1", + rules=[ + Rule.custom( + name="CustomeRule", + image_uri=ParameterString(name="rules_image_uri"), + instance_type=ParameterString(name="rules_instance_type"), + volume_size_in_gb=ParameterInteger(name="rules_volume_size"), + source="path/to/my_custom_rule.py", + rule_to_invoke=ParameterString(name="rules_to_invoke"), + container_local_output_path=ParameterString(name="rules_local_output"), + s3_output_path=ParameterString(name="rules_to_s3_output_path"), + other_trials_s3_input_paths=[ParameterString(name="rules_other_s3_input")], + rule_parameters={"threshold": ParameterString(name="rules_param")}, + collections_to_save=[ + CollectionConfig( + name=ParameterString(name="rules_collections_name"), + parameters={"key1": ParameterString(name="rules_collections_param")}, + ) + ], + ) + ], + debugger_hook_config=DebuggerHookConfig( + s3_output_path=ParameterString(name="debugger_hook_s3_output"), + container_local_output_path=ParameterString(name="debugger_container_output"), + hook_parameters={"key1": ParameterString(name="debugger_hook_param")}, + collection_configs=[ + CollectionConfig( + name=ParameterString(name="debugger_collections_name"), + parameters={"key1": ParameterString(name="debugger_collections_param")}, + ) + ], + ), + tensorboard_output_config=TensorBoardOutputConfig( + 
s3_output_path=ParameterString(name="tensorboard_s3_output"), + container_local_output_path=ParameterString(name="tensorboard_container_output"), + ), + profiler_config=ProfilerConfig( + s3_output_path=ParameterString(name="profile_config_s3_output_path"), + system_monitor_interval_millis=ParameterInteger(name="profile_config_system_monitor"), + ), + inputs={ + "train": TrainingInput( + s3_data=ParameterString(name="train_inputs_s3_data"), + distribution=ParameterString(name="train_inputs_distribution"), + compression=ParameterString(name="train_inputs_compression"), + content_type=ParameterString(name="train_inputs_content_type"), + record_wrapping=ParameterString(name="train_inputs_record_wrapping"), + s3_data_type=ParameterString(name="train_inputs_s3_data_type"), + input_mode=ParameterString(name="train_inputs_input_mode"), + attribute_names=[ParameterString(name="train_inputs_attribute_name")], + target_attribute_name=ParameterString(name="train_inputs_target_attr_name"), + ), + }, + ), + transformer=dict( + data=f"s3://{BUCKET}/data", + ), + tuner=dict( + estimator=TensorFlow( + entry_point=TENSORFLOW_ENTRY_POINT, + role=ROLE, + framework_version="2.1.1", + py_version="py2", + instance_count=1, + instance_type="ml.m5.xlarge", + sagemaker_session=PIPELINE_SESSION, + enable_sagemaker_metrics=True, + max_retry_attempts=3, + hyperparameters={"static-hp": "hp1", "train_size": "1280"}, + ), + hyperparameter_ranges={ + "batch-size": IntegerParameter( + min_value=ParameterInteger(name="hyper_range_min_value"), + max_value=ParameterInteger(name="hyper_range_max_value"), + scaling_type=ParameterString(name="hyper_range_scaling_type"), + ), + }, + warm_start_config=WarmStartConfig( + warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM, + parents={ParameterString(name="warm_start_cfg_parent")}, + ), + estimator_name="estimator-1", + inputs={ + "estimator-1": TrainingInput(s3_data=ParameterString(name="inputs_estimator_1")), + }, + include_cls_metadata={"estimator-1": IS_TRUE}, + ), + model=dict( + serverless_inference_config=ServerlessInferenceConfig(), + framework_version="1.11.0", + py_version="py3", + accelerator_type="ml.eia2.xlarge", + ), + pipelinemodel=dict( + models=[ + SparkMLModel( + name="MySparkMLModel", + model_data=f"s3://{BUCKET}", + role=ROLE, + sagemaker_session=PIPELINE_SESSION, + env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, + ), + FrameworkModel( + image_uri=IMAGE_URI, + model_data=f"s3://{BUCKET}/model.tar.gz", + role=ROLE, + sagemaker_session=PIPELINE_SESSION, + entry_point=f"{DATA_DIR}/dummy_script.py", + name="modelName", + vpc_config={"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]}, + ), + ], + ), +) +STEP_CLASS = dict( + processor=ProcessingStep, + estimator=TrainingStep, + transformer=TransformStep, + tuner=TuningStep, + model=ModelStep, + pipelinemodel=ModelStep, +) + +# A dict for class __init__ parameters which are not used in +# request dict generated by a specific target function but are used in another +# For example, as for Model class constructor, the vpc_config is used only in +# model.create and is ignored in model.register +# Thus when testing on model.register we should skip replacing vpc_config +# with pipeline variables +CLASS_PARAMS_EXCLUDED_IN_TARGET_FUNC = dict( + model=dict( + register=dict( + common={"vpc_config", "enable_network_isolation"}, + ), + ), + pipelinemodel=dict( + register=dict( + common={"vpc_config", "enable_network_isolation"}, + ), + ), +) +# A dict for base class __init__ parameters which are not 
used in
+# some specific subclasses.
+# For example, TensorFlowModel uses **kwargs for duplicate parameters
+# in base class (FrameworkModel/Model) but it ignores the "image_config"
+# in target functions.
+BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS = dict(
+    model=dict(
+        TensorFlowModel={"image_config"},
+        SKLearnModel={"image_config"},
+        PyTorchModel={"image_config"},
+        XGBoostModel={"image_config"},
+        ChainerModel={"image_config"},
+        HuggingFaceModel={"image_config"},
+        MXNetModel={"image_config"},
+        KNNModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        SparkMLModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        KMeansModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        PCAModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        LDAModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        NTMModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        Object2VecModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        FactorizationMachinesModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        IPInsightsModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        RandomCutForestModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        LinearLearnerModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        MultiDataModel={
+            "model_data",  # model_data is overridden in __init__ with model_data_prefix
+            "image_config",
+            "container_log_level",  # it's simply ignored
+        },
+    ),
+    estimator=dict(
+        AlgorithmEstimator={  # its kwargs are ignored so base class parameters are ignored
+            "enable_network_isolation",
+            "container_log_level",
+            "checkpoint_s3_uri",
+            "checkpoint_local_path",
+            "enable_sagemaker_metrics",
+            "environment",
+            "max_retry_attempts",
+            "source_dir",
+            "entry_point",
+        },
+        SKLearn={
+            "instance_count",
+            "instance_type",
+        },
+    ),
+)
+# A dict to keep the optional arguments which should not be None according to the logic
+# specific to the subclass.
+PARAMS_SHOULD_NOT_BE_NONE = dict(
+    estimator=dict(
+        init=dict(
+            common={"instance_count", "instance_type"},
+            LDA={"mini_batch_size"},
+        )
+    ),
+    model=dict(
+        register=dict(
+            common={},
+            Model={"model_data"},
+            HuggingFaceModel={"model_data"},
+        ),
+        create=dict(
+            common={},
+            Model={"role"},
+            SparkMLModel={"role"},
+            MultiDataModel={"role"},
+        ),
+    ),
+)
+# A dict for parameters which should not be replaced with pipeline variables
+# since they are bonded with other parameters with None value.
+# For example:
+# Case 1: if outputs (a parameter in FrameworkProcessor.run) is None,
+# output_kms_key (a parameter in constructor) is omitted,
+# so we don't need to replace it with pipeline variables
+# Case 2: if image_uri is None, instance_type is not allowed to be a pipeline variable,
+# otherwise the class can fail to be instantiated
+UNSET_PARAM_BONDED_WITH_NONE = dict(
+    processor=dict(
+        init=dict(
+            common=dict(instance_type={"image_uri"}),
+        ),
+        run=dict(
+            common=dict(output_kms_key={"outputs"}),
+        ),
+    ),
+    estimator=dict(
+        init=dict(
+            common=dict(
+                # entry_point can only be parameterized when source_dir is given;
+                # if source_dir is None, skip parameterizing entry_point
+                entry_point={"source_dir"},
+                instance_type={"image_uri"},
+            ),
+        ),
+        fit=dict(
+            common=dict(
+                subnets={"security_group_ids"},
+                security_group_ids={"subnets"},
+                model_channel_name={"model_uri"},
+                checkpoint_local_path={"checkpoint_s3_uri"},
+                instance_type={"image_uri"},
+            ),
+        ),
+    ),
+    model=dict(
+        register=dict(
+            common=dict(
+                env={"model_package_group_name"},
+                image_config={"model_package_group_name"},
+                model_server_workers={"model_package_group_name"},
+                container_log_level={"model_package_group_name"},
+                framework={"model_package_group_name"},
+                framework_version={"model_package_group_name"},
+                nearest_model_name={"model_package_group_name"},
+                data_input_configuration={"model_package_group_name"},
+            )
+        ),
+    ),
+    pipelinemodel=dict(
+        register=dict(
+            common=dict(
+                image_uri={
+                    # model_package_name and model_package_group_name are mutually exclusive.
+                    # If model_package_group_name is not None, image_uri will be ignored
+                    "model_package_name"
+                },
+                framework={"model_package_group_name"},
+                framework_version={"model_package_group_name"},
+                nearest_model_name={"model_package_group_name"},
+                data_input_configuration={"model_package_group_name"},
+            ),
+        ),
+    ),
+)
+
+# A dict for parameters which should not be replaced with pipeline variables
+# since they are bonded with other parameters with not-None values. For example:
+# 1. for any model subclass, if model_package_name is not None, model_package_group_name should be None
+# and should not be replaced with a pipeline variable
+# 2. for MultiDataModel, if model is given, its kwargs including container_log_level will be ignored
+# Note: for any mutually exclusive parameters (e.g. model_package_name, model_package_group_name),
+# we can add an entry for each of them.
+UNSET_PARAM_BONDED_WITH_NOT_NONE = dict(
+    model=dict(
+        register=dict(
+            common=dict(
+                model_package_name={"model_package_group_name"},
+                model_package_group_name={"model_package_name"},
+            ),
+        ),
+    ),
+    pipelinemodel=dict(
+        register=dict(
+            common=dict(
+                model_package_name={"model_package_group_name"},
+                model_package_group_name={"model_package_name"},
+            ),
+        ),
+    ),
+    estimator=dict(
+        init=dict(
+            common=dict(),
+            TensorFlow=dict(
+                image_uri={"compiler_config"},
+                compiler_config={"image_uri"},
+            ),
+            HuggingFace=dict(
+                image_uri={"compiler_config"},
+                compiler_config={"image_uri"},
+            ),
+        )
+    ),
+)
+
+
+# A dict for parameters that should not be set to None since they are bonded with
+# other parameters with None value. For example:
+# if image_uri is None in TensorFlow, py_version should not be None
+# since it's used as a substitute argument to retrieve image_uri.
+SET_PARAM_BONDED_WITH_NONE = dict(
+    estimator=dict(
+        init=dict(
+            common=dict(),
+            TensorFlow=dict(
+                py_version={"image_uri"},
+                framework_version={"image_uri"},
+            ),
+            HuggingFace=dict(
+                transformers_version={"image_uri"},
+                tensorflow_version={"pytorch_version"},
+            ),
+        )
+    ),
+    model=dict(
+        register=dict(
+            common=dict(
+                inference_instances={"model_package_group_name"},
+                transform_instances={"model_package_group_name"},
+            ),
+        )
+    ),
+    pipelinemodel=dict(
+        register=dict(
+            common=dict(
+                inference_instances={"model_package_group_name"},
+                transform_instances={"model_package_group_name"},
+            ),
+        )
+    ),
+)
+
+# A dict for parameters that should not be set to None since they are bonded with
+# other parameters with not-None values. Thus we can skip them. For example:
+# dimension_reduction_target should not be None when dimension_reduction_type is set
+SET_PARAM_BONDED_WITH_NOT_NONE = dict(
+    estimator=dict(
+        init=dict(
+            common=dict(),
+            KNN=dict(dimension_reduction_target={"dimension_reduction_type"}),
+        ),
+    ),
+)
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/parameter_skip_checker.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/parameter_skip_checker.py
new file mode 100644
index 0000000000..c26761b384
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/parameter_skip_checker.py
@@ -0,0 +1,382 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from typing import TYPE_CHECKING
+
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import (
+    UNSET_PARAM_BONDED_WITH_NONE,
+    UNSET_PARAM_BONDED_WITH_NOT_NONE,
+    SET_PARAM_BONDED_WITH_NOT_NONE,
+    SET_PARAM_BONDED_WITH_NONE,
+    BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS,
+    CLASS_PARAMS_EXCLUDED_IN_TARGET_FUNC,
+    PARAMS_SHOULD_NOT_BE_NONE,
+    CLAZZ_ARGS,
+    FUNC_ARGS,
+    INIT,
+    COMMON,
+)
+
+if TYPE_CHECKING:
+    from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import (
+        PipelineVarCompatiTestTemplate,
+    )
+
+
+class ParameterSkipChecker:
+    """Check if the current parameter can be skipped for the test"""
+
+    def __init__(self, template: "PipelineVarCompatiTestTemplate"):
+        """Initialize a `ParameterSkipChecker` instance.
+
+        Args:
+            template (PipelineVarCompatiTestTemplate): The template object to check
+                the compatibility between Pipeline variables and the given class and target method.
+        """
+        self._template = template
+
+    def skip_setting_param_to_none(
+        self,
+        param_name: str,
+        target_func: str,
+    ) -> bool:
+        """Check if to skip setting the parameter to None.
+
+        Args:
+            param_name (str): The name of the parameter, which is checked to decide
+                if we can skip it in the loop.
+            target_func (str): The target function impacted by the check.
+                If target_func is init, it means the parameter affects initiating
+                the class object.
+        """
+        return (
+            self._is_param_should_not_be_none(param_name, target_func)
+            or self._need_set_param_bonded_with_none(param_name, target_func)
+            or self._need_set_param_bonded_with_not_none(param_name, target_func)
+        )
+
+    def skip_setting_clz_param_to_ppl_var(
+        self,
+        clz_param: str,
+        target_func: str,
+        param_with_none: str = None,
+    ) -> bool:
+        """Check if to skip setting the class parameter to a pipeline variable
+
+        Args:
+            clz_param (str): The name of the class __init__ parameter, which is checked
+                to decide if we can skip it in the loop.
+            target_func (str): The target function impacted by the check.
+                If target_func is init, it means the parameter affects initiating
+                the class object.
+            param_with_none (str): The name of the parameter with None value.
+        """
+        if target_func == INIT:
+            return (
+                self._need_unset_param_bonded_with_none(clz_param, INIT)
+                or self._is_base_clz_param_excluded_in_subclz(clz_param)
+                or self._need_unset_param_bonded_with_not_none(clz_param, INIT)
+            )
+        else:
+            return (
+                self._need_unset_param_bonded_with_none(clz_param, target_func)
+                or self._need_unset_param_bonded_with_not_none(clz_param, target_func)
+                or self._is_param_should_not_be_none(param_with_none, target_func)
+                or self._is_overridden_class_param(clz_param, target_func)
+                or self._is_clz_param_excluded_in_func(clz_param, target_func)
+            )
+
+    def skip_setting_func_param_to_ppl_var(
+        self,
+        func_param: str,
+        target_func: str,
+        param_with_none: str = None,
+    ) -> bool:
+        """Check if to skip setting the function parameter to a pipeline variable
+
+        Args:
+            func_param (str): The name of the function parameter, which is checked
+                to decide if we can skip it in the loop.
+            target_func (str): The target function impacted by the check.
+            param_with_none (str): The name of the parameter with None value.
+        """
+        return (
+            self._is_param_should_not_be_none(param_with_none, target_func)
+            or self._need_unset_param_bonded_with_none(func_param, target_func)
+            or self._need_unset_param_bonded_with_not_none(func_param, target_func)
+        )
+
+    def _need_unset_param_bonded_with_none(
+        self,
+        param_name: str,
+        target_func: str,
+    ) -> bool:
+        """Check if to skip testing with pipeline variables due to the bond relationship.
+
+        I.e. the parameter (param_name) is not present in the definition JSON,
+        or it is not allowed to be a pipeline variable if its bonded parameter is None.
+        Then we can skip replacing the param_name with pipeline variables
+
+        Args:
+            param_name (str): The name of the parameter, which is checked to decide
+                if we can skip setting it to a pipeline variable.
+            target_func (str): The target function impacted by the check.
+                If target_func is init, it means the bonded parameters affect initiating
+                the class object
+        """
+        return self._is_param_bonded_with_none(
+            param_map=UNSET_PARAM_BONDED_WITH_NONE,
+            param_name=param_name,
+            target_func=target_func,
+        )
+
+    def _need_unset_param_bonded_with_not_none(self, param_name: str, target_func: str) -> bool:
+        """Check if to skip testing with pipeline variables due to the bond relationship.
+
+        I.e. the parameter (param_name) is not present in the definition JSON, or it should
+        not be present, or it is not allowed to be a pipeline variable if its bonded parameter
+        is not None. Then we can skip replacing the param_name with pipeline variables
+
+        Args:
+            param_name (str): The name of the parameter, which is checked to decide
+                if we can skip replacing it with pipeline variables.
+            target_func (str): The target function impacted by the check.
+        """
+        return self._is_param_bonded_with_not_none(
+            param_map=UNSET_PARAM_BONDED_WITH_NOT_NONE,
+            param_name=param_name,
+            target_func=target_func,
+        )
+
+    def _need_set_param_bonded_with_none(self, param_name: str, target_func: str) -> bool:
+        """Check if to skip testing with None value due to the bond relationship.
+
+        I.e. if a parameter (another_param) is None, its substitute parameter (param_name)
+        should not be None. Thus we can skip the test round which sets the param_name
+        to None under the target function (target_func).
+
+        Args:
+            param_name (str): The name of the parameter, which is to be verified regarding
+                None value.
+            target_func (str): The target function impacted by this check.
+        """
+        return self._is_param_bonded_with_none(
+            param_map=SET_PARAM_BONDED_WITH_NONE,
+            param_name=param_name,
+            target_func=target_func,
+        )
+
+    def _need_set_param_bonded_with_not_none(self, param_name: str, target_func: str) -> bool:
+        """Check if to skip testing with None value due to the bond relationship.
+
+        I.e. if the parameter (another_param) is not None, its bonded parameter (param_name)
+        should not be None. Thus we can skip the test round which sets the param_name
+        to None under the target function (target_func).
+
+        Args:
+            param_name (str): The name of the parameter, which is to be verified
+                regarding None value.
+            target_func (str): The target function impacted by this check.
+        """
+        return self._is_param_bonded_with_not_none(
+            param_map=SET_PARAM_BONDED_WITH_NOT_NONE,
+            param_name=param_name,
+            target_func=target_func,
+        )
+
+    def _is_param_bonded_with_not_none(
+        self,
+        param_map: dict,
+        param_name: str,
+        target_func: str,
+    ) -> bool:
+        """Check if the parameter is bonded with one that has a not-None value.
+
+        Args:
+            param_map (dict): The parameter map storing the bond relationship.
+            param_name (str): The name of the parameter to be verified.
+            target_func (str): The target function impacted by this check.
+        """
+        template = self._template
+
+        def _not_none_checker(func: str, params_dict: dict):
+            for another_param in params_dict:
+                if template.default_args[CLAZZ_ARGS].get(another_param, None):
+                    return True
+                if func == INIT:
+                    continue
+                if template.default_args[FUNC_ARGS][func].get(another_param, None):
+                    return True
+            return False
+
+        return self._is_param_bonded(
+            param_map=param_map,
+            param_name=param_name,
+            target_func=target_func,
+            checker_func=_not_none_checker,
+        )
+
+    def _is_param_bonded_with_none(
+        self,
+        param_map: dict,
+        param_name: str,
+        target_func: str,
+    ) -> bool:
+        """Check if the parameter is bonded with another one with None value.
+
+        Args:
+            param_map (dict): The parameter map storing the bond relationship.
+            param_name (str): The name of the parameter to be verified.
+            target_func (str): The target function impacted by this check.
+ """ + template = self._template + + def _none_checker(func: str, params_dict: dict): + for another_param in params_dict: + if template.default_args[CLAZZ_ARGS].get(another_param, "N/A") is None: + return True + if func == INIT: + continue + if template.default_args[FUNC_ARGS][func].get(another_param, "N/A") is None: + return True + return False + + return self._is_param_bonded( + param_map=param_map, + param_name=param_name, + target_func=target_func, + checker_func=_none_checker, + ) + + def _is_param_bonded( + self, + param_map: dict, + param_name: str, + target_func: str, + checker_func: callable, + ) -> bool: + """Check if the parameter has a specific bond relationship. + + Args: + param_map (dict): The parameter map storing the bond relationship. + param_name (str): The name of the parameter to be verified. + target_func (str): The target function impacted by this check. + checker_func (callable): The checker function to check the specific bond relationship. + """ + template = self._template + if template.clazz_type not in param_map: + return False + if target_func not in param_map[template.clazz_type]: + return False + params_dict = param_map[template.clazz_type][target_func][COMMON].get(param_name, {}) + if not params_dict: + if template.clazz.__name__ not in param_map[template.clazz_type][target_func]: + return False + params_dict = param_map[template.clazz_type][target_func][template.clazz.__name__].get( + param_name, {} + ) + return checker_func(target_func, params_dict) + + def _is_base_clz_param_excluded_in_subclz(self, clz_param_name: str) -> bool: + """Check if to skip testing with pipeline variables on class parameter due to exclusion. + + I.e. the base class parameter (clz_param_name) should not be replaced with pipeline variables, + as it's not used in the subclass. + + Args: + clz_param_name (str): The name of the class parameter, which is to be verified. + """ + template = self._template + if template.clazz_type not in BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS: + return False + if ( + template.clazz.__name__ + not in BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS[template.clazz_type] + ): + return False + return ( + clz_param_name + in BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS[template.clazz_type][template.clazz.__name__] + ) + + def _is_overridden_class_param(self, clz_param_name: str, target_func: str) -> bool: + """Check if to skip testing with pipeline variables on class parameter due to override. + + I.e. the class parameter (clz_param_name) should not be replaced with pipeline variables + and tested on the target function (target_func) because it's overridden by a + function parameter with the same name. + e.g. image_uri in model.create can override that in model constructor. + + Args: + clz_param_name (str): The name of the class parameter, which is to be verified. + target_func (str): The target function impacted by the check. + """ + template = self._template + return template.default_args[FUNC_ARGS][target_func].get(clz_param_name, None) is not None + + def _is_clz_param_excluded_in_func(self, clz_param_name: str, target_func: str) -> bool: + """Check if to skip testing with pipeline variables on class parameter due to exclusion. + + I.e. the class parameter (clz_param_name) should not be replaced with pipeline variables + and tested on the target function (target_func), as it's not used there. + + Args: + clz_param_name (str): The name of the class parameter, which is to be verified. + target_func (str): The target function impacted by the check. 
+ """ + return self._is_param_included( + param_map=CLASS_PARAMS_EXCLUDED_IN_TARGET_FUNC, + param_name=clz_param_name, + target_func=target_func, + ) + + def _is_param_should_not_be_none(self, param_name: str, target_func: str) -> bool: + """Check if to skip testing due to the parameter should not be None. + + I.e. the parameter (param_name) is set to None in this round but it is not allowed + according to the logic. Thus we can skip this round of test. + + Args: + param_name (str): The name of the parameter, which is to be verified regarding None value. + target_func (str): The target function impacted by this check. + """ + return self._is_param_included( + param_map=PARAMS_SHOULD_NOT_BE_NONE, + param_name=param_name, + target_func=target_func, + ) + + def _is_param_included( + self, + param_map: dict, + param_name: str, + target_func: str, + ) -> bool: + """Check if the parameter is included in a specific relationship. + + Args: + param_map (dict): The parameter map storing the specific relationship. + param_name (str): The name of the parameter to be verified. + target_func (str): The target function impacted by this check. + """ + template = self._template + if template.clazz_type not in param_map: + return False + if target_func not in param_map[template.clazz_type]: + return False + if param_name in param_map[template.clazz_type][target_func][COMMON]: + return True + if template.clazz.__name__ not in param_map[template.clazz_type][target_func]: + return False + return param_name in param_map[template.clazz_type][target_func][template.clazz.__name__] diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py new file mode 100644 index 0000000000..c9867a3020 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py @@ -0,0 +1,586 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
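+#
+# Flow implemented below, in brief: PipelineVarCompatiTestTemplate collects
+# the required/optional parameters of the tested class __init__ and its
+# target methods, then check_compatibility() runs one round with every
+# optional arg set, followed by rounds that set each optional arg to None
+# in turn; in each round, class and method parameters are swapped with
+# pipeline variables (subject to ParameterSkipChecker's skip rules).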
+from __future__ import absolute_import + +import json + +from random import getrandbits +from typing import Optional +from typing_extensions import get_origin + +from sagemaker import Model, PipelineModel, AlgorithmEstimator +from sagemaker.estimator import EstimatorBase +from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase +from sagemaker.mxnet import MXNet +from sagemaker.processing import Processor +from sagemaker.clarify import SageMakerClarifyProcessor +from sagemaker.pytorch import PyTorch +from sagemaker.tensorflow import TensorFlow +from sagemaker.transformer import Transformer +from sagemaker.tuner import HyperparameterTuner +from sagemaker.workflow.step_collections import StepCollection +from tests.unit.sagemaker.workflow.test_mechanism.test_code import ( + STEP_CLASS, + FIXED_ARGUMENTS, + STR_VAL, + CLAZZ_ARGS, + FUNC_ARGS, + REQUIRED, + OPTIONAL, + INIT, + TYPE, + DEFAULT_VALUE, + COMMON, + PipelineVariableEncoder, +) +from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import ( + support_pipeline_variable, + get_param_dict, + generate_pipeline_vars_per_type, + clean_up_types, +) +from tests.unit.sagemaker.workflow.test_mechanism.test_code.parameter_skip_checker import ( + ParameterSkipChecker, +) + + +class PipelineVarCompatiTestTemplate: + """Check the compatibility between Pipeline variables and the given class, target method""" + + def __init__(self, clazz: type, default_args: dict): + """Initialize a `PipelineVarCompatiTestTemplate` instance. + + Args: + clazz (type): The class to test the compatibility. + default_args (dict): The given default arguments for the class and its target method. + """ + self.clazz = clazz + self.clazz_type = self._get_clazz_type() + self._target_funcs = self._get_target_functions() + self._clz_params = get_param_dict(clazz.__init__, clazz) + self._func_params = dict() + for func in self._target_funcs: + self._func_params[func.__name__] = get_param_dict(func) + self._set_and_restructure_default_args(default_args) + self._skip_param_checker = ParameterSkipChecker(self) + + def _set_and_restructure_default_args(self, default_args: dict): + """Set and restructure the default_args + + Restructure the default_args[FUNC_ARGS] if it's missing the layer of target function name + + Args: + default_args (dict): The given default arguments for the class and its target method. 
+ """ + self.default_args = default_args + # restructure the default_args[FUNC_ARGS] if it's missing the layer of target function name + if len(self._target_funcs) == 1: + target_func_name = self._target_funcs[0].__name__ + if target_func_name not in default_args[FUNC_ARGS]: + args = self.default_args.pop(FUNC_ARGS) + self.default_args[FUNC_ARGS] = dict() + self.default_args[FUNC_ARGS][target_func_name] = args + + self._check_or_fill_in_args( + params={**self._clz_params[REQUIRED], **self._clz_params[OPTIONAL]}, + default_args=self.default_args[CLAZZ_ARGS], + ) + for func in self._target_funcs: + func_name = func.__name__ + self._check_or_fill_in_args( + params={ + **self._func_params[func_name][REQUIRED], + **self._func_params[func_name][OPTIONAL], + }, + default_args=self.default_args[FUNC_ARGS][func_name], + ) + + def _get_clazz_type(self) -> str: + """Get the type (in str) of the downstream class""" + if issubclass(self.clazz, Processor): + return "processor" + if issubclass(self.clazz, EstimatorBase): + return "estimator" + if issubclass(self.clazz, Transformer): + return "transformer" + if issubclass(self.clazz, HyperparameterTuner): + return "tuner" + if issubclass(self.clazz, Model): + return "model" + if issubclass(self.clazz, PipelineModel): + return "pipelinemodel" + raise TypeError(f"Unsupported downstream class: {self.clazz}") + + def check_compatibility(self): + """The entry to check the compatibility""" + print( + "Starting to check Pipeline variable compatibility for class (%s) and target methods (%s)\n" + % (self.clazz.__name__, [func.__name__ for func in self._target_funcs]) + ) + + # Check the case when all args are assigned not-None values + print("## Starting to check the compatibility when all optional args are not None ##") + self._iterate_params_to_check_compatibility() + + # Check the case when one of the optional arg is None + print( + "## Starting to check the compatibility when one of the optional arg is None in each round ##" + ) + self._iterate_optional_params_to_check_compatibility() + + def _iterate_params_to_check_compatibility( + self, + param_with_none: Optional[str] = None, + test_func_for_none: Optional[str] = None, + ): + """Iterate each parameter and assign a pipeline var to it to test compatibility + + Args: + param_with_none (str): The name of the parameter with None value. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. 
+ """ + self._iterate_clz_params_to_check_compatibility(param_with_none, test_func_for_none) + self._iterate_func_params_to_check_compatibility(param_with_none, test_func_for_none) + + def _iterate_optional_params_to_check_compatibility(self): + """Iterate each optional parameter and set it to none to test compatibility""" + self._iterate_class_optional_params() + self._iterate_func_optional_params() + + def _iterate_class_optional_params(self): + """Iterate each optional parameter in class __init__ and check compatibility""" + print("### Starting to iterate optional parameters in class __init__") + self._iterate_optional_params( + optional_params=self._clz_params[OPTIONAL], + default_args=self.default_args[CLAZZ_ARGS], + ) + + def _iterate_func_optional_params(self): + """Iterate each function parameter and check compatibility""" + for func in self._target_funcs: + print(f"### Starting to iterate optional parameters in function {func.__name__}") + self._iterate_optional_params( + optional_params=self._func_params[func.__name__][OPTIONAL], + default_args=self.default_args[FUNC_ARGS][func.__name__], + test_func_for_none=func.__name__, + ) + + def _iterate_optional_params( + self, + optional_params: dict, + default_args: dict, + test_func_for_none: Optional[str] = None, + ): + """Iterate each optional parameter and check compatibility + Args: + optional_params (dict): The dict containing the optional parameters of a class or method. + default_args (dict): The dict containing the default arguments of a class or method. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. + """ + for param_name in optional_params.keys(): + if self._skip_param_checker.skip_setting_param_to_none(param_name, INIT): + continue + origin_val = default_args[param_name] + default_args[param_name] = None + print("=== Parameter (%s) is None in this round ===" % param_name) + self._iterate_params_to_check_compatibility(param_name, test_func_for_none) + default_args[param_name] = origin_val + + def _iterate_clz_params_to_check_compatibility( + self, + param_with_none: Optional[str] = None, + test_func_for_none: Optional[str] = None, + ): + """Iterate each class parameter and assign a pipeline var to it to test compatibility + + Args: + param_with_none (str): The name of the parameter with None value. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. 
+ """ + print( + f"#### Iterating parameters (with PipelineVariable annotation) " + f"in class {self.clazz.__name__} __init__ function" + ) + clz_params = {**self._clz_params[REQUIRED], **self._clz_params[OPTIONAL]} + # Iterate through each default arg + for clz_param_name, clz_default_arg in self.default_args[CLAZZ_ARGS].items(): + if clz_param_name == param_with_none: + continue + clz_param_type = clz_params[clz_param_name][TYPE] + if not support_pipeline_variable(clz_param_type): + continue + if self._skip_param_checker.skip_setting_clz_param_to_ppl_var( + clz_param=clz_param_name, target_func=INIT + ): + continue + + # For each arg which supports pipeline variables, + # Replace it with each one of generated pipeline variables + ppl_vars = generate_pipeline_vars_per_type(clz_param_name, clz_param_type) + for clz_ppl_var, expected_clz_expr in ppl_vars: + self.default_args[CLAZZ_ARGS][clz_param_name] = clz_ppl_var + obj = self.clazz(**self.default_args[CLAZZ_ARGS]) + for func in self._target_funcs: + func_name = func.__name__ + if test_func_for_none and test_func_for_none != func_name: + # Iterating optional parameters of a specific target function + # (test_func_for_none), which does not impact other target functions, + # so we can skip them + continue + if self._skip_param_checker._need_set_param_bonded_with_none( + param_with_none, func_name + ): # TODO: add to a public method + continue + if self._skip_param_checker.skip_setting_clz_param_to_ppl_var( + clz_param=clz_param_name, + target_func=func_name, + param_with_none=param_with_none, + ): + continue + # print( + # "Replacing class init arg (%s) with pipeline variable which is expected " + # "to be (%s). Testing with target function (%s)" + # % (clz_param_name, expected_clz_expr, func_name) + # ) + self._generate_and_verify_step_definition( + target_func=getattr(obj, func_name), + expected_expr=expected_clz_expr, + param_with_none=param_with_none, + ) + + # print("============================\n") + self.default_args[CLAZZ_ARGS][clz_param_name] = clz_default_arg + + def _iterate_func_params_to_check_compatibility( + self, + param_with_none: Optional[str] = None, + test_func_for_none: Optional[str] = None, + ): + """Iterate each target func parameter and assign a pipeline var to it + + Args: + param_with_none (str): The name of the parameter with None value. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. 
+ """ + obj = self.clazz(**self.default_args[CLAZZ_ARGS]) + for func in self._target_funcs: + func_name = func.__name__ + if test_func_for_none and test_func_for_none != func_name: + # Iterating optional parameters of a specific target function (test_func_for_none), + # which does not impact other target functions, so we can skip them + continue + if self._skip_param_checker._need_set_param_bonded_with_none( + param_with_none, func_name + ): # TODO: add to a public method + continue + print( + f"#### Iterating parameters (with PipelineVariable annotation) in target function: {func_name}" + ) + func_params = { + **self._func_params[func_name][REQUIRED], + **self._func_params[func_name][OPTIONAL], + } + for func_param_name, func_default_arg in self.default_args[FUNC_ARGS][ + func_name + ].items(): + if func_param_name == param_with_none: + continue + if not support_pipeline_variable(func_params[func_param_name][TYPE]): + continue + if self._skip_param_checker.skip_setting_func_param_to_ppl_var( + func_param=func_param_name, + target_func=func_name, + param_with_none=param_with_none, + ): + continue + # For each arg which supports pipeline variables, + # Replace it with each one of generated pipeline variables + ppl_vars = generate_pipeline_vars_per_type( + func_param_name, func_params[func_param_name][TYPE] + ) + for func_ppl_var, expected_func_expr in ppl_vars: + # print( + # "Replacing func arg (%s) with pipeline variable which is expected to be (%s)" + # % (func_param_name, expected_func_expr) + # ) + self.default_args[FUNC_ARGS][func_name][func_param_name] = func_ppl_var + self._generate_and_verify_step_definition( + target_func=getattr(obj, func_name), + expected_expr=expected_func_expr, + param_with_none=param_with_none, + ) + + self.default_args[FUNC_ARGS][func_name][func_param_name] = func_default_arg + # print("-------------------------\n") + + def _generate_and_verify_step_definition( + self, + target_func: callable, + expected_expr: dict, + param_with_none: str, + ): + """Generate a pipeline and verify the pipeline definition + + Args: + target_func (callable): The function to generate step_args. + expected_expr (dict): The expected json expression of a class or method argument. + param_with_none (str): The name of the parameter with None value. + """ + args = dict( + name="MyStep", + step_args=target_func(**self.default_args[FUNC_ARGS][target_func.__name__]), + ) + step = STEP_CLASS[self.clazz_type](**args) + if isinstance(step, StepCollection): + request_dicts = step.request_dicts() + else: + request_dicts = [step.to_request()] + + step_dsl = json.dumps(request_dicts, cls=PipelineVariableEncoder) + step_dsl_obj = json.loads(step_dsl) + exp_origin = json.dumps(expected_expr["origin"]) + exp_to_str = json.dumps(expected_expr["to_string"]) + # if the testing arg is a dict, we may need to remove the outer {} of its expected expr + # to compare, since for HyperParameters, some other arguments are auto inserted to the dict + assert ( + exp_origin in step_dsl + or exp_to_str in step_dsl + or exp_origin[1:-1] in step_dsl + or exp_to_str[1:-1] in step_dsl + ) + self._verify_composite_object_against_pipeline_var(param_with_none, step_dsl, step_dsl_obj) + + def _verify_composite_object_against_pipeline_var( + self, + param_with_none: str, + step_dsl: str, + step_dsl_obj: object, + ): + """verify pipeline definition regarding composite objects against pipeline variables + + Args: + param_with_none (str): The name of the parameter with None value. 
+ step_dsl (str): The step definition retrieved from the pipeline definition DSL.
+ step_dsl_obj (object): The JSON-loaded object of the step definition.
+ """
+ # TODO: remove the following hard-coded assertions once recursive assignment is added
+ if issubclass(self.clazz, Processor):
+ if param_with_none != "network_config":
+ assert '{"Get": "Parameters.nw_cfg_subnets"}' in step_dsl
+ assert '{"Get": "Parameters.nw_cfg_security_group_ids"}' in step_dsl
+ assert '{"Get": "Parameters.nw_cfg_enable_nw_isolation"}' in step_dsl
+ if issubclass(self.clazz, SageMakerClarifyProcessor):
+ if param_with_none != "data_config":
+ assert '{"Get": "Parameters.clarify_processor_input"}' in step_dsl
+ assert '{"Get": "Parameters.clarify_processor_output"}' in step_dsl
+ else:
+ if param_with_none != "outputs":
+ assert '{"Get": "Parameters.proc_output_source"}' in step_dsl
+ assert '{"Get": "Parameters.proc_output_dest"}' in step_dsl
+ assert '{"Get": "Parameters.proc_output_app_managed"}' in step_dsl
+ if param_with_none != "inputs":
+ assert '{"Get": "Parameters.proc_input_source"}' in step_dsl
+ assert '{"Get": "Parameters.proc_input_dest"}' in step_dsl
+ assert '{"Get": "Parameters.proc_input_s3_data_type"}' in step_dsl
+ assert '{"Get": "Parameters.proc_input_app_managed"}' in step_dsl
+ elif issubclass(self.clazz, EstimatorBase):
+ if issubclass(self.clazz, AmazonAlgorithmEstimatorBase):
+ # AmazonAlgorithmEstimatorBase's input is records
+ if param_with_none != "records":
+ assert '{"Get": "Parameters.records_s3_data"}' in step_dsl
+ assert '{"Get": "Parameters.records_s3_data_type"}' in step_dsl
+ assert '{"Get": "Parameters.records_channel"}' in step_dsl
+ else:
+ if param_with_none != "inputs":
+ assert '{"Get": "Parameters.train_inputs_s3_data"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_distribution"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_compression"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_content_type"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_record_wrapping"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_s3_data_type"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_input_mode"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_attribute_name"}' in step_dsl
+ assert '{"Get": "Parameters.train_inputs_target_attr_name"}' in step_dsl
+ if not issubclass(self.clazz, (TensorFlow, MXNet, PyTorch, AlgorithmEstimator)):
+ # debugger_hook_config may be disabled for these first 3 frameworks
+ # AlgorithmEstimator ignores the kwargs
+ if param_with_none != "debugger_hook_config":
+ assert '{"Get": "Parameters.debugger_hook_s3_output"}' in step_dsl
+ assert '{"Get": "Parameters.debugger_container_output"}' in step_dsl
+ assert '{"Get": "Parameters.debugger_hook_param"}' in step_dsl
+ assert '{"Get": "Parameters.debugger_collections_name"}' in step_dsl
+ assert '{"Get": "Parameters.debugger_collections_param"}' in step_dsl
+ if not issubclass(self.clazz, AlgorithmEstimator):
+ # AlgorithmEstimator ignores the kwargs
+ if param_with_none != "profiler_config":
+ assert '{"Get": "Parameters.profile_config_s3_output_path"}' in step_dsl
+ assert '{"Get": "Parameters.profile_config_system_monitor"}' in step_dsl
+ if param_with_none != "tensorboard_output_config":
+ assert '{"Get": "Parameters.tensorboard_s3_output"}' in step_dsl
+ assert '{"Get": "Parameters.tensorboard_container_output"}' in step_dsl
+ if param_with_none != "rules":
+ assert '{"Get": "Parameters.rules_image_uri"}' in step_dsl
+ assert '{"Get": "Parameters.rules_instance_type"}' in step_dsl
+ assert '{"Get": "Parameters.rules_volume_size"}' in step_dsl
+ assert '{"Get": "Parameters.rules_to_invoke"}' in step_dsl
+ assert '{"Get": "Parameters.rules_local_output"}' in step_dsl
+ assert '{"Get": "Parameters.rules_to_s3_output_path"}' in step_dsl
+ assert '{"Get": "Parameters.rules_other_s3_input"}' in step_dsl
+ assert '{"Get": "Parameters.rules_param"}' in step_dsl
+ if not issubclass(self.clazz, (TensorFlow, MXNet, PyTorch)):
+ # collections_to_save is added to the debugger rules,
+ # which may be disabled for some frameworks
+ assert '{"Get": "Parameters.rules_collections_name"}' in step_dsl
+ assert '{"Get": "Parameters.rules_collections_param"}' in step_dsl
+ elif issubclass(self.clazz, HyperparameterTuner):
+ if param_with_none != "inputs":
+ assert '{"Get": "Parameters.inputs_estimator_1"}' in step_dsl
+ if param_with_none != "warm_start_config":
+ assert '{"Get": "Parameters.warm_start_cfg_parent"}' in step_dsl
+ if param_with_none != "hyperparameter_ranges":
+ assert (
+ json.dumps(
+ {
+ "Std:Join": {
+ "On": "",
+ "Values": [{"Get": "Parameters.hyper_range_min_value"}],
+ }
+ }
+ )
+ in step_dsl
+ )
+ assert (
+ json.dumps(
+ {
+ "Std:Join": {
+ "On": "",
+ "Values": [{"Get": "Parameters.hyper_range_max_value"}],
+ }
+ }
+ )
+ in step_dsl
+ )
+ assert '{"Get": "Parameters.hyper_range_scaling_type"}' in step_dsl
+ elif issubclass(self.clazz, (Model, PipelineModel)):
+ if step_dsl_obj[-1]["Type"] == "Model":
+ return
+ if param_with_none != "model_metrics":
+ assert '{"Get": "Parameters.model_statistics_content_type"}' in step_dsl
+ assert '{"Get": "Parameters.model_statistics_s3_uri"}' in step_dsl
+ assert '{"Get": "Parameters.model_statistics_content_digest"}' in step_dsl
+ if param_with_none != "metadata_properties":
+ assert '{"Get": "Parameters.meta_properties_commit_id"}' in step_dsl
+ assert '{"Get": "Parameters.meta_properties_repository"}' in step_dsl
+ assert '{"Get": "Parameters.meta_properties_generated_by"}' in step_dsl
+ assert '{"Get": "Parameters.meta_properties_project_id"}' in step_dsl
+ if param_with_none != "drift_check_baselines":
+ assert '{"Get": "Parameters.drift_constraints_content_type"}' in step_dsl
+ assert '{"Get": "Parameters.drift_constraints_s3_uri"}' in step_dsl
+ assert '{"Get": "Parameters.drift_constraints_content_digest"}' in step_dsl
+ assert '{"Get": "Parameters.drift_bias_content_type"}' in step_dsl
+ assert '{"Get": "Parameters.drift_bias_s3_uri"}' in step_dsl
+ assert '{"Get": "Parameters.drift_bias_content_digest"}' in step_dsl
+
+ def _get_non_pipeline_val(self, n: str, t: type) -> object:
+ """Get a value (not a pipeline variable) based on the parameter type and name
+
+ Args:
+ n (str): The parameter name. If a parameter has a pre-defined value,
+ it will be returned directly.
+ t (type): The parameter type. If a parameter does not have a pre-defined value,
+ an arg will be auto-generated based on the type.
+
+ Return:
+ object: A Python primitive value is returned.
+ """ + if n in FIXED_ARGUMENTS[COMMON]: + return FIXED_ARGUMENTS[COMMON][n] + if n in FIXED_ARGUMENTS[self.clazz_type]: + return FIXED_ARGUMENTS[self.clazz_type][n] + if t is str: + return STR_VAL + if t is int: + return 1 + if t is float: + return 1e-4 + if t is bool: + return bool(getrandbits(1)) + if t in [list, tuple, dict, set]: + return t() + + raise TypeError(f"Unable to parse type: {t}.") + + def _check_or_fill_in_args(self, params: dict, default_args: dict): + """Check if every args are provided and not None + + Otherwise fill in with some default values + + Args: + params (dict): The dict indicating the type of each parameter. + default_args (dict): The dict of args to be checked or filled in. + """ + for param_name, value in params.items(): + if param_name in default_args: + # User specified the default value + continue + if value[DEFAULT_VALUE]: + # The parameter has default value in method definition + default_args[param_name] = value[DEFAULT_VALUE] + continue + clean_type = clean_up_types(value[TYPE]) + origin_type = get_origin(clean_type) + if origin_type is None: + default_args[param_name] = self._get_non_pipeline_val(param_name, clean_type) + else: + default_args[param_name] = self._get_non_pipeline_val(param_name, origin_type) + + self._check_or_update_default_args(default_args) + + def _check_or_update_default_args(self, default_args: dict): + """To check if the default args are valid and update them if not + + Args: + default_args (dict): The dict of args to be checked or updated. + """ + if issubclass(self.clazz, EstimatorBase): + if "disable_profiler" in default_args and default_args["disable_profiler"] is True: + default_args["profiler_config"] = None + + def _get_target_functions(self) -> list: + """Fetch the target functions based on class + + Return: + list: The list of target functions is returned. + """ + if issubclass(self.clazz, Processor): + if issubclass(self.clazz, SageMakerClarifyProcessor): + return [ + self.clazz.run_pre_training_bias, + self.clazz.run_post_training_bias, + self.clazz.run_bias, + self.clazz.run_explainability, + ] + return [self.clazz.run] + if issubclass(self.clazz, EstimatorBase): + return [self.clazz.fit] + if issubclass(self.clazz, Transformer): + return [self.clazz.transform] + if issubclass(self.clazz, HyperparameterTuner): + return [self.clazz.fit] + if issubclass(self.clazz, (Model, PipelineModel)): + return [self.clazz.register, self.clazz.create] + raise TypeError(f"Unable to get target function for class {self.clazz}") diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/utilities.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/utilities.py new file mode 100644 index 0000000000..4c464a3985 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/utilities.py @@ -0,0 +1,267 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import
+
+import os
+from inspect import signature
+from typing import Union
+
+from typing_extensions import get_args, get_origin
+
+from sagemaker import Model
+from sagemaker.estimator import EstimatorBase
+from sagemaker.fw_utils import UploadedCode
+from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.entities import PipelineVariable
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import (
+ PIPELINE_VARIABLES,
+ REQUIRED,
+ OPTIONAL,
+ DEFAULT_VALUE,
+ IMAGE_URI,
+)
+
+
+def support_pipeline_variable(t: type) -> bool:
+ """Check if a parameter supports pipeline variables according to its type
+
+ Args:
+ t (type): The type to be checked
+
+ Return:
+ bool: True if it supports pipeline variables; False otherwise.
+ """
+ return "PipelineVariable" in str(t)
+
+
+def get_param_dict(func, clazz=None) -> dict:
+ """Get a parameter dict of a given function
+
+ The parameter dict indicates whether a parameter is required or not, as well as its type.
+
+ Args:
+ func (function): A class constructor method or other class methods.
+ clazz (type): The corresponding class whose method is passed in.
+
+ Return:
+ dict: A parameter dict is returned.
+ """
+ params = list()
+ params.append(signature(func))
+ if func.__name__ == "__init__" and issubclass(clazz, (EstimatorBase, Model)):
+ # Go through the constructors of all parent classes to collect all parameters, since
+ # estimator and model classes use **kwargs for parameters defined in parent classes.
+ # The leaf class's parameters should be on top of the params list and take priority
+ _get_params_from_parent_class_constructors(clazz, params)
+
+ params_dict = dict(
+ required=dict(),
+ optional=dict(),
+ )
+ for param in params:
+ for param_val in param.parameters.values():
+ if param_val.annotation is param_val.empty:
+ continue
+ val = dict(type=param_val.annotation, default_value=None)
+ if param_val.name == "sagemaker_session":
+ # Treat sagemaker_session as required as it must be a PipelineSession obj
+ if not _is_in_params_dict(param_val.name, params_dict):
+ params_dict[REQUIRED][param_val.name] = val
+ elif param_val.default is param_val.empty or param_val.default is not None:
+ if not _is_in_params_dict(param_val.name, params_dict):
+ # Some parameters, e.g. entry_point in TensorFlow, appear as both a required
+ # (in Framework) and an optional (in EstimatorBase) parameter. The annotation
+ # defined in the class node (i.e. Framework) which is closer to the leaf class
+ # (TensorFlow) should win.
+ if param_val.default is not param_val.empty:
+ val[DEFAULT_VALUE] = param_val.default
+ params_dict[REQUIRED][param_val.name] = val
+ else:
+ if not _is_in_params_dict(param_val.name, params_dict):
+ params_dict[OPTIONAL][param_val.name] = val
+ return params_dict
+
+
+def _is_in_params_dict(param_name: str, params_dict: dict):
+ """Check if the parameter is in the parameter dict
+
+ Args:
+ param_name (str): The name of the parameter to be checked
+ params_dict (dict): The parameter dict in which to check if the param_name exists
+ """
+ return param_name in params_dict[REQUIRED] or param_name in params_dict[OPTIONAL]
+
+
+def _get_params_from_parent_class_constructors(clazz: type, params: list):
+ """Get constructor parameters from the parent classes
+
+ Args:
+ clazz (type): The downstream class to collect parameters from all its parent constructors
+ params (list): The list to collect all parameters
+ """
+ while clazz.__name__ not in {"EstimatorBase", "Model"}:
+ parent_class = clazz.__base__
+ params.append(signature(parent_class.__init__))
+ clazz = parent_class
+
+
+def generate_pipeline_vars_per_type(
+ param_name: str,
+ param_type: type,
+) -> list:
+ """Provide a list of possible PipelineVariable objects.
+
+ For example, if the type hint is Union[str, PipelineVariable],
+ return [ParameterString, Properties, JsonGet, Join, ExecutionVariable]
+
+ Args:
+ param_name (str): The name of the parameter to generate the pipeline variable list.
+ param_type (type): The type of the parameter to generate the pipeline variable list.
+
+ Return:
+ list: A list of possible PipelineVariable objects is returned.
+ """
+ # verify that the param allows pipeline variables
+ if "PipelineVariable" not in str(param_type):
+ raise TypeError("The type: %s does not support PipelineVariable." % param_type)
+
+ types = get_args(param_type)
+ # e.g. Union[str, PipelineVariable] or Union[str, PipelineVariable, NoneType]
+ if PipelineVariable in types:
+ # PipelineVariable corresponds to Python primitive types,
+ # i.e. str, int, float, bool
+ ppl_var = _get_pipeline_var(types=types)
+ return ppl_var
+
+ # e.g. Union[List[...], NoneType] or Union[Dict[...], NoneType] etc.
+ clean_type = clean_up_types(param_type)
+ origin_type = get_origin(clean_type)
+ if origin_type not in [list, dict, set, tuple]:
+ raise TypeError(f"Unsupported type: {param_type} for param: {param_name}")
+ sub_types = get_args(clean_type)
+
+ # e.g. List[...], Tuple[...], Set[...]
+ if origin_type in [list, tuple, set]:
+ ppl_var_list = generate_pipeline_vars_per_type(param_name, sub_types[0])
+ return [
+ (
+ origin_type([var]),
+ dict(
+ origin=origin_type([expected["origin"]]),
+ to_string=origin_type([expected["to_string"]]),
+ ),
+ )
+ for var, expected in ppl_var_list
+ ]
+
+ # e.g. Dict[...]
+ if origin_type is dict:
+ key_type = sub_types[0]
+ if key_type is not str:
+ raise TypeError(
+ f"Unsupported type: {key_type} for dict key in {param_name} of {param_type} type"
+ )
+ ppl_var_list = generate_pipeline_vars_per_type(param_name, sub_types[1])
+ return [
+ (
+ dict(MyKey=var),
+ dict(
+ origin=dict(MyKey=expected["origin"]),
+ to_string=dict(MyKey=expected["to_string"]),
+ ),
+ )
+ for var, expected in ppl_var_list
+ ]
+ return list()
+
+
+def clean_up_types(t: type) -> type:
+ """Clean up a Union type and return its first subtype that is not NoneType
+
+ For example, for Union[str, int, NoneType] it returns str.
+
+ Args:
+ t (type): The type of a parameter to be cleaned up.
+
+ Return:
+ type: The cleaned up type is returned.
+ """
+ if get_origin(t) is Union:
+ types = get_args(t)
+ return list(filter(lambda typ: "NoneType" not in str(typ), types))[0]
+ return t
+
+
+def _get_pipeline_var(types: tuple) -> list:
+ """Get a pipeline variable based on one of the parameter types.
+
+ Args:
+ types (tuple): The possible types of a parameter.
+
+ Return:
+ list: A list of possible PipelineVariable objects is returned
+ """
+ if str in types:
+ return PIPELINE_VARIABLES["str"]
+ if int in types:
+ return PIPELINE_VARIABLES["int"]
+ if float in types:
+ return PIPELINE_VARIABLES["float"]
+ if bool in types:
+ return PIPELINE_VARIABLES["bool"]
+ raise TypeError(f"Unable to parse types: {types}.")
+
+
+def mock_tar_and_upload_dir(
+ session,
+ bucket,
+ s3_key_prefix,
+ script,
+ directory=None,
+ dependencies=None,
+ kms_key=None,
+ s3_resource=None,
+ settings=None,
+):
+ """Minimal mock of the behavior of tar_and_upload_dir"""
+ if directory and (is_pipeline_variable(directory) or directory.lower().startswith("s3://")):
+ return UploadedCode(s3_prefix=directory, script_name=script)
+ script_name = script if directory else os.path.basename(script)
+ key = "%s/sourcedir.tar.gz" % s3_key_prefix
+ return UploadedCode(s3_prefix="s3://%s/%s" % (bucket, key), script_name=script_name)
+
+
+def mock_image_uris_retrieve(
+ framework,
+ region,
+ version=None,
+ py_version=None,
+ instance_type=None,
+ accelerator_type=None,
+ image_scope=None,
+ container_version=None,
+ distribution=None,
+ base_framework_version=None,
+ training_compiler_config=None,
+ model_id=None,
+ model_version=None,
+ tolerate_vulnerable_model=False,
+ tolerate_deprecated_model=False,
+ sdk_version=None,
+ inference_tool=None,
+ serverless_inference_config=None,
+) -> str:
+ """Minimal mock of the behavior of image_uris.retrieve"""
+ args = dict(locals())
+ for name, val in args.items():
+ if is_pipeline_variable(val):
+ raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val)))
+ return IMAGE_URI
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py
new file mode 100644
index 0000000000..e4e3a74b98
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py
@@ -0,0 +1,440 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
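+"""Entry points checking pipeline variable compatibility for estimator classes.
+
+Every test below follows the same pattern, sketched here with a hypothetical
+MyEstimator class standing in for the estimator under test (the mocks keep
+image lookups and code uploads out of the picture):
+
+    @patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+    @patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir))
+    def test_my_estimator_compatibility():
+        default_args = dict(clazz_args=dict(), func_args=dict())
+        PipelineVarCompatiTestTemplate(clazz=MyEstimator, default_args=default_args).check_compatibility()
+"""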
+from __future__ import absolute_import + +from random import getrandbits + +from unittest.mock import patch, MagicMock + +from sagemaker.estimator import Estimator +from sagemaker.huggingface import TrainingCompilerConfig as HF_TrainingCompilerConfig +from sagemaker.tensorflow import TensorFlow +from sagemaker.amazon.lda import LDA +from sagemaker.xgboost.estimator import XGBoost +from sagemaker.rl.estimator import RLEstimator +from sagemaker.amazon.pca import PCA +from sagemaker.mxnet.estimator import MXNet +from sagemaker.amazon.randomcutforest import RandomCutForest +from sagemaker.amazon.factorization_machines import FactorizationMachines +from sagemaker.algorithm import AlgorithmEstimator +from sagemaker.sklearn.estimator import SKLearn +from sagemaker.amazon.ipinsights import IPInsights +from sagemaker.huggingface.estimator import HuggingFace +from sagemaker.tensorflow.estimator import TrainingCompilerConfig as TF_TrainingCompilerConfig +from sagemaker.amazon.ntm import NTM +from sagemaker.pytorch import PyTorch +from sagemaker.chainer import Chainer +from sagemaker.amazon.linear_learner import LinearLearner +from sagemaker.amazon.knn import KNN +from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import ( + PipelineVarCompatiTestTemplate, +) +from sagemaker.amazon.object2vec import Object2Vec +from sagemaker.amazon.kmeans import KMeans +from tests.unit.sagemaker.workflow.test_mechanism.test_code import IMAGE_URI, MockProperties +from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import ( + mock_tar_and_upload_dir, + mock_image_uris_retrieve, +) + +_IS_TRUE = bool(getrandbits(1)) + + +# These tests provide the incomplete default arg dict +# within which some class or target func parameters are missing. 
+# The test template will fill in those missing args +# Note: the default args should not include PipelineVariable objects +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Estimator, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_tensorflow_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + compiler_config=TF_TrainingCompilerConfig() if _IS_TRUE else None, + image_uri=IMAGE_URI if not _IS_TRUE else None, + instance_type="ml.p3.2xlarge", + framework_version="2.9", + py_version="py39", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=TensorFlow, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_lda_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + instance_count=1, + ), + func_args=dict( + mini_batch_size=128, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LDA, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_pca_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PCA, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_mxnet_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="1.4.0"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=MXNet, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_rl_estimator_compatibility(): + default_args = dict( + 
clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=RLEstimator, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ntm_estimator_compatibility(): + default_args = dict( + clazz_args=dict(clip_gradient=1e-3), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=NTM, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_rcf_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=RandomCutForest, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_xgboost_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="1.2-1"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=XGBoost, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_object2vec_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Object2Vec, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_fm_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=FactorizationMachines, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ae_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=AlgorithmEstimator, + default_args=default_args, + ) 
+ test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_sklearn_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + py_version="py3", + instance_count=1, + instance_type="ml.m5.xlarge", + framework_version="0.20.0", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=SKLearn, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ipinsights_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=IPInsights, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_huggingface_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + instance_type="ml.p3.2xlarge", + transformers_version="4.11", + tensorflow_version="2.5" if _IS_TRUE else None, + pytorch_version="1.9" if not _IS_TRUE else None, + compiler_config=HF_TrainingCompilerConfig() if _IS_TRUE else None, + image_uri=IMAGE_URI if not _IS_TRUE else None, + py_version="py37" if _IS_TRUE else "py38", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=HuggingFace, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_pytorch_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="1.8.0", py_version="py3"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PyTorch, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_kmeans_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KMeans, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) 
+@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_knn_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + dimension_reduction_target=6, + dimension_reduction_type="sign", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KNN, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_chainer_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="4.0"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Chainer, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ll_estimator_compatibility(): + default_args = dict( + clazz_args=dict(init_method="normal"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LinearLearner, + default_args=default_args, + ) + test_template.check_compatibility() diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_models.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_models.py new file mode 100644 index 0000000000..611e415f2e --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_models.py @@ -0,0 +1,550 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from random import getrandbits +from unittest.mock import patch, MagicMock + +from sagemaker import ( + Model, + KNNModel, + KMeansModel, + PCAModel, + LDAModel, + NTMModel, + Object2VecModel, + FactorizationMachinesModel, + IPInsightsModel, + RandomCutForestModel, + LinearLearnerModel, + PipelineModel, +) +from sagemaker.chainer import ChainerModel +from sagemaker.huggingface import HuggingFaceModel +from sagemaker.model import FrameworkModel +from sagemaker.multidatamodel import MultiDataModel +from sagemaker.mxnet import MXNetModel +from sagemaker.pytorch import PyTorchModel +from sagemaker.sklearn import SKLearnModel +from sagemaker.sparkml import SparkMLModel +from sagemaker.tensorflow import TensorFlowModel +from sagemaker.workflow._utils import _RepackModelStep +from sagemaker.xgboost import XGBoostModel +from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import ( + PipelineVarCompatiTestTemplate, +) +from tests.unit.sagemaker.workflow.test_mechanism.test_code import BUCKET, MockProperties +from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import ( + mock_image_uris_retrieve, + mock_tar_and_upload_dir, +) + +_IS_TRUE = bool(getrandbits(1)) +_mock_properties = MockProperties(step_name="MyStep") +_mock_properties.__dict__["ModelArtifacts"] = MockProperties( + step_name="MyStep", path="ModelArtifacts" +) +_mock_properties.ModelArtifacts.__dict__["S3ModelArtifacts"] = MockProperties( + step_name="MyStep", path="ModelArtifacts.S3ModelArtifacts" +) + + +# These tests provide the incomplete default arg dict +# within which some class or target func parameters are missing. +# The test template will fill in those missing args +# Note: the default args should not include PipelineVariable objects +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict( + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.t2.medium", "ml.m5.xlarge"], + model_package_group_name="my-model-pkg-group" if not _IS_TRUE else None, + model_package_name="my-model-pkg" if _IS_TRUE else None, + ), + create=dict( + instance_type="ml.t2.medium", + ), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Model, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_framework_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = 
PipelineVarCompatiTestTemplate( + clazz=FrameworkModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +# Skip validate source dir because we skip _inject_repack_script +# Thus, repack script is not in the dir and the validation fails +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_tensorflow_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=TensorFlowModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_knn_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KNNModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_sparkml_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=SparkMLModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_kmeans_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KMeansModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", 
MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_pca_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PCAModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_lda_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LDAModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_ntm_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=NTMModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_object2vec_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Object2VecModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, 
"_inject_repack_script", MagicMock()) +def test_factorizationmachines_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=FactorizationMachinesModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_ipinsights_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=IPInsightsModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_randomcutforest_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=RandomCutForestModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_linearlearner_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LinearLearnerModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_sklearn_model_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="0.20.0"), + func_args=dict( + register=dict(), + create=dict(accelerator_type=None), + ), + ) + test_template = 
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=SKLearnModel,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir))
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock())
+@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True))
+def test_pytorch_model_compatibility():
+    default_args = dict(
+        clazz_args=dict(),
+        func_args=dict(
+            register=dict(),
+            create=dict(),
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=PyTorchModel,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir))
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock())
+@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True))
+def test_xgboost_model_compatibility():
+    default_args = dict(
+        clazz_args=dict(
+            framework_version="1",
+        ),
+        func_args=dict(
+            register=dict(),
+            create=dict(),
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=XGBoostModel,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir))
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock())
+def test_chainer_model_compatibility():
+    default_args = dict(
+        clazz_args=dict(framework_version="4.0.0"),
+        func_args=dict(
+            register=dict(),
+            create=dict(accelerator_type=None),
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=ChainerModel,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties))
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir))
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock())
+@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True))
+def test_huggingface_model_compatibility():
+    default_args = dict(
+        clazz_args=dict(
+            tensorflow_version="2.4.1" if _IS_TRUE else None,
+            pytorch_version="1.7.1" if not _IS_TRUE else None,
+            transformers_version="4.6.1",
+            py_version="py37" if _IS_TRUE else "py36",
else "py36", + ), + func_args=dict( + register=dict(), + create=dict(accelerator_type=None), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=HuggingFaceModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_mxnet_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=MXNetModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_multidata_model_compatibility(): + default_args = dict( + clazz_args=dict( + model_data_prefix=f"s3://{BUCKET}", + model=None, + ), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=MultiDataModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_pipelinemodel_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PipelineModel, + default_args=default_args, + ) + test_template.check_compatibility() diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_processors.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_processors.py new file mode 100644 index 0000000000..b677b909be --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_processors.py @@ -0,0 +1,363 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from random import getrandbits
+from unittest.mock import MagicMock, patch
+
+from sagemaker.processing import FrameworkProcessor, ScriptProcessor, Processor
+from sagemaker.pytorch.processing import PyTorchProcessor
+from sagemaker.clarify import (
+    SageMakerClarifyProcessor,
+    BiasConfig,
+    ModelConfig,
+    ModelPredictedLabelConfig,
+    PDPConfig,
+)
+from sagemaker.tensorflow.processing import TensorFlowProcessor
+from sagemaker.xgboost.processing import XGBoostProcessor
+from sagemaker.mxnet.processing import MXNetProcessor
+from sagemaker.sklearn.processing import SKLearnProcessor
+from sagemaker.huggingface.processing import HuggingFaceProcessor
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import (
+    PipelineVarCompatiTestTemplate,
+)
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import (
+    ROLE,
+    DUMMY_S3_SCRIPT_PATH,
+    PIPELINE_SESSION,
+    MockProperties,
+)
+
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import (
+    mock_image_uris_retrieve,
+)
+
+_IS_TRUE = bool(getrandbits(1))
+
+
+# These tests provide an incomplete default arg dict, in which some class or
+# target function parameters are missing; the test template fills in the
+# missing args.
+# Note: the default args should not include PipelineVariable objects
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_processor_compatibility():
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            volume_size_in_gb=None,
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=Processor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_script_processor_compatibility():
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            volume_size_in_gb=None,
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=ScriptProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_framework_processor_compatibility():
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            py_version="py3",
+            volume_size_in_gb=None,
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=FrameworkProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_pytorch_processor_compatibility():
+    default_args = dict(
+        clazz_args=dict(
+            framework_version="1.8.1",
+            role=ROLE,
+            py_version="py3",
+            volume_size_in_gb=None,
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=PyTorchProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_sagemaker_clarify_processor():
+    bias_config = BiasConfig(
+        facet_values_or_threshold=0.6,
+        facet_name="facet_name",
+        label_values_or_threshold=0.6,
+    )
+    model_config = ModelConfig(
+        model_name="my-model",
+        instance_count=1,
+        instance_type="ml.m5.xlarge",
+    )
+    model_pred_config = ModelPredictedLabelConfig(label="pred", probability_threshold=0.6)
+
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(
+            run_pre_training_bias=dict(
+                data_bias_config=bias_config,
+            ),
+            run_post_training_bias=dict(
+                data_bias_config=bias_config,
+                model_config=model_config,
+                model_predicted_label_config=model_pred_config,
+            ),
+            run_bias=dict(
+                bias_config=bias_config,
+                model_config=model_config,
+                model_predicted_label_config=model_pred_config,
+            ),
+            run_explainability=dict(
+                model_config=model_config,
+                model_scores=ModelPredictedLabelConfig(label="pred", probability_threshold=0.6),
+                explainability_config=PDPConfig(features=["f1", "f2", "f3", "f4"]),
+            ),
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=SageMakerClarifyProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_tensorflow_processor():
+    default_args = dict(
+        clazz_args=dict(
+            framework_version="2.8",
+            role=ROLE,
+            py_version="py39",
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=TensorFlowProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_xgboost_processor():
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            framework_version="1.2-1",
+            py_version="py3",
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=XGBoostProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+# TODO: need to merge Jerry's fix from the latest SDK master branch to unblock these tests
+# @patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=MockProperties(step_name="MyStep")))
+# @patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+# def test_spark_jar_processor():
+#     # takes a really long time, since the .run method has many args
+#     default_args = dict(
+#         clazz_args=dict(
+#             role=ROLE,
+#             framework_version="2.4",
+#             py_version="py37",
+#             sagemaker_session=PIPELINE_SESSION,
+#         ),
+#         func_args=dict(
+#             submit_app=DUMMY_S3_SCRIPT_PATH,
+#         ),
+#     )
+#     test_template = PipelineVarCompatiTestTemplate(
+#         clazz=SparkJarProcessor,
+#         default_args=default_args,
+#     )
+#     test_template.check_compatibility()
+
+
+# @patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=MockProperties(step_name="MyStep")))
+# @patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+# def test_py_spark_processor():
+#     # takes a really long time, since the .run method has many args
+#     default_args = dict(
+#         clazz_args=dict(
+#             role=ROLE,
+#             framework_version="2.4",
+#             py_version="py37",
+#             sagemaker_session=PIPELINE_SESSION,
+#         ),
+#         func_args=dict(
+#             submit_app=DUMMY_S3_SCRIPT_PATH,
+#         ),
+#     )
+#     test_template = PipelineVarCompatiTestTemplate(
+#         clazz=PySparkProcessor,
+#         default_args=default_args,
+#     )
+#     test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_mxnet_processor():
+    # takes a really long time, since the .run method has many args
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            framework_version="1.6",
+            py_version="py3",
+            sagemaker_session=PIPELINE_SESSION,
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=MXNetProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_sklearn_processor():
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            framework_version="0.23-1",
+            sagemaker_session=PIPELINE_SESSION,
+            instance_type="ml.m5.xlarge",
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=SKLearnProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
+
+
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+def test_hugging_face_processor():
+    default_args = dict(
+        clazz_args=dict(
+            role=ROLE,
+            sagemaker_session=PIPELINE_SESSION,
+            transformers_version="4.6",
+            tensorflow_version="2.4" if _IS_TRUE else None,
+            pytorch_version="1.8" if not _IS_TRUE else None,
+            py_version="py37" if _IS_TRUE else "py36",
+            instance_type="ml.p3.xlarge",
+        ),
+        func_args=dict(
+            code=DUMMY_S3_SCRIPT_PATH,
+        ),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=HuggingFaceProcessor,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_transformer.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_transformer.py
new file mode 100644
index 0000000000..51c9f2799e
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_transformer.py
@@ -0,0 +1,41 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import MagicMock, patch
+
+from sagemaker.transformer import Transformer
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import (
+    PipelineVarCompatiTestTemplate,
+)
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import MockProperties
+
+
+# These tests provide an incomplete default arg dict, in which some class or
+# target function parameters are missing; the test template fills in the
+# missing args.
+# Note: the default args should not include PipelineVariable objects
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+def test_transformer_compatibility():
+    default_args = dict(
+        clazz_args=dict(),
+        func_args=dict(),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=Transformer,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_tuner.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_tuner.py
new file mode 100644
index 0000000000..868cfa7344
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_tuner.py
@@ -0,0 +1,47 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import patch, MagicMock
+
+from sagemaker.tuner import HyperparameterTuner
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import (
+    PipelineVarCompatiTestTemplate,
+)
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import MockProperties
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import (
+    mock_image_uris_retrieve,
+    mock_tar_and_upload_dir,
+)
+
+
+# These tests provide an incomplete default arg dict, in which some class or
+# target function parameters are missing; the test template fills in the
+# missing args.
+# Note: the default args should not include PipelineVariable objects
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir))
+def test_hyperparametertuner_compatibility():
+    default_args = dict(
+        clazz_args=dict(),
+        func_args=dict(),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=HyperparameterTuner,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py
index 4fe36dec00..502a7674f8 100644
--- a/tests/unit/sagemaker/workflow/test_training_step.py
+++ b/tests/unit/sagemaker/workflow/test_training_step.py
@@ -274,7 +274,9 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
         "DependsOn": ["TestStep"],
         "Arguments": step_args.args,
     }
-    assert step_train.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
+    assert step_train.properties.TrainingJobName.expr == {
+        "Get": "Steps.MyTrainingStep.TrainingJobName"
+    }
     adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
     assert ordered(adjacency_list) == ordered(
         {"MyTrainingStep": [], "SecondTestStep": ["MyTrainingStep"], "TestStep": ["MyTrainingStep"]}
@@ -352,9 +354,6 @@ def test_training_step_with_framework_estimator(
     estimator.sagemaker_session = pipeline_session
     step_args = estimator.fit(inputs=TrainingInput(s3_data=training_input))

-    from sagemaker.workflow.retry import SageMakerJobStepRetryPolicy, SageMakerJobExceptionTypeEnum
-    from sagemaker.workflow.parameters import ParameterInteger
-
     step = TrainingStep(
         name="MyTrainingStep",
         step_args=step_args,

From dcdf66f9cab43a6fee561411efbadec2db61a4df Mon Sep 17 00:00:00 2001
From: Dewen Qi
Date: Thu, 4 Aug 2022 19:08:23 -0700
Subject: [PATCH 4/6] go with model base and tf

---
 src/sagemaker/estimator.py | 7 +++++--
 src/sagemaker/fw_utils.py  | 8 ++++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index fe87186646..a4b769a306 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -50,6 +50,7 @@
     validate_source_code_input_against_pipeline_variables,
 )
 from sagemaker.inputs import TrainingInput, FileSystemInput
+from sagemaker.instance_group import InstanceGroup
 from sagemaker.job import _Job
 from sagemaker.jumpstart.utils import (
     add_jumpstart_tags,
@@ -149,7 +150,7 @@ def __init__(
         code_location: Optional[str] = None,
         entry_point: Optional[Union[str, PipelineVariable]] = None,
         dependencies: Optional[List[Union[str]]] = None,
-        instance_groups: Optional[Dict[str, Union[str, int]]] = None,
+        instance_groups: Optional[List[InstanceGroup]] = None,
         **kwargs,
     ):
         """Initialize an ``EstimatorBase`` instance.
@@ -1580,6 +1581,8 @@ def _get_instance_type(self):

         for instance_group in self.instance_groups:
             instance_type = instance_group.instance_type
+            if is_pipeline_variable(instance_type):
+                continue
             match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)

             if match:
@@ -2179,7 +2182,7 @@ def __init__(
         code_location: Optional[str] = None,
         entry_point: Optional[Union[str, PipelineVariable]] = None,
         dependencies: Optional[List[str]] = None,
-        instance_groups: Optional[Dict[str, Union[str, int]]] = None,
+        instance_groups: Optional[List[InstanceGroup]] = None,
         **kwargs,
     ):
         """Initialize an ``Estimator`` instance.
diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index 66d1ed80f3..47af026842 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -871,6 +871,14 @@ def validate_distribution_instance(sagemaker_session, distribution, instance_typ
             # Strategy modelparallel is not enabled
             return

+    if is_pipeline_variable(instance_type):
+        logger.warning(
+            "instance_type is a pipeline variable, which is only interpreted at "
+            "pipeline execution time. Because modelparallel only runs on GPU-enabled "
+            "instances, the instance type specified must support GPU at execution time."
+        )
+        return
+
     instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types(
         InstanceTypes=[f"{instance_type}"]
     )

From bcb4e4c828ab37a2409b757562cf3b2891fe2666 Mon Sep 17 00:00:00 2001
From: Dewen Qi
Date: Thu, 4 Aug 2022 19:09:26 -0700
Subject: [PATCH 5/6] go with estimator subclasses

---
 src/sagemaker/pytorch/estimator.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index 55a4fff89d..c904bea44d 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -31,6 +31,7 @@
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
+from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable

 logger = logging.getLogger("sagemaker")
@@ -51,7 +52,7 @@ def __init__(
         source_dir: Optional[Union[str, PipelineVariable]] = None,
         hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
-        distribution: Dict = None,
+        distribution: Optional[Dict] = None,
         **kwargs
     ):
         """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -224,7 +225,7 @@ def __init__(

         if distribution is not None:
             instance_type = self._get_instance_type()
             # remove "ml." prefix
-            if instance_type[:3] == "ml.":
+            if not is_pipeline_variable(instance_type) and instance_type[:3] == "ml.":
                 instance_type = instance_type[3:]
             validate_distribution_instance(self.sagemaker_session, distribution, instance_type)

From 46bba6883a6b4993e9d8dd28a9257e06c10d62c4 Mon Sep 17 00:00:00 2001
From: Dewen Qi
Date: Thu, 4 Aug 2022 19:57:46 -0700
Subject: [PATCH 6/6] update TM as per latest estimator changes

---
 src/sagemaker/instance_group.py               | 10 ++--
 .../test_mechanism/test_code/__init__.py      | 53 ++++++++++++++++---
 ...est_pipeline_var_compatibility_template.py | 13 +++--
 ...eline_var_compatibility_with_estimators.py |  1 -
 4 files changed, 64 insertions(+), 13 deletions(-)

diff --git a/src/sagemaker/instance_group.py b/src/sagemaker/instance_group.py
index 5042787be5..8669e8b706 100644
--- a/src/sagemaker/instance_group.py
+++ b/src/sagemaker/instance_group.py
@@ -13,15 +13,19 @@
 """Defines the InstanceGroup class that configures a heterogeneous cluster."""
 from __future__ import absolute_import

+from typing import Optional, Union
+
+from sagemaker.workflow.entities import PipelineVariable
+

 class InstanceGroup(object):
     """The class to create instance groups for a heterogeneous cluster."""

     def __init__(
         self,
-        instance_group_name=None,
-        instance_type=None,
-        instance_count=None,
+        instance_group_name: Optional[Union[str, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
     ):
         """It initializes an ``InstanceGroup`` instance.
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py
index 4a9b2a37b0..0ad89c662c 100644
--- a/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py
@@ -19,6 +19,7 @@

 from sagemaker import ModelMetrics, MetricsSource, FileSource, Predictor
 from sagemaker.drift_check_baselines import DriftCheckBaselines
+from sagemaker.instance_group import InstanceGroup
 from sagemaker.metadata_properties import MetadataProperties
 from sagemaker.model import FrameworkModel
 from sagemaker.parameter import IntegerParameter
@@ -233,6 +234,10 @@ def _generate_all_pipeline_vars() -> dict:
     )


+# TODO: we should remove _IS_TRUE_TMP and replace its usages with IS_TRUE.
+# Currently `instance_groups` does not work well with some estimator subclasses,
+# so we temporarily hard-code it to False, which disables instance_groups in these tests.
+_IS_TRUE_TMP = False
 IS_TRUE = bool(getrandbits(1))
 PIPELINE_SESSION = _generate_mock_pipeline_session()
 PIPELINE_VARIABLES = _generate_all_pipeline_vars()
@@ -240,7 +245,6 @@ def _generate_all_pipeline_vars() -> dict:
 # TODO: need to recursively assign with Pipeline Variable in later changes
 FIXED_ARGUMENTS = dict(
     common=dict(
-        instance_type=INSTANCE_TYPE,
         role=ROLE,
         sagemaker_session=PIPELINE_SESSION,
         source_dir=f"s3://{BUCKET}/source",
@@ -281,6 +285,7 @@ def _generate_all_pipeline_vars() -> dict:
         response_types=["application/json"],
     ),
     processor=dict(
+        instance_type=INSTANCE_TYPE,
         estimator_cls=PyTorch,
         code=f"s3://{BUCKET}/code",
         spark_event_logs_s3_uri=f"s3://{BUCKET}/my-spark-output-path",
@@ -438,13 +443,33 @@ def _generate_all_pipeline_vars() -> dict:
                 input_mode=ParameterString(name="train_inputs_input_mode"),
                 attribute_names=[ParameterString(name="train_inputs_attribute_name")],
                 target_attribute_name=ParameterString(name="train_inputs_target_attr_name"),
+                instance_groups=[ParameterString(name="train_inputs_instance_groups")],
+            ),
+        },
+        instance_groups=[
+            InstanceGroup(
+                instance_group_name=ParameterString(name="instance_group_name"),
+                # hard-code the instance_type here because InstanceGroup.instance_type
+                # is used to retrieve the image_uri if image_uri is not present, and
+                # currently the test mechanism does not support skipping test cases
+                # related to bonded parameters in composite variables (i.e. the InstanceGroup)
+                # TODO: we should support skipping tests on bonded parameters in composite vars
+                instance_type="ml.m5.xlarge",
+                instance_count=ParameterString(name="instance_group_instance_count"),
+            ),
+        ]
+        if _IS_TRUE_TMP
+        else None,
+        instance_type="ml.m5.xlarge" if not _IS_TRUE_TMP else None,
+        instance_count=1 if not _IS_TRUE_TMP else None,
+        distribution={} if not _IS_TRUE_TMP else None,
     ),
     transformer=dict(
+        instance_type=INSTANCE_TYPE,
         data=f"s3://{BUCKET}/data",
     ),
     tuner=dict(
+        instance_type=INSTANCE_TYPE,
         estimator=TensorFlow(
             entry_point=TENSORFLOW_ENTRY_POINT,
             role=ROLE,
@@ -475,12 +500,14 @@ def _generate_all_pipeline_vars() -> dict:
         include_cls_metadata={"estimator-1": IS_TRUE},
     ),
     model=dict(
+        instance_type=INSTANCE_TYPE,
         serverless_inference_config=ServerlessInferenceConfig(),
         framework_version="1.11.0",
         py_version="py3",
         accelerator_type="ml.eia2.xlarge",
     ),
     pipelinemodel=dict(
+        instance_type=INSTANCE_TYPE,
        models=[
             SparkMLModel(
                 name="MySparkMLModel",
@@ -577,12 +604,17 @@ def _generate_all_pipeline_vars() -> dict:
         },
     ),
 )
-# A dict to keep the optional arguments which should not be None according to the logic
-# specific to the subclass.
+# A dict of the optional arguments that should not be set to None during the
+# test iteration, according to the logic specific to each subclass.
 PARAMS_SHOULD_NOT_BE_NONE = dict(
     estimator=dict(
         init=dict(
-            common={"instance_count", "instance_type"},
+            # TODO: we should remove the three instance_* parameters here.
+            # For the mutually exclusive parameters (instance_groups vs.
+            # instance_count/instance_type), if one side is set to None during
+            # iteration, the other side should get a non-None value, instead of
+            # listing them here and forcing them to be non-None.
+            common={"instance_count", "instance_type", "instance_groups"},
             LDA={"mini_batch_size"},
         )
     ),
@@ -692,7 +724,10 @@ def _generate_all_pipeline_vars() -> dict:
     ),
     estimator=dict(
         init=dict(
-            common=dict(),
+            common=dict(
+                entry_point={"enable_network_isolation"},
+                source_dir={"enable_network_isolation"},
+            ),
             TensorFlow=dict(
                 image_uri={"compiler_config"},
                 compiler_config={"image_uri"},
             ),
             HuggingFace=dict(
                 image_uri={"compiler_config"},
                 compiler_config={"image_uri"},
             ),
-        )
+        ),
+        fit=dict(
+            common=dict(
+                instance_count={"instance_groups"},
+                instance_type={"instance_groups"},
+            ),
+        ),
     ),
 )
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py
index c9867a3020..ca64ad871a 100644
--- a/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py
@@ -15,7 +15,7 @@
 import json
 from random import getrandbits
-from typing import Optional
+from typing import Optional, List
 from typing_extensions import get_origin

 from sagemaker import Model, PipelineModel, AlgorithmEstimator
@@ -368,14 +368,14 @@ def _verify_composite_object_against_pipeline_var(
         self,
         param_with_none: str,
         step_dsl: str,
-        step_dsl_obj: object,
+        step_dsl_obj: List[dict],
     ):
         """verify pipeline definition regarding composite objects against pipeline variables

         Args:
             param_with_none (str): The name of the parameter with None value.
             step_dsl (str): The step definition retrieved from the pipeline definition DSL.
-            step_dsl_obj (objet): The json load object of the step definition.
+            step_dsl_obj (List[dict]): The JSON-loaded object of the step definition.
""" # TODO: remove the following hard code assertion once recursive assignment is added if issubclass(self.clazz, Processor): @@ -398,6 +398,12 @@ def _verify_composite_object_against_pipeline_var( assert '{"Get": "Parameters.proc_input_s3_data_type"}' in step_dsl assert '{"Get": "Parameters.proc_input_app_managed"}' in step_dsl elif issubclass(self.clazz, EstimatorBase): + if ( + param_with_none != "instance_groups" + and self.default_args[CLAZZ_ARGS]["instance_groups"] + ): + assert '{"Get": "Parameters.instance_group_name"}' in step_dsl + assert '{"Get": "Parameters.instance_group_instance_count"}' in step_dsl if issubclass(self.clazz, AmazonAlgorithmEstimatorBase): # AmazonAlgorithmEstimatorBase's input is records if param_with_none != "records": @@ -415,6 +421,7 @@ def _verify_composite_object_against_pipeline_var( assert '{"Get": "Parameters.train_inputs_input_mode"}' in step_dsl assert '{"Get": "Parameters.train_inputs_attribute_name"}' in step_dsl assert '{"Get": "Parameters.train_inputs_target_attr_name"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_instance_groups"}' in step_dsl if not issubclass(self.clazz, (TensorFlow, MXNet, PyTorch, AlgorithmEstimator)): # debugger_hook_config may be disabled for these first 3 frameworks # AlgorithmEstimator ignores the kwargs diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py index e4e3a74b98..f749118b6a 100644 --- a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py @@ -291,7 +291,6 @@ def test_sklearn_estimator_compatibility(): clazz_args=dict( py_version="py3", instance_count=1, - instance_type="ml.m5.xlarge", framework_version="0.20.0", ), func_args=dict(),