diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py
index 1ab5ee3bcf..0ddf9fae78 100644
--- a/src/sagemaker/algorithm.py
+++ b/src/sagemaker/algorithm.py
@@ -13,15 +13,22 @@
 """Test docstring"""
 from __future__ import absolute_import
 
+from typing import Optional, Union, Dict, List
+
 import sagemaker
 import sagemaker.parameter
 from sagemaker import vpc_utils
 from sagemaker.deserializers import BytesDeserializer
 from sagemaker.deprecations import removed_kwargs
 from sagemaker.estimator import EstimatorBase
+from sagemaker.inputs import TrainingInput, FileSystemInput
 from sagemaker.serializers import IdentitySerializer
 from sagemaker.transformer import Transformer
 from sagemaker.predictor import Predictor
+from sagemaker.session import Session
+from sagemaker.workflow.entities import PipelineVariable
+
+from sagemaker.workflow import is_pipeline_variable
 
 
 class AlgorithmEstimator(EstimatorBase):
@@ -37,28 +44,28 @@ class AlgorithmEstimator(EstimatorBase):
 
     def __init__(
         self,
-        algorithm_arn,
-        role,
-        instance_count,
-        instance_type,
-        volume_size=30,
-        volume_kms_key=None,
-        max_run=24 * 60 * 60,
-        input_mode="File",
-        output_path=None,
-        output_kms_key=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        hyperparameters=None,
-        tags=None,
-        subnets=None,
-        security_group_ids=None,
-        model_uri=None,
-        model_channel_name="model",
-        metric_definitions=None,
-        encrypt_inter_container_traffic=False,
-        use_spot_instances=False,
-        max_wait=None,
+        algorithm_arn: str,
+        role: str,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        volume_size: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_run: Union[int, PipelineVariable] = 24 * 60 * 60,
+        input_mode: Union[str, PipelineVariable] = "File",
+        output_path: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        subnets: Optional[List[Union[str, PipelineVariable]]] = None,
+        security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
+        model_uri: Optional[str] = None,
+        model_channel_name: Union[str, PipelineVariable] = "model",
+        metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False,
+        use_spot_instances: Union[bool, PipelineVariable] = False,
+        max_wait: Optional[Union[int, PipelineVariable]] = None,
         **kwargs  # pylint: disable=W0613
     ):
         """Initialize an ``AlgorithmEstimator`` instance.
@@ -186,7 +193,7 @@ def validate_train_spec(self):
         # Check that the input mode provided is compatible with the training input modes for the
         # algorithm.
         input_modes = self._algorithm_training_input_modes(train_spec["TrainingChannels"])
-        if self.input_mode not in input_modes:
+        if not is_pipeline_variable(self.input_mode) and self.input_mode not in input_modes:
             raise ValueError(
                 "Invalid input mode: %s. %s only supports: %s"
                 % (self.input_mode, algorithm_name, input_modes)
@@ -194,14 +201,17 @@ def validate_train_spec(self):
 
         # Check that the training instance type is compatible with the algorithm.
        supported_instances = train_spec["SupportedTrainingInstanceTypes"]
-        if self.instance_type not in supported_instances:
+        if (
+            not is_pipeline_variable(self.instance_type)
+            and self.instance_type not in supported_instances
+        ):
             raise ValueError(
                 "Invalid instance_type: %s. %s supports the following instance types: %s"
                 % (self.instance_type, algorithm_name, supported_instances)
             )
 
         # Verify if distributed training is supported by the algorithm
-        if (
+        if not is_pipeline_variable(self.instance_count) and (
             self.instance_count > 1
             and "SupportsDistributedTraining" in train_spec
             and not train_spec["SupportsDistributedTraining"]
@@ -414,12 +424,18 @@ def _prepare_for_training(self, job_name=None):
 
         super(AlgorithmEstimator, self)._prepare_for_training(job_name)
 
-    def fit(self, inputs=None, wait=True, logs=True, job_name=None):
+    def fit(
+        self,
+        inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+    ):
         """Placeholder docstring"""
         if inputs:
             self._validate_input_channels(inputs)
 
-        super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
+        return super(AlgorithmEstimator, self).fit(inputs, wait, logs, job_name)
 
     def _validate_input_channels(self, channels):
         """Placeholder docstring"""
diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py
index 09e77d612a..dad5d54dcd 100644
--- a/src/sagemaker/amazon/amazon_estimator.py
+++ b/src/sagemaker/amazon/amazon_estimator.py
@@ -13,6 +13,8 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
+from typing import Optional, Union, Dict
+
 import json
 import logging
 import tempfile
@@ -27,7 +29,10 @@
 from sagemaker.estimator import EstimatorBase, _TrainingJob
 from sagemaker.inputs import FileSystemInput, TrainingInput
 from sagemaker.utils import sagemaker_timestamp
+from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.parameters import ParameterBoolean
+from sagemaker.workflow import is_pipeline_variable
 
 logger = logging.getLogger(__name__)
 
@@ -40,16 +45,16 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
     feature_dim = hp("feature_dim", validation.gt(0), data_type=int)
     mini_batch_size = hp("mini_batch_size", validation.gt(0), data_type=int)
 
-    repo_name = None
-    repo_version = None
+    repo_name: Optional[str] = None
+    repo_version: Optional[str] = None
 
     def __init__(
         self,
-        role,
-        instance_count=None,
-        instance_type=None,
-        data_location=None,
-        enable_network_isolation=False,
+        role: str,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        data_location: Optional[str] = None,
+        enable_network_isolation: Union[bool, ParameterBoolean] = False,
         **kwargs
     ):
         """Initialize an AmazonAlgorithmEstimatorBase.
@@ -113,6 +118,11 @@ def data_location(self):
     @data_location.setter
     def data_location(self, data_location):
         """Placeholder docstring"""
+        if is_pipeline_variable(data_location):
+            raise ValueError(
+                "data_location argument has to be a string " "rather than a pipeline variable"
+            )
+
         if not data_location.startswith("s3://"):
             raise ValueError(
                 'Expecting an S3 URL beginning with "s3://". 
Got "{}"'.format(data_location) @@ -196,12 +206,12 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): @runnable_by_pipeline def fit( self, - records, - mini_batch_size=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, + records: "RecordSet", + mini_batch_size: Optional[int] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, ): """Fit this Estimator on serialized Record objects, stored in S3. @@ -304,7 +314,12 @@ class RecordSet(object): """Placeholder docstring""" def __init__( - self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train" + self, + s3_data: Union[str, PipelineVariable], + num_records: int, + feature_dim: int, + s3_data_type: Union[str, PipelineVariable] = "ManifestFile", + channel: Union[str, PipelineVariable] = "train", ): """A collection of Amazon :class:~`Record` objects serialized and stored in S3. diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 5e9c2098b9..0d3f570e60 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -37,83 +37,83 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase): sparse datasets economically. """ - repo_name = "factorization-machines" - repo_version = 1 + repo_name: str = "factorization-machines" + repo_version: int = 1 - num_factors = hp("num_factors", gt(0), "An integer greater than zero", int) - predictor_type = hp( + num_factors: hp = hp("num_factors", gt(0), "An integer greater than zero", int) + predictor_type: hp = hp( "predictor_type", isin("binary_classifier", "regressor"), 'Value "binary_classifier" or "regressor"', str, ) - epochs = hp("epochs", gt(0), "An integer greater than 0", int) - clip_gradient = hp("clip_gradient", (), "A float value", float) - eps = hp("eps", (), "A float value", float) - rescale_grad = hp("rescale_grad", (), "A float value", float) - bias_lr = hp("bias_lr", ge(0), "A non-negative float", float) - linear_lr = hp("linear_lr", ge(0), "A non-negative float", float) - factors_lr = hp("factors_lr", ge(0), "A non-negative float", float) - bias_wd = hp("bias_wd", ge(0), "A non-negative float", float) - linear_wd = hp("linear_wd", ge(0), "A non-negative float", float) - factors_wd = hp("factors_wd", ge(0), "A non-negative float", float) - bias_init_method = hp( + epochs: hp = hp("epochs", gt(0), "An integer greater than 0", int) + clip_gradient: hp = hp("clip_gradient", (), "A float value", float) + eps: hp = hp("eps", (), "A float value", float) + rescale_grad: hp = hp("rescale_grad", (), "A float value", float) + bias_lr: hp = hp("bias_lr", ge(0), "A non-negative float", float) + linear_lr: hp = hp("linear_lr", ge(0), "A non-negative float", float) + factors_lr: hp = hp("factors_lr", ge(0), "A non-negative float", float) + bias_wd: hp = hp("bias_wd", ge(0), "A non-negative float", float) + linear_wd: hp = hp("linear_wd", ge(0), "A non-negative float", float) + factors_wd: hp = hp("factors_wd", ge(0), "A non-negative float", float) + bias_init_method: hp = hp( "bias_init_method", isin("normal", "uniform", "constant"), 'Value "normal", "uniform" or "constant"', str, ) - bias_init_scale = hp("bias_init_scale", ge(0), "A non-negative float", float) - bias_init_sigma = hp("bias_init_sigma", ge(0), "A non-negative float", float) - bias_init_value = hp("bias_init_value", (), "A float value", float) - linear_init_method = 
hp( + bias_init_scale: hp = hp("bias_init_scale", ge(0), "A non-negative float", float) + bias_init_sigma: hp = hp("bias_init_sigma", ge(0), "A non-negative float", float) + bias_init_value: hp = hp("bias_init_value", (), "A float value", float) + linear_init_method: hp = hp( "linear_init_method", isin("normal", "uniform", "constant"), 'Value "normal", "uniform" or "constant"', str, ) - linear_init_scale = hp("linear_init_scale", ge(0), "A non-negative float", float) - linear_init_sigma = hp("linear_init_sigma", ge(0), "A non-negative float", float) - linear_init_value = hp("linear_init_value", (), "A float value", float) - factors_init_method = hp( + linear_init_scale: hp = hp("linear_init_scale", ge(0), "A non-negative float", float) + linear_init_sigma: hp = hp("linear_init_sigma", ge(0), "A non-negative float", float) + linear_init_value: hp = hp("linear_init_value", (), "A float value", float) + factors_init_method: hp = hp( "factors_init_method", isin("normal", "uniform", "constant"), 'Value "normal", "uniform" or "constant"', str, ) - factors_init_scale = hp("factors_init_scale", ge(0), "A non-negative float", float) - factors_init_sigma = hp("factors_init_sigma", ge(0), "A non-negative float", float) - factors_init_value = hp("factors_init_value", (), "A float value", float) + factors_init_scale: hp = hp("factors_init_scale", ge(0), "A non-negative float", float) + factors_init_sigma: hp = hp("factors_init_sigma", ge(0), "A non-negative float", float) + factors_init_value: hp = hp("factors_init_value", (), "A float value", float) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_factors=None, - predictor_type=None, - epochs=None, - clip_gradient=None, - eps=None, - rescale_grad=None, - bias_lr=None, - linear_lr=None, - factors_lr=None, - bias_wd=None, - linear_wd=None, - factors_wd=None, - bias_init_method=None, - bias_init_scale=None, - bias_init_sigma=None, - bias_init_value=None, - linear_init_method=None, - linear_init_scale=None, - linear_init_sigma=None, - linear_init_value=None, - factors_init_method=None, - factors_init_scale=None, - factors_init_sigma=None, - factors_init_value=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_factors: Optional[int] = None, + predictor_type: Optional[str] = None, + epochs: Optional[int] = None, + clip_gradient: Optional[float] = None, + eps: Optional[float] = None, + rescale_grad: Optional[float] = None, + bias_lr: Optional[float] = None, + linear_lr: Optional[float] = None, + factors_lr: Optional[float] = None, + bias_wd: Optional[float] = None, + linear_wd: Optional[float] = None, + factors_wd: Optional[float] = None, + bias_init_method: Optional[str] = None, + bias_init_scale: Optional[float] = None, + bias_init_sigma: Optional[float] = None, + bias_init_value: Optional[float] = None, + linear_init_method: Optional[str] = None, + linear_init_scale: Optional[float] = None, + linear_init_sigma: Optional[float] = None, + linear_init_value: Optional[float] = None, + factors_init_method: Optional[str] = None, + factors_init_scale: Optional[float] = None, + factors_init_sigma: Optional[float] = None, + factors_init_value: Optional[float] = None, **kwargs ): """Factorization Machines is :class:`Estimator` for general-purpose supervised learning. 
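
The `hp` descriptors typed throughout this file validate on assignment: setting an attribute runs the validation expression passed to `hp()`. A rough sketch of that behavior, outside the diff; the role ARN is hypothetical and constructing the estimator assumes AWS credentials and a region are configured:

from sagemaker.amazon.factorization_machines import FactorizationMachines

fm = FactorizationMachines(
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # hypothetical role ARN
    instance_count=1,
    instance_type="ml.c5.xlarge",
    num_factors=10,
    predictor_type="binary_classifier",
)
fm.epochs = 25    # accepted: 25 satisfies gt(0)
# fm.epochs = 0   # would raise ValueError, citing "An integer greater than 0"
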
diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py index 856927cb13..973668ed56 100644 --- a/src/sagemaker/amazon/hyperparameter.py +++ b/src/sagemaker/amazon/hyperparameter.py @@ -14,7 +14,6 @@ from __future__ import absolute_import import json - from sagemaker.workflow import is_pipeline_variable diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 097f6b45dc..50f6f03566 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -36,45 +36,45 @@ class IPInsights(AmazonAlgorithmEstimatorBase): as user IDs or account numbers. """ - repo_name = "ipinsights" - repo_version = 1 - MINI_BATCH_SIZE = 10000 + repo_name: str = "ipinsights" + repo_version: int = 1 + MINI_BATCH_SIZE: int = 10000 - num_entity_vectors = hp( + num_entity_vectors: hp = hp( "num_entity_vectors", (ge(1), le(250000000)), "An integer in [1, 250000000]", int ) - vector_dim = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int) + vector_dim: hp = hp("vector_dim", (ge(4), le(4096)), "An integer in [4, 4096]", int) - batch_metrics_publish_interval = hp( + batch_metrics_publish_interval: hp = hp( "batch_metrics_publish_interval", (ge(1)), "An integer greater than 0", int ) - epochs = hp("epochs", (ge(1)), "An integer greater than 0", int) - learning_rate = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float) - num_ip_encoder_layers = hp( + epochs: hp = hp("epochs", (ge(1)), "An integer greater than 0", int) + learning_rate: hp = hp("learning_rate", (ge(1e-6), le(10.0)), "A float in [1e-6, 10.0]", float) + num_ip_encoder_layers: hp = hp( "num_ip_encoder_layers", (ge(0), le(100)), "An integer in [0, 100]", int ) - random_negative_sampling_rate = hp( + random_negative_sampling_rate: hp = hp( "random_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int ) - shuffled_negative_sampling_rate = hp( + shuffled_negative_sampling_rate: hp = hp( "shuffled_negative_sampling_rate", (ge(0), le(500)), "An integer in [0, 500]", int ) - weight_decay = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float) + weight_decay: hp = hp("weight_decay", (ge(0.0), le(10.0)), "A float in [0.0, 10.0]", float) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_entity_vectors=None, - vector_dim=None, - batch_metrics_publish_interval=None, - epochs=None, - learning_rate=None, - num_ip_encoder_layers=None, - random_negative_sampling_rate=None, - shuffled_negative_sampling_rate=None, - weight_decay=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_entity_vectors: Optional[int] = None, + vector_dim: Optional[int] = None, + batch_metrics_publish_interval: Optional[int] = None, + epochs: Optional[int] = None, + learning_rate: Optional[float] = None, + num_ip_encoder_layers: Optional[int] = None, + random_negative_sampling_rate: Optional[int] = None, + shuffled_negative_sampling_rate: Optional[int] = None, + weight_decay: Optional[float] = None, **kwargs ): """This estimator is for IP Insights. 
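
The `is_pipeline_variable` helper imported across these files is the pivot of the change set: pipeline parameters are placeholders resolved only at pipeline execution time, so compile-time value checks have to be skipped for them. A minimal sketch of what it distinguishes; the parameter name is hypothetical:

from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.parameters import ParameterInteger

# A pipeline parameter is a placeholder; its concrete value is unknown until execution.
count_param = ParameterInteger(name="TrainInstanceCount", default_value=1)

assert is_pipeline_variable(count_param)  # True: resolved at run time
assert not is_pipeline_variable(2)        # False: a plain literal
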
diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py
index 581e93e02a..1b925af6e4 100644
--- a/src/sagemaker/amazon/kmeans.py
+++ b/src/sagemaker/amazon/kmeans.py
@@ -13,7 +13,7 @@
 """Placeholder docstring"""
 from __future__ import absolute_import
 
-from typing import Union, Optional
+from typing import Union, Optional, List
 
 from sagemaker import image_uris
 from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
@@ -36,23 +36,25 @@ class KMeans(AmazonAlgorithmEstimatorBase):
     the algorithm to use to determine similarity.
     """
 
-    repo_name = "kmeans"
-    repo_version = 1
+    repo_name: str = "kmeans"
+    repo_version: int = 1
 
-    k = hp("k", gt(1), "An integer greater-than 1", int)
-    init_method = hp("init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str)
-    max_iterations = hp("local_lloyd_max_iter", gt(0), "An integer greater-than 0", int)
-    tol = hp("local_lloyd_tol", (ge(0), le(1)), "An float in [0, 1]", float)
-    num_trials = hp("local_lloyd_num_trials", gt(0), "An integer greater-than 0", int)
-    local_init_method = hp(
+    k: hp = hp("k", gt(1), "An integer greater-than 1", int)
+    init_method: hp = hp(
+        "init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str
+    )
+    max_iterations: hp = hp("local_lloyd_max_iter", gt(0), "An integer greater-than 0", int)
+    tol: hp = hp("local_lloyd_tol", (ge(0), le(1)), "A float in [0, 1]", float)
+    num_trials: hp = hp("local_lloyd_num_trials", gt(0), "An integer greater-than 0", int)
+    local_init_method: hp = hp(
         "local_lloyd_init_method", isin("random", "kmeans++"), 'One of "random", "kmeans++"', str
     )
-    half_life_time_size = hp(
+    half_life_time_size: hp = hp(
         "half_life_time_size", ge(0), "An integer greater-than-or-equal-to 0", int
     )
-    epochs = hp("epochs", gt(0), "An integer greater-than 0", int)
-    center_factor = hp("extra_center_factor", gt(0), "An integer greater-than 0", int)
-    eval_metrics = hp(
+    epochs: hp = hp("epochs", gt(0), "An integer greater-than 0", int)
+    center_factor: hp = hp("extra_center_factor", gt(0), "An integer greater-than 0", int)
+    eval_metrics: hp = hp(
         name="eval_metrics",
         validation_message='A comma separated list of "msd" or "ssd"',
         data_type=list,
@@ -60,19 +62,19 @@ class KMeans(AmazonAlgorithmEstimatorBase):
 
     def __init__(
         self,
-        role,
-        instance_count=None,
-        instance_type=None,
-        k=None,
-        init_method=None,
-        max_iterations=None,
-        tol=None,
-        num_trials=None,
-        local_init_method=None,
-        half_life_time_size=None,
-        epochs=None,
-        center_factor=None,
-        eval_metrics=None,
+        role: str,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        k: Optional[int] = None,
+        init_method: Optional[str] = None,
+        max_iterations: Optional[int] = None,
+        tol: Optional[float] = None,
+        num_trials: Optional[int] = None,
+        local_init_method: Optional[str] = None,
+        half_life_time_size: Optional[int] = None,
+        epochs: Optional[int] = None,
+        center_factor: Optional[int] = None,
+        eval_metrics: Optional[List[str]] = None,
         **kwargs
     ):
         """A k-means clustering class :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`.
diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py
index 14ba404ebf..2054c36483 100644
--- a/src/sagemaker/amazon/knn.py
+++ b/src/sagemaker/amazon/knn.py
@@ -37,54 +37,54 @@ class KNN(AmazonAlgorithmEstimatorBase):
     the average of their feature values as the predicted value.
""" - repo_name = "knn" - repo_version = 1 + repo_name: str = "knn" + repo_version: int = 1 - k = hp("k", (ge(1)), "An integer greater than 0", int) - sample_size = hp("sample_size", (ge(1)), "An integer greater than 0", int) - predictor_type = hp( + k: hp = hp("k", (ge(1)), "An integer greater than 0", int) + sample_size: hp = hp("sample_size", (ge(1)), "An integer greater than 0", int) + predictor_type: hp = hp( "predictor_type", isin("classifier", "regressor"), 'One of "classifier" or "regressor"', str ) - dimension_reduction_target = hp( + dimension_reduction_target: hp = hp( "dimension_reduction_target", (ge(1)), "An integer greater than 0 and less than feature_dim", int, ) - dimension_reduction_type = hp( + dimension_reduction_type: hp = hp( "dimension_reduction_type", isin("sign", "fjlt"), 'One of "sign" or "fjlt"', str ) - index_metric = hp( + index_metric: hp = hp( "index_metric", isin("COSINE", "INNER_PRODUCT", "L2"), 'One of "COSINE", "INNER_PRODUCT", "L2"', str, ) - index_type = hp( + index_type: hp = hp( "index_type", isin("faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ"), 'One of "faiss.Flat", "faiss.IVFFlat", "faiss.IVFPQ"', str, ) - faiss_index_ivf_nlists = hp( + faiss_index_ivf_nlists: hp = hp( "faiss_index_ivf_nlists", (), '"auto" or an integer greater than 0', str ) - faiss_index_pq_m = hp("faiss_index_pq_m", (ge(1)), "An integer greater than 0", int) + faiss_index_pq_m: hp = hp("faiss_index_pq_m", (ge(1)), "An integer greater than 0", int) def __init__( self, - role, - instance_count=None, - instance_type=None, - k=None, - sample_size=None, - predictor_type=None, - dimension_reduction_type=None, - dimension_reduction_target=None, - index_type=None, - index_metric=None, - faiss_index_ivf_nlists=None, - faiss_index_pq_m=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + k: Optional[int] = None, + sample_size: Optional[int] = None, + predictor_type: Optional[str] = None, + dimension_reduction_type: Optional[str] = None, + dimension_reduction_target: Optional[int] = None, + index_type: Optional[str] = None, + index_metric: Optional[str] = None, + faiss_index_ivf_nlists: Optional[str] = None, + faiss_index_pq_m: Optional[int] = None, **kwargs ): """k-nearest neighbors (KNN) is :class:`Estimator` used for classification and regression. @@ -158,6 +158,7 @@ def __init__( self.index_metric = index_metric self.faiss_index_ivf_nlists = faiss_index_ivf_nlists self.faiss_index_pq_m = faiss_index_pq_m + if dimension_reduction_type and not dimension_reduction_target: raise ValueError( '"dimension_reduction_target" is required when "dimension_reduction_type" is set.' diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index 4158b6cc27..56b621dcb0 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -26,6 +26,7 @@ from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow import is_pipeline_variable class LDA(AmazonAlgorithmEstimatorBase): @@ -37,24 +38,24 @@ class LDA(AmazonAlgorithmEstimatorBase): word, and the categories are the topics. 
""" - repo_name = "lda" - repo_version = 1 + repo_name: str = "lda" + repo_version: int = 1 - num_topics = hp("num_topics", gt(0), "An integer greater than zero", int) - alpha0 = hp("alpha0", gt(0), "A positive float", float) - max_restarts = hp("max_restarts", gt(0), "An integer greater than zero", int) - max_iterations = hp("max_iterations", gt(0), "An integer greater than zero", int) - tol = hp("tol", gt(0), "A positive float", float) + num_topics: hp = hp("num_topics", gt(0), "An integer greater than zero", int) + alpha0: hp = hp("alpha0", gt(0), "A positive float", float) + max_restarts: hp = hp("max_restarts", gt(0), "An integer greater than zero", int) + max_iterations: hp = hp("max_iterations", gt(0), "An integer greater than zero", int) + tol: hp = hp("tol", gt(0), "A positive float", float) def __init__( self, - role, - instance_type=None, - num_topics=None, - alpha0=None, - max_restarts=None, - max_iterations=None, - tol=None, + role: str, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_topics: Optional[int] = None, + alpha0: Optional[float] = None, + max_restarts: Optional[int] = None, + max_iterations: Optional[int] = None, + tol: Optional[float] = None, **kwargs ): """Latent Dirichlet Allocation (LDA) is :class:`Estimator` used for unsupervised learning. @@ -124,7 +125,8 @@ def __init__( :class:`~sagemaker.estimator.EstimatorBase`. """ # this algorithm only supports single instance training - if kwargs.pop("instance_count", 1) != 1: + instance_count = kwargs.pop("instance_count", 1) + if is_pipeline_variable(instance_count) or instance_count != 1: print( "LDA only supports single instance training. Defaulting to 1 {}.".format( instance_type diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index d02ed2875f..50d2c33d03 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -13,6 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import +import logging from typing import Union, Optional from sagemaker import image_uris @@ -26,6 +27,9 @@ from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow import is_pipeline_variable + +logger = logging.getLogger(__name__) class LinearLearner(AmazonAlgorithmEstimatorBase): @@ -39,12 +43,12 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): of the label y. 
""" - repo_name = "linear-learner" - repo_version = 1 + repo_name: str = "linear-learner" + repo_version: int = 1 - DEFAULT_MINI_BATCH_SIZE = 1000 + DEFAULT_MINI_BATCH_SIZE: int = 1000 - binary_classifier_model_selection_criteria = hp( + binary_classifier_model_selection_criteria: hp = hp( "binary_classifier_model_selection_criteria", isin( "accuracy", @@ -57,32 +61,36 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): ), data_type=str, ) - target_recall = hp("target_recall", (gt(0), lt(1)), "A float in (0,1)", float) - target_precision = hp("target_precision", (gt(0), lt(1)), "A float in (0,1)", float) - positive_example_weight_mult = hp( + target_recall: hp = hp("target_recall", (gt(0), lt(1)), "A float in (0,1)", float) + target_precision: hp = hp("target_precision", (gt(0), lt(1)), "A float in (0,1)", float) + positive_example_weight_mult: hp = hp( "positive_example_weight_mult", (), "A float greater than 0 or 'auto' or 'balanced'", str ) - epochs = hp("epochs", gt(0), "An integer greater-than 0", int) - predictor_type = hp( + epochs: hp = hp("epochs", gt(0), "An integer greater-than 0", int) + predictor_type: hp = hp( "predictor_type", isin("binary_classifier", "regressor", "multiclass_classifier"), 'One of "binary_classifier" or "multiclass_classifier" or "regressor"', str, ) - use_bias = hp("use_bias", (), "Either True or False", bool) - num_models = hp("num_models", gt(0), "An integer greater-than 0", int) - num_calibration_samples = hp("num_calibration_samples", gt(0), "An integer greater-than 0", int) - init_method = hp("init_method", isin("uniform", "normal"), 'One of "uniform" or "normal"', str) - init_scale = hp("init_scale", gt(0), "A float greater-than 0", float) - init_sigma = hp("init_sigma", gt(0), "A float greater-than 0", float) - init_bias = hp("init_bias", (), "A number", float) - optimizer = hp( + use_bias: hp = hp("use_bias", (), "Either True or False", bool) + num_models: hp = hp("num_models", gt(0), "An integer greater-than 0", int) + num_calibration_samples: hp = hp( + "num_calibration_samples", gt(0), "An integer greater-than 0", int + ) + init_method: hp = hp( + "init_method", isin("uniform", "normal"), 'One of "uniform" or "normal"', str + ) + init_scale: hp = hp("init_scale", gt(0), "A float greater-than 0", float) + init_sigma: hp = hp("init_sigma", gt(0), "A float greater-than 0", float) + init_bias: hp = hp("init_bias", (), "A number", float) + optimizer: hp = hp( "optimizer", isin("sgd", "adam", "rmsprop", "auto"), 'One of "sgd", "adam", "rmsprop" or "auto', str, ) - loss = hp( + loss: hp = hp( "loss", isin( "logistic", @@ -100,83 +108,89 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): ' "eps_insensitive_absolute_loss", "quantile_loss", "huber_loss", "softmax_loss" or "auto"', str, ) - wd = hp("wd", ge(0), "A float greater-than or equal to 0", float) - l1 = hp("l1", ge(0), "A float greater-than or equal to 0", float) - momentum = hp("momentum", (ge(0), lt(1)), "A float in [0,1)", float) - learning_rate = hp("learning_rate", gt(0), "A float greater-than 0", float) - beta_1 = hp("beta_1", (ge(0), lt(1)), "A float in [0,1)", float) - beta_2 = hp("beta_2", (ge(0), lt(1)), "A float in [0,1)", float) - bias_lr_mult = hp("bias_lr_mult", gt(0), "A float greater-than 0", float) - bias_wd_mult = hp("bias_wd_mult", ge(0), "A float greater-than or equal to 0", float) - use_lr_scheduler = hp("use_lr_scheduler", (), "A boolean", bool) - lr_scheduler_step = hp("lr_scheduler_step", gt(0), "An integer greater-than 0", int) - lr_scheduler_factor = 
hp("lr_scheduler_factor", (gt(0), lt(1)), "A float in (0,1)", float) - lr_scheduler_minimum_lr = hp("lr_scheduler_minimum_lr", gt(0), "A float greater-than 0", float) - normalize_data = hp("normalize_data", (), "A boolean", bool) - normalize_label = hp("normalize_label", (), "A boolean", bool) - unbias_data = hp("unbias_data", (), "A boolean", bool) - unbias_label = hp("unbias_label", (), "A boolean", bool) - num_point_for_scaler = hp("num_point_for_scaler", gt(0), "An integer greater-than 0", int) - margin = hp("margin", ge(0), "A float greater-than or equal to 0", float) - quantile = hp("quantile", (gt(0), lt(1)), "A float in (0,1)", float) - loss_insensitivity = hp("loss_insensitivity", gt(0), "A float greater-than 0", float) - huber_delta = hp("huber_delta", ge(0), "A float greater-than or equal to 0", float) - early_stopping_patience = hp("early_stopping_patience", gt(0), "An integer greater-than 0", int) - early_stopping_tolerance = hp( + wd: hp = hp("wd", ge(0), "A float greater-than or equal to 0", float) + l1: hp = hp("l1", ge(0), "A float greater-than or equal to 0", float) + momentum: hp = hp("momentum", (ge(0), lt(1)), "A float in [0,1)", float) + learning_rate: hp = hp("learning_rate", gt(0), "A float greater-than 0", float) + beta_1: hp = hp("beta_1", (ge(0), lt(1)), "A float in [0,1)", float) + beta_2: hp = hp("beta_2", (ge(0), lt(1)), "A float in [0,1)", float) + bias_lr_mult: hp = hp("bias_lr_mult", gt(0), "A float greater-than 0", float) + bias_wd_mult: hp = hp("bias_wd_mult", ge(0), "A float greater-than or equal to 0", float) + use_lr_scheduler: hp = hp("use_lr_scheduler", (), "A boolean", bool) + lr_scheduler_step: hp = hp("lr_scheduler_step", gt(0), "An integer greater-than 0", int) + lr_scheduler_factor: hp = hp("lr_scheduler_factor", (gt(0), lt(1)), "A float in (0,1)", float) + lr_scheduler_minimum_lr: hp = hp( + "lr_scheduler_minimum_lr", gt(0), "A float greater-than 0", float + ) + normalize_data: hp = hp("normalize_data", (), "A boolean", bool) + normalize_label: hp = hp("normalize_label", (), "A boolean", bool) + unbias_data: hp = hp("unbias_data", (), "A boolean", bool) + unbias_label: hp = hp("unbias_label", (), "A boolean", bool) + num_point_for_scaler: hp = hp("num_point_for_scaler", gt(0), "An integer greater-than 0", int) + margin: hp = hp("margin", ge(0), "A float greater-than or equal to 0", float) + quantile: hp = hp("quantile", (gt(0), lt(1)), "A float in (0,1)", float) + loss_insensitivity: hp = hp("loss_insensitivity", gt(0), "A float greater-than 0", float) + huber_delta: hp = hp("huber_delta", ge(0), "A float greater-than or equal to 0", float) + early_stopping_patience: hp = hp( + "early_stopping_patience", gt(0), "An integer greater-than 0", int + ) + early_stopping_tolerance: hp = hp( "early_stopping_tolerance", gt(0), "A float greater-than 0", float ) - num_classes = hp("num_classes", (gt(0), le(1000000)), "An integer in [1,1000000]", int) - accuracy_top_k = hp("accuracy_top_k", (gt(0), le(1000000)), "An integer in [1,1000000]", int) - f_beta = hp("f_beta", gt(0), "A float greater-than 0", float) - balance_multiclass_weights = hp("balance_multiclass_weights", (), "A boolean", bool) + num_classes: hp = hp("num_classes", (gt(0), le(1000000)), "An integer in [1,1000000]", int) + accuracy_top_k: hp = hp( + "accuracy_top_k", (gt(0), le(1000000)), "An integer in [1,1000000]", int + ) + f_beta: hp = hp("f_beta", gt(0), "A float greater-than 0", float) + balance_multiclass_weights: hp = hp("balance_multiclass_weights", (), "A boolean", bool) def 
__init__( self, - role, - instance_count=None, - instance_type=None, - predictor_type=None, - binary_classifier_model_selection_criteria=None, - target_recall=None, - target_precision=None, - positive_example_weight_mult=None, - epochs=None, - use_bias=None, - num_models=None, - num_calibration_samples=None, - init_method=None, - init_scale=None, - init_sigma=None, - init_bias=None, - optimizer=None, - loss=None, - wd=None, - l1=None, - momentum=None, - learning_rate=None, - beta_1=None, - beta_2=None, - bias_lr_mult=None, - bias_wd_mult=None, - use_lr_scheduler=None, - lr_scheduler_step=None, - lr_scheduler_factor=None, - lr_scheduler_minimum_lr=None, - normalize_data=None, - normalize_label=None, - unbias_data=None, - unbias_label=None, - num_point_for_scaler=None, - margin=None, - quantile=None, - loss_insensitivity=None, - huber_delta=None, - early_stopping_patience=None, - early_stopping_tolerance=None, - num_classes=None, - accuracy_top_k=None, - f_beta=None, - balance_multiclass_weights=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + predictor_type: Optional[str] = None, + binary_classifier_model_selection_criteria: Optional[str] = None, + target_recall: Optional[float] = None, + target_precision: Optional[float] = None, + positive_example_weight_mult: Optional[str] = None, + epochs: Optional[int] = None, + use_bias: Optional[bool] = None, + num_models: Optional[int] = None, + num_calibration_samples: Optional[int] = None, + init_method: Optional[str] = None, + init_scale: Optional[float] = None, + init_sigma: Optional[float] = None, + init_bias: Optional[float] = None, + optimizer: Optional[str] = None, + loss: Optional[str] = None, + wd: Optional[float] = None, + l1: Optional[float] = None, + momentum: Optional[float] = None, + learning_rate: Optional[float] = None, + beta_1: Optional[float] = None, + beta_2: Optional[float] = None, + bias_lr_mult: Optional[float] = None, + bias_wd_mult: Optional[float] = None, + use_lr_scheduler: Optional[bool] = None, + lr_scheduler_step: Optional[int] = None, + lr_scheduler_factor: Optional[float] = None, + lr_scheduler_minimum_lr: Optional[float] = None, + normalize_data: Optional[bool] = None, + normalize_label: Optional[bool] = None, + unbias_data: Optional[bool] = None, + unbias_label: Optional[bool] = None, + num_point_for_scaler: Optional[int] = None, + margin: Optional[float] = None, + quantile: Optional[float] = None, + loss_insensitivity: Optional[float] = None, + huber_delta: Optional[float] = None, + early_stopping_patience: Optional[int] = None, + early_stopping_tolerance: Optional[float] = None, + num_classes: Optional[int] = None, + accuracy_top_k: Optional[int] = None, + f_beta: Optional[float] = None, + balance_multiclass_weights: Optional[bool] = None, **kwargs ): """An :class:`Estimator` for binary classification and regression. 
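
The hunk that follows guards LinearLearner's default mini-batch computation. With a concrete instance_count the old formula still applies; with a pipeline variable the per-instance record count is unknowable before execution, hence the fallback to 1. A sketch of the arithmetic being preserved, as a standalone reimplementation rather than the SDK code itself:

DEFAULT_MINI_BATCH_SIZE = 1000

def default_mini_batch_size(num_records: int, instance_count: int) -> int:
    # mini_batch_size can't exceed the number of records per instance,
    # otherwise the training job fails
    return min(DEFAULT_MINI_BATCH_SIZE, max(1, num_records // instance_count))

assert default_mini_batch_size(10_000, 4) == 1000  # capped at the default
assert default_mini_batch_size(300, 4) == 75       # limited by records per instance
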
@@ -424,10 +438,21 @@ def _prepare_for_training(self, records, mini_batch_size=None, job_name=None): num_records = records.num_records # mini_batch_size can't be greater than number of records or training job fails - default_mini_batch_size = min( - self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)) - ) - mini_batch_size = mini_batch_size or default_mini_batch_size + if not mini_batch_size: + if is_pipeline_variable(self.instance_count): + logger.warning( + "mini_batch_size is not given in .fit() and instance_count is a " + "pipeline variable (%s) which is only parsed in execution time. " + "Thus setting mini_batch_size to 1, as it can't be greater than " + "number of records per instance_count, otherwise the training job fails.", + type(self.instance_count), + ) + mini_batch_size = 1 + else: + mini_batch_size = min( + self.DEFAULT_MINI_BATCH_SIZE, max(1, int(num_records / self.instance_count)) + ) + super(LinearLearner, self)._prepare_for_training( records, mini_batch_size=mini_batch_size, job_name=job_name ) diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 83c2f97348..8dccb0d079 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -13,7 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Union, Optional +from typing import Optional, Union, List from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase @@ -36,53 +36,59 @@ class NTM(AmazonAlgorithmEstimatorBase): "mileage", and "speed" are likely to share a topic on "transportation" for example. """ - repo_name = "ntm" - repo_version = 1 + repo_name: str = "ntm" + repo_version: int = 1 - num_topics = hp("num_topics", (ge(2), le(1000)), "An integer in [2, 1000]", int) - encoder_layers = hp( + num_topics: hp = hp("num_topics", (ge(2), le(1000)), "An integer in [2, 1000]", int) + encoder_layers: hp = hp( name="encoder_layers", validation_message="A comma separated list of " "positive integers", data_type=list, ) - epochs = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) - encoder_layers_activation = hp( + epochs: hp = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) + encoder_layers_activation: hp = hp( "encoder_layers_activation", isin("sigmoid", "tanh", "relu"), 'One of "sigmoid", "tanh" or "relu"', str, ) - optimizer = hp( + optimizer: hp = hp( "optimizer", isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"), 'One of "adagrad", "adam", "rmsprop", "sgd" and "adadelta"', str, ) - tolerance = hp("tolerance", (ge(1e-6), le(0.1)), "A float in [1e-6, 0.1]", float) - num_patience_epochs = hp("num_patience_epochs", (ge(1), le(10)), "An integer in [1, 10]", int) - batch_norm = hp(name="batch_norm", validation_message="Value must be a boolean", data_type=bool) - rescale_gradient = hp("rescale_gradient", (ge(1e-3), le(1.0)), "A float in [1e-3, 1.0]", float) - clip_gradient = hp("clip_gradient", ge(1e-3), "A float greater equal to 1e-3", float) - weight_decay = hp("weight_decay", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) - learning_rate = hp("learning_rate", (ge(1e-6), le(1.0)), "A float in [1e-6, 1.0]", float) + tolerance: hp = hp("tolerance", (ge(1e-6), le(0.1)), "A float in [1e-6, 0.1]", float) + num_patience_epochs: hp = hp( + "num_patience_epochs", (ge(1), le(10)), "An integer in [1, 10]", int + ) + batch_norm: hp = hp( + name="batch_norm", validation_message="Value must be a boolean", data_type=bool + ) + rescale_gradient: hp = hp( + 
"rescale_gradient", (ge(1e-3), le(1.0)), "A float in [1e-3, 1.0]", float + ) + clip_gradient: hp = hp("clip_gradient", ge(1e-3), "A float greater equal to 1e-3", float) + weight_decay: hp = hp("weight_decay", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) + learning_rate: hp = hp("learning_rate", (ge(1e-6), le(1.0)), "A float in [1e-6, 1.0]", float) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_topics=None, - encoder_layers=None, - epochs=None, - encoder_layers_activation=None, - optimizer=None, - tolerance=None, - num_patience_epochs=None, - batch_norm=None, - rescale_gradient=None, - clip_gradient=None, - weight_decay=None, - learning_rate=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_topics: Optional[int] = None, + encoder_layers: Optional[List] = None, + epochs: Optional[int] = None, + encoder_layers_activation: Optional[str] = None, + optimizer: Optional[str] = None, + tolerance: Optional[float] = None, + num_patience_epochs: Optional[int] = None, + batch_norm: Optional[bool] = None, + rescale_gradient: Optional[float] = None, + clip_gradient: Optional[float] = None, + weight_decay: Optional[float] = None, + learning_rate: Optional[float] = None, **kwargs ): """Neural Topic Model (NTM) is :class:`Estimator` used for unsupervised learning. diff --git a/src/sagemaker/amazon/object2vec.py b/src/sagemaker/amazon/object2vec.py index 1fbd846cbf..492f24f230 100644 --- a/src/sagemaker/amazon/object2vec.py +++ b/src/sagemaker/amazon/object2vec.py @@ -53,132 +53,142 @@ class Object2Vec(AmazonAlgorithmEstimatorBase): objects in the original space in the embedding space. """ - repo_name = "object2vec" - repo_version = 1 - MINI_BATCH_SIZE = 32 - - enc_dim = hp("enc_dim", (ge(4), le(10000)), "An integer in [4, 10000]", int) - mini_batch_size = hp("mini_batch_size", (ge(1), le(10000)), "An integer in [1, 10000]", int) - epochs = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) - early_stopping_patience = hp( + repo_name: str = "object2vec" + repo_version: int = 1 + MINI_BATCH_SIZE: int = 32 + + enc_dim: hp = hp("enc_dim", (ge(4), le(10000)), "An integer in [4, 10000]", int) + mini_batch_size: hp = hp("mini_batch_size", (ge(1), le(10000)), "An integer in [1, 10000]", int) + epochs: hp = hp("epochs", (ge(1), le(100)), "An integer in [1, 100]", int) + early_stopping_patience: hp = hp( "early_stopping_patience", (ge(1), le(5)), "An integer in [1, 5]", int ) - early_stopping_tolerance = hp( + early_stopping_tolerance: hp = hp( "early_stopping_tolerance", (ge(1e-06), le(0.1)), "A float in [1e-06, 0.1]", float ) - dropout = hp("dropout", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) - weight_decay = hp("weight_decay", (ge(0.0), le(10000.0)), "A float in [0.0, 10000.0]", float) - bucket_width = hp("bucket_width", (ge(0), le(100)), "An integer in [0, 100]", int) - num_classes = hp("num_classes", (ge(2), le(30)), "An integer in [2, 30]", int) - mlp_layers = hp("mlp_layers", (ge(1), le(10)), "An integer in [1, 10]", int) - mlp_dim = hp("mlp_dim", (ge(2), le(10000)), "An integer in [2, 10000]", int) - mlp_activation = hp( + dropout: hp = hp("dropout", (ge(0.0), le(1.0)), "A float in [0.0, 1.0]", float) + weight_decay: hp = hp( + "weight_decay", (ge(0.0), le(10000.0)), "A float in [0.0, 10000.0]", float + ) + bucket_width: hp = hp("bucket_width", (ge(0), le(100)), "An integer in [0, 100]", int) + num_classes: hp = hp("num_classes", (ge(2), le(30)), 
"An integer in [2, 30]", int) + mlp_layers: hp = hp("mlp_layers", (ge(1), le(10)), "An integer in [1, 10]", int) + mlp_dim: hp = hp("mlp_dim", (ge(2), le(10000)), "An integer in [2, 10000]", int) + mlp_activation: hp = hp( "mlp_activation", isin("tanh", "relu", "linear"), 'One of "tanh", "relu", "linear"', str ) - output_layer = hp( + output_layer: hp = hp( "output_layer", isin("softmax", "mean_squared_error"), 'One of "softmax", "mean_squared_error"', str, ) - optimizer = hp( + optimizer: hp = hp( "optimizer", isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"), 'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', str, ) - learning_rate = hp("learning_rate", (ge(1e-06), le(1.0)), "A float in [1e-06, 1.0]", float) + learning_rate: hp = hp("learning_rate", (ge(1e-06), le(1.0)), "A float in [1e-06, 1.0]", float) - negative_sampling_rate = hp( + negative_sampling_rate: hp = hp( "negative_sampling_rate", (ge(0), le(100)), "An integer in [0, 100]", int ) - comparator_list = hp( + comparator_list: hp = hp( "comparator_list", _list_check_subset(["hadamard", "concat", "abs_diff"]), 'Comma-separated of hadamard, concat, abs_diff. E.g. "hadamard,abs_diff"', str, ) - tied_token_embedding_weight = hp( + tied_token_embedding_weight: hp = hp( "tied_token_embedding_weight", (), "Either True or False", bool ) - token_embedding_storage_type = hp( + token_embedding_storage_type: hp = hp( "token_embedding_storage_type", isin("dense", "row_sparse"), 'One of "dense", "row_sparse"', str, ) - enc0_network = hp( + enc0_network: hp = hp( "enc0_network", isin("hcnn", "bilstm", "pooled_embedding"), 'One of "hcnn", "bilstm", "pooled_embedding"', str, ) - enc1_network = hp( + enc1_network: hp = hp( "enc1_network", isin("hcnn", "bilstm", "pooled_embedding", "enc0"), 'One of "hcnn", "bilstm", "pooled_embedding", "enc0"', str, ) - enc0_cnn_filter_width = hp("enc0_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int) - enc1_cnn_filter_width = hp("enc1_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int) - enc0_max_seq_len = hp("enc0_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) - enc1_max_seq_len = hp("enc1_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) - enc0_token_embedding_dim = hp( + enc0_cnn_filter_width: hp = hp( + "enc0_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int + ) + enc1_cnn_filter_width: hp = hp( + "enc1_cnn_filter_width", (ge(1), le(9)), "An integer in [1, 9]", int + ) + enc0_max_seq_len: hp = hp("enc0_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) + enc1_max_seq_len: hp = hp("enc1_max_seq_len", (ge(1), le(5000)), "An integer in [1, 5000]", int) + enc0_token_embedding_dim: hp = hp( "enc0_token_embedding_dim", (ge(2), le(1000)), "An integer in [2, 1000]", int ) - enc1_token_embedding_dim = hp( + enc1_token_embedding_dim: hp = hp( "enc1_token_embedding_dim", (ge(2), le(1000)), "An integer in [2, 1000]", int ) - enc0_vocab_size = hp("enc0_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int) - enc1_vocab_size = hp("enc1_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int) - enc0_layers = hp("enc0_layers", (ge(1), le(4)), "An integer in [1, 4]", int) - enc1_layers = hp("enc1_layers", (ge(1), le(4)), "An integer in [1, 4]", int) - enc0_freeze_pretrained_embedding = hp( + enc0_vocab_size: hp = hp( + "enc0_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int + ) + enc1_vocab_size: hp = hp( + "enc1_vocab_size", (ge(2), le(3000000)), "An integer in [2, 3000000]", int + ) 
+    enc0_layers: hp = hp("enc0_layers", (ge(1), le(4)), "An integer in [1, 4]", int)
+    enc1_layers: hp = hp("enc1_layers", (ge(1), le(4)), "An integer in [1, 4]", int)
+    enc0_freeze_pretrained_embedding: hp = hp(
         "enc0_freeze_pretrained_embedding", (), "Either True or False", bool
     )
-    enc1_freeze_pretrained_embedding = hp(
+    enc1_freeze_pretrained_embedding: hp = hp(
         "enc1_freeze_pretrained_embedding", (), "Either True or False", bool
     )
 
     def __init__(
         self,
-        role,
-        instance_count=None,
-        instance_type=None,
-        epochs=None,
-        enc0_max_seq_len=None,
-        enc0_vocab_size=None,
-        enc_dim=None,
-        mini_batch_size=None,
-        early_stopping_patience=None,
-        early_stopping_tolerance=None,
-        dropout=None,
-        weight_decay=None,
-        bucket_width=None,
-        num_classes=None,
-        mlp_layers=None,
-        mlp_dim=None,
-        mlp_activation=None,
-        output_layer=None,
-        optimizer=None,
-        learning_rate=None,
-        negative_sampling_rate=None,
-        comparator_list=None,
-        tied_token_embedding_weight=None,
-        token_embedding_storage_type=None,
-        enc0_network=None,
-        enc1_network=None,
-        enc0_cnn_filter_width=None,
-        enc1_cnn_filter_width=None,
-        enc1_max_seq_len=None,
-        enc0_token_embedding_dim=None,
-        enc1_token_embedding_dim=None,
-        enc1_vocab_size=None,
-        enc0_layers=None,
-        enc1_layers=None,
-        enc0_freeze_pretrained_embedding=None,
-        enc1_freeze_pretrained_embedding=None,
+        role: str,
+        instance_count: Optional[Union[int, PipelineVariable]] = None,
+        instance_type: Optional[Union[str, PipelineVariable]] = None,
+        epochs: Optional[int] = None,
+        enc0_max_seq_len: Optional[int] = None,
+        enc0_vocab_size: Optional[int] = None,
+        enc_dim: Optional[int] = None,
+        mini_batch_size: Optional[int] = None,
+        early_stopping_patience: Optional[int] = None,
+        early_stopping_tolerance: Optional[float] = None,
+        dropout: Optional[float] = None,
+        weight_decay: Optional[float] = None,
+        bucket_width: Optional[int] = None,
+        num_classes: Optional[int] = None,
+        mlp_layers: Optional[int] = None,
+        mlp_dim: Optional[int] = None,
+        mlp_activation: Optional[str] = None,
+        output_layer: Optional[str] = None,
+        optimizer: Optional[str] = None,
+        learning_rate: Optional[float] = None,
+        negative_sampling_rate: Optional[int] = None,
+        comparator_list: Optional[str] = None,
+        tied_token_embedding_weight: Optional[bool] = None,
+        token_embedding_storage_type: Optional[str] = None,
+        enc0_network: Optional[str] = None,
+        enc1_network: Optional[str] = None,
+        enc0_cnn_filter_width: Optional[int] = None,
+        enc1_cnn_filter_width: Optional[int] = None,
+        enc1_max_seq_len: Optional[int] = None,
+        enc0_token_embedding_dim: Optional[int] = None,
+        enc1_token_embedding_dim: Optional[int] = None,
+        enc1_vocab_size: Optional[int] = None,
+        enc0_layers: Optional[int] = None,
+        enc1_layers: Optional[int] = None,
+        enc0_freeze_pretrained_embedding: Optional[bool] = None,
+        enc1_freeze_pretrained_embedding: Optional[bool] = None,
         **kwargs
     ):
         """Object2Vec is :class:`Estimator` used for anomaly detection.
diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py
index e3127fd7a1..13440533d5 100644
--- a/src/sagemaker/amazon/pca.py
+++ b/src/sagemaker/amazon/pca.py
@@ -35,22 +35,24 @@ class PCA(AmazonAlgorithmEstimatorBase):
     retain as much information as possible.
""" - repo_name = "pca" - repo_version = 1 + repo_name: str = "pca" + repo_version: int = 1 - DEFAULT_MINI_BATCH_SIZE = 500 + DEFAULT_MINI_BATCH_SIZE: int = 500 - num_components = hp("num_components", gt(0), "Value must be an integer greater than zero", int) - algorithm_mode = hp( + num_components: hp = hp( + "num_components", gt(0), "Value must be an integer greater than zero", int + ) + algorithm_mode: hp = hp( "algorithm_mode", isin("regular", "randomized"), 'Value must be one of "regular" and "randomized"', str, ) - subtract_mean = hp( + subtract_mean: hp = hp( name="subtract_mean", validation_message="Value must be a boolean", data_type=bool ) - extra_components = hp( + extra_components: hp = hp( name="extra_components", validation_message="Value must be an integer greater than or equal to 0, or -1.", data_type=int, @@ -58,13 +60,13 @@ class PCA(AmazonAlgorithmEstimatorBase): def __init__( self, - role, - instance_count=None, - instance_type=None, - num_components=None, - algorithm_mode=None, - subtract_mean=None, - extra_components=None, + role: str, + instance_count: Optional[int] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_components: Optional[int] = None, + algorithm_mode: Optional[str] = None, + subtract_mean: Optional[bool] = None, + extra_components: Optional[int] = None, **kwargs ): """A Principal Components Analysis (PCA) diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index c38d75e3e4..16e4a09955 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -13,7 +13,7 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Optional, Union +from typing import Optional, Union, List from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase @@ -36,30 +36,30 @@ class RandomCutForest(AmazonAlgorithmEstimatorBase): or unclassifiable data points. """ - repo_name = "randomcutforest" - repo_version = 1 - MINI_BATCH_SIZE = 1000 + repo_name: str = "randomcutforest" + repo_version: int = 1 + MINI_BATCH_SIZE: int = 1000 - eval_metrics = hp( + eval_metrics: hp = hp( name="eval_metrics", validation_message='A comma separated list of "accuracy" or "precision_recall_fscore"', data_type=list, ) - num_trees = hp("num_trees", (ge(50), le(1000)), "An integer in [50, 1000]", int) - num_samples_per_tree = hp( + num_trees: hp = hp("num_trees", (ge(50), le(1000)), "An integer in [50, 1000]", int) + num_samples_per_tree: hp = hp( "num_samples_per_tree", (ge(1), le(2048)), "An integer in [1, 2048]", int ) - feature_dim = hp("feature_dim", (ge(1), le(10000)), "An integer in [1, 10000]", int) + feature_dim: hp = hp("feature_dim", (ge(1), le(10000)), "An integer in [1, 10000]", int) def __init__( self, - role, - instance_count=None, - instance_type=None, - num_samples_per_tree=None, - num_trees=None, - eval_metrics=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + num_samples_per_tree: Optional[int] = None, + num_trees: Optional[int] = None, + eval_metrics: Optional[List] = None, **kwargs ): """An `Estimator` class implementing a Random Cut Forest. 
diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 12c22eae91..112c67a4d8 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -13,8 +13,9 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging -from typing import Union, Optional from sagemaker.estimator import Framework, EstimatorBase from sagemaker.fw_utils import ( @@ -34,26 +35,26 @@ class Chainer(Framework): """Handle end-to-end training and deployment of custom Chainer code.""" - _framework_name = "chainer" + _framework_name: str = "chainer" # Hyperparameters - _use_mpi = "sagemaker_use_mpi" - _num_processes = "sagemaker_num_processes" - _process_slots_per_host = "sagemaker_process_slots_per_host" - _additional_mpi_options = "sagemaker_additional_mpi_options" + _use_mpi: str = "sagemaker_use_mpi" + _num_processes: str = "sagemaker_num_processes" + _process_slots_per_host: str = "sagemaker_process_slots_per_host" + _additional_mpi_options: str = "sagemaker_additional_mpi_options" def __init__( self, entry_point: Union[str, PipelineVariable], - use_mpi=None, - num_processes=None, - process_slots_per_host=None, - additional_mpi_options=None, + use_mpi: Optional[Union[bool, PipelineVariable]] = None, + num_processes: Optional[Union[int, PipelineVariable]] = None, + process_slots_per_host: Optional[Union[int, PipelineVariable]] = None, + additional_mpi_options: Optional[Union[str, PipelineVariable]] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - framework_version=None, - py_version=None, - image_uri=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + framework_version: Optional[str] = None, + py_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, **kwargs ): """This ``Estimator`` executes an Chainer script in a managed execution environment. 
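
The Chainer-specific arguments typed above are serialized into the "sagemaker_*" hyperparameter keys the class declares. A rough sketch of how they are passed; the script name, role ARN, and versions are hypothetical:

from sagemaker.chainer import Chainer

estimator = Chainer(
    entry_point="train.py",  # hypothetical training script
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # hypothetical
    instance_count=2,
    instance_type="ml.p3.2xlarge",
    framework_version="5.0.0",
    py_version="py3",
    use_mpi=True,     # surfaces as "sagemaker_use_mpi"
    num_processes=4,  # surfaces as "sagemaker_num_processes"
)
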
diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 6590d30514..be16ccfb9e 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -17,6 +17,8 @@
 """
 from __future__ import absolute_import, print_function
 
+from typing import Union, List, Optional, Dict, Any
+
 import copy
 import json
 import logging
@@ -26,25 +28,29 @@
 import tempfile
 from abc import ABC, abstractmethod
 
 from sagemaker import image_uris, s3, utils
+from sagemaker.session import Session
+from sagemaker.network import NetworkConfig
 from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow import is_pipeline_variable
 
 logger = logging.getLogger(__name__)
 
 
-class DataConfig:
+class DataConfig:  # TODO: add PipelineVariable to rest of fields
     """Config object related to configurations of the input and output dataset."""
 
     def __init__(
         self,
-        s3_data_input_path,
-        s3_output_path,
-        s3_analysis_config_output_path=None,
-        label=None,
-        headers=None,
-        features=None,
-        dataset_type="text/csv",
-        s3_compression_type="None",
-        joinsource=None,
+        s3_data_input_path: Union[str, PipelineVariable],
+        s3_output_path: Union[str, PipelineVariable],
+        s3_analysis_config_output_path: Optional[str] = None,
+        label: Optional[str] = None,
+        headers: Optional[List[str]] = None,
+        features: Optional[List[str]] = None,
+        dataset_type: str = "text/csv",
+        s3_compression_type: str = "None",
+        joinsource: Optional[str] = None,
         facet_dataset_uri=None,
         facet_headers=None,
         predicted_label_dataset_uri=None,
@@ -186,10 +192,10 @@ class BiasConfig:
 
     def __init__(
         self,
-        label_values_or_threshold,
-        facet_name,
-        facet_values_or_threshold=None,
-        group_name=None,
+        label_values_or_threshold: Union[int, float, str],
+        facet_name: Union[str, int, List[str], List[int]],
+        facet_values_or_threshold: Optional[Union[int, float, str]] = None,
+        group_name: Optional[str] = None,
     ):
         """Initializes a configuration of the sensitive groups in the dataset.
 
@@ -265,21 +271,21 @@ def get_config(self):
         return copy.deepcopy(self.analysis_config)
 
 
-class ModelConfig:
+class ModelConfig:  # TODO add pipeline annotation
    """Config object related to a model and its endpoint to be created."""
 
     def __init__(
         self,
-        model_name,
-        instance_count,
-        instance_type,
-        accept_type=None,
-        content_type=None,
-        content_template=None,
-        custom_attributes=None,
-        accelerator_type=None,
-        endpoint_name_prefix=None,
-        target_model=None,
+        model_name: str,
+        instance_count: int,
+        instance_type: str,
+        accept_type: Optional[str] = None,
+        content_type: Optional[str] = None,
+        content_template: Optional[str] = None,
+        custom_attributes: Optional[str] = None,
+        accelerator_type: Optional[str] = None,
+        endpoint_name_prefix: Optional[str] = None,
+        target_model: Optional[str] = None,
     ):
         r"""Initializes a configuration of a model and the endpoint to be created for it.
 
@@ -378,10 +384,10 @@ class ModelPredictedLabelConfig:
 
     def __init__(
         self,
-        label=None,
-        probability=None,
-        probability_threshold=None,
-        label_headers=None,
+        label: Optional[Union[str, int]] = None,
+        probability: Optional[Union[str, int]] = None,
+        probability_threshold: Optional[float] = None,
+        label_headers: Optional[List[str]] = None,
     ):
         """Initializes a model output config to extract the predicted label or predicted score(s).
 
@@ -473,7 +479,9 @@ class PDPConfig(ExplainabilityConfig):
     and the corresponding values are included in the analysis output.
""" # noqa E501 - def __init__(self, features=None, grid_resolution=15, top_k_features=10): + def __init__( + self, features: Optional[List] = None, grid_resolution: int = 15, top_k_features: int = 10 + ): """Initializes PDP config. Args: @@ -641,8 +649,8 @@ class TextConfig: def __init__( self, - granularity, - language, + granularity: str, + language: str, ): """Initializes a text configuration. @@ -697,13 +705,13 @@ class ImageConfig: def __init__( self, - model_type, - num_segments=None, - feature_extraction_method=None, - segment_compactness=None, - max_objects=None, - iou_threshold=None, - context=None, + model_type: str, + num_segments: Optional[int] = None, + feature_extraction_method: Optional[str] = None, + segment_compactness: Optional[float] = None, + max_objects: Optional[int] = None, + iou_threshold: Optional[float] = None, + context: Optional[float] = None, ): """Initializes a config object for Computer Vision (CV) Image explainability. @@ -778,15 +786,15 @@ class SHAPConfig(ExplainabilityConfig): def __init__( self, - baseline=None, - num_samples=None, - agg_method=None, - use_logit=False, - save_local_shap_values=True, - seed=None, - num_clusters=None, - text_config=None, - image_config=None, + baseline: Optional[Union[str, List]] = None, + num_samples: Optional[int] = None, + agg_method: Optional[str] = None, + use_logit: Optional[bool] = None, + save_local_shap_values: Optional[bool] = None, + seed: Optional[int] = None, + num_clusters: Optional[int] = None, + text_config: Optional[TextConfig] = None, + image_config: Optional[ImageConfig] = None, ): """Initializes config for SHAP analysis. @@ -866,19 +874,19 @@ class SageMakerClarifyProcessor(Processor): def __init__( self, - role, - instance_count, - instance_type, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, - job_name_prefix=None, - version=None, + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, + job_name_prefix: Optional[str] = None, + version: Optional[str] = None, ): """Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations. 
@@ -949,13 +957,13 @@ def run(self, **_): def _run( self, - data_config, - analysis_config, - wait, - logs, - job_name, - kms_key, - experiment_config, + data_config: DataConfig, + analysis_config: Dict[str, Any], + wait: bool, + logs: bool, + job_name: str, + kms_key: str, + experiment_config: Dict[str, str], ): """Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container @@ -991,6 +999,15 @@ def _run( analysis_config_file = os.path.join(tmpdirname, "analysis_config.json") with open(analysis_config_file, "w") as f: json.dump(analysis_config, f) + + if ( + is_pipeline_variable(data_config.s3_data_input_path) + and not data_config.s3_analysis_config_output_path + ): + raise ValueError( + "If s3_data_input_path for DataConfig is a pipeline variable, " + "s3_analysis_config_output_path must not be null" + ) s3_analysis_config_file = _upload_analysis_config( analysis_config_file, data_config.s3_analysis_config_output_path or data_config.s3_output_path, @@ -1033,14 +1050,14 @@ def _run( def run_pre_training_bias( self, - data_config, - data_bias_config, - methods="all", - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + data_bias_config: BiasConfig, + methods: str = "all", + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + kms_key: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods @@ -1103,16 +1120,16 @@ def run_pre_training_bias( def run_post_training_bias( self, - data_config, - data_bias_config, - model_config, - model_predicted_label_config, - methods="all", - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + data_bias_config: BiasConfig, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + methods: str = "all", + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + kms_key: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute post-training bias @@ -1192,17 +1209,17 @@ def run_post_training_bias( def run_bias( self, - data_config, - bias_config, - model_config, - model_predicted_label_config=None, - pre_training_methods="all", - post_training_methods="all", - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + bias_config: BiasConfig, + model_config: ModelConfig, + model_predicted_label_config: Optional[ModelPredictedLabelConfig] = None, + pre_training_methods: str = "all", + post_training_methods: str = "all", + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + kms_key: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods @@ -1298,15 +1315,15 @@ def run_bias( def run_explainability( self, - data_config, - model_config, - explainability_config, - model_scores=None, - wait=True, - logs=True, - job_name=None, - kms_key=None, - experiment_config=None, + data_config: DataConfig, + model_config: ModelConfig, + explainability_config: Union[ExplainabilityConfig, List], + model_scores: Optional[Union[int, ModelPredictedLabelConfig]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + kms_key: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, ): """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.
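The new guard in ``_run`` exists because the analysis config file is uploaded at pipeline compile time, so its destination cannot itself be a placeholder. A sketch of a ``DataConfig`` that satisfies the check (bucket names and the parameter name are assumptions):

```python
from sagemaker.clarify import DataConfig
from sagemaker.workflow.parameters import ParameterString

data_config = DataConfig(
    # The input path resolves only at pipeline execution time...
    s3_data_input_path=ParameterString(name="ClarifyInputPath"),
    s3_output_path="s3://my-bucket/clarify-output",
    # ...so the analysis config needs a concrete, compile-time destination.
    s3_analysis_config_output_path="s3://my-bucket/clarify-analysis-config",
    label="target",
    dataset_type="text/csv",
)
```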
diff --git a/src/sagemaker/debugger/debugger.py b/src/sagemaker/debugger/debugger.py index d2d53547f1..23f7b651a3 100644 --- a/src/sagemaker/debugger/debugger.py +++ b/src/sagemaker/debugger/debugger.py @@ -24,12 +24,15 @@ from abc import ABC +from typing import Union, Optional, List, Dict + import attr import smdebug_rulesconfig as rule_configs from sagemaker import image_uris from sagemaker.utils import build_dict +from sagemaker.workflow.entities import PipelineVariable framework_name = "debugger" DEBUGGER_FLAG = "USE_SMDEBUG" @@ -311,17 +314,17 @@ def sagemaker( @classmethod def custom( cls, - name, - image_uri, - instance_type, - volume_size_in_gb, - source=None, - rule_to_invoke=None, - container_local_output_path=None, - s3_output_path=None, - other_trials_s3_input_paths=None, - rule_parameters=None, - collections_to_save=None, + name: str, + image_uri: Union[str, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable], + source: Optional[str] = None, + rule_to_invoke: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None, + rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collections_to_save: Optional[List["CollectionConfig"]] = None, actions=None, ): """Initialize a ``Rule`` object for a *custom* debugging rule. @@ -610,10 +613,10 @@ class DebuggerHookConfig(object): def __init__( self, - s3_output_path=None, - container_local_output_path=None, - hook_parameters=None, - collection_configs=None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collection_configs: Optional[List["CollectionConfig"]] = None, ): """Initialize the DebuggerHookConfig instance. @@ -679,7 +682,11 @@ def _to_request_dict(self): class TensorBoardOutputConfig(object): """Create a tensor output configuration object for debugging visualizations on TensorBoard.""" - def __init__(self, s3_output_path, container_local_output_path=None): + def __init__( + self, + s3_output_path: Union[str, PipelineVariable], + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + ): """Initialize the TensorBoardOutputConfig instance. Args: @@ -708,7 +715,11 @@ def _to_request_dict(self): class CollectionConfig(object): """Creates tensor collections for SageMaker Debugger.""" - def __init__(self, name, parameters=None): + def __init__( + self, + name: Union[str, PipelineVariable], + parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + ): """Constructor for collection configuration.
Args: diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 371d161bbe..807ba91e79 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -13,7 +13,10 @@ """Configuration for collecting system and framework metrics in SageMaker training jobs.""" from __future__ import absolute_import +from typing import Optional, Union + from sagemaker.debugger.framework_profile import FrameworkProfile +from sagemaker.workflow.entities import PipelineVariable class ProfilerConfig(object): @@ -26,9 +29,9 @@ class ProfilerConfig(object): def __init__( self, - s3_output_path=None, - system_monitor_interval_millis=None, - framework_profile_params=None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None, + framework_profile_params: Optional[FrameworkProfile] = None, ): """Initialize a ``ProfilerConfig`` instance. diff --git a/src/sagemaker/drift_check_baselines.py b/src/sagemaker/drift_check_baselines.py index 24aa4787d0..9c3b8dbd57 100644 --- a/src/sagemaker/drift_check_baselines.py +++ b/src/sagemaker/drift_check_baselines.py @@ -13,21 +13,25 @@ """This file contains code related to drift check baselines""" from __future__ import absolute_import +from typing import Optional + +from sagemaker.model_metrics import MetricsSource, FileSource + class DriftCheckBaselines(object): """Accepts drift check baselines parameters for conversion to request dict.""" def __init__( self, - model_statistics=None, - model_constraints=None, - model_data_statistics=None, - model_data_constraints=None, - bias_config_file=None, - bias_pre_training_constraints=None, - bias_post_training_constraints=None, - explainability_constraints=None, - explainability_config_file=None, + model_statistics: Optional[MetricsSource] = None, + model_constraints: Optional[MetricsSource] = None, + model_data_statistics: Optional[MetricsSource] = None, + model_data_constraints: Optional[MetricsSource] = None, + bias_config_file: Optional[FileSource] = None, + bias_pre_training_constraints: Optional[MetricsSource] = None, + bias_post_training_constraints: Optional[MetricsSource] = None, + explainability_constraints: Optional[MetricsSource] = None, + explainability_config_file: Optional[FileSource] = None, ): """Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict. diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 9d0c30ff27..a4b769a306 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -50,6 +50,7 @@ validate_source_code_input_against_pipeline_variables, ) from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.instance_group import InstanceGroup from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_tags, @@ -149,7 +150,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, - instance_groups: Optional[Dict[str, Union[str, int]]] = None, + instance_groups: Optional[List[InstanceGroup]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance.
@@ -1580,6 +1581,8 @@ def _get_instance_type(self): for instance_group in self.instance_groups: instance_type = instance_group.instance_type + if is_pipeline_variable(instance_type): + continue match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) if match: @@ -2179,7 +2182,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, - instance_groups: Optional[Dict[str, Union[str, int]]] = None, + instance_groups: Optional[List[InstanceGroup]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -2874,7 +2877,15 @@ def _validate_and_set_debugger_configs(self): # Disable debugger if checkpointing is enabled by the customer if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config: if self._framework_name in {"mxnet", "pytorch", "tensorflow"}: - if self.instance_count > 1 or ( + if is_pipeline_variable(self.instance_count): + logger.warning( + "SMDebug does not currently support distributed training jobs " + "with checkpointing enabled. Therefore, to allow instance_count to be " + "parameterized and changed to any value at execution time, " + "debugger_hook_config is disabled." + ) + self.debugger_hook_config = False + elif self.instance_count > 1 or ( hasattr(self, "distribution") and self.distribution is not None # pylint: disable=no-member ): diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 613bbd3742..47af026842 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -555,7 +555,7 @@ def validate_smdistributed( if "smdistributed" not in distribution: # Distribution strategy other than smdistributed is selected return - if is_pipeline_variable(instance_type): + if is_pipeline_variable(instance_type) or (image_uri and is_pipeline_variable(image_uri)): # The instance_type is not available at compile time. # Rather, it's given at pipeline execution time return @@ -871,6 +871,14 @@ def validate_distribution_instance(sagemaker_session, distribution, instance_typ # Strategy modelparallel is not enabled return + if is_pipeline_variable(instance_type): + logger.warning( + "instance_type is a pipeline variable, which is only interpreted at " + "pipeline execution time. As modelparallel only runs on GPU-enabled " + "instances, the instance type specified at execution time has to support GPU."
+ ) + return + instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types( InstanceTypes=[f"{instance_type}"] ) diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 628c14dc8e..4d8b409eb4 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -13,9 +13,10 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging import re -from typing import Optional, Union, Dict from sagemaker.deprecations import renamed_kwargs from sagemaker.estimator import Framework, EstimatorBase diff --git a/src/sagemaker/huggingface/processing.py b/src/sagemaker/huggingface/processing.py index 3f3813b778..63810b0eb9 100644 --- a/src/sagemaker/huggingface/processing.py +++ b/src/sagemaker/huggingface/processing.py @@ -17,9 +17,15 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List, Dict + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.huggingface.estimator import HuggingFace +from sagemaker.workflow.entities import PipelineVariable + class HuggingFaceProcessor(FrameworkProcessor): """Handles Amazon SageMaker processing tasks for jobs using HuggingFace containers.""" @@ -28,25 +34,25 @@ class HuggingFaceProcessor(FrameworkProcessor): def __init__( self, - role, - instance_count, - instance_type, - transformers_version=None, - tensorflow_version=None, - pytorch_version=None, - py_version="py36", - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + transformers_version: Optional[str] = None, + tensorflow_version: Optional[str] = None, + pytorch_version: Optional[str] = None, + py_version: str = "py36", + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a HuggingFace execution environment. 
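The recurring pattern in ``fw_utils.py`` is the ``is_pipeline_variable`` guard: placeholders only resolve at execution time, so compile-time checks must be skipped or deferred for them. A quick illustration of the guard's behavior (parameter name assumed):

```python
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.parameters import ParameterString

instance_type = ParameterString(name="TrainInstanceType", default_value="ml.p3.16xlarge")

print(is_pipeline_variable(instance_type))     # True -> defer validation to execution time
print(is_pipeline_variable("ml.p3.16xlarge"))  # False -> safe to validate immediately
```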
diff --git a/src/sagemaker/huggingface/training_compiler/config.py b/src/sagemaker/huggingface/training_compiler/config.py index 07a3bcf9b7..b19fb2be2b 100644 --- a/src/sagemaker/huggingface/training_compiler/config.py +++ b/src/sagemaker/huggingface/training_compiler/config.py @@ -13,8 +13,10 @@ """Configuration for the SageMaker Training Compiler.""" from __future__ import absolute_import import logging +from typing import Union from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig): def __init__( self, - enabled=True, - debug=False, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, ): """This class initializes a ``TrainingCompilerConfig`` instance. diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 3481c138bd..5bfcb0d672 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -13,8 +13,12 @@ """Amazon SageMaker channel configurations for S3 data sources and file system data sources""" from __future__ import absolute_import, print_function +from typing import Union, Optional, List + import attr +from sagemaker.workflow.entities import PipelineVariable + FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"] FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"] @@ -29,17 +33,17 @@ class TrainingInput(object): def __init__( self, - s3_data, - distribution=None, - compression=None, - content_type=None, - record_wrapping=None, - s3_data_type="S3Prefix", - instance_groups=None, - input_mode=None, - attribute_names=None, - target_attribute_name=None, - shuffle_config=None, + s3_data: Union[str, PipelineVariable], + distribution: Optional[Union[str, PipelineVariable]] = None, + compression: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, + record_wrapping: Optional[Union[str, PipelineVariable]] = None, + s3_data_type: Union[str, PipelineVariable] = "S3Prefix", + instance_groups: Optional[List[Union[str, PipelineVariable]]] = None, + input_mode: Optional[Union[str, PipelineVariable]] = None, + attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, + target_attribute_name: Optional[Union[str, PipelineVariable]] = None, + shuffle_config: Optional["ShuffleConfig"] = None, ): r"""Create a definition for input data used by an SageMaker training job. diff --git a/src/sagemaker/instance_group.py b/src/sagemaker/instance_group.py index 5042787be5..8669e8b706 100644 --- a/src/sagemaker/instance_group.py +++ b/src/sagemaker/instance_group.py @@ -13,15 +13,19 @@ """Defines the InstanceGroup class that configures a heterogeneous cluster.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class InstanceGroup(object): """The class to create instance groups for a heterogeneous cluster.""" def __init__( self, - instance_group_name=None, - instance_type=None, - instance_count=None, + instance_group_name: Optional[Union[str, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + instance_count: Optional[Union[int, PipelineVariable]] = None, ): """It initializes an ``InstanceGroup`` instance. 
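Together with the ``instance_groups: Optional[List[InstanceGroup]]`` correction in ``estimator.py`` above, this lets a heterogeneous cluster mix concrete and parameterized groups. A sketch (group names, types, and counts are assumptions):

```python
from sagemaker.instance_group import InstanceGroup
from sagemaker.workflow.parameters import ParameterInteger, ParameterString

instance_groups = [
    # A fixed group, fully resolved at definition time.
    InstanceGroup(instance_group_name="data_server", instance_type="ml.c5.xlarge", instance_count=1),
    # A parameterized group, resolved at pipeline execution time.
    InstanceGroup(
        instance_group_name="workers",
        instance_type=ParameterString(name="WorkerInstanceType"),
        instance_count=ParameterInteger(name="WorkerCount", default_value=2),
    ),
]
# Passed to an estimator as instance_groups=instance_groups.
```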
diff --git a/src/sagemaker/metadata_properties.py b/src/sagemaker/metadata_properties.py index 4bc77ed0ee..b25aff9168 100644 --- a/src/sagemaker/metadata_properties.py +++ b/src/sagemaker/metadata_properties.py @@ -13,16 +13,20 @@ """This file contains code related to metadata properties.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class MetadataProperties(object): """Accepts metadata properties parameters for conversion to request dict.""" def __init__( self, - commit_id=None, - repository=None, - generated_by=None, - project_id=None, + commit_id: Optional[Union[str, PipelineVariable]] = None, + repository: Optional[Union[str, PipelineVariable]] = None, + generated_by: Optional[Union[str, PipelineVariable]] = None, + project_id: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetadataProperties`` instance and turn parameters into dict. diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a2c6da4bb7..8772fa724f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -18,7 +18,7 @@ import logging import os import copy -from typing import List, Dict +from typing import List, Dict, Optional, Union import sagemaker from sagemaker import ( @@ -29,7 +29,11 @@ utils, git_utils, ) +from sagemaker.session import Session +from sagemaker.model_metrics import ModelMetrics from sagemaker.deprecations import removed_kwargs +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties from sagemaker.predictor import PredictorBase from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer @@ -37,10 +41,12 @@ from sagemaker.utils import ( unique_name_from_base, update_container_with_inference_params, + to_string, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession LOGGER = logging.getLogger("sagemaker") @@ -82,23 +88,23 @@ class Model(ModelBase): def __init__( self, - image_uri, - model_data=None, - role=None, - predictor_cls=None, - env=None, - name=None, - vpc_config=None, - sagemaker_session=None, - enable_network_isolation=False, - model_kms_key=None, - image_config=None, - source_dir=None, - code_location=None, - entry_point=None, - container_log_level=logging.INFO, - dependencies=None, - git_config=None, + image_uri: Union[str, PipelineVariable], + model_data: Optional[Union[str, PipelineVariable]] = None, + role: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, + sagemaker_session: Optional[Session] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + model_kms_key: Optional[str] = None, + image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + source_dir: Optional[str] = None, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, ): """Initialize a SageMaker ``Model``.
@@ -298,28 +304,28 @@ def __init__( @runnable_by_pipeline def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - validation_specification=None, - domain=None, - task=None, - sample_payload_url=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + validation_specification: Optional[Union[str, PipelineVariable]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -349,11 +355,11 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). - sample_payload_url (str): The S3 path where the sample payload is stored - (default: None). task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). framework (str): Machine learning framework of the model package container image (default: None). 
framework_version (str): Framework version of the Model Package Container Image @@ -421,10 +427,10 @@ def register( @runnable_by_pipeline def create( self, - instance_type: str = None, - accelerator_type: str = None, - serverless_inference_config: ServerlessInferenceConfig = None, - tags: List[Dict[str, str]] = None, + instance_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, ): """Create a SageMaker Model Entity @@ -608,7 +614,7 @@ def _script_mode_env_vars(self): return { SCRIPT_PARAM_NAME.upper(): script_name or str(), DIR_PARAM_NAME.upper(): dir_name or str(), - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } @@ -1286,19 +1292,19 @@ class FrameworkModel(Model): def __init__( self, - model_data, - image_uri, - role, - entry_point, - source_dir=None, - predictor_cls=None, - env=None, - name=None, - container_log_level=logging.INFO, - code_location=None, - sagemaker_session=None, - dependencies=None, - git_config=None, + model_data: Union[str, PipelineVariable], + image_uri: Union[str, PipelineVariable], + role: str, + entry_point: str, + source_dir: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, **kwargs, ): """Initialize a ``FrameworkModel``. diff --git a/src/sagemaker/model_metrics.py b/src/sagemaker/model_metrics.py index acce4e13c9..83a43d3f18 100644 --- a/src/sagemaker/model_metrics.py +++ b/src/sagemaker/model_metrics.py @@ -13,20 +13,24 @@ """This file contains code related to model metrics, including metric source and file source.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class ModelMetrics(object): """Accepts model metrics parameters for conversion to request dict.""" def __init__( self, - model_statistics=None, - model_constraints=None, - model_data_statistics=None, - model_data_constraints=None, - bias=None, - explainability=None, - bias_pre_training=None, - bias_post_training=None, + model_statistics: Optional["MetricsSource"] = None, + model_constraints: Optional["MetricsSource"] = None, + model_data_statistics: Optional["MetricsSource"] = None, + model_data_constraints: Optional["MetricsSource"] = None, + bias: Optional["MetricsSource"] = None, + explainability: Optional["MetricsSource"] = None, + bias_pre_training: Optional["MetricsSource"] = None, + bias_post_training: Optional["MetricsSource"] = None, ): """Initialize a ``ModelMetrics`` instance and turn parameters into dict. @@ -99,9 +103,9 @@ class MetricsSource(object): def __init__( self, - content_type, - s3_uri, - content_digest=None, + content_type: Union[str, PipelineVariable], + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetricsSource`` instance and turn parameters into dict. 
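The ``to_string`` swap in ``_script_mode_env_vars`` matters because plain ``str()`` on a pipeline variable cannot produce a resolvable expression. A sketch of the assumed behavior (parameter name and printed expression are illustrative):

```python
from sagemaker.utils import to_string
from sagemaker.workflow.parameters import ParameterInteger

log_level = ParameterInteger(name="ContainerLogLevel", default_value=20)

# to_string() stringifies plain values unchanged...
print(to_string(20))  # 20
# ...and wraps pipeline variables in a string-producing pipeline function,
# e.g. {'Std:Join': {'On': '', 'Values': [{'Get': 'Parameters.ContainerLogLevel'}]}}
print(to_string(log_level).expr)
```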
@@ -127,9 +131,9 @@ class FileSource(object): def __init__( self, - s3_uri, - content_digest=None, - content_type=None, + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``FileSource`` instance and turn parameters into dict. diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index 3f0c054929..48974a3413 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -13,8 +13,9 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Dict, Union + import logging -from typing import Union, Optional, Dict from packaging.version import Version diff --git a/src/sagemaker/mxnet/processing.py b/src/sagemaker/mxnet/processing.py index 663b08be4b..71bce7cdff 100644 --- a/src/sagemaker/mxnet/processing.py +++ b/src/sagemaker/mxnet/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List, Dict + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.mxnet.estimator import MXNet from sagemaker.processing import FrameworkProcessor +from sagemaker.workflow.entities import PipelineVariable class MXNetProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class MXNetProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", # New kwarg + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, # New arg + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a managed MXNet execution environment. diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 8da6e04768..11272ccb63 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -41,6 +41,7 @@ from sagemaker.apiutils._base_types import ApiObject from sagemaker.s3 import S3Uploader + logger = logging.getLogger(__name__) @@ -773,10 +774,10 @@ def start_new(cls, processor, inputs, outputs, experiment_config): process_args = cls._get_process_args(processor, inputs, outputs, experiment_config) # Print the job name and the user's inputs and outputs as lists of dictionaries. 
- print() - print("Job Name: ", process_args["job_name"]) - print("Inputs: ", process_args["inputs"]) - print("Outputs: ", process_args["output_config"]["Outputs"]) + logger.debug("Job Name: %s", process_args["job_name"]) + logger.debug("Inputs: %s", process_args["inputs"]) + logger.debug("Outputs: %s", process_args["output_config"]["Outputs"]) # Call sagemaker_session.process using the arguments dictionary. processor.sagemaker_session.process(**process_args) diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 622e79084c..c904bea44d 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -13,8 +13,9 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging -from typing import Union, Optional from packaging.version import Version @@ -30,6 +31,7 @@ from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -45,12 +47,12 @@ class PyTorch(Framework): def __init__( self, entry_point: Union[str, PipelineVariable], - framework_version=None, - py_version=None, + framework_version: Optional[str] = None, + py_version: Optional[str] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - image_uri=None, - distribution=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Optional[Dict] = None, **kwargs ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -223,7 +225,7 @@ def __init__( if distribution is not None: instance_type = self._get_instance_type() # remove "ml."
prefix - if instance_type[:3] == "ml.": + if not is_pipeline_variable(instance_type) and instance_type[:3] == "ml.": instance_type = instance_type[3:] validate_distribution_instance(self.sagemaker_session, distribution, instance_type) diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index a6581efac6..3717fa9ccd 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List, Dict + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.pytorch.estimator import PyTorch +from sagemaker.workflow.entities import PipelineVariable class PyTorchProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class PyTorchProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", # New kwarg + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a PyTorch execution environment. diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index b004dd87b8..b95f192ea8 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -13,10 +13,11 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, List, Dict + import enum import logging import re -from typing import Union, Optional from sagemaker import image_uris, fw_utils from sagemaker.estimator import Framework, EstimatorBase @@ -77,13 +78,13 @@ class RLEstimator(Framework): def __init__( self, entry_point: Union[str, PipelineVariable], - toolkit=None, - toolkit_version=None, - framework=None, + toolkit: Optional[RLToolkit] = None, + toolkit_version: Optional[str] = None, + framework: Optional[RLFramework] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - image_uri=None, - metric_definitions=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, **kwargs ): """Creates an RLEstimator for managed Reinforcement Learning (RL).
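A recurring reason for these guards, like the ``"ml."`` prefix strip in the PyTorch estimator above, is that parameter objects do not behave like ``str``, so string operations on them would fail at definition time. Assumed behavior, for illustration only:

```python
from sagemaker.workflow.parameters import ParameterString

instance_type = ParameterString(name="InstanceType", default_value="ml.p3.2xlarge")
try:
    instance_type[:3]  # a plain str would return "ml."
except TypeError as err:
    print(err)  # presumably: 'ParameterString' object is not subscriptable
```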
diff --git a/src/sagemaker/serverless/serverless_inference_config.py b/src/sagemaker/serverless/serverless_inference_config.py index 39950f4f84..adc98a319a 100644 --- a/src/sagemaker/serverless/serverless_inference_config.py +++ b/src/sagemaker/serverless/serverless_inference_config.py @@ -27,8 +27,8 @@ class ServerlessInferenceConfig(object): def __init__( self, - memory_size_in_mb=2048, - max_concurrency=5, + memory_size_in_mb: int = 2048, + max_concurrency: int = 5, ): """Initialize a ServerlessInferenceConfig object for serverless inference configuration. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 145bf41cbe..221434d7db 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2633,7 +2633,9 @@ def _create_model_request( request["VpcConfig"] = vpc_config if enable_network_isolation: - request["EnableNetworkIsolation"] = True + # enable_network_isolation may be a pipeline variable which is + # resolved at pipeline execution time + request["EnableNetworkIsolation"] = enable_network_isolation return request diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index e13fbb764c..72372f602c 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -13,8 +13,9 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging -from typing import Union, Optional from sagemaker import image_uris from sagemaker.deprecations import renamed_kwargs @@ -28,6 +29,7 @@ from sagemaker.sklearn.model import SKLearnModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow import is_pipeline_variable logger = logging.getLogger("sagemaker") @@ -40,12 +42,12 @@ class SKLearn(Framework): def __init__( self, entry_point: Union[str, PipelineVariable], - framework_version=None, - py_version="py3", + framework_version: Optional[str] = None, + py_version: str = "py3", source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - image_uri=None, - image_uri_region=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + image_uri_region: Optional[str] = None, **kwargs ): """Creates a SKLearn Estimator for Scikit-learn environment. @@ -128,11 +130,17 @@ def __init__( self.framework_version = framework_version self.py_version = py_version + if instance_type: + if is_pipeline_variable(instance_type): + raise ValueError("instance_type argument cannot be a pipeline variable") # SciKit-Learn does not support distributed training or training on GPU instance types. # Fail fast. _validate_not_gpu_instance_type(instance_type) if instance_count: + if is_pipeline_variable(instance_count): + raise ValueError("instance_count argument cannot be a pipeline variable") + if instance_count != 1: raise AttributeError( "Scikit-Learn does not support distributed training.
Please remove the " diff --git a/src/sagemaker/sklearn/processing.py b/src/sagemaker/sklearn/processing.py index c5445e31f4..7a4a953a3d 100644 --- a/src/sagemaker/sklearn/processing.py +++ b/src/sagemaker/sklearn/processing.py @@ -17,9 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, List, Dict, Optional + +from sagemaker.network import NetworkConfig from sagemaker import image_uris, Session from sagemaker.processing import ScriptProcessor from sagemaker.sklearn import defaults +from sagemaker.workflow.entities import PipelineVariable class SKLearnProcessor(ScriptProcessor): @@ -27,20 +31,20 @@ class SKLearnProcessor(ScriptProcessor): def __init__( self, - framework_version, - role, - instance_type, - instance_count, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, + role: str, + instance_type: Union[str, PipelineVariable], + instance_count: Union[int, PipelineVariable], + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initialize an ``SKLearnProcessor`` instance. diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 4db647e140..9533f475a1 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, Dict from packaging import version @@ -27,6 +28,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow import is_pipeline_variable from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -41,12 +43,12 @@ class TensorFlow(Framework): def __init__( self, - py_version=None, - framework_version=None, - model_dir=None, - image_uri=None, - distribution=None, - compiler_config=None, + py_version: Optional[str] = None, + framework_version: Optional[str] = None, + model_dir: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Optional[Dict[str, str]] = None, + compiler_config: Optional[TrainingCompilerConfig] = None, **kwargs, ): """Initialize a ``TensorFlow`` estimator.
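By contrast with the deferred checks elsewhere in this change set, ``SKLearn`` fails fast above, since Scikit-learn jobs are single-instance and CPU-only regardless of execution-time values. A sketch of the expected behavior (script name and role ARN are placeholders):

```python
from sagemaker.sklearn.estimator import SKLearn
from sagemaker.workflow.parameters import ParameterString

try:
    SKLearn(
        entry_point="train.py",  # placeholder script
        framework_version="0.23-1",
        role="arn:aws:iam::111111111111:role/DummyRole",
        instance_count=1,
        instance_type=ParameterString(name="TrainInstanceType"),
    )
except ValueError as err:
    print(err)  # instance_type argument cannot be a pipeline variable
```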
@@ -251,6 +253,8 @@ def _only_legacy_mode_supported(self): def _only_python_3_supported(self): """Placeholder docstring""" + if not self.framework_version: + return False return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION @classmethod diff --git a/src/sagemaker/tensorflow/processing.py b/src/sagemaker/tensorflow/processing.py index 38bf784b7c..aa489e6b8c 100644 --- a/src/sagemaker/tensorflow/processing.py +++ b/src/sagemaker/tensorflow/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, List, Dict, Optional + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.tensorflow.estimator import TensorFlow +from sagemaker.workflow.entities import PipelineVariable class TensorFlowProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class TensorFlowProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, # New arg + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", # New kwarg + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, # New arg + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in a TensorFlow execution environment. diff --git a/src/sagemaker/tensorflow/training_compiler/config.py b/src/sagemaker/tensorflow/training_compiler/config.py index d14cc3359b..35516477a1 100644 --- a/src/sagemaker/tensorflow/training_compiler/config.py +++ b/src/sagemaker/tensorflow/training_compiler/config.py @@ -13,10 +13,13 @@ """Configuration for the SageMaker Training Compiler.""" from __future__ import absolute_import import logging +from typing import Union + from packaging.specifiers import SpecifierSet from packaging.version import Version from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -29,8 +32,8 @@ class TrainingCompilerConfig(BaseConfig): def __init__( self, - enabled=True, - debug=False, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, ): """This class initializes a ``TrainingCompilerConfig`` instance.
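With ``enabled`` and ``debug`` widened to ``Union[bool, PipelineVariable]``, the compiler switches can be decided per pipeline execution. A sketch (parameter names assumed):

```python
from sagemaker.huggingface import TrainingCompilerConfig
from sagemaker.workflow.parameters import ParameterBoolean

compiler_config = TrainingCompilerConfig(
    enabled=ParameterBoolean(name="CompilerEnabled", default_value=True),
    debug=ParameterBoolean(name="CompilerDebug", default_value=False),
)
```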
diff --git a/src/sagemaker/training_compiler/config.py b/src/sagemaker/training_compiler/config.py index cd1fbd5957..5e0c312242 100644 --- a/src/sagemaker/training_compiler/config.py +++ b/src/sagemaker/training_compiler/config.py @@ -14,6 +14,8 @@ from __future__ import absolute_import import logging +from sagemaker.workflow import is_pipeline_variable + logger = logging.getLogger(__name__) @@ -132,6 +134,19 @@ def validate( ValueError: Raised if the requested configuration is not compatible with SageMaker Training Compiler. """ + if estimator.image_uri: + error_helper_string = ( + "Overriding the image URI is currently not supported " + "for SageMaker Training Compiler. " + "Specify the following parameters to run the Hugging Face training job " + "with SageMaker Training Compiler enabled: " + "transformers_version, tensorflow_version or pytorch_version, and compiler_config." + ) + raise ValueError(error_helper_string) + + if is_pipeline_variable(estimator.instance_type): + # skip the validation if the instance type is a pipeline variable + return if "local" not in estimator.instance_type: requested_instance_class = estimator.instance_type.split(".")[ diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index d71b8e1433..a7f07963fc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -36,7 +36,6 @@ from sagemaker.session_settings import SessionSettings from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string - ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" MAX_BUCKET_PATHS_COUNT = 5 S3_PREFIX = "s3://" diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index f6f0005f1f..498d009dd0 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -13,8 +13,9 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union, Dict + import logging -from typing import Union, Optional from sagemaker import image_uris from sagemaker.deprecations import renamed_kwargs @@ -26,11 +27,12 @@ ) from sagemaker.session import Session from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT -from sagemaker.workflow.entities import PipelineVariable from sagemaker.xgboost import defaults from sagemaker.xgboost.model import XGBoostModel from sagemaker.xgboost.utils import validate_py_version, validate_framework_version +from sagemaker.workflow.entities import PipelineVariable + logger = logging.getLogger("sagemaker") @@ -45,12 +47,12 @@ class XGBoost(Framework): def __init__( self, entry_point: Union[str, PipelineVariable], - framework_version, + framework_version: str, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters=None, - py_version="py3", - image_uri=None, - image_uri_region=None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + py_version: str = "py3", + image_uri: Optional[Union[str, PipelineVariable]] = None, + image_uri_region: Optional[str] = None, **kwargs ): """An estimator that executes an XGBoost-based SageMaker Training Job.
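Analogously to the other estimators, XGBoost hyperparameter values may now be pipeline variables. A sketch (script, role, and parameter names are assumptions):

```python
from sagemaker.workflow.parameters import ParameterString
from sagemaker.xgboost.estimator import XGBoost

estimator = XGBoost(
    entry_point="train.py",  # placeholder script
    framework_version="1.5-1",
    py_version="py3",
    role="arn:aws:iam::111111111111:role/DummyRole",  # placeholder role
    instance_count=1,
    instance_type="ml.m5.xlarge",
    # Resolved and injected as a string hyperparameter at execution time.
    hyperparameters={"num_round": ParameterString(name="NumRound", default_value="50")},
)
```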
diff --git a/src/sagemaker/xgboost/processing.py b/src/sagemaker/xgboost/processing.py index 3ca4b2361f..f3e9c08f02 100644 --- a/src/sagemaker/xgboost/processing.py +++ b/src/sagemaker/xgboost/processing.py @@ -17,8 +17,13 @@ """ from __future__ import absolute_import +from typing import Union, List, Dict, Optional + +from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.processing import FrameworkProcessor from sagemaker.xgboost.estimator import XGBoost +from sagemaker.workflow.entities import PipelineVariable class XGBoostProcessor(FrameworkProcessor): @@ -28,23 +33,23 @@ class XGBoostProcessor(FrameworkProcessor): def __init__( self, - framework_version, # New arg - role, - instance_count, - instance_type, - py_version="py3", # New kwarg - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, # New arg - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + framework_version: str, + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """This processor executes a Python script in an XGBoost execution environment. diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py new file mode 100644 index 0000000000..0ad89c662c --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/__init__.py @@ -0,0 +1,796 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import json +import os +from random import getrandbits +from typing import List + +from sagemaker import ModelMetrics, MetricsSource, FileSource, Predictor +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.instance_group import InstanceGroup +from sagemaker.metadata_properties import MetadataProperties +from sagemaker.model import FrameworkModel +from sagemaker.parameter import IntegerParameter +from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.sparkml import SparkMLModel +from sagemaker.tensorflow import TensorFlow +from sagemaker.tuner import WarmStartConfig, WarmStartTypes +from mock import Mock, PropertyMock +from sagemaker import TrainingInput +from sagemaker.debugger import ( + Rule, + DebuggerHookConfig, + TensorBoardOutputConfig, + ProfilerConfig, + CollectionConfig, +) +from sagemaker.clarify import DataConfig +from sagemaker.network import NetworkConfig +from sagemaker.processing import ProcessingInput, ProcessingOutput +from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.rl.estimator import RLToolkit, RLFramework +from sagemaker.pytorch import PyTorch +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.execution_variables import ExecutionVariables +from sagemaker.workflow.functions import Join, JsonGet +from sagemaker.workflow.model_step import ModelStep +from sagemaker.workflow.parameters import ( + ParameterString, + ParameterInteger, + ParameterFloat, + ParameterBoolean, +) +from sagemaker.workflow.pipeline_context import PipelineSession +from sagemaker.workflow.properties import PropertyFile +from sagemaker.workflow.steps import ProcessingStep, TrainingStep, TuningStep, TransformStep +from tests.unit import DATA_DIR + +STR_VAL = "MyString" +ROLE = "DummyRole" +INSTANCE_TYPE = "ml.m5.xlarge" +BUCKET = "my-bucket" +REGION = "us-west-2" +IMAGE_URI = "fakeimage" +DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py" +TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies") +TENSORFLOW_ENTRY_POINT = os.path.join(TENSORFLOW_PATH, "inference.py") +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" + +DEFAULT_VALUE = "default_value" +CLAZZ_ARGS = "clazz_args" +FUNC_ARGS = "func_args" +REQUIRED = "required" +OPTIONAL = "optional" +COMMON = "common" +INIT = "init" +TYPE = "type" + + +class PipelineVariableEncoder(json.JSONEncoder): + """The Json Encoder for PipelineVariable""" + + def default(self, obj): + """To return a serializable object for the input object if it's a PipelineVariable + + or call the base implementation if it's not + + Args: + obj (object): The input object to be handled. 
+ """ + if isinstance(obj, PipelineVariable): + return obj.expr + return json.JSONEncoder.default(self, obj) + + +class MockProperties(PipelineVariable): + """A mock object or Pipeline Properties""" + + def __init__( + self, + step_name: str, + path: str = None, + shape_name: str = None, + shape_names: List[str] = None, + service_name: str = "sagemaker", + ): + """Initialize a MockProperties object""" + self.step_name = step_name + self.path = path + + @property + def expr(self): + """The 'Get' expression dict for a `Properties`.""" + return {"Get": f"Steps.{self.step_name}.Outcome"} + + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + return [self.step_name] + + +def _generate_mock_pipeline_session(): + """Generate mock pipeline session""" + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + client_mock.describe_algorithm.return_value = { + "TrainingSpecification": { + "TrainingChannels": [ + { + "SupportedContentTypes": ["application/json"], + "SupportedInputModes": ["File"], + "Name": "train", + } + ], + "SupportedTrainingInstanceTypes": ["ml.m5.xlarge", "ml.m4.xlarge"], + "SupportedHyperParameters": [ + { + "Name": "MyKey", + "Type": "FreeText", + } + ], + }, + "AlgorithmName": "algo-name", + } + + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=ROLE) + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + session_mock = Mock(region_name=REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client_mock + + return PipelineSession( + boto_session=session_mock, + sagemaker_client=client_mock, + default_bucket=BUCKET, + ) + + +def _generate_all_pipeline_vars() -> dict: + """Generate a dic with all kinds of Pipeline variables""" + # Parameter + ppl_param_str = ParameterString(name="MyString") + ppl_param_int = ParameterInteger(name="MyInt") + ppl_param_float = ParameterFloat(name="MyFloat") + ppl_param_bool = ParameterBoolean(name="MyBool") + + # Function + ppl_join = Join(on=" ", values=[ppl_param_int, ppl_param_float, 1, "test"]) + property_file = PropertyFile( + name="name", + output_name="result", + path="output", + ) + ppl_json_get = JsonGet( + step_name="my-step", + property_file=property_file, + json_path="my-json-path", + ) + + # Properties + ppl_prop = MockProperties(step_name="MyPreStep") + + # Execution Variables + ppl_exe_var = ExecutionVariables.PIPELINE_NAME + + return dict( + str=[ + ( + ppl_param_str, + dict(origin=ppl_param_str.expr, to_string=ppl_param_str.to_string().expr), + ), + (ppl_join, dict(origin=ppl_join.expr, to_string=ppl_join.to_string().expr)), + (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)), + (ppl_prop, dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr)), + (ppl_exe_var, dict(origin=ppl_exe_var.expr, to_string=ppl_exe_var.to_string().expr)), + ], + int=[ + ( + ppl_param_int, + dict(origin=ppl_param_int.expr, to_string=ppl_param_int.to_string().expr), + ), + (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)), + (ppl_prop, dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr)), + ], + float=[ + ( + ppl_param_float, + dict(origin=ppl_param_float.expr, to_string=ppl_param_float.to_string().expr), + ), + (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)), + ( + ppl_prop, + 
dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr), + ), + ], + bool=[ + ( + ppl_param_bool, + dict(origin=ppl_param_bool.expr, to_string=ppl_param_bool.to_string().expr), + ), + (ppl_json_get, dict(origin=ppl_json_get.expr, to_string=ppl_json_get.to_string().expr)), + ( + ppl_prop, + dict(origin=ppl_prop.expr, to_string=ppl_prop.to_string().expr), + ), + ], + ) + + +# TODO: we should remove the _IS_TRUE_TMP and replace its usages with IS_TRUE +# As currently the `instance_groups` does not work well with some estimator subclasses, +# we temporarily hard code it to False which disables the instance_groups +_IS_TRUE_TMP = False +IS_TRUE = bool(getrandbits(1)) +PIPELINE_SESSION = _generate_mock_pipeline_session() +PIPELINE_VARIABLES = _generate_all_pipeline_vars() + +# TODO: need to recursively assign with Pipeline Variable in later changes +FIXED_ARGUMENTS = dict( + common=dict( + role=ROLE, + sagemaker_session=PIPELINE_SESSION, + source_dir=f"s3://{BUCKET}/source", + entry_point=TENSORFLOW_ENTRY_POINT, + dependencies=[os.path.join(TENSORFLOW_PATH, "dependency.py")], + code_location=f"s3://{BUCKET}/code", + predictor_cls=Predictor, + model_metrics=ModelMetrics( + model_statistics=MetricsSource( + content_type=ParameterString(name="model_statistics_content_type"), + s3_uri=ParameterString(name="model_statistics_s3_uri"), + content_digest=ParameterString(name="model_statistics_content_digest"), + ) + ), + metadata_properties=MetadataProperties( + commit_id=ParameterString(name="meta_properties_commit_id"), + repository=ParameterString(name="meta_properties_repository"), + generated_by=ParameterString(name="meta_properties_generated_by"), + project_id=ParameterString(name="meta_properties_project_id"), + ), + drift_check_baselines=DriftCheckBaselines( + model_constraints=MetricsSource( + content_type=ParameterString(name="drift_constraints_content_type"), + s3_uri=ParameterString(name="drift_constraints_s3_uri"), + content_digest=ParameterString(name="drift_constraints_content_digest"), + ), + bias_config_file=FileSource( + content_type=ParameterString(name="drift_bias_content_type"), + s3_uri=ParameterString(name="drift_bias_s3_uri"), + content_digest=ParameterString(name="drift_bias_content_digest"), + ), + ), + model_package_name="my-model-pkg" if IS_TRUE else None, + model_package_group_name="my-model-pkg-group" if not IS_TRUE else None, + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.t2.medium", "ml.m5.xlarge"], + content_types=["application/json"], + response_types=["application/json"], + ), + processor=dict( + instance_type=INSTANCE_TYPE, + estimator_cls=PyTorch, + code=f"s3://{BUCKET}/code", + spark_event_logs_s3_uri=f"s3://{BUCKET}/my-spark-output-path", + framework_version="1.8", + network_config=NetworkConfig( + subnets=[ParameterString(name="nw_cfg_subnets")], + security_group_ids=[ParameterString(name="nw_cfg_security_group_ids")], + enable_network_isolation=ParameterBoolean(name="nw_cfg_enable_nw_isolation"), + encrypt_inter_container_traffic=ParameterBoolean( + name="nw_cfg_encrypt_inter_container_traffic" + ), + ), + inputs=[ + ProcessingInput( + source=ParameterString(name="proc_input_source"), + destination=ParameterString(name="proc_input_dest"), + s3_data_type=ParameterString(name="proc_input_s3_data_type"), + app_managed=ParameterBoolean(name="proc_input_app_managed"), + ), + ], + outputs=[ + ProcessingOutput( + source=ParameterString(name="proc_output_source"), + destination=ParameterString(name="proc_output_dest"), + 
app_managed=ParameterBoolean(name="proc_output_app_managed"), + ), + ], + data_config=DataConfig( + s3_data_input_path=ParameterString(name="clarify_processor_input"), + s3_output_path=ParameterString(name="clarify_processor_output"), + s3_analysis_config_output_path="s3://analysis_config_output_path", + ), + data_bias_config=DataConfig( + s3_data_input_path=ParameterString(name="clarify_processor_input"), + s3_output_path=ParameterString(name="clarify_processor_output"), + s3_analysis_config_output_path="s3://analysis_config_output_path", + ), + ), + estimator=dict( + image_uri_region="us-west-2", + input_mode="File", + records=RecordSet( + s3_data=ParameterString(name="records_s3_data"), + num_records=1000, + feature_dim=128, + s3_data_type=ParameterString(name="records_s3_data_type"), + channel=ParameterString(name="records_channel"), + ), + disable_profiler=False, + vector_dim=128, + enc_dim=128, + momentum=1e-6, + beta_1=1e-4, + beta_2=1e-4, + mini_batch_size=1000, + dropout=0.25, + num_classes=10, + mlp_dim=512, + mlp_activation="relu", + output_layer="softmax", + comparator_list="hadamard,concat,abs_diff", + token_embedding_storage_type="dense", + enc0_network="bilstm", + enc1_network="bilstm", + enc0_token_embedding_dim=256, + enc1_token_embedding_dim=256, + enc0_vocab_size=512, + enc1_vocab_size=512, + bias_init_method="normal", + factors_init_method="normal", + predictor_type="regressor", + linear_init_method="uniform", + toolkit=RLToolkit.RAY, + toolkit_version="1.6.0", + framework=RLFramework.PYTORCH, + algorithm_mode="regular", + num_topics=6, + k=6, + init_method="kmeans++", + local_init_method="kmeans++", + eval_metrics="mds,ssd", + tol=1e-4, + dimension_reduction_type="sign", + index_type="faiss.Flat", + faiss_index_ivf_nlists="auto", + index_metric="COSINE", + binary_classifier_model_selection_criteria="f1", + positive_example_weight_mult="auto", + loss="logistic", + target_recall=0.1, + target_precision=0.8, + early_stopping_tolerance=1e-4, + encoder_layers_activation="relu", + optimizer="adam", + tolerance=1e-4, + rescale_gradient=1e-2, + weight_decay=1e-6, + learning_rate=1e-4, + num_trees=50, + source_dir=f"s3://{BUCKET}/source", + entry_point=os.path.join(TENSORFLOW_PATH, "inference.py"), + dependencies=[os.path.join(TENSORFLOW_PATH, "dependency.py")], + code_location=f"s3://{BUCKET}/code", + output_path=f"s3://{BUCKET}/output", + model_uri=f"s3://{BUCKET}/model", + py_version="py2", + framework_version="2.1.1", + rules=[ + Rule.custom( + name="CustomeRule", + image_uri=ParameterString(name="rules_image_uri"), + instance_type=ParameterString(name="rules_instance_type"), + volume_size_in_gb=ParameterInteger(name="rules_volume_size"), + source="path/to/my_custom_rule.py", + rule_to_invoke=ParameterString(name="rules_to_invoke"), + container_local_output_path=ParameterString(name="rules_local_output"), + s3_output_path=ParameterString(name="rules_to_s3_output_path"), + other_trials_s3_input_paths=[ParameterString(name="rules_other_s3_input")], + rule_parameters={"threshold": ParameterString(name="rules_param")}, + collections_to_save=[ + CollectionConfig( + name=ParameterString(name="rules_collections_name"), + parameters={"key1": ParameterString(name="rules_collections_param")}, + ) + ], + ) + ], + debugger_hook_config=DebuggerHookConfig( + s3_output_path=ParameterString(name="debugger_hook_s3_output"), + container_local_output_path=ParameterString(name="debugger_container_output"), + hook_parameters={"key1": ParameterString(name="debugger_hook_param")}, + 
                collection_configs=[
+                    CollectionConfig(
+                        name=ParameterString(name="debugger_collections_name"),
+                        parameters={"key1": ParameterString(name="debugger_collections_param")},
+                    )
+                ],
+            ),
+            tensorboard_output_config=TensorBoardOutputConfig(
+                s3_output_path=ParameterString(name="tensorboard_s3_output"),
+                container_local_output_path=ParameterString(name="tensorboard_container_output"),
+            ),
+            profiler_config=ProfilerConfig(
+                s3_output_path=ParameterString(name="profile_config_s3_output_path"),
+                system_monitor_interval_millis=ParameterInteger(name="profile_config_system_monitor"),
+            ),
+            inputs={
+                "train": TrainingInput(
+                    s3_data=ParameterString(name="train_inputs_s3_data"),
+                    distribution=ParameterString(name="train_inputs_distribution"),
+                    compression=ParameterString(name="train_inputs_compression"),
+                    content_type=ParameterString(name="train_inputs_content_type"),
+                    record_wrapping=ParameterString(name="train_inputs_record_wrapping"),
+                    s3_data_type=ParameterString(name="train_inputs_s3_data_type"),
+                    input_mode=ParameterString(name="train_inputs_input_mode"),
+                    attribute_names=[ParameterString(name="train_inputs_attribute_name")],
+                    target_attribute_name=ParameterString(name="train_inputs_target_attr_name"),
+                    instance_groups=[ParameterString(name="train_inputs_instance_groups")],
+                ),
+            },
+            instance_groups=[
+                InstanceGroup(
+                    instance_group_name=ParameterString(name="instance_group_name"),
+                    # hard code the instance_type here because InstanceGroup.instance_type
+                    # would be used to retrieve image_uri if image_uri is not present,
+                    # and currently the test mechanism does not support skipping the test case
+                    # relating to bonded parameters in composite variables (i.e. the InstanceGroup)
+                    # TODO: we should support skipping tests on bonded parameters in composite vars
+                    instance_type="ml.m5.xlarge",
+                    instance_count=ParameterString(name="instance_group_instance_count"),
+                ),
+            ]
+            if _IS_TRUE_TMP
+            else None,
+            instance_type="ml.m5.xlarge" if not _IS_TRUE_TMP else None,
+            instance_count=1 if not _IS_TRUE_TMP else None,
+            distribution={} if not _IS_TRUE_TMP else None,
+        ),
+        transformer=dict(
+            instance_type=INSTANCE_TYPE,
+            data=f"s3://{BUCKET}/data",
+        ),
+        tuner=dict(
+            instance_type=INSTANCE_TYPE,
+            estimator=TensorFlow(
+                entry_point=TENSORFLOW_ENTRY_POINT,
+                role=ROLE,
+                framework_version="2.1.1",
+                py_version="py2",
+                instance_count=1,
+                instance_type="ml.m5.xlarge",
+                sagemaker_session=PIPELINE_SESSION,
+                enable_sagemaker_metrics=True,
+                max_retry_attempts=3,
+                hyperparameters={"static-hp": "hp1", "train_size": "1280"},
+            ),
+            hyperparameter_ranges={
+                "batch-size": IntegerParameter(
+                    min_value=ParameterInteger(name="hyper_range_min_value"),
+                    max_value=ParameterInteger(name="hyper_range_max_value"),
+                    scaling_type=ParameterString(name="hyper_range_scaling_type"),
+                ),
+            },
+            warm_start_config=WarmStartConfig(
+                warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
+                parents={ParameterString(name="warm_start_cfg_parent")},
+            ),
+            estimator_name="estimator-1",
+            inputs={
+                "estimator-1": TrainingInput(s3_data=ParameterString(name="inputs_estimator_1")),
+            },
+            include_cls_metadata={"estimator-1": IS_TRUE},
+        ),
+        model=dict(
+            instance_type=INSTANCE_TYPE,
+            serverless_inference_config=ServerlessInferenceConfig(),
+            framework_version="1.11.0",
+            py_version="py3",
+            accelerator_type="ml.eia2.xlarge",
+        ),
+        pipelinemodel=dict(
+            instance_type=INSTANCE_TYPE,
+            models=[
+                SparkMLModel(
+                    name="MySparkMLModel",
+                    model_data=f"s3://{BUCKET}",
+                    role=ROLE,
+                    sagemaker_session=PIPELINE_SESSION,
+                    env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
+                ),
+                FrameworkModel(
+                    image_uri=IMAGE_URI,
+                    model_data=f"s3://{BUCKET}/model.tar.gz",
+                    role=ROLE,
+                    sagemaker_session=PIPELINE_SESSION,
+                    entry_point=f"{DATA_DIR}/dummy_script.py",
+                    name="modelName",
+                    vpc_config={"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]},
+                ),
+            ],
+        ),
+)
+STEP_CLASS = dict(
+    processor=ProcessingStep,
+    estimator=TrainingStep,
+    transformer=TransformStep,
+    tuner=TuningStep,
+    model=ModelStep,
+    pipelinemodel=ModelStep,
+)
+
+# A dict for class __init__ parameters which are not used in the
+# request dict generated by a specific target function but are used in another.
+# For example, in the Model class constructor, vpc_config is used only in
+# model.create and is ignored in model.register.
+# Thus, when testing model.register, we should skip replacing vpc_config
+# with pipeline variables
+CLASS_PARAMS_EXCLUDED_IN_TARGET_FUNC = dict(
+    model=dict(
+        register=dict(
+            common={"vpc_config", "enable_network_isolation"},
+        ),
+    ),
+    pipelinemodel=dict(
+        register=dict(
+            common={"vpc_config", "enable_network_isolation"},
+        ),
+    ),
+)
+# A dict for base class __init__ parameters which are not used in
+# some specific subclasses.
+# For example, TensorFlowModel uses **kwargs for duplicate parameters
+# in the base class (FrameworkModel/Model) but it ignores the "image_config"
+# in target functions.
+BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS = dict(
+    model=dict(
+        TensorFlowModel={"image_config"},
+        SKLearnModel={"image_config"},
+        PyTorchModel={"image_config"},
+        XGBoostModel={"image_config"},
+        ChainerModel={"image_config"},
+        HuggingFaceModel={"image_config"},
+        MXNetModel={"image_config"},
+        KNNModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        SparkMLModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        KMeansModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        PCAModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        LDAModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        NTMModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        Object2VecModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        FactorizationMachinesModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        IPInsightsModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        RandomCutForestModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        LinearLearnerModel={"image_uri"},  # it's overridden in __init__ with a fixed value
+        MultiDataModel={
+            "model_data",  # model_data is overridden in __init__ with model_data_prefix
+            "image_config",
+            "container_log_level",  # it's simply ignored
+        },
+    ),
+    estimator=dict(
+        AlgorithmEstimator={  # its kwargs are ignored, so base class parameters are ignored
+            "enable_network_isolation",
+            "container_log_level",
+            "checkpoint_s3_uri",
+            "checkpoint_local_path",
+            "enable_sagemaker_metrics",
+            "environment",
+            "max_retry_attempts",
+            "source_dir",
+            "entry_point",
+        },
+        SKLearn={
+            "instance_count",
+            "instance_type",
+        },
+    ),
+)
+# A dict to keep the optional arguments which should not be set to None
+# in the test iteration according to the logic specific to the subclass.
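+# For example (an illustrative reading of the entries below):
+#     estimator=dict(init=dict(LDA={"mini_batch_size"}))
+# means that when iterating over the optional __init__ arguments of the LDA
+# estimator, mini_batch_size is never set to None, since the subclass logic
+# requires it to be present.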
+PARAMS_SHOULD_NOT_BE_NONE = dict(
+    estimator=dict(
+        init=dict(
+            # TODO: we should remove the three instance_ parameters here
+            # For mutually exclusive parameters: instance group
+            # vs instance count/instance type, if any side is set to None during iteration,
+            # the other side should get a not None value, instead of listing them here
+            # and forcing them to be not None
+            common={"instance_count", "instance_type", "instance_groups"},
+            LDA={"mini_batch_size"},
+        )
+    ),
+    model=dict(
+        register=dict(
+            common={},
+            Model={"model_data"},
+            HuggingFaceModel={"model_data"},
+        ),
+        create=dict(
+            common={},
+            Model={"role"},
+            SparkMLModel={"role"},
+            MultiDataModel={"role"},
+        ),
+    ),
+)
+# A dict for parameters which should not be replaced with pipeline variables
+# since they are bonded with other parameters with None value. For example:
+# Case 1: if outputs (a parameter in FrameworkProcessor.run) is None,
+# output_kms_key (a parameter in the constructor) is omitted,
+# so there is no need to replace it with pipeline variables
+# Case 2: if image_uri is None, instance_type is not allowed to be a pipeline variable,
+# otherwise the class can fail to be instantiated
+UNSET_PARAM_BONDED_WITH_NONE = dict(
+    processor=dict(
+        init=dict(
+            common=dict(instance_type={"image_uri"}),
+        ),
+        run=dict(
+            common=dict(output_kms_key={"outputs"}),
+        ),
+    ),
+    estimator=dict(
+        init=dict(
+            common=dict(
+                # entry_point can only be parameterized when source_dir is given
+                # if source_dir is None, parameterizing entry_point should be skipped
+                entry_point={"source_dir"},
+                instance_type={"image_uri"},
+            ),
+        ),
+        fit=dict(
+            common=dict(
+                subnets={"security_group_ids"},
+                security_group_ids={"subnets"},
+                model_channel_name={"model_uri"},
+                checkpoint_local_path={"checkpoint_s3_uri"},
+                instance_type={"image_uri"},
+            ),
+        ),
+    ),
+    model=dict(
+        register=dict(
+            common=dict(
+                env={"model_package_group_name"},
+                image_config={"model_package_group_name"},
+                model_server_workers={"model_package_group_name"},
+                container_log_level={"model_package_group_name"},
+                framework={"model_package_group_name"},
+                framework_version={"model_package_group_name"},
+                nearest_model_name={"model_package_group_name"},
+                data_input_configuration={"model_package_group_name"},
+            )
+        ),
+    ),
+    pipelinemodel=dict(
+        register=dict(
+            common=dict(
+                image_uri={
+                    # model_package_name and model_package_group_name are mutually exclusive.
+                    # If model_package_group_name is not None, image_uri will be ignored
+                    "model_package_name"
+                },
+                framework={"model_package_group_name"},
+                framework_version={"model_package_group_name"},
+                nearest_model_name={"model_package_group_name"},
+                data_input_configuration={"model_package_group_name"},
+            ),
+        ),
+    ),
+)
+
+# A dict for parameters which should not be replaced with pipeline variables
+# since they are bonded with other parameters with a not-None value. For example:
+# 1. for any model subclass, if model_package_name is not None, model_package_group_name should be None
+# and should be skipped when replacing with a pipeline variable
+# 2. for MultiDataModel, if model is given, its kwargs including container_log_level will be ignored
+# Note: for any mutually exclusive parameters (e.g. model_package_name, model_package_group_name),
+# we can add an entry for each of them.
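+# For example (an illustrative reading of the entries below):
+#     model=dict(register=dict(common=dict(model_package_name={"model_package_group_name"})))
+# means that whenever model_package_group_name holds a not-None value,
+# model_package_name is skipped and never replaced with a pipeline variable.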
+UNSET_PARAM_BONDED_WITH_NOT_NONE = dict( + model=dict( + register=dict( + common=dict( + model_package_name={"model_package_group_name"}, + model_package_group_name={"model_package_name"}, + ), + ), + ), + pipelinemodel=dict( + register=dict( + common=dict( + model_package_name={"model_package_group_name"}, + model_package_group_name={"model_package_name"}, + ), + ), + ), + estimator=dict( + init=dict( + common=dict( + entry_point={"enable_network_isolation"}, + source_dir={"enable_network_isolation"}, + ), + TensorFlow=dict( + image_uri={"compiler_config"}, + compiler_config={"image_uri"}, + ), + HuggingFace=dict( + image_uri={"compiler_config"}, + compiler_config={"image_uri"}, + ), + ), + fit=dict( + common=dict( + instance_count={"instance_groups"}, + instance_type={"instance_groups"}, + ), + ), + ), +) + + +# A dict for parameters that should not be set to None since they are bonded with +# other parameters with None value. For example: +# if image_uri is None in TensorFlow, py_version should not be None +# since it's used as substitute argument to retrieve image_uri. +SET_PARAM_BONDED_WITH_NONE = dict( + estimator=dict( + init=dict( + common=dict(), + TensorFlow=dict( + py_version={"image_uri"}, + framework_version={"image_uri"}, + ), + HuggingFace=dict( + transformers_version={"image_uri"}, + tensorflow_version={"pytorch_version"}, + ), + ) + ), + model=dict( + register=dict( + common=dict( + inference_instances={"model_package_group_name"}, + transform_instances={"model_package_group_name"}, + ), + ) + ), + pipelinemodel=dict( + register=dict( + common=dict( + inference_instances={"model_package_group_name"}, + transform_instances={"model_package_group_name"}, + ), + ) + ), +) + +# A dict for parameters that should not be set to None since they are bonded with +# other parameters with not None value. Thus we can skip it. For example: +# dimension_reduction_target should not be none when dimension_reduction_type is set +SET_PARAM_BONDED_WITH_NOT_NONE = dict( + estimator=dict( + init=dict( + common=dict(), + KNN=dict(dimension_reduction_target={"dimension_reduction_type"}), + ), + ), +) diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/parameter_skip_checker.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/parameter_skip_checker.py new file mode 100644 index 0000000000..c26761b384 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/parameter_skip_checker.py @@ -0,0 +1,382 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import
+
+from typing import TYPE_CHECKING
+
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import (
+    UNSET_PARAM_BONDED_WITH_NONE,
+    UNSET_PARAM_BONDED_WITH_NOT_NONE,
+    SET_PARAM_BONDED_WITH_NOT_NONE,
+    SET_PARAM_BONDED_WITH_NONE,
+    BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS,
+    CLASS_PARAMS_EXCLUDED_IN_TARGET_FUNC,
+    PARAMS_SHOULD_NOT_BE_NONE,
+    CLAZZ_ARGS,
+    FUNC_ARGS,
+    INIT,
+    COMMON,
+)
+
+if TYPE_CHECKING:
+    from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import (
+        PipelineVarCompatiTestTemplate,
+    )
+
+
+class ParameterSkipChecker:
+    """Check if the current parameter can be skipped for the test"""
+
+    def __init__(self, template: "PipelineVarCompatiTestTemplate"):
+        """Initialize a `ParameterSkipChecker` instance.
+
+        Args:
+            template (PipelineVarCompatiTestTemplate): The template object to check
+                the compatibility between Pipeline variables and the given class and target method.
+        """
+        self._template = template
+
+    def skip_setting_param_to_none(
+        self,
+        param_name: str,
+        target_func: str,
+    ) -> bool:
+        """Check if to skip setting the parameter to None.
+
+        Args:
+            param_name (str): The name of the parameter, which is to be verified
+                to see if it can be skipped in the loop.
+            target_func (str): The target function impacted by the check.
+                If target_func is init, it means the parameter affects initiating
+                the class object.
+        """
+        return (
+            self._is_param_should_not_be_none(param_name, target_func)
+            or self._need_set_param_bonded_with_none(param_name, target_func)
+            or self._need_set_param_bonded_with_not_none(param_name, target_func)
+        )
+
+    def skip_setting_clz_param_to_ppl_var(
+        self,
+        clz_param: str,
+        target_func: str,
+        param_with_none: str = None,
+    ) -> bool:
+        """Check if to skip setting the class parameter to a pipeline variable
+
+        Args:
+            clz_param (str): The name of the class __init__ parameter, which is to
+                be verified to see if it can be skipped in the loop.
+            target_func (str): The target function impacted by the check.
+                If target_func is init, it means the parameter affects initiating
+                the class object.
+            param_with_none (str): The name of the parameter with None value.
+        """
+        if target_func == INIT:
+            return (
+                self._need_unset_param_bonded_with_none(clz_param, INIT)
+                or self._is_base_clz_param_excluded_in_subclz(clz_param)
+                or self._need_unset_param_bonded_with_not_none(clz_param, INIT)
+            )
+        else:
+            return (
+                self._need_unset_param_bonded_with_none(clz_param, target_func)
+                or self._need_unset_param_bonded_with_not_none(clz_param, target_func)
+                or self._is_param_should_not_be_none(param_with_none, target_func)
+                or self._is_overridden_class_param(clz_param, target_func)
+                or self._is_clz_param_excluded_in_func(clz_param, target_func)
+            )
+
+    def skip_setting_func_param_to_ppl_var(
+        self,
+        func_param: str,
+        target_func: str,
+        param_with_none: str = None,
+    ) -> bool:
+        """Check if to skip setting the function parameter to a pipeline variable
+
+        Args:
+            func_param (str): The name of the function parameter, which is to
+                be verified to see if it can be skipped in the loop.
+            target_func (str): The target function impacted by the check.
+            param_with_none (str): The name of the parameter with None value.
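+
+        Example (illustrative): when the target function is Estimator.fit and
+        model_uri is None, the bonded model_channel_name argument is skipped,
+        per UNSET_PARAM_BONDED_WITH_NONE.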
+ """ + return ( + self._is_param_should_not_be_none(param_with_none, target_func) + or self._need_unset_param_bonded_with_none(func_param, target_func) + or self._need_unset_param_bonded_with_not_none(func_param, target_func) + ) + + def _need_unset_param_bonded_with_none( + self, + param_name: str, + target_func: str, + ) -> bool: + """Check if to skip testing with pipeline variables due to the bond relationship. + + I.e. the parameter (param_name) does not present in the definition json + or it is not allowed to be a pipeline variable if its boned parameter is None. + Then we can skip replacing the param_name with pipeline variables + + Args: + param_name (str): The name of the parameter, which is to be verified that + if we can skip setting it to be pipeline variables. + target_func (str): The target function impacted by the check. + If target_func is init, it means the bonded parameters affect initiating + the class object + """ + return self._is_param_bonded_with_none( + param_map=UNSET_PARAM_BONDED_WITH_NONE, + param_name=param_name, + target_func=target_func, + ) + + def _need_unset_param_bonded_with_not_none(self, param_name: str, target_func: str) -> bool: + """Check if to skip testing with pipeline variables due to the bond relationship. + + I.e. the parameter (param_name) does not present in the definition json or it should + not be presented or it is not allowed to be a pipeline variable if its boned parameter + is not None. Then we can skip replacing the param_name with pipeline variables + + Args: + param_name (str): The name of the parameter, which is to be verified that + if we can skip replacing it with pipeline variables. + target_func (str): The target function impacted by the check. + """ + return self._is_param_bonded_with_not_none( + param_map=UNSET_PARAM_BONDED_WITH_NOT_NONE, + param_name=param_name, + target_func=target_func, + ) + + def _need_set_param_bonded_with_none(self, param_name: str, target_func: str) -> bool: + """Check if to skip testing with None value due to the bond relationship. + + I.e. if a parameter (another_param) is None, its substitute parameter (param_name) + should not be None. Thus we can skip the test round which sets the param_name + to None under the target function (target_func). + + Args: + param_name (str): The name of the parameter, which is to be verified regarding + None value. + target_func (str): The target function impacted by this check. + """ + return self._is_param_bonded_with_none( + param_map=SET_PARAM_BONDED_WITH_NONE, + param_name=param_name, + target_func=target_func, + ) + + def _need_set_param_bonded_with_not_none(self, param_name: str, target_func: str) -> bool: + """Check if to skip testing with None value due to the bond relationship. + + I.e. if the parameter (another_param) is not None, its bonded parameter (param_name) + should not be None. Thus we can skip the test round which sets the param_name + to None under the target function (target_func). + + Args: + param_name (str): The name of the parameter, which is to be verified + regarding None value. + target_func (str): The target function impacted by this check. + """ + return self._is_param_bonded_with_not_none( + param_map=SET_PARAM_BONDED_WITH_NOT_NONE, + param_name=param_name, + target_func=target_func, + ) + + def _is_param_bonded_with_not_none( + self, + param_map: dict, + param_name: str, + target_func: str, + ) -> bool: + """Check if the parameter is bonded one with not None value. 
+ + Args: + param_map (dict): The parameter map storing the bond relationship. + param_name (str): The name of the parameter to be verified. + target_func (str): The target function impacted by this check. + """ + template = self._template + + def _not_none_checker(func: str, params_dict: dict): + for another_param in params_dict: + if template.default_args[CLAZZ_ARGS].get(another_param, None): + return True + if func == INIT: + continue + if template.default_args[FUNC_ARGS][func].get(another_param, None): + return True + return False + + return self._is_param_bonded( + param_map=param_map, + param_name=param_name, + target_func=target_func, + checker_func=_not_none_checker, + ) + + def _is_param_bonded_with_none( + self, + param_map: dict, + param_name: str, + target_func: str, + ) -> bool: + """Check if the parameter is bonded with another one with None value. + + Args: + param_map (dict): The parameter map storing the bond relationship. + param_name (str): The name of the parameter to be verified. + target_func (str): The target function impacted by this check. + """ + template = self._template + + def _none_checker(func: str, params_dict: dict): + for another_param in params_dict: + if template.default_args[CLAZZ_ARGS].get(another_param, "N/A") is None: + return True + if func == INIT: + continue + if template.default_args[FUNC_ARGS][func].get(another_param, "N/A") is None: + return True + return False + + return self._is_param_bonded( + param_map=param_map, + param_name=param_name, + target_func=target_func, + checker_func=_none_checker, + ) + + def _is_param_bonded( + self, + param_map: dict, + param_name: str, + target_func: str, + checker_func: callable, + ) -> bool: + """Check if the parameter has a specific bond relationship. + + Args: + param_map (dict): The parameter map storing the bond relationship. + param_name (str): The name of the parameter to be verified. + target_func (str): The target function impacted by this check. + checker_func (callable): The checker function to check the specific bond relationship. + """ + template = self._template + if template.clazz_type not in param_map: + return False + if target_func not in param_map[template.clazz_type]: + return False + params_dict = param_map[template.clazz_type][target_func][COMMON].get(param_name, {}) + if not params_dict: + if template.clazz.__name__ not in param_map[template.clazz_type][target_func]: + return False + params_dict = param_map[template.clazz_type][target_func][template.clazz.__name__].get( + param_name, {} + ) + return checker_func(target_func, params_dict) + + def _is_base_clz_param_excluded_in_subclz(self, clz_param_name: str) -> bool: + """Check if to skip testing with pipeline variables on class parameter due to exclusion. + + I.e. the base class parameter (clz_param_name) should not be replaced with pipeline variables, + as it's not used in the subclass. + + Args: + clz_param_name (str): The name of the class parameter, which is to be verified. + """ + template = self._template + if template.clazz_type not in BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS: + return False + if ( + template.clazz.__name__ + not in BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS[template.clazz_type] + ): + return False + return ( + clz_param_name + in BASE_CLASS_PARAMS_EXCLUDED_IN_SUB_CLASS[template.clazz_type][template.clazz.__name__] + ) + + def _is_overridden_class_param(self, clz_param_name: str, target_func: str) -> bool: + """Check if to skip testing with pipeline variables on class parameter due to override. + + I.e. 
the class parameter (clz_param_name) should not be replaced with pipeline variables
+        and tested on the target function (target_func) because it's overridden by a
+        function parameter with the same name.
+        e.g. image_uri in model.create can override the one in the model constructor.
+
+        Args:
+            clz_param_name (str): The name of the class parameter, which is to be verified.
+            target_func (str): The target function impacted by the check.
+        """
+        template = self._template
+        return template.default_args[FUNC_ARGS][target_func].get(clz_param_name, None) is not None
+
+    def _is_clz_param_excluded_in_func(self, clz_param_name: str, target_func: str) -> bool:
+        """Check if to skip testing with pipeline variables on class parameter due to exclusion.
+
+        I.e. the class parameter (clz_param_name) should not be replaced with pipeline variables
+        and tested on the target function (target_func), as it's not used there.
+
+        Args:
+            clz_param_name (str): The name of the class parameter, which is to be verified.
+            target_func (str): The target function impacted by the check.
+        """
+        return self._is_param_included(
+            param_map=CLASS_PARAMS_EXCLUDED_IN_TARGET_FUNC,
+            param_name=clz_param_name,
+            target_func=target_func,
+        )
+
+    def _is_param_should_not_be_none(self, param_name: str, target_func: str) -> bool:
+        """Check if to skip testing because the parameter should not be None.
+
+        I.e. the parameter (param_name) is set to None in this round but it is not allowed
+        according to the logic. Thus we can skip this round of testing.
+
+        Args:
+            param_name (str): The name of the parameter, which is to be verified regarding None value.
+            target_func (str): The target function impacted by this check.
+        """
+        return self._is_param_included(
+            param_map=PARAMS_SHOULD_NOT_BE_NONE,
+            param_name=param_name,
+            target_func=target_func,
+        )
+
+    def _is_param_included(
+        self,
+        param_map: dict,
+        param_name: str,
+        target_func: str,
+    ) -> bool:
+        """Check if the parameter is included in a specific relationship.
+
+        Args:
+            param_map (dict): The parameter map storing the specific relationship.
+            param_name (str): The name of the parameter to be verified.
+            target_func (str): The target function impacted by this check.
+        """
+        template = self._template
+        if template.clazz_type not in param_map:
+            return False
+        if target_func not in param_map[template.clazz_type]:
+            return False
+        if param_name in param_map[template.clazz_type][target_func][COMMON]:
+            return True
+        if template.clazz.__name__ not in param_map[template.clazz_type][target_func]:
+            return False
+        return param_name in param_map[template.clazz_type][target_func][template.clazz.__name__]
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py
new file mode 100644
index 0000000000..ca64ad871a
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/test_pipeline_var_compatibility_template.py
@@ -0,0 +1,593 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied.
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json + +from random import getrandbits +from typing import Optional, List +from typing_extensions import get_origin + +from sagemaker import Model, PipelineModel, AlgorithmEstimator +from sagemaker.estimator import EstimatorBase +from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase +from sagemaker.mxnet import MXNet +from sagemaker.processing import Processor +from sagemaker.clarify import SageMakerClarifyProcessor +from sagemaker.pytorch import PyTorch +from sagemaker.tensorflow import TensorFlow +from sagemaker.transformer import Transformer +from sagemaker.tuner import HyperparameterTuner +from sagemaker.workflow.step_collections import StepCollection +from tests.unit.sagemaker.workflow.test_mechanism.test_code import ( + STEP_CLASS, + FIXED_ARGUMENTS, + STR_VAL, + CLAZZ_ARGS, + FUNC_ARGS, + REQUIRED, + OPTIONAL, + INIT, + TYPE, + DEFAULT_VALUE, + COMMON, + PipelineVariableEncoder, +) +from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import ( + support_pipeline_variable, + get_param_dict, + generate_pipeline_vars_per_type, + clean_up_types, +) +from tests.unit.sagemaker.workflow.test_mechanism.test_code.parameter_skip_checker import ( + ParameterSkipChecker, +) + + +class PipelineVarCompatiTestTemplate: + """Check the compatibility between Pipeline variables and the given class, target method""" + + def __init__(self, clazz: type, default_args: dict): + """Initialize a `PipelineVarCompatiTestTemplate` instance. + + Args: + clazz (type): The class to test the compatibility. + default_args (dict): The given default arguments for the class and its target method. + """ + self.clazz = clazz + self.clazz_type = self._get_clazz_type() + self._target_funcs = self._get_target_functions() + self._clz_params = get_param_dict(clazz.__init__, clazz) + self._func_params = dict() + for func in self._target_funcs: + self._func_params[func.__name__] = get_param_dict(func) + self._set_and_restructure_default_args(default_args) + self._skip_param_checker = ParameterSkipChecker(self) + + def _set_and_restructure_default_args(self, default_args: dict): + """Set and restructure the default_args + + Restructure the default_args[FUNC_ARGS] if it's missing the layer of target function name + + Args: + default_args (dict): The given default arguments for the class and its target method. 
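+
+        Example (illustrative): for a class whose single target function is
+        fit, default_args[FUNC_ARGS] may be given either as {"fit": {...}} or
+        directly as {...}; the latter is wrapped into the former here.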
+ """ + self.default_args = default_args + # restructure the default_args[FUNC_ARGS] if it's missing the layer of target function name + if len(self._target_funcs) == 1: + target_func_name = self._target_funcs[0].__name__ + if target_func_name not in default_args[FUNC_ARGS]: + args = self.default_args.pop(FUNC_ARGS) + self.default_args[FUNC_ARGS] = dict() + self.default_args[FUNC_ARGS][target_func_name] = args + + self._check_or_fill_in_args( + params={**self._clz_params[REQUIRED], **self._clz_params[OPTIONAL]}, + default_args=self.default_args[CLAZZ_ARGS], + ) + for func in self._target_funcs: + func_name = func.__name__ + self._check_or_fill_in_args( + params={ + **self._func_params[func_name][REQUIRED], + **self._func_params[func_name][OPTIONAL], + }, + default_args=self.default_args[FUNC_ARGS][func_name], + ) + + def _get_clazz_type(self) -> str: + """Get the type (in str) of the downstream class""" + if issubclass(self.clazz, Processor): + return "processor" + if issubclass(self.clazz, EstimatorBase): + return "estimator" + if issubclass(self.clazz, Transformer): + return "transformer" + if issubclass(self.clazz, HyperparameterTuner): + return "tuner" + if issubclass(self.clazz, Model): + return "model" + if issubclass(self.clazz, PipelineModel): + return "pipelinemodel" + raise TypeError(f"Unsupported downstream class: {self.clazz}") + + def check_compatibility(self): + """The entry to check the compatibility""" + print( + "Starting to check Pipeline variable compatibility for class (%s) and target methods (%s)\n" + % (self.clazz.__name__, [func.__name__ for func in self._target_funcs]) + ) + + # Check the case when all args are assigned not-None values + print("## Starting to check the compatibility when all optional args are not None ##") + self._iterate_params_to_check_compatibility() + + # Check the case when one of the optional arg is None + print( + "## Starting to check the compatibility when one of the optional arg is None in each round ##" + ) + self._iterate_optional_params_to_check_compatibility() + + def _iterate_params_to_check_compatibility( + self, + param_with_none: Optional[str] = None, + test_func_for_none: Optional[str] = None, + ): + """Iterate each parameter and assign a pipeline var to it to test compatibility + + Args: + param_with_none (str): The name of the parameter with None value. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. 
+ """ + self._iterate_clz_params_to_check_compatibility(param_with_none, test_func_for_none) + self._iterate_func_params_to_check_compatibility(param_with_none, test_func_for_none) + + def _iterate_optional_params_to_check_compatibility(self): + """Iterate each optional parameter and set it to none to test compatibility""" + self._iterate_class_optional_params() + self._iterate_func_optional_params() + + def _iterate_class_optional_params(self): + """Iterate each optional parameter in class __init__ and check compatibility""" + print("### Starting to iterate optional parameters in class __init__") + self._iterate_optional_params( + optional_params=self._clz_params[OPTIONAL], + default_args=self.default_args[CLAZZ_ARGS], + ) + + def _iterate_func_optional_params(self): + """Iterate each function parameter and check compatibility""" + for func in self._target_funcs: + print(f"### Starting to iterate optional parameters in function {func.__name__}") + self._iterate_optional_params( + optional_params=self._func_params[func.__name__][OPTIONAL], + default_args=self.default_args[FUNC_ARGS][func.__name__], + test_func_for_none=func.__name__, + ) + + def _iterate_optional_params( + self, + optional_params: dict, + default_args: dict, + test_func_for_none: Optional[str] = None, + ): + """Iterate each optional parameter and check compatibility + Args: + optional_params (dict): The dict containing the optional parameters of a class or method. + default_args (dict): The dict containing the default arguments of a class or method. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. + """ + for param_name in optional_params.keys(): + if self._skip_param_checker.skip_setting_param_to_none(param_name, INIT): + continue + origin_val = default_args[param_name] + default_args[param_name] = None + print("=== Parameter (%s) is None in this round ===" % param_name) + self._iterate_params_to_check_compatibility(param_name, test_func_for_none) + default_args[param_name] = origin_val + + def _iterate_clz_params_to_check_compatibility( + self, + param_with_none: Optional[str] = None, + test_func_for_none: Optional[str] = None, + ): + """Iterate each class parameter and assign a pipeline var to it to test compatibility + + Args: + param_with_none (str): The name of the parameter with None value. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. 
+ """ + print( + f"#### Iterating parameters (with PipelineVariable annotation) " + f"in class {self.clazz.__name__} __init__ function" + ) + clz_params = {**self._clz_params[REQUIRED], **self._clz_params[OPTIONAL]} + # Iterate through each default arg + for clz_param_name, clz_default_arg in self.default_args[CLAZZ_ARGS].items(): + if clz_param_name == param_with_none: + continue + clz_param_type = clz_params[clz_param_name][TYPE] + if not support_pipeline_variable(clz_param_type): + continue + if self._skip_param_checker.skip_setting_clz_param_to_ppl_var( + clz_param=clz_param_name, target_func=INIT + ): + continue + + # For each arg which supports pipeline variables, + # Replace it with each one of generated pipeline variables + ppl_vars = generate_pipeline_vars_per_type(clz_param_name, clz_param_type) + for clz_ppl_var, expected_clz_expr in ppl_vars: + self.default_args[CLAZZ_ARGS][clz_param_name] = clz_ppl_var + obj = self.clazz(**self.default_args[CLAZZ_ARGS]) + for func in self._target_funcs: + func_name = func.__name__ + if test_func_for_none and test_func_for_none != func_name: + # Iterating optional parameters of a specific target function + # (test_func_for_none), which does not impact other target functions, + # so we can skip them + continue + if self._skip_param_checker._need_set_param_bonded_with_none( + param_with_none, func_name + ): # TODO: add to a public method + continue + if self._skip_param_checker.skip_setting_clz_param_to_ppl_var( + clz_param=clz_param_name, + target_func=func_name, + param_with_none=param_with_none, + ): + continue + # print( + # "Replacing class init arg (%s) with pipeline variable which is expected " + # "to be (%s). Testing with target function (%s)" + # % (clz_param_name, expected_clz_expr, func_name) + # ) + self._generate_and_verify_step_definition( + target_func=getattr(obj, func_name), + expected_expr=expected_clz_expr, + param_with_none=param_with_none, + ) + + # print("============================\n") + self.default_args[CLAZZ_ARGS][clz_param_name] = clz_default_arg + + def _iterate_func_params_to_check_compatibility( + self, + param_with_none: Optional[str] = None, + test_func_for_none: Optional[str] = None, + ): + """Iterate each target func parameter and assign a pipeline var to it + + Args: + param_with_none (str): The name of the parameter with None value. + test_func_for_none (str): The name of the function which is being tested by + replacing optional parameters to None. 
+ """ + obj = self.clazz(**self.default_args[CLAZZ_ARGS]) + for func in self._target_funcs: + func_name = func.__name__ + if test_func_for_none and test_func_for_none != func_name: + # Iterating optional parameters of a specific target function (test_func_for_none), + # which does not impact other target functions, so we can skip them + continue + if self._skip_param_checker._need_set_param_bonded_with_none( + param_with_none, func_name + ): # TODO: add to a public method + continue + print( + f"#### Iterating parameters (with PipelineVariable annotation) in target function: {func_name}" + ) + func_params = { + **self._func_params[func_name][REQUIRED], + **self._func_params[func_name][OPTIONAL], + } + for func_param_name, func_default_arg in self.default_args[FUNC_ARGS][ + func_name + ].items(): + if func_param_name == param_with_none: + continue + if not support_pipeline_variable(func_params[func_param_name][TYPE]): + continue + if self._skip_param_checker.skip_setting_func_param_to_ppl_var( + func_param=func_param_name, + target_func=func_name, + param_with_none=param_with_none, + ): + continue + # For each arg which supports pipeline variables, + # Replace it with each one of generated pipeline variables + ppl_vars = generate_pipeline_vars_per_type( + func_param_name, func_params[func_param_name][TYPE] + ) + for func_ppl_var, expected_func_expr in ppl_vars: + # print( + # "Replacing func arg (%s) with pipeline variable which is expected to be (%s)" + # % (func_param_name, expected_func_expr) + # ) + self.default_args[FUNC_ARGS][func_name][func_param_name] = func_ppl_var + self._generate_and_verify_step_definition( + target_func=getattr(obj, func_name), + expected_expr=expected_func_expr, + param_with_none=param_with_none, + ) + + self.default_args[FUNC_ARGS][func_name][func_param_name] = func_default_arg + # print("-------------------------\n") + + def _generate_and_verify_step_definition( + self, + target_func: callable, + expected_expr: dict, + param_with_none: str, + ): + """Generate a pipeline and verify the pipeline definition + + Args: + target_func (callable): The function to generate step_args. + expected_expr (dict): The expected json expression of a class or method argument. + param_with_none (str): The name of the parameter with None value. + """ + args = dict( + name="MyStep", + step_args=target_func(**self.default_args[FUNC_ARGS][target_func.__name__]), + ) + step = STEP_CLASS[self.clazz_type](**args) + if isinstance(step, StepCollection): + request_dicts = step.request_dicts() + else: + request_dicts = [step.to_request()] + + step_dsl = json.dumps(request_dicts, cls=PipelineVariableEncoder) + step_dsl_obj = json.loads(step_dsl) + exp_origin = json.dumps(expected_expr["origin"]) + exp_to_str = json.dumps(expected_expr["to_string"]) + # if the testing arg is a dict, we may need to remove the outer {} of its expected expr + # to compare, since for HyperParameters, some other arguments are auto inserted to the dict + assert ( + exp_origin in step_dsl + or exp_to_str in step_dsl + or exp_origin[1:-1] in step_dsl + or exp_to_str[1:-1] in step_dsl + ) + self._verify_composite_object_against_pipeline_var(param_with_none, step_dsl, step_dsl_obj) + + def _verify_composite_object_against_pipeline_var( + self, + param_with_none: str, + step_dsl: str, + step_dsl_obj: List[dict], + ): + """verify pipeline definition regarding composite objects against pipeline variables + + Args: + param_with_none (str): The name of the parameter with None value. 
+ step_dsl (str): The step definition retrieved from the pipeline definition DSL. + step_dsl_obj (List[dict]): The json load object of the step definition. + """ + # TODO: remove the following hard code assertion once recursive assignment is added + if issubclass(self.clazz, Processor): + if param_with_none != "network_config": + assert '{"Get": "Parameters.nw_cfg_subnets"}' in step_dsl + assert '{"Get": "Parameters.nw_cfg_security_group_ids"}' in step_dsl + assert '{"Get": "Parameters.nw_cfg_enable_nw_isolation"}' in step_dsl + if issubclass(self.clazz, SageMakerClarifyProcessor): + if param_with_none != "data_config": + assert '{"Get": "Parameters.clarify_processor_input"}' in step_dsl + assert '{"Get": "Parameters.clarify_processor_output"}' in step_dsl + else: + if param_with_none != "outputs": + assert '{"Get": "Parameters.proc_output_source"}' in step_dsl + assert '{"Get": "Parameters.proc_output_dest"}' in step_dsl + assert '{"Get": "Parameters.proc_output_app_managed"}' in step_dsl + if param_with_none != "inputs": + assert '{"Get": "Parameters.proc_input_source"}' in step_dsl + assert '{"Get": "Parameters.proc_input_dest"}' in step_dsl + assert '{"Get": "Parameters.proc_input_s3_data_type"}' in step_dsl + assert '{"Get": "Parameters.proc_input_app_managed"}' in step_dsl + elif issubclass(self.clazz, EstimatorBase): + if ( + param_with_none != "instance_groups" + and self.default_args[CLAZZ_ARGS]["instance_groups"] + ): + assert '{"Get": "Parameters.instance_group_name"}' in step_dsl + assert '{"Get": "Parameters.instance_group_instance_count"}' in step_dsl + if issubclass(self.clazz, AmazonAlgorithmEstimatorBase): + # AmazonAlgorithmEstimatorBase's input is records + if param_with_none != "records": + assert '{"Get": "Parameters.records_s3_data"}' in step_dsl + assert '{"Get": "Parameters.records_s3_data_type"}' in step_dsl + assert '{"Get": "Parameters.records_channel"}' in step_dsl + else: + if param_with_none != "inputs": + assert '{"Get": "Parameters.train_inputs_s3_data"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_distribution"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_compression"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_content_type"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_record_wrapping"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_s3_data_type"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_input_mode"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_attribute_name"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_target_attr_name"}' in step_dsl + assert '{"Get": "Parameters.train_inputs_instance_groups"}' in step_dsl + if not issubclass(self.clazz, (TensorFlow, MXNet, PyTorch, AlgorithmEstimator)): + # debugger_hook_config may be disabled for these first 3 frameworks + # AlgorithmEstimator ignores the kwargs + if param_with_none != "debugger_hook_config": + assert '{"Get": "Parameters.debugger_hook_s3_output"}' in step_dsl + assert '{"Get": "Parameters.debugger_container_output"}' in step_dsl + assert '{"Get": "Parameters.debugger_hook_param"}' in step_dsl + assert '{"Get": "Parameters.debugger_collections_name"}' in step_dsl + assert '{"Get": "Parameters.debugger_collections_param"}' in step_dsl + if not issubclass(self.clazz, AlgorithmEstimator): + # AlgorithmEstimator ignores the kwargs + if param_with_none != "profiler_config": + assert '{"Get": "Parameters.profile_config_s3_output_path"}' in step_dsl + assert '{"Get": "Parameters.profile_config_system_monitor"}' in 
step_dsl + if param_with_none != "tensorboard_output_config": + assert '{"Get": "Parameters.tensorboard_s3_output"}' in step_dsl + assert '{"Get": "Parameters.tensorboard_container_output"}' in step_dsl + if param_with_none != "rules": + assert '{"Get": "Parameters.rules_image_uri"}' in step_dsl + assert '{"Get": "Parameters.rules_instance_type"}' in step_dsl + assert '{"Get": "Parameters.rules_volume_size"}' in step_dsl + assert '{"Get": "Parameters.rules_to_invoke"}' in step_dsl + assert '{"Get": "Parameters.rules_local_output"}' in step_dsl + assert '{"Get": "Parameters.rules_to_s3_output_path"}' in step_dsl + assert '{"Get": "Parameters.rules_other_s3_input"}' in step_dsl + assert '{"Get": "Parameters.rules_param"}' in step_dsl + if not issubclass(self.clazz, (TensorFlow, MXNet, PyTorch)): + # The collections_to_save is added to debugger rules, + # which may be disabled for some frameworks + assert '{"Get": "Parameters.rules_collections_name"}' in step_dsl + assert '{"Get": "Parameters.rules_collections_param"}' in step_dsl + elif issubclass(self.clazz, HyperparameterTuner): + if param_with_none != "inputs": + assert '{"Get": "Parameters.inputs_estimator_1"}' in step_dsl + if param_with_none != "warm_start_config": + assert '{"Get": "Parameters.warm_start_cfg_parent"}' in step_dsl + if param_with_none != "hyperparameter_ranges": + assert ( + json.dumps( + { + "Std:Join": { + "On": "", + "Values": [{"Get": "Parameters.hyper_range_min_value"}], + } + } + ) + in step_dsl + ) + assert ( + json.dumps( + { + "Std:Join": { + "On": "", + "Values": [{"Get": "Parameters.hyper_range_max_value"}], + } + } + ) + in step_dsl + ) + assert '{"Get": "Parameters.hyper_range_scaling_type"}' in step_dsl + elif issubclass(self.clazz, (Model, PipelineModel)): + if step_dsl_obj[-1]["Type"] == "Model": + return + if param_with_none != "model_metrics": + assert '{"Get": "Parameters.model_statistics_content_type"}' in step_dsl + assert '{"Get": "Parameters.model_statistics_s3_uri"}' in step_dsl + assert '{"Get": "Parameters.model_statistics_content_digest"}' in step_dsl + if param_with_none != "metadata_properties": + assert '{"Get": "Parameters.meta_properties_commit_id"}' in step_dsl + assert '{"Get": "Parameters.meta_properties_repository"}' in step_dsl + assert '{"Get": "Parameters.meta_properties_generated_by"}' in step_dsl + assert '{"Get": "Parameters.meta_properties_project_id"}' in step_dsl + if param_with_none != "drift_check_baselines": + assert '{"Get": "Parameters.drift_constraints_content_type"}' in step_dsl + assert '{"Get": "Parameters.drift_constraints_s3_uri"}' in step_dsl + assert '{"Get": "Parameters.drift_constraints_content_digest"}' in step_dsl + assert '{"Get": "Parameters.drift_bias_content_type"}' in step_dsl + assert '{"Get": "Parameters.drift_bias_s3_uri"}' in step_dsl + assert '{"Get": "Parameters.drift_bias_content_digest"}' in step_dsl + + def _get_non_pipeline_val(self, n: str, t: type) -> object: + """Get the value (not a Pipeline variable) based on parameter type and name + + Args: + n (str): The parameter name. If a parameter has a pre-defined value, + it will be returned directly. + t (type): The parameter type. If a parameter does not have a pre-defined value, + an arg will be auto-generated based on the type. + + Return: + object: A Python primitive value is returned. 
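+
+        Example (illustrative): _get_non_pipeline_val("role", str) returns the
+        fixed value ROLE from FIXED_ARGUMENTS, while a name with no pre-defined
+        value, e.g. _get_non_pipeline_val("my_param", int), falls back to the
+        auto-generated value 1 ("my_param" is a hypothetical parameter name).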
+ """ + if n in FIXED_ARGUMENTS[COMMON]: + return FIXED_ARGUMENTS[COMMON][n] + if n in FIXED_ARGUMENTS[self.clazz_type]: + return FIXED_ARGUMENTS[self.clazz_type][n] + if t is str: + return STR_VAL + if t is int: + return 1 + if t is float: + return 1e-4 + if t is bool: + return bool(getrandbits(1)) + if t in [list, tuple, dict, set]: + return t() + + raise TypeError(f"Unable to parse type: {t}.") + + def _check_or_fill_in_args(self, params: dict, default_args: dict): + """Check if every args are provided and not None + + Otherwise fill in with some default values + + Args: + params (dict): The dict indicating the type of each parameter. + default_args (dict): The dict of args to be checked or filled in. + """ + for param_name, value in params.items(): + if param_name in default_args: + # User specified the default value + continue + if value[DEFAULT_VALUE]: + # The parameter has default value in method definition + default_args[param_name] = value[DEFAULT_VALUE] + continue + clean_type = clean_up_types(value[TYPE]) + origin_type = get_origin(clean_type) + if origin_type is None: + default_args[param_name] = self._get_non_pipeline_val(param_name, clean_type) + else: + default_args[param_name] = self._get_non_pipeline_val(param_name, origin_type) + + self._check_or_update_default_args(default_args) + + def _check_or_update_default_args(self, default_args: dict): + """To check if the default args are valid and update them if not + + Args: + default_args (dict): The dict of args to be checked or updated. + """ + if issubclass(self.clazz, EstimatorBase): + if "disable_profiler" in default_args and default_args["disable_profiler"] is True: + default_args["profiler_config"] = None + + def _get_target_functions(self) -> list: + """Fetch the target functions based on class + + Return: + list: The list of target functions is returned. + """ + if issubclass(self.clazz, Processor): + if issubclass(self.clazz, SageMakerClarifyProcessor): + return [ + self.clazz.run_pre_training_bias, + self.clazz.run_post_training_bias, + self.clazz.run_bias, + self.clazz.run_explainability, + ] + return [self.clazz.run] + if issubclass(self.clazz, EstimatorBase): + return [self.clazz.fit] + if issubclass(self.clazz, Transformer): + return [self.clazz.transform] + if issubclass(self.clazz, HyperparameterTuner): + return [self.clazz.fit] + if issubclass(self.clazz, (Model, PipelineModel)): + return [self.clazz.register, self.clazz.create] + raise TypeError(f"Unable to get target function for class {self.clazz}") diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_code/utilities.py b/tests/unit/sagemaker/workflow/test_mechanism/test_code/utilities.py new file mode 100644 index 0000000000..4c464a3985 --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_code/utilities.py @@ -0,0 +1,267 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import
+
+import os
+from inspect import signature
+
+from sagemaker import Model
+from sagemaker.estimator import EstimatorBase
+from sagemaker.fw_utils import UploadedCode
+from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.entities import PipelineVariable
+from typing import Union
+from typing_extensions import get_args, get_origin
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import (
+    PIPELINE_VARIABLES,
+    REQUIRED,
+    OPTIONAL,
+    DEFAULT_VALUE,
+    IMAGE_URI,
+)
+
+
+def support_pipeline_variable(t: type) -> bool:
+    """Check if pipeline variables are supported by a parameter according to its type
+
+    Args:
+        t (type): The type to be checked
+
+    Return:
+        bool: True if it supports pipeline variables, False otherwise.
+    """
+    return "PipelineVariable" in str(t)
+
+
+def get_param_dict(func, clazz=None) -> dict:
+    """Get a parameter dict of a given function
+
+    The parameter dict indicates if a parameter is required or not, as well as its type.
+
+    Args:
+        func (function): A class constructor method or other class methods.
+        clazz (type): The corresponding class whose method is passed in.
+
+    Return:
+        dict: A parameter dict is returned.
+    """
+    params = list()
+    params.append(signature(func))
+    if func.__name__ == "__init__" and issubclass(clazz, (EstimatorBase, Model)):
+        # Go through the constructors of all parent classes to collect the full set of
+        # parameters, since estimator and model classes use **kwargs for parameters
+        # defined in parent classes.
+        # The leaf class's parameters should be at the top of the params list
+        # and take priority.
+        _get_params_from_parent_class_constructors(clazz, params)
+
+    params_dict = dict(
+        required=dict(),
+        optional=dict(),
+    )
+    for param in params:
+        for param_val in param.parameters.values():
+            if param_val.annotation is param_val.empty:
+                continue
+            val = dict(type=param_val.annotation, default_value=None)
+            if param_val.name == "sagemaker_session":
+                # Treat sagemaker_session as required as it must be a PipelineSession obj
+                if not _is_in_params_dict(param_val.name, params_dict):
+                    params_dict[REQUIRED][param_val.name] = val
+            elif param_val.default is param_val.empty or param_val.default is not None:
+                if not _is_in_params_dict(param_val.name, params_dict):
+                    # Some parameters, e.g. entry_point in TensorFlow, appear as both a
+                    # required (in Framework) and an optional (in EstimatorBase) parameter.
+                    # The annotation defined in the class (i.e. Framework) which is closer
+                    # to the leaf class (TensorFlow) should win.
+                    if param_val.default is not param_val.empty:
+                        val[DEFAULT_VALUE] = param_val.default
+                    params_dict[REQUIRED][param_val.name] = val
+            else:
+                if not _is_in_params_dict(param_val.name, params_dict):
+                    params_dict[OPTIONAL][param_val.name] = val
+    return params_dict
+
+
+def _is_in_params_dict(param_name: str, params_dict: dict):
+    """Check whether the parameter is in the parameter dict
+
+    Args:
+        param_name (str): The name of the parameter to be checked
+        params_dict (dict): The parameter dict among which to check if the param_name exists
+    """
+    return param_name in params_dict[REQUIRED] or param_name in params_dict[OPTIONAL]
+
+
+def _get_params_from_parent_class_constructors(clazz: type, params: list):
+    """Get constructor parameters from all parent classes
+
+    Args:
+        clazz (type): The leaf class whose parent class constructors are traversed
+            to collect parameters
+        params (list): The list into which all parameters are collected
+    """
+    while clazz.__name__ not in {"EstimatorBase", "Model"}:
+        parent_class = clazz.__base__
+        params.append(signature(parent_class.__init__))
+        clazz = parent_class
+
+
+def generate_pipeline_vars_per_type(
+    param_name: str,
+    param_type: type,
+) -> list:
+    """Provide a list of possible PipelineVariable objects.
+
+    For example, if type_hint is Union[str, PipelineVariable],
+    return [ParameterString, Properties, JsonGet, Join, ExecutionVariable]
+
+    Args:
+        param_name (str): The name of the parameter to generate the pipeline variable list.
+        param_type (type): The type of the parameter to generate the pipeline variable list.
+
+    Return:
+        list: A list of possible PipelineVariable objects is returned.
+    """
+    # verify if params allow pipeline variables
+    if "PipelineVariable" not in str(param_type):
+        raise TypeError(f"The type: {param_type} does not support PipelineVariable.")
+
+    types = get_args(param_type)
+    # e.g. Union[str, PipelineVariable] or Union[str, PipelineVariable, NoneType]
+    if PipelineVariable in types:
+        # PipelineVariable corresponds to Python primitive types
+        # i.e. str, int, float, bool
+        ppl_var = _get_pipeline_var(types=types)
+        return ppl_var
+
+    # e.g. Union[List[...], NoneType] or Union[Dict[...], NoneType] etc.
+    clean_type = clean_up_types(param_type)
+    origin_type = get_origin(clean_type)
+    if origin_type not in [list, dict, set, tuple]:
+        raise TypeError(f"Unsupported type: {param_type} for param: {param_name}")
+    sub_types = get_args(clean_type)
+
+    # e.g. List[...], Tuple[...], Set[...]
+    if origin_type in [list, tuple, set]:
+        ppl_var_list = generate_pipeline_vars_per_type(param_name, sub_types[0])
+        return [
+            (
+                origin_type([var]),
+                dict(
+                    origin=origin_type([expected["origin"]]),
+                    to_string=origin_type([expected["to_string"]]),
+                ),
+            )
+            for var, expected in ppl_var_list
+        ]
+
+    # e.g. Dict[...]
+    if origin_type is dict:
+        key_type = sub_types[0]
+        if key_type is not str:
+            raise TypeError(
+                f"Unsupported type: {key_type} for dict key in {param_name} of {param_type} type"
+            )
+        ppl_var_list = generate_pipeline_vars_per_type(param_name, sub_types[1])
+        return [
+            (
+                dict(MyKey=var),
+                dict(
+                    origin=dict(MyKey=expected["origin"]),
+                    to_string=dict(MyKey=expected["to_string"]),
+                ),
+            )
+            for var, expected in ppl_var_list
+        ]
+    return list()
+
+
+def clean_up_types(t: type) -> type:
+    """Clean up the Union type and return its first subtype that is not a NoneType
+
+    For example, for Union[str, int, NoneType] it will return str
+
+    Args:
+        t (type): The type of a parameter to be cleaned up.
+
+    Return:
+        type: The cleaned-up type is returned.
+    """
+    if get_origin(t) == Union:
+        types = get_args(t)
+        return list(filter(lambda x: "NoneType" not in str(x), types))[0]
+    return t
+
+
+def _get_pipeline_var(types: tuple) -> list:
+    """Get a list of pipeline variables based on one kind of the parameter types.
+
+    Args:
+        types (tuple): The possible types of a parameter.
+
+    Return:
+        list: A list of possible PipelineVariable objects is returned.
+    """
+    if str in types:
+        return PIPELINE_VARIABLES["str"]
+    if int in types:
+        return PIPELINE_VARIABLES["int"]
+    if float in types:
+        return PIPELINE_VARIABLES["float"]
+    if bool in types:
+        return PIPELINE_VARIABLES["bool"]
+    raise TypeError(f"Unable to parse types: {types}.")
+
+
+def mock_tar_and_upload_dir(
+    session,
+    bucket,
+    s3_key_prefix,
+    script,
+    directory=None,
+    dependencies=None,
+    kms_key=None,
+    s3_resource=None,
+    settings=None,
+):
+    """Briefly mock the behavior of tar_and_upload_dir"""
+    if directory and (is_pipeline_variable(directory) or directory.lower().startswith("s3://")):
+        return UploadedCode(s3_prefix=directory, script_name=script)
+    script_name = script if directory else os.path.basename(script)
+    key = "%s/sourcedir.tar.gz" % s3_key_prefix
+    return UploadedCode(s3_prefix="s3://%s/%s" % (bucket, key), script_name=script_name)
+
+
+def mock_image_uris_retrieve(
+    framework,
+    region,
+    version=None,
+    py_version=None,
+    instance_type=None,
+    accelerator_type=None,
+    image_scope=None,
+    container_version=None,
+    distribution=None,
+    base_framework_version=None,
+    training_compiler_config=None,
+    model_id=None,
+    model_version=None,
+    tolerate_vulnerable_model=False,
+    tolerate_deprecated_model=False,
+    sdk_version=None,
+    inference_tool=None,
+    serverless_inference_config=None,
+) -> str:
+    """Briefly mock the behavior of image_uris.retrieve"""
+    args = dict(locals())
+    for name, val in args.items():
+        if is_pipeline_variable(val):
+            raise ValueError("%s should not be a pipeline variable (%s)" % (name, type(val)))
+    return IMAGE_URI
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py
new file mode 100644
index 0000000000..f749118b6a
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_estimators.py
@@ -0,0 +1,439 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
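+# An illustrative sketch of the mechanism these entry points drive: the template
+# swaps every Union[..., PipelineVariable] argument for pipeline variables and
+# greps the rendered pipeline DSL for the matching {"Get": ...} reference.
+# Roughly, assuming the public ParameterString API:
+#
+#     from sagemaker.workflow.parameters import ParameterString
+#
+#     param = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+#     assert param.expr == {"Get": "Parameters.InstanceType"}
+#
+# param.expr is the JSON fragment the template's assertions look for in the
+# step definition.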
+from __future__ import absolute_import + +from random import getrandbits + +from unittest.mock import patch, MagicMock + +from sagemaker.estimator import Estimator +from sagemaker.huggingface import TrainingCompilerConfig as HF_TrainingCompilerConfig +from sagemaker.tensorflow import TensorFlow +from sagemaker.amazon.lda import LDA +from sagemaker.xgboost.estimator import XGBoost +from sagemaker.rl.estimator import RLEstimator +from sagemaker.amazon.pca import PCA +from sagemaker.mxnet.estimator import MXNet +from sagemaker.amazon.randomcutforest import RandomCutForest +from sagemaker.amazon.factorization_machines import FactorizationMachines +from sagemaker.algorithm import AlgorithmEstimator +from sagemaker.sklearn.estimator import SKLearn +from sagemaker.amazon.ipinsights import IPInsights +from sagemaker.huggingface.estimator import HuggingFace +from sagemaker.tensorflow.estimator import TrainingCompilerConfig as TF_TrainingCompilerConfig +from sagemaker.amazon.ntm import NTM +from sagemaker.pytorch import PyTorch +from sagemaker.chainer import Chainer +from sagemaker.amazon.linear_learner import LinearLearner +from sagemaker.amazon.knn import KNN +from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import ( + PipelineVarCompatiTestTemplate, +) +from sagemaker.amazon.object2vec import Object2Vec +from sagemaker.amazon.kmeans import KMeans +from tests.unit.sagemaker.workflow.test_mechanism.test_code import IMAGE_URI, MockProperties +from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import ( + mock_tar_and_upload_dir, + mock_image_uris_retrieve, +) + +_IS_TRUE = bool(getrandbits(1)) + + +# These tests provide the incomplete default arg dict +# within which some class or target func parameters are missing. 
+# The test template will fill in those missing args +# Note: the default args should not include PipelineVariable objects +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Estimator, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_tensorflow_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + compiler_config=TF_TrainingCompilerConfig() if _IS_TRUE else None, + image_uri=IMAGE_URI if not _IS_TRUE else None, + instance_type="ml.p3.2xlarge", + framework_version="2.9", + py_version="py39", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=TensorFlow, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_lda_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + instance_count=1, + ), + func_args=dict( + mini_batch_size=128, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LDA, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_pca_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PCA, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_mxnet_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="1.4.0"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=MXNet, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_rl_estimator_compatibility(): + default_args = dict( + 
clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=RLEstimator, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ntm_estimator_compatibility(): + default_args = dict( + clazz_args=dict(clip_gradient=1e-3), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=NTM, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_rcf_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=RandomCutForest, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_xgboost_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="1.2-1"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=XGBoost, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_object2vec_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Object2Vec, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_fm_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=FactorizationMachines, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ae_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=AlgorithmEstimator, + default_args=default_args, + ) 
+ test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_sklearn_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + py_version="py3", + instance_count=1, + framework_version="0.20.0", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=SKLearn, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ipinsights_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=IPInsights, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_huggingface_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + instance_type="ml.p3.2xlarge", + transformers_version="4.11", + tensorflow_version="2.5" if _IS_TRUE else None, + pytorch_version="1.9" if not _IS_TRUE else None, + compiler_config=HF_TrainingCompilerConfig() if _IS_TRUE else None, + image_uri=IMAGE_URI if not _IS_TRUE else None, + py_version="py37" if _IS_TRUE else "py38", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=HuggingFace, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_pytorch_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="1.8.0", py_version="py3"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PyTorch, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_kmeans_estimator_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KMeans, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", 
MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_knn_estimator_compatibility(): + default_args = dict( + clazz_args=dict( + dimension_reduction_target=6, + dimension_reduction_type="sign", + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KNN, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_chainer_estimator_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="4.0"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Chainer, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +def test_ll_estimator_compatibility(): + default_args = dict( + clazz_args=dict(init_method="normal"), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LinearLearner, + default_args=default_args, + ) + test_template.check_compatibility() diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_models.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_models.py new file mode 100644 index 0000000000..611e415f2e --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_models.py @@ -0,0 +1,550 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from random import getrandbits +from unittest.mock import patch, MagicMock + +from sagemaker import ( + Model, + KNNModel, + KMeansModel, + PCAModel, + LDAModel, + NTMModel, + Object2VecModel, + FactorizationMachinesModel, + IPInsightsModel, + RandomCutForestModel, + LinearLearnerModel, + PipelineModel, +) +from sagemaker.chainer import ChainerModel +from sagemaker.huggingface import HuggingFaceModel +from sagemaker.model import FrameworkModel +from sagemaker.multidatamodel import MultiDataModel +from sagemaker.mxnet import MXNetModel +from sagemaker.pytorch import PyTorchModel +from sagemaker.sklearn import SKLearnModel +from sagemaker.sparkml import SparkMLModel +from sagemaker.tensorflow import TensorFlowModel +from sagemaker.workflow._utils import _RepackModelStep +from sagemaker.xgboost import XGBoostModel +from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import ( + PipelineVarCompatiTestTemplate, +) +from tests.unit.sagemaker.workflow.test_mechanism.test_code import BUCKET, MockProperties +from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import ( + mock_image_uris_retrieve, + mock_tar_and_upload_dir, +) + +_IS_TRUE = bool(getrandbits(1)) +_mock_properties = MockProperties(step_name="MyStep") +_mock_properties.__dict__["ModelArtifacts"] = MockProperties( + step_name="MyStep", path="ModelArtifacts" +) +_mock_properties.ModelArtifacts.__dict__["S3ModelArtifacts"] = MockProperties( + step_name="MyStep", path="ModelArtifacts.S3ModelArtifacts" +) + + +# These tests provide the incomplete default arg dict +# within which some class or target func parameters are missing. +# The test template will fill in those missing args +# Note: the default args should not include PipelineVariable objects +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict( + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.t2.medium", "ml.m5.xlarge"], + model_package_group_name="my-model-pkg-group" if not _IS_TRUE else None, + model_package_name="my-model-pkg" if _IS_TRUE else None, + ), + create=dict( + instance_type="ml.t2.medium", + ), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Model, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_framework_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = 
PipelineVarCompatiTestTemplate( + clazz=FrameworkModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +# Skip validate source dir because we skip _inject_repack_script +# Thus, repack script is not in the dir and the validation fails +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_tensorflow_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=TensorFlowModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_knn_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KNNModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_sparkml_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=SparkMLModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_kmeans_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=KMeansModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", 
MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_pca_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PCAModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_lda_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LDAModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_ntm_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=NTMModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_object2vec_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Object2VecModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, 
"_inject_repack_script", MagicMock()) +def test_factorizationmachines_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=FactorizationMachinesModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_ipinsights_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=IPInsightsModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_randomcutforest_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=RandomCutForestModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_linearlearner_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=LinearLearnerModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_sklearn_model_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="0.20.0"), + func_args=dict( + register=dict(), + create=dict(accelerator_type=None), + ), + ) + test_template = 
PipelineVarCompatiTestTemplate( + clazz=SKLearnModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_pytorch_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PyTorchModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_xgboost_model_compatibility(): + default_args = dict( + clazz_args=dict( + framework_version="1", + ), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=XGBoostModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_chainer_model_compatibility(): + default_args = dict( + clazz_args=dict(framework_version="4.0.0"), + func_args=dict( + register=dict(), + create=dict(accelerator_type=None), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=ChainerModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_huggingface_model_compatibility(): + default_args = dict( + clazz_args=dict( + tensorflow_version="2.4.1" if _IS_TRUE else None, + pytorch_version="1.7.1" if not _IS_TRUE else None, + transformers_version="4.6.1", + py_version="py37" if _IS_TRUE 
else "py36", + ), + func_args=dict( + register=dict(), + create=dict(accelerator_type=None), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=HuggingFaceModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +@patch("sagemaker.estimator.validate_source_dir", MagicMock(return_value=True)) +def test_mxnet_model_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=MXNetModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_multidata_model_compatibility(): + default_args = dict( + clazz_args=dict( + model_data_prefix=f"s3://{BUCKET}", + model=None, + ), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=MultiDataModel, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.workflow._utils.Properties", MagicMock(return_value=_mock_properties)) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir)) +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch.object(_RepackModelStep, "_inject_repack_script", MagicMock()) +def test_pipelinemodel_compatibility(): + default_args = dict( + clazz_args=dict(), + func_args=dict( + register=dict(), + create=dict(), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PipelineModel, + default_args=default_args, + ) + test_template.check_compatibility() diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_processors.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_processors.py new file mode 100644 index 0000000000..b677b909be --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_processors.py @@ -0,0 +1,363 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from random import getrandbits +from unittest.mock import MagicMock, patch + +from sagemaker.processing import FrameworkProcessor, ScriptProcessor, Processor +from sagemaker.pytorch.processing import PyTorchProcessor +from sagemaker.clarify import ( + SageMakerClarifyProcessor, + BiasConfig, + ModelConfig, + ModelPredictedLabelConfig, + PDPConfig, +) +from sagemaker.tensorflow.processing import TensorFlowProcessor +from sagemaker.xgboost.processing import XGBoostProcessor +from sagemaker.mxnet.processing import MXNetProcessor +from sagemaker.sklearn.processing import SKLearnProcessor +from sagemaker.huggingface.processing import HuggingFaceProcessor +from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import ( + PipelineVarCompatiTestTemplate, +) +from tests.unit.sagemaker.workflow.test_mechanism.test_code import ( + ROLE, + DUMMY_S3_SCRIPT_PATH, + PIPELINE_SESSION, + MockProperties, +) + +from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import ( + mock_image_uris_retrieve, +) + +_IS_TRUE = bool(getrandbits(1)) + + +# These tests provide the incomplete default arg dict +# within which some class or target func parameters are missing. +# The test template will fill in those missing args +# Note: the default args should not include PipelineVariable objects +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_processor_compatibility(): + default_args = dict( + clazz_args=dict( + role=ROLE, + volume_size_in_gb=None, + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict(), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=Processor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_script_processor_compatibility(): + default_args = dict( + clazz_args=dict( + role=ROLE, + volume_size_in_gb=None, + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=ScriptProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_framework_processor_compatibility(): + default_args = dict( + clazz_args=dict( + role=ROLE, + py_version="py3", + volume_size_in_gb=None, + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=FrameworkProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def 
test_pytorch_processor_compatibility(): + default_args = dict( + clazz_args=dict( + framework_version="1.8.1", + role=ROLE, + py_version="py3", + volume_size_in_gb=None, + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=PyTorchProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_sagemaker_clarify_processor(): + bias_config = BiasConfig( + facet_values_or_threshold=0.6, + facet_name="facet_name", + label_values_or_threshold=0.6, + ) + model_config = ModelConfig( + model_name="my-model", + instance_count=1, + instance_type="ml.m5.xlarge", + ) + model_pred_config = ModelPredictedLabelConfig(label="pred", probability_threshold=0.6) + + default_args = dict( + clazz_args=dict( + role=ROLE, + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict( + run_pre_training_bias=dict( + data_bias_config=bias_config, + ), + run_post_training_bias=dict( + data_bias_config=bias_config, + model_config=model_config, + model_predicted_label_config=model_pred_config, + ), + run_bias=dict( + bias_config=bias_config, + model_config=model_config, + model_predicted_label_config=model_pred_config, + ), + run_explainability=dict( + model_config=model_config, + model_scores=ModelPredictedLabelConfig(label="pred", probability_threshold=0.6), + explainability_config=PDPConfig(features=["f1", "f2", "f3", "f4"]), + ), + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=SageMakerClarifyProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_tensorflow_processor(): + default_args = dict( + clazz_args=dict( + framework_version="2.8", + role=ROLE, + py_version="py39", + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=TensorFlowProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_xgboost_processor(): + default_args = dict( + clazz_args=dict( + role=ROLE, + framework_version="1.2-1", + py_version="py3", + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=XGBoostProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +# TODO: need to merge a fix from Jerry from latest sdk master branch to unblock +# @patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=MockProperties(step_name="MyStep"))) +# @patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +# def test_spark_jar_processor(): +# # takes a really long time, since the .run has many args +# default_args = dict( +# clazz_args=dict( +# role=ROLE, +# framework_version="2.4", +# py_version="py37", +# sagemaker_session=PIPELINE_SESSION, +# ), +# 
func_args=dict( +# submit_app=DUMMY_S3_SCRIPT_PATH, +# ), +# ) +# test_template = PipelineVarCompatiTestTemplate( +# clazz=SparkJarProcessor, +# default_args=default_args, +# ) +# test_template.check_compatibility() + + +# @patch("sagemaker.workflow.steps.Properties", MagicMock(return_value=MockProperties(step_name="MyStep"))) +# @patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +# def test_py_spark_processor(): +# # takes a really long time, since the .run has many args +# default_args = dict( +# clazz_args=dict( +# role=ROLE, +# framework_version="2.4", +# py_version="py37", +# sagemaker_session=PIPELINE_SESSION, +# ), +# func_args=dict( +# submit_app=DUMMY_S3_SCRIPT_PATH, +# ), +# ) +# test_template = PipelineVarCompatiTestTemplate( +# clazz=PySparkProcessor, +# default_args=default_args, +# ) +# test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_mxnet_processor(): + # takes a really long time, since the .run has many args + default_args = dict( + clazz_args=dict( + role=ROLE, + framework_version="1.6", + py_version="py3", + sagemaker_session=PIPELINE_SESSION, + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=MXNetProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_sklearn_processor(): + default_args = dict( + clazz_args=dict( + role=ROLE, + framework_version="0.23-1", + sagemaker_session=PIPELINE_SESSION, + instance_type="ml.m5.xlarge", + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=SKLearnProcessor, + default_args=default_args, + ) + test_template.check_compatibility() + + +@patch( + "sagemaker.workflow.steps.Properties", + MagicMock(return_value=MockProperties(step_name="MyStep")), +) +@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve)) +def test_hugging_face_processor(): + default_args = dict( + clazz_args=dict( + role=ROLE, + sagemaker_session=PIPELINE_SESSION, + transformers_version="4.6", + tensorflow_version="2.4" if _IS_TRUE else None, + pytorch_version="1.8" if not _IS_TRUE else None, + py_version="py37" if _IS_TRUE else "py36", + instance_type="ml.p3.xlarge", + ), + func_args=dict( + code=DUMMY_S3_SCRIPT_PATH, + ), + ) + test_template = PipelineVarCompatiTestTemplate( + clazz=HuggingFaceProcessor, + default_args=default_args, + ) + test_template.check_compatibility() diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_transformer.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_transformer.py new file mode 100644 index 0000000000..51c9f2799e --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_transformer.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import MagicMock, patch
+
+from sagemaker.transformer import Transformer
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import (
+    PipelineVarCompatiTestTemplate,
+)
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import MockProperties
+
+
+# These tests provide an incomplete default arg dict
+# in which some class or target func parameters are missing.
+# The test template will fill in those missing args
+# Note: the default args should not include PipelineVariable objects
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+def test_transformer_compatibility():
+    default_args = dict(
+        clazz_args=dict(),
+        func_args=dict(),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=Transformer,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
diff --git a/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_tuner.py b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_tuner.py
new file mode 100644
index 0000000000..868cfa7344
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/test_mechanism/test_entries/test_pipeline_var_compatibility_with_tuner.py
@@ -0,0 +1,47 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import patch, MagicMock
+
+from sagemaker.tuner import HyperparameterTuner
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.test_pipeline_var_compatibility_template import (
+    PipelineVarCompatiTestTemplate,
+)
+from tests.unit.sagemaker.workflow.test_mechanism.test_code import MockProperties
+from tests.unit.sagemaker.workflow.test_mechanism.test_code.utilities import (
+    mock_image_uris_retrieve,
+    mock_tar_and_upload_dir,
+)
+
+
+# These tests provide an incomplete default arg dict
+# in which some class or target func parameters are missing.
+# The test template will fill in those missing args
+# Note: the default args should not include PipelineVariable objects
+@patch(
+    "sagemaker.workflow.steps.Properties",
+    MagicMock(return_value=MockProperties(step_name="MyStep")),
+)
+@patch("sagemaker.image_uris.retrieve", MagicMock(side_effect=mock_image_uris_retrieve))
+@patch("sagemaker.estimator.tar_and_upload_dir", MagicMock(side_effect=mock_tar_and_upload_dir))
+def test_hyperparametertuner_compatibility():
+    default_args = dict(
+        clazz_args=dict(),
+        func_args=dict(),
+    )
+    test_template = PipelineVarCompatiTestTemplate(
+        clazz=HyperparameterTuner,
+        default_args=default_args,
+    )
+    test_template.check_compatibility()
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py
index 262d0eb558..e320ac7ebb 100644
--- a/tests/unit/sagemaker/workflow/test_processing_step.py
+++ b/tests/unit/sagemaker/workflow/test_processing_step.py
@@ -356,7 +356,6 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
     processor = ScriptProcessor(
         role=ROLE,
         image_uri=IMAGE_URI,
-        command=["python3"],
         instance_type=INSTANCE_TYPE,
         instance_count=1,
         volume_size_in_gb=100,
@@ -430,6 +429,9 @@ def test_processing_step_with_framework_processor(
         == processing_output.destination
     )
 
+    del step_args["AppSpecification"]["ContainerEntrypoint"]
+    del step_def["Arguments"]["AppSpecification"]["ContainerEntrypoint"]
+
     del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
     del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py
index f043048095..502a7674f8 100644
--- a/tests/unit/sagemaker/workflow/test_training_step.py
+++ b/tests/unit/sagemaker/workflow/test_training_step.py
@@ -56,6 +56,8 @@
 from sagemaker.inputs import TrainingInput
 from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
+from sagemaker.workflow.condition_step import ConditionStep
+from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
 
 REGION = "us-west-2"
 BUCKET = "my-bucket"
@@ -232,18 +234,28 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar
         assert "Running within a PipelineSession" in str(w[-1].message)
 
     with warnings.catch_warnings(record=True) as w:
-        step = TrainingStep(
+        step_train = TrainingStep(
             name="MyTrainingStep",
             step_args=step_args,
             description="TrainingStep description",
             display_name="MyTrainingStep",
-            depends_on=["TestStep", "SecondTestStep"],
+            depends_on=["TestStep"],
         )
         assert len(w) == 0
 
+    step_condition = ConditionStep(
+        name="MyConditionStep",
+        conditions=[
+            ConditionGreaterThanOrEqualTo(
+                left=step_train.properties.FinalMetricDataList["val:acc"].Value, right=0.95
+            )
+        ],
+        if_steps=[custom_step2],
+    )
+
     pipeline = Pipeline(
         name="MyPipeline",
-        steps=[step, custom_step1, custom_step2],
+        steps=[step_train, step_condition, custom_step1, custom_step2],
         parameters=[enable_network_isolation, encrypt_container_traffic],
         sagemaker_session=pipeline_session,
     )
@@ -251,15 +263,20 @@
         "Get": "Parameters.encrypt_container_traffic"
     }
     step_args.args["EnableNetworkIsolation"] = {"Get": "Parameters.encrypt_container_traffic"}
+    assert step_condition.conditions[0].left.expr == {
+        "Get": "Steps.MyTrainingStep.FinalMetricDataList['val:acc'].Value"
+    }
     assert json.loads(pipeline.definition())["Steps"][0] == {
"Name": "MyTrainingStep", "Description": "TrainingStep description", "DisplayName": "MyTrainingStep", "Type": "Training", - "DependsOn": ["TestStep", "SecondTestStep"], + "DependsOn": ["TestStep"], "Arguments": step_args.args, } - assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} + assert step_train.properties.TrainingJobName.expr == { + "Get": "Steps.MyTrainingStep.TrainingJobName" + } adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list assert ordered(adjacency_list) == ordered( {"MyTrainingStep": [], "SecondTestStep": ["MyTrainingStep"], "TestStep": ["MyTrainingStep"]}