From 4b8cb2953e1646cd91691386a045f2ec3270ac14 Mon Sep 17 00:00:00 2001 From: Qiushi Wuye Date: Fri, 14 Aug 2020 21:00:30 +0000 Subject: [PATCH 1/5] feat: support creating endpoints with model images from private registries --- src/sagemaker/image_config.py | 39 ++++++++++++++++++++++++ src/sagemaker/model.py | 14 ++++++++- src/sagemaker/session.py | 9 +++++- tests/unit/sagemaker/model/test_model.py | 13 ++++++++ tests/unit/test_image_config.py | 30 ++++++++++++++++++ 5 files changed, 103 insertions(+), 2 deletions(-) create mode 100644 src/sagemaker/image_config.py create mode 100644 tests/unit/test_image_config.py diff --git a/src/sagemaker/image_config.py b/src/sagemaker/image_config.py new file mode 100644 index 0000000000..3fe13970eb --- /dev/null +++ b/src/sagemaker/image_config.py @@ -0,0 +1,39 @@ +# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to create and manage SageMaker ``ImageConfig``""" + + +class ImageConfig(object): + """Configuration of Docker image used in Model.""" + + def __init__( + self, repository_access_mode="Platform", + ): + """Initialize an ``ImageConfig``. + + Args: + repository_access_mode (str): Set this to one of the following values (default: "Platform"): + * Platform: The model image is hosted in Amazon ECR. + * Vpc: The model image is hosted in a private Docker registry in your VPC. + """ + self.repository_access_mode = repository_access_mode + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + req = { + "RepositoryAccessMode": "Platform", + } + if self.repository_access_mode is not None: + req["RepositoryAccessMode"] = self.repository_access_mode + + return req diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 7730bb3ba1..bdd1197a52 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -43,6 +43,7 @@ def __init__( sagemaker_session=None, enable_network_isolation=False, model_kms_key=None, + image_config=None, ): """Initialize an SageMaker ``Model``. @@ -80,6 +81,10 @@ def __init__( or from the model container. model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked + image_config (sagemaker.ImageConfig): Specifies whether the image of + model container is pulled from ECR, or private registry in your + VPC. By default it is set to pull model container image from + ECR. (default: None). """ self.model_data = model_data self.image_uri = image_uri @@ -94,6 +99,7 @@ def __init__( self._is_compiled_model = False self._enable_network_isolation = enable_network_isolation self.model_kms_key = model_kms_key + self.image_config = image_config def _init_sagemaker_session_if_does_not_exist(self, instance_type): """Set ``self.sagemaker_session`` to be a ``LocalSession`` or @@ -127,7 +133,13 @@ def prepare_container_def( Returns: dict: A container definition object usable with the CreateModel API. """ - return sagemaker.container_def(self.image_uri, self.model_data, self.env) + image_config_dict = None + if self.image_config: + image_config_dict = self.image_config._to_request_dict() + + return sagemaker.container_def( + self.image_uri, self.model_data, self.env, image_config_dict=image_config_dict + ) def enable_network_isolation(self): """Whether to enable network isolation when creating this Model diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 7888a86268..5376c7d640 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3254,7 +3254,9 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): print() -def container_def(image_uri, model_data_url=None, env=None, container_mode=None): +def container_def( + image_uri, model_data_url=None, env=None, container_mode=None, image_config_dict=None +): """Create a definition for executing a container as part of a SageMaker model. Args: @@ -3266,6 +3268,9 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None) * MultiModel: Indicates that model container can support hosting multiple models * SingleModel: Indicates that model container can support hosting a single model This is the default model container mode when container_mode = None + image_config_dict (dict): Specifies whether the image of model container is pulled from ECR, + or private registry in your VPC. By default it is set to pull model container image + from ECR. (default: None). Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if passed via `PrimaryContainers` field. @@ -3277,6 +3282,8 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None) c_def["ModelDataUrl"] = model_data_url if container_mode: c_def["Mode"] = container_mode + if image_config_dict: + c_def["ImageConfig"] = image_config_dict return c_def diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 284745d3e4..8f306312a1 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -17,6 +17,7 @@ import sagemaker from sagemaker.model import Model +from sagemaker.image_config import ImageConfig MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -26,6 +27,7 @@ INSTANCE_COUNT = 2 INSTANCE_TYPE = "ml.c4.4xlarge" ROLE = "some-role" +REPOSITORY_ACCESS_MODE = "Vpc" @pytest.fixture @@ -54,6 +56,17 @@ def test_prepare_container_def_with_model_data_and_env(): assert expected == container_def +def test_prepare_container_def_with_image_config(): + image_config = ImageConfig(repository_access_mode=REPOSITORY_ACCESS_MODE) + model = Model(MODEL_IMAGE, image_config=image_config) + + expected_image_config_dict = {"RepositoryAccessMode": "Vpc"} + expected = {"Image": MODEL_IMAGE, "ImageConfig": expected_image_config_dict, "Environment": {}} + + container_def = model.prepare_container_def() + assert expected == container_def + + def test_model_enable_network_isolation(): model = Model(MODEL_IMAGE, MODEL_DATA) assert model.enable_network_isolation() is False diff --git a/tests/unit/test_image_config.py b/tests/unit/test_image_config.py new file mode 100644 index 0000000000..e80ed6246e --- /dev/null +++ b/tests/unit/test_image_config.py @@ -0,0 +1,30 @@ +# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.image_config import ImageConfig + +REPOSITORY_ACCESS_MODE_PLATFORM = "Platform" +REPOSITORY_ACCESS_MODE_VPC = "Vpc" + + +def test_init_with_defaults(): + image_config = ImageConfig() + + assert image_config.repository_access_mode == REPOSITORY_ACCESS_MODE_PLATFORM + + +def test_init_with_non_defaults(): + image_config = ImageConfig(repository_access_mode=REPOSITORY_ACCESS_MODE_VPC) + + assert image_config.repository_access_mode == REPOSITORY_ACCESS_MODE_VPC From dcea5097ce48cf60a700b7542224222f0ded45ea Mon Sep 17 00:00:00 2001 From: Qiushi Wuye Date: Mon, 17 Aug 2020 22:30:16 +0000 Subject: [PATCH 2/5] fix: fix pylint and flake8 errors --- src/sagemaker/image_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_config.py b/src/sagemaker/image_config.py index 3fe13970eb..d8efbd78f1 100644 --- a/src/sagemaker/image_config.py +++ b/src/sagemaker/image_config.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """This module contains code to create and manage SageMaker ``ImageConfig``""" +from __future__ import absolute_import class ImageConfig(object): @@ -22,7 +23,8 @@ def __init__( """Initialize an ``ImageConfig``. Args: - repository_access_mode (str): Set this to one of the following values (default: "Platform"): + repository_access_mode (str): Set this to one of the following + values (default: "Platform"): * Platform: The model image is hosted in Amazon ECR. * Vpc: The model image is hosted in a private Docker registry in your VPC. """ From 17b14f898eeb326d46975b04c1bec93e93ba769d Mon Sep 17 00:00:00 2001 From: Qiushi Wuye Date: Wed, 26 Aug 2020 19:00:54 +0000 Subject: [PATCH 3/5] change: simplify modeling of ImageConfig `ImageConfig` class is overkill to model a configuration bag. --- src/sagemaker/amazon/amazon_estimator.py | 4 +- src/sagemaker/amazon/kmeans.py | 6 +- src/sagemaker/amazon/knn.py | 4 +- src/sagemaker/amazon/lda.py | 12 ++-- src/sagemaker/amazon/ntm.py | 4 +- src/sagemaker/amazon/pca.py | 4 +- src/sagemaker/analytics.py | 13 ++-- src/sagemaker/automl/automl.py | 3 +- src/sagemaker/cli/framework_upgrade.py | 10 ++- src/sagemaker/image_config.py | 41 ----------- src/sagemaker/local/local_session.py | 3 +- src/sagemaker/model.py | 8 +-- src/sagemaker/model_monitor/dataset_format.py | 3 +- .../model_monitor/model_monitoring.py | 3 +- .../model_monitor/monitoring_files.py | 12 ++-- src/sagemaker/parameter.py | 6 +- src/sagemaker/session.py | 14 ++-- src/sagemaker/tensorflow/model.py | 5 +- src/sagemaker/transformer.py | 3 +- src/sagemaker/tuner.py | 3 +- tests/data/sagemaker_rl/coach_launcher.py | 7 +- tests/data/sagemaker_rl/configuration_list.py | 17 ++--- tests/integ/test_airflow_config.py | 10 ++- tests/integ/test_multidatamodel.py | 3 +- tests/integ/test_mxnet.py | 5 +- tests/integ/test_processing.py | 5 +- tests/integ/test_pytorch.py | 5 +- tests/integ/test_sklearn.py | 30 ++++++-- tests/integ/test_tfs.py | 4 +- tests/integ/test_tuner.py | 8 +-- .../v2/modifiers/test_framework_version.py | 24 +++++-- .../compatibility/v2/modifiers/test_serde.py | 70 +++++++++++++++---- .../image_uris/test_dlc_frameworks.py | 7 +- .../unit/sagemaker/image_uris/test_sklearn.py | 5 +- .../unit/sagemaker/image_uris/test_xgboost.py | 5 +- tests/unit/sagemaker/model/test_deploy.py | 12 ++-- tests/unit/sagemaker/model/test_model.py | 30 +++++--- tests/unit/test_image_config.py | 30 -------- tests/unit/test_linear_learner.py | 6 +- tests/unit/test_ntm.py | 8 ++- tests/unit/test_pca.py | 6 +- tests/unit/test_predictor.py | 3 +- tests/unit/test_pytorch.py | 6 +- tests/unit/test_session.py | 5 +- tests/unit/test_tuner.py | 8 ++- tests/unit/tuner_test_utils.py | 6 +- 46 files changed, 283 insertions(+), 203 deletions(-) delete mode 100644 src/sagemaker/image_config.py delete mode 100644 tests/unit/test_image_config.py diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 307ac97f6d..34b52ccf9f 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -94,7 +94,9 @@ def __init__( def training_image_uri(self): """Placeholder docstring""" return image_uris.retrieve( - self.repo_name, self.sagemaker_session.boto_region_name, version=self.repo_version, + self.repo_name, + self.sagemaker_session.boto_region_name, + version=self.repo_version, ) def hyperparameters(self): diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index b7556d671c..fa2b3d1455 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -29,7 +29,7 @@ class KMeans(AmazonAlgorithmEstimatorBase): As the result of KMeans, members of a group are as similar as possible to one another and as different as possible from members of other groups. You define the attributes that you want - the algorithm to use to determine similarity. """ + the algorithm to use to determine similarity.""" repo_name = "kmeans" repo_version = 1 @@ -257,7 +257,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): """ sagemaker_session = sagemaker_session or Session() image_uri = image_uris.retrieve( - KMeans.repo_name, sagemaker_session.boto_region_name, version=KMeans.repo_version, + KMeans.repo_name, + sagemaker_session.boto_region_name, + version=KMeans.repo_version, ) super(KMeansModel, self).__init__( image_uri, diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index ab4a4d1495..54fc141047 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -246,7 +246,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): """ sagemaker_session = sagemaker_session or Session() image_uri = image_uris.retrieve( - KNN.repo_name, sagemaker_session.boto_region_name, version=KNN.repo_version, + KNN.repo_name, + sagemaker_session.boto_region_name, + version=KNN.repo_version, ) super(KNNModel, self).__init__( image_uri, diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index 3e4ef16e87..063dac5978 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -27,10 +27,10 @@ class LDA(AmazonAlgorithmEstimatorBase): """An unsupervised learning algorithm attempting to describe data as distinct categories. - LDA is most commonly used to discover a - user-specified number of topics shared by documents within a text corpus. Here each - observation is a document, the features are the presence (or occurrence count) of each - word, and the categories are the topics.""" + LDA is most commonly used to discover a + user-specified number of topics shared by documents within a text corpus. Here each + observation is a document, the features are the presence (or occurrence count) of each + word, and the categories are the topics.""" repo_name = "lda" repo_version = 1 @@ -230,7 +230,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): """ sagemaker_session = sagemaker_session or Session() image_uri = image_uris.retrieve( - LDA.repo_name, sagemaker_session.boto_region_name, version=LDA.repo_version, + LDA.repo_name, + sagemaker_session.boto_region_name, + version=LDA.repo_version, ) super(LDAModel, self).__init__( image_uri, diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 0e1d8fc39b..acbb5e4014 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -259,7 +259,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): """ sagemaker_session = sagemaker_session or Session() image_uri = image_uris.retrieve( - NTM.repo_name, sagemaker_session.boto_region_name, version=NTM.repo_version, + NTM.repo_name, + sagemaker_session.boto_region_name, + version=NTM.repo_version, ) super(NTMModel, self).__init__( image_uri, diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index f7d52b54e4..9ab5d6a447 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -240,7 +240,9 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): """ sagemaker_session = sagemaker_session or Session() image_uri = image_uris.retrieve( - PCA.repo_name, sagemaker_session.boto_region_name, version=PCA.repo_version, + PCA.repo_name, + sagemaker_session.boto_region_name, + version=PCA.repo_version, ) super(PCAModel, self).__init__( image_uri, diff --git a/src/sagemaker/analytics.py b/src/sagemaker/analytics.py index d3abd0e543..b1238fdcbb 100644 --- a/src/sagemaker/analytics.py +++ b/src/sagemaker/analytics.py @@ -416,8 +416,7 @@ def _metric_names_for_training_job(self): class ExperimentAnalytics(AnalyticsMetricsBase): - """Fetch trial component data and make them accessible for analytics. - """ + """Fetch trial component data and make them accessible for analytics.""" MAX_TRIAL_COMPONENTS = 10000 @@ -477,16 +476,14 @@ def __init__( @property def name(self): - """Name of the Experiment being analyzed - """ + """Name of the Experiment being analyzed""" return self._experiment_name def __repr__(self): return "" % self.name def clear_cache(self): - """Clear the object of all local caches of API methods. - """ + """Clear the object of all local caches of API methods.""" super(ExperimentAnalytics, self).clear_cache() self._trial_components = None @@ -570,13 +567,13 @@ def _reshape(self, trial_component): def _fetch_dataframe(self): """Return a pandas dataframe with all the trial_components, - along with their parameters and metrics. + along with their parameters and metrics. """ df = pd.DataFrame([self._reshape(component) for component in self._get_trial_components()]) return df def _get_trial_components(self, force_refresh=False): - """ Get all trial components matching the given search query expression. + """Get all trial components matching the given search query expression. Args: force_refresh (bool): Set to True to fetch the latest data from SageMaker API. diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index ab381565e7..b23389d80b 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -26,8 +26,7 @@ class AutoML(object): - """A class for creating and interacting with SageMaker AutoML jobs - """ + """A class for creating and interacting with SageMaker AutoML jobs""" def __init__( self, diff --git a/src/sagemaker/cli/framework_upgrade.py b/src/sagemaker/cli/framework_upgrade.py index 81247c5325..36931ffb79 100644 --- a/src/sagemaker/cli/framework_upgrade.py +++ b/src/sagemaker/cli/framework_upgrade.py @@ -158,7 +158,13 @@ def add_region(existing_content, region, account): def add_version( - existing_content, short_version, full_version, scope, processors, py_versions, tag_prefix, + existing_content, + short_version, + full_version, + scope, + processors, + py_versions, + tag_prefix, ): """Read framework image uri information from json file to a dictionary, update it with new framework version information, then write the dictionary back to json file. @@ -172,7 +178,7 @@ def add_version( processors (str): Supported processors (e.g. "cpu,gpu"). py_versions (str): Supported Python versions (e.g. "py3,py37"). tag_prefix (str): Algorithm image's tag prefix. - """ + """ if py_versions: py_versions = py_versions.split(",") processors = processors.split(",") diff --git a/src/sagemaker/image_config.py b/src/sagemaker/image_config.py deleted file mode 100644 index d8efbd78f1..0000000000 --- a/src/sagemaker/image_config.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -"""This module contains code to create and manage SageMaker ``ImageConfig``""" -from __future__ import absolute_import - - -class ImageConfig(object): - """Configuration of Docker image used in Model.""" - - def __init__( - self, repository_access_mode="Platform", - ): - """Initialize an ``ImageConfig``. - - Args: - repository_access_mode (str): Set this to one of the following - values (default: "Platform"): - * Platform: The model image is hosted in Amazon ECR. - * Vpc: The model image is hosted in a private Docker registry in your VPC. - """ - self.repository_access_mode = repository_access_mode - - def _to_request_dict(self): - """Generates a request dictionary using the parameters provided to the class.""" - req = { - "RepositoryAccessMode": "Platform", - } - if self.repository_access_mode is not None: - req["RepositoryAccessMode"] = self.repository_access_mode - - return req diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 388eeb7fc2..69943e8cae 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -448,8 +448,7 @@ class file_input(object): """Amazon SageMaker channel configuration for FILE data sources, used in local mode.""" def __init__(self, fileUri, content_type=None): - """Create a definition for input data used by an SageMaker training job in local mode. - """ + """Create a definition for input data used by an SageMaker training job in local mode.""" self.config = { "DataSource": { "FileDataSource": { diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index bdd1197a52..f57bb76f83 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -81,7 +81,7 @@ def __init__( or from the model container. model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - image_config (sagemaker.ImageConfig): Specifies whether the image of + image_config (dict[str, object]): Specifies whether the image of model container is pulled from ECR, or private registry in your VPC. By default it is set to pull model container image from ECR. (default: None). @@ -133,12 +133,8 @@ def prepare_container_def( Returns: dict: A container definition object usable with the CreateModel API. """ - image_config_dict = None - if self.image_config: - image_config_dict = self.image_config._to_request_dict() - return sagemaker.container_def( - self.image_uri, self.model_data, self.env, image_config_dict=image_config_dict + self.image_uri, self.model_data, self.env, image_config=self.image_config ) def enable_network_isolation(self): diff --git a/src/sagemaker/model_monitor/dataset_format.py b/src/sagemaker/model_monitor/dataset_format.py index f4c9c0b967..8dcdb79476 100644 --- a/src/sagemaker/model_monitor/dataset_format.py +++ b/src/sagemaker/model_monitor/dataset_format.py @@ -18,8 +18,7 @@ class DatasetFormat(object): - """Represents a Dataset Format that is used when calling a DefaultModelMonitor. - """ + """Represents a Dataset Format that is used when calling a DefaultModelMonitor.""" @staticmethod def csv(header=True, output_columns_position="START"): diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index a3c001b70f..c01a974f9b 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -566,8 +566,7 @@ def latest_monitoring_constraint_violations( return latest_monitoring_execution.constraint_violations(file_name=file_name) def describe_latest_baselining_job(self): - """Describe the latest baselining job kicked off by the suggest workflow. - """ + """Describe the latest baselining job kicked off by the suggest workflow.""" if self.latest_baselining_job is None: raise ValueError("No suggestion jobs were kicked off.") return self.latest_baselining_job.describe() diff --git a/src/sagemaker/model_monitor/monitoring_files.py b/src/sagemaker/model_monitor/monitoring_files.py index 71d02b2149..ae9038cff0 100644 --- a/src/sagemaker/model_monitor/monitoring_files.py +++ b/src/sagemaker/model_monitor/monitoring_files.py @@ -29,8 +29,7 @@ class ModelMonitoringFile(object): - """Represents a file with a body and an S3 uri. - """ + """Represents a file with a body and an S3 uri.""" def __init__(self, body_dict, file_s3_uri, kms_key, sagemaker_session): """Initializes a file with a body and an S3 uri. @@ -76,8 +75,7 @@ def save(self, new_save_location_s3_uri=None): class Statistics(ModelMonitoringFile): - """Represents the statistics JSON file used in Amazon SageMaker Model Monitoring. - """ + """Represents the statistics JSON file used in Amazon SageMaker Model Monitoring.""" def __init__(self, body_dict, statistics_file_s3_uri, kms_key=None, sagemaker_session=None): """Initializes the Statistics object used in Amazon SageMaker Model Monitoring. @@ -202,8 +200,7 @@ def from_file_path(cls, statistics_file_path, kms_key=None, sagemaker_session=No class Constraints(ModelMonitoringFile): - """Represents the constraints JSON file used in Amazon SageMaker Model Monitoring. - """ + """Represents the constraints JSON file used in Amazon SageMaker Model Monitoring.""" def __init__(self, body_dict, constraints_file_s3_uri, kms_key=None, sagemaker_session=None): """Initializes the Constraints object used in Amazon SageMaker Model Monitoring. @@ -354,8 +351,7 @@ def set_monitoring(self, enable_monitoring, feature_name=None): class ConstraintViolations(ModelMonitoringFile): - """Represents the constraint violations JSON file used in Amazon SageMaker Model Monitoring. - """ + """Represents the constraint violations JSON file used in Amazon SageMaker Model Monitoring.""" def __init__( self, body_dict, constraint_violations_file_s3_uri, kms_key=None, sagemaker_session=None diff --git a/src/sagemaker/parameter.py b/src/sagemaker/parameter.py index 9ed7cd2731..ac7e059413 100644 --- a/src/sagemaker/parameter.py +++ b/src/sagemaker/parameter.py @@ -161,9 +161,9 @@ def cast_to_type(cls, value): class IntegerParameter(ParameterRange): """A class for representing hyperparameters that have an integer range of possible values. - Args: - min_value (int): The minimum value for the range. - max_value (int): The maximum value for the range. + Args: + min_value (int): The minimum value for the range. + max_value (int): The maximum value for the range. """ __name__ = "Integer" diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 05ece950b6..8a54881237 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3030,7 +3030,9 @@ def endpoint_from_model_data( lambda: self.sagemaker_client.describe_model(ModelName=name) ): primary_container = container_def( - image_uri=image_uri, model_data_url=model_s3_location, env=model_environment_vars, + image_uri=image_uri, + model_data_url=model_s3_location, + env=model_environment_vars, ) self.create_model( name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config @@ -3441,9 +3443,7 @@ def logs_for_transform_job(self, job_name, wait=False, poll=10): print() -def container_def( - image_uri, model_data_url=None, env=None, container_mode=None, image_config_dict=None -): +def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config=None): """Create a definition for executing a container as part of a SageMaker model. Args: @@ -3455,7 +3455,7 @@ def container_def( * MultiModel: Indicates that model container can support hosting multiple models * SingleModel: Indicates that model container can support hosting a single model This is the default model container mode when container_mode = None - image_config_dict (dict): Specifies whether the image of model container is pulled from ECR, + image_config (dict): Specifies whether the image of model container is pulled from ECR, or private registry in your VPC. By default it is set to pull model container image from ECR. (default: None). Returns: @@ -3469,8 +3469,8 @@ def container_def( c_def["ModelDataUrl"] = model_data_url if container_mode: c_def["Mode"] = container_mode - if image_config_dict: - c_def["ImageConfig"] = image_config_dict + if image_config: + c_def["ImageConfig"] = image_config return c_def diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 751cdb6c84..f239df41f1 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -58,7 +58,10 @@ def __init__( version of the model will be used. """ super(TensorFlowPredictor, self).__init__( - endpoint_name, sagemaker_session, serializer, deserializer, + endpoint_name, + sagemaker_session, + serializer, + deserializer, ) attributes = [] diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index a52bc02167..d712437d01 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -262,8 +262,7 @@ def wait(self, logs=True): self.latest_transform_job.wait(logs=logs) def stop_transform_job(self, wait=True): - """Stop latest running batch transform job. - """ + """Stop latest running batch transform job.""" self._ensure_last_transform_job() self.latest_transform_job.stop() if wait: diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index c2864093da..6859f10680 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -1021,8 +1021,7 @@ def hyperparameter_ranges(self): ) def hyperparameter_ranges_dict(self): - """Return a dictionary of hyperparameter ranges for all estimators in ``estimator_dict`` - """ + """Return a dictionary of hyperparameter ranges for all estimators in ``estimator_dict``""" if self._hyperparameter_ranges_dict is None: return None diff --git a/tests/data/sagemaker_rl/coach_launcher.py b/tests/data/sagemaker_rl/coach_launcher.py index 4d69e582e1..2e3d624f7f 100644 --- a/tests/data/sagemaker_rl/coach_launcher.py +++ b/tests/data/sagemaker_rl/coach_launcher.py @@ -33,7 +33,7 @@ class CoachConfigurationList(ConfigurationList): - """Helper Object for converting CLI arguments (or SageMaker hyperparameters) + """Helper Object for converting CLI arguments (or SageMaker hyperparameters) into Coach configuration. """ @@ -270,7 +270,7 @@ def _save_onnx_model(self): @classmethod def train_main(cls): - """Entrypoint for training. + """Entrypoint for training. Parses command-line arguments and starts training. """ trainer = cls() @@ -303,8 +303,7 @@ def define_environment(self): ) def get_graph_manager_from_args(self, args): - """Returns the GraphManager object for coach to use to train by calling improve() - """ + """Returns the GraphManager object for coach to use to train by calling improve()""" # NOTE: TaskParameters are not configurable at this time. # Visualization diff --git a/tests/data/sagemaker_rl/configuration_list.py b/tests/data/sagemaker_rl/configuration_list.py index 4728ba7b60..e4e888de56 100644 --- a/tests/data/sagemaker_rl/configuration_list.py +++ b/tests/data/sagemaker_rl/configuration_list.py @@ -2,20 +2,19 @@ class ConfigurationList(object): - """Helper Object for converting CLI arguments (or SageMaker hyperparameters) + """Helper Object for converting CLI arguments (or SageMaker hyperparameters) into Coach configuration. """ def __init__(self): """Args: - - arg_list [list]: list of arguments on the command-line like [key1, value1, key2, value2, ...] - - prefix [str]: Prefix for every key that must be present, e.g. "--" for common command-line args + - arg_list [list]: list of arguments on the command-line like [key1, value1, key2, value2, ...] + - prefix [str]: Prefix for every key that must be present, e.g. "--" for common command-line args """ self.hp_dict = {} def store(self, name, value): - """Store a key/value hyperparameter combination - """ + """Store a key/value hyperparameter combination""" self.hp_dict[name] = value def apply_subset(self, config_object, prefix): @@ -41,8 +40,7 @@ def apply_subset(self, config_object, prefix): del self.hp_dict[key] def _set_rl_property_value(self, obj, key, val, path=""): - """Sets a property on obj to val, or to a sub-object within obj if key looks like "foo.bar" - """ + """Sets a property on obj to val, or to a sub-object within obj if key looks like "foo.bar" """ if key.find(".") >= 0: top_key, sub_keys = key_list = key.split(".", 1) if top_key.startswith("__"): @@ -63,8 +61,7 @@ def _set_rl_property_value(self, obj, key, val, path=""): obj.__dict__[key] = val def _autotype(self, val): - """Converts string to an int or float as possible. - """ + """Converts string to an int or float as possible.""" try: return int(val) except ValueError: @@ -83,7 +80,7 @@ def _parse_type(self, key, val): Automatically detects ints and floats when possible. If the key takes the form "foo:bar" then it looks in ALLOWED_TYPES for an entry of bar, and instantiates one of those objects, passing - val to the constructor. So if key="foo:EnvironmentSteps" then + val to the constructor. So if key="foo:EnvironmentSteps" then """ val = self._autotype(val) if key.find(":") > 0: diff --git a/tests/integ/test_airflow_config.py b/tests/integ/test_airflow_config.py index 64a4861a48..3a5761cdef 100644 --- a/tests/integ/test_airflow_config.py +++ b/tests/integ/test_airflow_config.py @@ -478,7 +478,10 @@ def test_mxnet_airflow_config_uploads_data_source_to_s3( @pytest.mark.canary_quick def test_sklearn_airflow_config_uploads_data_source_to_s3( - sagemaker_session, cpu_instance_type, sklearn_latest_version, sklearn_latest_py_version, + sagemaker_session, + cpu_instance_type, + sklearn_latest_version, + sklearn_latest_py_version, ): with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS): script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py") @@ -515,7 +518,10 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3( @pytest.mark.canary_quick def test_tf_airflow_config_uploads_data_source_to_s3( - sagemaker_session, cpu_instance_type, tf_full_version, tf_full_py_version, + sagemaker_session, + cpu_instance_type, + tf_full_version, + tf_full_py_version, ): with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS): tf = TensorFlow( diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py index 75289f20d3..d5e53c408a 100644 --- a/tests/integ/test_multidatamodel.py +++ b/tests/integ/test_multidatamodel.py @@ -117,8 +117,7 @@ def _delete_repository(ecr_client, repository_name): def _ecr_login(ecr_client): - """ Get a login credentials for an ecr client. - """ + """Get a login credentials for an ecr client.""" login = ecr_client.get_authorization_token() b64token = login["authorizationData"][0]["authorizationToken"].encode("utf-8") username, password = base64.b64decode(b64token).decode("utf-8").split(":") diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index a90a183b49..6a143d0ebb 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -83,7 +83,10 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type) def test_deploy_estimator_with_different_instance_types( - mxnet_training_job, sagemaker_session, cpu_instance_type, alternative_cpu_instance_type, + mxnet_training_job, + sagemaker_session, + cpu_instance_type, + alternative_cpu_instance_type, ): def _deploy_estimator_and_assert_instance_type(estimator, instance_type): # don't use timeout_and_delete_endpoint_by_name because this tests if diff --git a/tests/integ/test_processing.py b/tests/integ/test_processing.py index eda8e3445d..246818bb85 100644 --- a/tests/integ/test_processing.py +++ b/tests/integ/test_processing.py @@ -59,7 +59,10 @@ def sagemaker_session_with_custom_bucket( @pytest.fixture(scope="module") def image_uri( - sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type, sagemaker_session, + sklearn_latest_version, + sklearn_latest_py_version, + cpu_instance_type, + sagemaker_session, ): return image_uris.retrieve( "sklearn", diff --git a/tests/integ/test_pytorch.py b/tests/integ/test_pytorch.py index 0309d6eec6..ac468cdd6f 100644 --- a/tests/integ/test_pytorch.py +++ b/tests/integ/test_pytorch.py @@ -182,7 +182,10 @@ def test_deploy_packed_model_with_entry_point_name( test_region() not in EI_SUPPORTED_REGIONS, reason="EI isn't supported in that specific region." ) def test_deploy_model_with_accelerator( - sagemaker_session, cpu_instance_type, pytorch_eia_latest_version, pytorch_eia_latest_py_version, + sagemaker_session, + cpu_instance_type, + pytorch_eia_latest_version, + pytorch_eia_latest_py_version, ): endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp()) model_data = sagemaker_session.upload_data(path=EIA_MODEL) diff --git a/tests/integ/test_sklearn.py b/tests/integ/test_sklearn.py index a98201fd19..b5d992296b 100644 --- a/tests/integ/test_sklearn.py +++ b/tests/integ/test_sklearn.py @@ -31,16 +31,25 @@ "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968" ) def sklearn_training_job( - sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type, + sagemaker_session, + sklearn_latest_version, + sklearn_latest_py_version, + cpu_instance_type, ): return _run_mnist_training_job( - sagemaker_session, cpu_instance_type, sklearn_latest_version, sklearn_latest_py_version, + sagemaker_session, + cpu_instance_type, + sklearn_latest_version, + sklearn_latest_py_version, ) sagemaker_session.boto_region_name def test_training_with_additional_hyperparameters( - sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type, + sagemaker_session, + sklearn_latest_version, + sklearn_latest_py_version, + cpu_instance_type, ): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py") @@ -68,7 +77,10 @@ def test_training_with_additional_hyperparameters( def test_training_with_network_isolation( - sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type, + sagemaker_session, + sklearn_latest_version, + sklearn_latest_py_version, + cpu_instance_type, ): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): script_path = os.path.join(DATA_DIR, "sklearn_mnist", "mnist.py") @@ -147,7 +159,10 @@ def test_deploy_model( "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968" ) def test_async_fit( - sagemaker_session, cpu_instance_type, sklearn_latest_version, sklearn_latest_py_version, + sagemaker_session, + cpu_instance_type, + sklearn_latest_version, + sklearn_latest_py_version, ): endpoint_name = "test-sklearn-attach-deploy-{}".format(sagemaker_timestamp()) @@ -172,7 +187,10 @@ def test_async_fit( def test_failed_training_job( - sagemaker_session, sklearn_latest_version, sklearn_latest_py_version, cpu_instance_type, + sagemaker_session, + sklearn_latest_version, + sklearn_latest_py_version, + cpu_instance_type, ): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): script_path = os.path.join(DATA_DIR, "sklearn_mnist", "failure_script.py") diff --git a/tests/integ/test_tfs.py b/tests/integ/test_tfs.py index 242eb101ff..f76c734383 100644 --- a/tests/integ/test_tfs.py +++ b/tests/integ/test_tfs.py @@ -238,7 +238,9 @@ def test_predict_csv(tfs_predictor): expected_result = {"predictions": [[3.5, 4.0, 5.5], [3.5, 4.0, 5.5]]} predictor = TensorFlowPredictor( - tfs_predictor.endpoint_name, tfs_predictor.sagemaker_session, serializer=CSVSerializer(), + tfs_predictor.endpoint_name, + tfs_predictor.sagemaker_session, + serializer=CSVSerializer(), ) result = predictor.predict(input_data) diff --git a/tests/integ/test_tuner.py b/tests/integ/test_tuner.py index dc23a199fd..02dd131e92 100644 --- a/tests/integ/test_tuner.py +++ b/tests/integ/test_tuner.py @@ -213,7 +213,7 @@ def test_tuning_kmeans_identical_dataset_algorithm_tuner( sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges ): """Tests Identical dataset and algorithm use case with one parent and child job launched with - .identical_dataset_and_algorithm_tuner() """ + .identical_dataset_and_algorithm_tuner()""" parent_tuning_job_name = unique_name_from_base("km-iden1-parent", max_length=32) child_tuning_job_name = unique_name_from_base("km-iden1-child", max_length=32) @@ -249,7 +249,7 @@ def test_create_tuning_kmeans_identical_dataset_algorithm_tuner( sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges ): """Tests Identical dataset and algorithm use case with one parent and child job launched with - .create_identical_dataset_and_algorithm_tuner() """ + .create_identical_dataset_and_algorithm_tuner()""" parent_tuning_job_name = unique_name_from_base("km-iden2-parent", max_length=32) child_tuning_job_name = unique_name_from_base("km-iden2-child", max_length=32) @@ -290,7 +290,7 @@ def test_transfer_learning_tuner( sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges ): """Tests Transfer learning use case with one parent and child job launched with - .transfer_learning_tuner() """ + .transfer_learning_tuner()""" parent_tuning_job_name = unique_name_from_base("km-tran1-parent", max_length=32) child_tuning_job_name = unique_name_from_base("km-tran1-child", max_length=32) @@ -328,7 +328,7 @@ def test_create_transfer_learning_tuner( sagemaker_session, kmeans_train_set, kmeans_estimator, hyperparameter_ranges ): """Tests Transfer learning use case with two parents and child job launched with - create_transfer_learning_tuner() """ + create_transfer_learning_tuner()""" parent_tuning_job_name_1 = unique_name_from_base("km-tran2-parent1", max_length=32) parent_tuning_job_name_2 = unique_name_from_base("km-tran2-parent2", max_length=32) child_tuning_job_name = unique_name_from_base("km-tran2-child", max_length=32) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py index ed74c00034..f35a50b929 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py @@ -28,7 +28,11 @@ class Template: """ def __init__( - self, framework, framework_version, py_version, py_version_for_model=True, + self, + framework, + framework_version, + py_version, + py_version_for_model=True, ): self.framework = framework self.framework_version = framework_version @@ -97,9 +101,21 @@ def _format_templates(keywords, templates): py_version="py2", py_version_for_model=False, ), - Template(framework="MXNet", framework_version="1.2.0", py_version="py2",), - Template(framework="Chainer", framework_version="4.1.0", py_version="py3",), - Template(framework="PyTorch", framework_version="0.4.0", py_version="py3",), + Template( + framework="MXNet", + framework_version="1.2.0", + py_version="py2", + ), + Template( + framework="Chainer", + framework_version="4.1.0", + py_version="py3", + ), + Template( + framework="PyTorch", + framework_version="0.4.0", + py_version="py3", + ), Template( framework="SKLearn", framework_version="0.20.0", diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py index c71468bf41..4f52e0d1f1 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py @@ -64,8 +64,14 @@ def test_constructor_node_should_be_modified(src, expected): ("sagemaker.predictor._NpySerializer()", "serializers.NumpySerializer()"), ("sagemaker.predictor._CsvDeserializer()", "deserializers.CSVDeserializer()"), ("sagemaker.predictor.BytesDeserializer()", "deserializers.BytesDeserializer()"), - ("sagemaker.predictor.StringDeserializer()", "deserializers.StringDeserializer()",), - ("sagemaker.predictor.StreamDeserializer()", "deserializers.StreamDeserializer()",), + ( + "sagemaker.predictor.StringDeserializer()", + "deserializers.StringDeserializer()", + ), + ( + "sagemaker.predictor.StreamDeserializer()", + "deserializers.StreamDeserializer()", + ), ("sagemaker.predictor._NumpyDeserializer()", "deserializers.NumpyDeserializer()"), ("sagemaker.predictor._JsonDeserializer()", "deserializers.JSONDeserializer()"), ( @@ -100,18 +106,54 @@ def test_constructor_modify_node(src, expected): @pytest.mark.parametrize( "src, expected", [ - ("sagemaker.predictor.csv_serializer", True,), - ("sagemaker.predictor.json_serializer", True,), - ("sagemaker.predictor.npy_serializer", True,), - ("sagemaker.predictor.csv_deserializer", True,), - ("sagemaker.predictor.json_deserializer", True,), - ("sagemaker.predictor.numpy_deserializer", True,), - ("csv_serializer", True,), - ("json_serializer", True,), - ("npy_serializer", True,), - ("csv_deserializer", True,), - ("json_deserializer", True,), - ("numpy_deserializer", True,), + ( + "sagemaker.predictor.csv_serializer", + True, + ), + ( + "sagemaker.predictor.json_serializer", + True, + ), + ( + "sagemaker.predictor.npy_serializer", + True, + ), + ( + "sagemaker.predictor.csv_deserializer", + True, + ), + ( + "sagemaker.predictor.json_deserializer", + True, + ), + ( + "sagemaker.predictor.numpy_deserializer", + True, + ), + ( + "csv_serializer", + True, + ), + ( + "json_serializer", + True, + ), + ( + "npy_serializer", + True, + ), + ( + "csv_deserializer", + True, + ), + ( + "json_deserializer", + True, + ), + ( + "numpy_deserializer", + True, + ), ], ) def test_name_node_should_be_modified(src, expected): diff --git a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py index 71f0822ef3..d5558c31aa 100644 --- a/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py +++ b/tests/unit/sagemaker/image_uris/test_dlc_frameworks.py @@ -170,7 +170,12 @@ def _expected_tf_inference_uri(tf_inference_version, processor="cpu", region=REG account = _sagemaker_or_dlc_account(repo, region) return expected_uris.framework_uri( - repo, tf_inference_version, account, py_version, processor=processor, region=region, + repo, + tf_inference_version, + account, + py_version, + processor=processor, + region=region, ) diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py index 97fb5b8a1e..28def49869 100644 --- a/tests/unit/sagemaker/image_uris/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/test_sklearn.py @@ -81,7 +81,10 @@ def test_py2_error(sklearn_version): def test_gpu_error(sklearn_version): with pytest.raises(ValueError) as e: image_uris.retrieve( - "sklearn", region="us-west-2", version=sklearn_version, instance_type="ml.p2.xlarge", + "sklearn", + region="us-west-2", + version=sklearn_version, + instance_type="ml.p2.xlarge", ) assert "Unsupported processor: gpu." in str(e.value) diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 698e4931af..99d9dced21 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -78,7 +78,10 @@ def test_xgboost_framework(xgboost_framework_version): for region in regions.regions(): uri = image_uris.retrieve( - framework="xgboost", region=region, version=xgboost_framework_version, py_version="py3", + framework="xgboost", + region=region, + version=xgboost_framework_version, + py_version="py3", ) expected = expected_uris.framework_uri( diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index fb5aa1e750..717a697231 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -149,10 +149,12 @@ def test_deploy_generates_endpoint_name_each_time_from_model_name( ) model.deploy( - instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, ) model.deploy( - instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, ) base_from_name.assert_called_with(MODEL_NAME) @@ -173,10 +175,12 @@ def test_deploy_generates_endpoint_name_each_time_from_base_name( model._base_name = base_name model.deploy( - instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, ) model.deploy( - instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, ) base_from_name.assert_not_called() diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 8f306312a1..98b5e1e35f 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -17,7 +17,6 @@ import sagemaker from sagemaker.model import Model -from sagemaker.image_config import ImageConfig MODEL_DATA = "s3://bucket/model.tar.gz" MODEL_IMAGE = "mi" @@ -27,7 +26,6 @@ INSTANCE_COUNT = 2 INSTANCE_TYPE = "ml.c4.4xlarge" ROLE = "some-role" -REPOSITORY_ACCESS_MODE = "Vpc" @pytest.fixture @@ -57,11 +55,14 @@ def test_prepare_container_def_with_model_data_and_env(): def test_prepare_container_def_with_image_config(): - image_config = ImageConfig(repository_access_mode=REPOSITORY_ACCESS_MODE) + image_config = {"RepositoryAccessMode": "Vpc"} model = Model(MODEL_IMAGE, image_config=image_config) - expected_image_config_dict = {"RepositoryAccessMode": "Vpc"} - expected = {"Image": MODEL_IMAGE, "ImageConfig": expected_image_config_dict, "Environment": {}} + expected = { + "Image": MODEL_IMAGE, + "ImageConfig": {"RepositoryAccessMode": "Vpc"}, + "Environment": {}, + } container_def = model.prepare_container_def() assert expected == container_def @@ -166,14 +167,23 @@ def test_create_sagemaker_model_generates_model_name( container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} prepare_container_def.return_value = container_def - model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session,) + model = Model( + MODEL_IMAGE, + MODEL_DATA, + sagemaker_session=sagemaker_session, + ) model._create_sagemaker_model(INSTANCE_TYPE) base_name_from_image.assert_called_with(MODEL_IMAGE) name_from_base.assert_called_with(base_name_from_image.return_value) sagemaker_session.create_model.assert_called_with( - MODEL_NAME, None, container_def, vpc_config=None, enable_network_isolation=False, tags=None, + MODEL_NAME, + None, + container_def, + vpc_config=None, + enable_network_isolation=False, + tags=None, ) @@ -186,7 +196,11 @@ def test_create_sagemaker_model_generates_model_name_each_time( container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA} prepare_container_def.return_value = container_def - model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session,) + model = Model( + MODEL_IMAGE, + MODEL_DATA, + sagemaker_session=sagemaker_session, + ) model._create_sagemaker_model(INSTANCE_TYPE) model._create_sagemaker_model(INSTANCE_TYPE) diff --git a/tests/unit/test_image_config.py b/tests/unit/test_image_config.py deleted file mode 100644 index e80ed6246e..0000000000 --- a/tests/unit/test_image_config.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from __future__ import absolute_import - -from sagemaker.image_config import ImageConfig - -REPOSITORY_ACCESS_MODE_PLATFORM = "Platform" -REPOSITORY_ACCESS_MODE_VPC = "Vpc" - - -def test_init_with_defaults(): - image_config = ImageConfig() - - assert image_config.repository_access_mode == REPOSITORY_ACCESS_MODE_PLATFORM - - -def test_init_with_non_defaults(): - image_config = ImageConfig(repository_access_mode=REPOSITORY_ACCESS_MODE_VPC) - - assert image_config.repository_access_mode == REPOSITORY_ACCESS_MODE_VPC diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index 4754f6c845..cdf6641758 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -67,7 +67,11 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): lr = LinearLearner( - ROLE, INSTANCE_COUNT, INSTANCE_TYPE, PREDICTOR_TYPE, sagemaker_session=sagemaker_session, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + PREDICTOR_TYPE, + sagemaker_session=sagemaker_session, ) assert lr.role == ROLE assert lr.instance_count == INSTANCE_COUNT diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 6a3aeeabbe..3805d640a9 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -65,7 +65,13 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): - ntm = NTM(ROLE, INSTANCE_COUNT, INSTANCE_TYPE, NUM_TOPICS, sagemaker_session=sagemaker_session,) + ntm = NTM( + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + NUM_TOPICS, + sagemaker_session=sagemaker_session, + ) assert ntm.role == ROLE assert ntm.instance_count == INSTANCE_COUNT assert ntm.instance_type == INSTANCE_TYPE diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index d6295b5554..4861b57c5a 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -66,7 +66,11 @@ def sagemaker_session(): def test_init_required_positional(sagemaker_session): pca = PCA( - ROLE, INSTANCE_COUNT, INSTANCE_TYPE, NUM_COMPONENTS, sagemaker_session=sagemaker_session, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + NUM_COMPONENTS, + sagemaker_session=sagemaker_session, ) assert pca.role == ROLE assert pca.instance_count == INSTANCE_COUNT diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 81762e1cb4..5667a3b454 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -323,7 +323,8 @@ def test_update_endpoint_instance_type_and_count(name_from_base, production_vari new_instance_type = "ml.c4.xlarge" predictor.update_endpoint( - initial_instance_count=new_instance_count, instance_type=new_instance_type, + initial_instance_count=new_instance_count, + instance_type=new_instance_type, ) assert [existing_model_name] == predictor._model_names diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 25c6e0ea11..1d33fb25dc 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -410,7 +410,11 @@ def test_model_image_accelerator(sagemaker_session): def test_model_prepare_container_def_no_instance_type_or_image(): model = PyTorchModel( - MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, framework_version="1.3.1", py_version="py3", + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version="1.3.1", + py_version="py3", ) with pytest.raises(ValueError) as e: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1842f6a1aa..8eaccff988 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1710,7 +1710,10 @@ def test_create_model_from_job_with_container_def(sagemaker_session): ims = sagemaker_session ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT ims.create_model_from_job( - JOB_NAME, image_uri="some-image", model_data_url="some-data", env={"a": "b"}, + JOB_NAME, + image_uri="some-image", + model_data_url="some-data", + env={"a": "b"}, ) [create_model_call] = ims.sagemaker_client.create_model.call_args_list c_def = create_model_call[1]["PrimaryContainer"] diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 0e61ec70eb..5fcba01ffe 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -80,7 +80,11 @@ def test_prepare_for_training(tuner): def test_prepare_for_tuning_with_amazon_estimator(tuner, sagemaker_session): tuner.estimator = PCA( - ROLE, INSTANCE_COUNT, INSTANCE_TYPE, NUM_COMPONENTS, sagemaker_session=sagemaker_session, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + NUM_COMPONENTS, + sagemaker_session=sagemaker_session, ) tuner._prepare_for_tuning() @@ -1434,7 +1438,7 @@ def test_describe(tuner): def _convert_tuning_job_details(job_details, estimator_name): """Convert a tuning job description using the 'TrainingJobDefinition' field into a new one using a single-item - 'TrainingJobDefinitions' field (list). + 'TrainingJobDefinitions' field (list). """ assert "TrainingJobDefinition" in job_details diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index f5dd138fd0..23a28e1063 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -80,7 +80,11 @@ sagemaker_session=SAGEMAKER_SESSION, ) ESTIMATOR_TWO = PCA( - ROLE, INSTANCE_COUNT, INSTANCE_TYPE, NUM_COMPONENTS, sagemaker_session=SAGEMAKER_SESSION, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + NUM_COMPONENTS, + sagemaker_session=SAGEMAKER_SESSION, ) WARM_START_CONFIG = WarmStartConfig( From 84f073e83f7cdc16eb2bbc4c08ca68567eb7ec01 Mon Sep 17 00:00:00 2001 From: ChoiByungWook Date: Tue, 2 Mar 2021 19:32:23 -0800 Subject: [PATCH 4/5] add newline in between return and input doc --- src/sagemaker/session.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index d2d92c9ebc..835ede3777 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4030,9 +4030,10 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None, * MultiModel: Indicates that model container can support hosting multiple models * SingleModel: Indicates that model container can support hosting a single model This is the default model container mode when container_mode = None - image_config (dict): Specifies whether the image of model container is pulled from ECR, + image_config (dict[str, str]): Specifies whether the image of model container is pulled from ECR, or private registry in your VPC. By default it is set to pull model container image from ECR. (default: None). + Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if passed via `PrimaryContainers` field. From 3e17050ffa904425f05168d676c8f351b4a6ec73 Mon Sep 17 00:00:00 2001 From: ChoiByungWook Date: Tue, 2 Mar 2021 19:50:58 -0800 Subject: [PATCH 5/5] split long doc string to newline --- src/sagemaker/model.py | 2 +- src/sagemaker/session.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index e59718a558..7e4f79a87e 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -91,7 +91,7 @@ def __init__( or from the model container. model_kms_key (str): KMS key ARN used to encrypt the repacked model archive file if the model is repacked - image_config (dict[str, object]): Specifies whether the image of + image_config (dict[str, str]): Specifies whether the image of model container is pulled from ECR, or private registry in your VPC. By default it is set to pull model container image from ECR. (default: None). diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 835ede3777..ddda685c99 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4030,9 +4030,9 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None, * MultiModel: Indicates that model container can support hosting multiple models * SingleModel: Indicates that model container can support hosting a single model This is the default model container mode when container_mode = None - image_config (dict[str, str]): Specifies whether the image of model container is pulled from ECR, - or private registry in your VPC. By default it is set to pull model container image - from ECR. (default: None). + image_config (dict[str, str]): Specifies whether the image of model container is pulled + from ECR, or private registry in your VPC. By default it is set to pull model + container image from ECR. (default: None). Returns: dict[str, str]: A complete container definition object usable with the CreateModel API if