From 747c239d182ffc8df0cda3ab486144fd32bbdb3d Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 08:36:23 +0530 Subject: [PATCH 1/8] feature: adding customer metadata support to registermodel step --- src/sagemaker/estimator.py | 4 +++ src/sagemaker/model.py | 4 +++ src/sagemaker/mxnet/model.py | 4 +++ src/sagemaker/pytorch/model.py | 4 +++ src/sagemaker/session.py | 21 +++++++++++ src/sagemaker/tensorflow/model.py | 5 +++ src/sagemaker/workflow/_utils.py | 5 +++ src/sagemaker/workflow/step_collections.py | 7 +++- tests/integ/test_mxnet.py | 41 ++++++++++++++++++++++ tests/integ/test_workflow.py | 3 ++ tests/unit/test_session.py | 3 ++ 11 files changed, 100 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6431ca8afc..fd74633584 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1263,6 +1263,7 @@ def register( compile_model_family=None, model_name=None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1292,6 +1293,8 @@ def register( model will be used (default: None). model_name (str): User defined model name (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1322,6 +1325,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) @property diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index ede78c7cce..47b4538bcc 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -304,6 +304,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -329,6 +330,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -356,6 +359,7 @@ def register( description=description, container_def_list=[container_def], drift_check_baselines=drift_check_baselines, + customer_metadata_properties = customer_metadata_properties, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index df0dd31a28..0a10cbf3c1 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -158,6 +158,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -183,6 +184,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -211,6 +214,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 3a0c3a283c..0f51788626 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -157,6 +157,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -182,6 +183,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -210,6 +213,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 91b89ea4c9..6825953ae3 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2778,6 +2778,7 @@ def create_model_package_from_containers( approval_status="PendingManualApproval", description=None, drift_check_baselines=None, + customer_metadata_properties=customer_metadata_properties, ): """Get request dictionary for CreateModelPackage API. @@ -2803,6 +2804,9 @@ def create_model_package_from_containers( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + """ request = get_create_model_package_request( @@ -2819,7 +2823,14 @@ def create_model_package_from_containers( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) + try: + self.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"]) + except ClientError as e: + self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"]) return self.sagemaker_client.create_model_package(**request) def wait_for_model_package(self, model_package_name, poll=5): @@ -4120,6 +4131,7 @@ def get_model_package_args( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get arguments for create_model_package method. @@ -4148,6 +4160,8 @@ def get_model_package_args( (default: None). container_def_list (list): A list of container defintiions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: dict: A dictionary of method argument names and values. """ @@ -4185,6 +4199,8 @@ def get_model_package_args( model_package_args["description"] = description if tags is not None: model_package_args["tags"] = tags + if customer_metadata_properties is not None: + model_package_args["customer_metadata_properties"] = customer_metadata_properties return model_package_args @@ -4203,6 +4219,7 @@ def get_create_model_package_request( description=None, tags=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. @@ -4229,6 +4246,8 @@ def get_create_model_package_request( tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). """ if all([model_package_name, model_package_group_name]): @@ -4250,6 +4269,8 @@ def get_create_model_package_request( request_dict["DriftCheckBaselines"] = drift_check_baselines if metadata_properties: request_dict["MetadataProperties"] = metadata_properties + if customer_metadata_properties is not None: + request_dict["CustomerMetadataProperties"] = customer_metadata_properties if containers is not None: if not all([content_types, response_types, inference_instances, transform_instances]): raise ValueError( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 0b8d2f7235..9f6a7841d5 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -201,6 +201,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -226,6 +227,9 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + Returns: A `sagemaker.model.ModelPackage` instance. @@ -254,6 +258,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index ca078fe7ea..f413b990da 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -310,6 +310,7 @@ def __init__( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Constructor of a register model step. @@ -347,6 +348,8 @@ def __init__( this step depends on retry_policies (List[RetryPolicy]): The list of retry policies for the current step drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -362,6 +365,7 @@ def __init__( self.tags = tags self.model_metrics = model_metrics self.drift_check_baselines = drift_check_baselines + self.customer_metadata_properties = customer_metadata_properties, self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri @@ -435,6 +439,7 @@ def arguments(self) -> RequestType: description=self.description, tags=self.tags, container_def_list=self.container_def_list, + customer_metadata_properties=self.customer_metadata_properties ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index f4606488b2..27060d928e 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -75,6 +75,7 @@ def __init__( tags=None, model: Union[Model, PipelineModel] = None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -95,7 +96,7 @@ def __init__( for the repack model step register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies for register model step - model_package_group_name (str): The Model Package Group name, exclusive to + model_package_group_name (str): The Model Package Group name or Arn, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package versioned (default: None). model_metrics (ModelMetrics): ModelMetrics object (default: None). @@ -113,6 +114,9 @@ def __init__( model (object or Model): A PipelineModel object that comprises a list of models which gets executed as a serial inference pipeline or a Model object. drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + **kwargs: additional arguments to `create_model`. """ steps: List[Step] = [] @@ -229,6 +233,7 @@ def __init__( tags=tags, container_def_list=self.container_def_list, retry_policies=register_model_step_retry_policies, + customer_metadata_properties=customer_metadata_properties, **kwargs, ) if not repack_model: diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index d13108d471..e286f1481c 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -230,6 +230,47 @@ def test_register_model_package( assert result is not None sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_name) +def test_register_model_package_via_group( + mxnet_training_job, + sagemaker_session, + mxnet_inference_latest_version, + mxnet_inference_latest_py_version, + cpu_instance_type, +): + endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp()) + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): + desc = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=mxnet_training_job + ) + model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py") + model = MXNetModel( + model_data, + "SageMakerRole", + entry_point=script_path, + py_version=mxnet_inference_latest_py_version, + sagemaker_session=sagemaker_session, + framework_version=mxnet_inference_latest_version, + ) + model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp()) + model_pkg = model.register( + content_types=["application/json"], + response_types=["application/json"], + inference_instances=["ml.m5.large"], + transform_instances=["ml.m5.large"], + model_package_group_name=model_package_group_name, + ) + assert isinstance(model_pkg, ModelPackage) + predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name) + data = numpy.zeros(shape=(1, 1, 28, 28)) + result = predictor.predict(data) + assert result is not None + model_packages = \ + sagemaker_session.sagemaker_client.list_model_packages(ModelPackageGroupName=model_package_group_name)[ + 'ModelPackageSummaryList'] + for model_package in model_packages: + sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package['ModelPackageArn']) + sagemaker_session.sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_package_group_name) def test_register_model_package_versioned( mxnet_training_job, diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index d2c142ee38..4cf4ebb244 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -1951,6 +1951,7 @@ def test_model_registration_with_drift_check_baselines( content_type="application/json", ), ) + customer_metadata_properties = {"key1": "value1"} estimator = XGBoost( entry_point="training.py", source_dir=os.path.join(DATA_DIR, "sip"), @@ -1972,6 +1973,7 @@ def test_model_registration_with_drift_check_baselines( model_package_group_name="testModelPackageGroup", model_metrics=model_metrics, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) pipeline = Pipeline( @@ -2042,6 +2044,7 @@ def test_model_registration_with_drift_check_baselines( response["DriftCheckBaselines"]["ModelDataQuality"]["Statistics"]["ContentType"] == "application/json" ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties break finally: try: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8604835890..e84c0142a5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2385,6 +2385,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): marketplace_cert = (True,) approval_status = ("Approved",) description = "description" + customer_metadata_properties = {"key1": "value1"} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -2398,6 +2399,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): approval_status=approval_status, description=description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) expected_args = { "ModelPackageName": model_package_name, @@ -2414,6 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "CertifyForMarketplace": marketplace_cert, "ModelApprovalStatus": approval_status, "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) From 538dbdf9bdf9b8b6ff07550b1bd8afa79109de5e Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 12:46:15 +0530 Subject: [PATCH 2/8] Modifying session object --- src/sagemaker/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 6825953ae3..8be7a0707e 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2778,7 +2778,7 @@ def create_model_package_from_containers( approval_status="PendingManualApproval", description=None, drift_check_baselines=None, - customer_metadata_properties=customer_metadata_properties, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. From 46dd80027d404fbb5bdaba6d0f383b13eb6b3b27 Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 13:19:10 +0530 Subject: [PATCH 3/8] Fixing lint errors --- src/sagemaker/model.py | 2 +- src/sagemaker/session.py | 2 +- tests/integ/test_mxnet.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 47b4538bcc..fe264994e8 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -359,7 +359,7 @@ def register( description=description, container_def_list=[container_def], drift_check_baselines=drift_check_baselines, - customer_metadata_properties = customer_metadata_properties, + customer_metadata_properties=customer_metadata_properties, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 8be7a0707e..412aedfcb5 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2828,7 +2828,7 @@ def create_model_package_from_containers( try: self.sagemaker_client.describe_model_package_group( ModelPackageGroupName=request["ModelPackageGroupName"]) - except ClientError as e: + except ClientError: self.sagemaker_client.create_model_package_group( ModelPackageGroupName=request["ModelPackageGroupName"]) return self.sagemaker_client.create_model_package(**request) diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index e286f1481c..688b08855b 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -230,6 +230,7 @@ def test_register_model_package( assert result is not None sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_name) + def test_register_model_package_via_group( mxnet_training_job, sagemaker_session, @@ -266,12 +267,13 @@ def test_register_model_package_via_group( result = predictor.predict(data) assert result is not None model_packages = \ - sagemaker_session.sagemaker_client.list_model_packages(ModelPackageGroupName=model_package_group_name)[ + sagemaker_session.sagemaker_client.list_model_packages(ModelPackageGroupName=model_package_group_name)[ 'ModelPackageSummaryList'] for model_package in model_packages: sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package['ModelPackageArn']) sagemaker_session.sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_package_group_name) + def test_register_model_package_versioned( mxnet_training_job, sagemaker_session, From d216a86da2d6b4f099643a0c0752885bc19aec19 Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 13:22:30 +0530 Subject: [PATCH 4/8] fixing lint --- tests/integ/test_mxnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index 688b08855b..63269188d4 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -268,7 +268,7 @@ def test_register_model_package_via_group( assert result is not None model_packages = \ sagemaker_session.sagemaker_client.list_model_packages(ModelPackageGroupName=model_package_group_name)[ - 'ModelPackageSummaryList'] + 'ModelPackageSummaryList'] for model_package in model_packages: sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package['ModelPackageArn']) sagemaker_session.sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_package_group_name) From 6a5f707ec7e209012640711edd49d07391b997b4 Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 13:26:47 +0530 Subject: [PATCH 5/8] fixing lint errors --- src/sagemaker/workflow/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index f413b990da..d39429f0e6 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -365,7 +365,7 @@ def __init__( self.tags = tags self.model_metrics = model_metrics self.drift_check_baselines = drift_check_baselines - self.customer_metadata_properties = customer_metadata_properties, + self.customer_metadata_properties = customer_metadata_properties self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri From fbe676d0cd8a470ffdbf55a2d7e7811762a6fec9 Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 14:34:23 +0530 Subject: [PATCH 6/8] Fixing black-check formats --- src/sagemaker/session.py | 6 ++++-- src/sagemaker/workflow/_utils.py | 2 +- tests/integ/test_mxnet.py | 18 ++++++++++++------ tests/unit/test_session.py | 2 +- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 412aedfcb5..78650774b8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2827,10 +2827,12 @@ def create_model_package_from_containers( ) try: self.sagemaker_client.describe_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"]) + ModelPackageGroupName=request["ModelPackageGroupName"] + ) except ClientError: self.sagemaker_client.create_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"]) + ModelPackageGroupName=request["ModelPackageGroupName"] + ) return self.sagemaker_client.create_model_package(**request) def wait_for_model_package(self, model_package_name, poll=5): diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index d39429f0e6..d341af211d 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -439,7 +439,7 @@ def arguments(self) -> RequestType: description=self.description, tags=self.tags, container_def_list=self.container_def_list, - customer_metadata_properties=self.customer_metadata_properties + customer_metadata_properties=self.customer_metadata_properties, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index 63269188d4..05c68e0ac1 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -253,7 +253,9 @@ def test_register_model_package_via_group( sagemaker_session=sagemaker_session, framework_version=mxnet_inference_latest_version, ) - model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp()) + model_package_group_name = "register-model-package-{}".format( + sagemaker_timestamp() + ) model_pkg = model.register( content_types=["application/json"], response_types=["application/json"], @@ -266,12 +268,16 @@ def test_register_model_package_via_group( data = numpy.zeros(shape=(1, 1, 28, 28)) result = predictor.predict(data) assert result is not None - model_packages = \ - sagemaker_session.sagemaker_client.list_model_packages(ModelPackageGroupName=model_package_group_name)[ - 'ModelPackageSummaryList'] + model_packages = sagemaker_session.sagemaker_client.list_model_packages( + ModelPackageGroupName=model_package_group_name + )["ModelPackageSummaryList"] for model_package in model_packages: - sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package['ModelPackageArn']) - sagemaker_session.sagemaker_client.delete_model_package_group(ModelPackageGroupName=model_package_group_name) + sagemaker_session.sagemaker_client.delete_model_package( + ModelPackageName=model_package["ModelPackageArn"] + ) + sagemaker_session.sagemaker_client.delete_model_package_group( + ModelPackageGroupName=model_package_group_name + ) def test_register_model_package_versioned( diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e84c0142a5..4523253a7f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2416,7 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "CertifyForMarketplace": marketplace_cert, "ModelApprovalStatus": approval_status, "DriftCheckBaselines": drift_check_baselines, - "CustomerMetadataProperties": customer_metadata_properties + "CustomerMetadataProperties": customer_metadata_properties, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) From a1182f1e96da224d04948ed65c2f7ef5b0d2902a Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 14:45:57 +0530 Subject: [PATCH 7/8] Removing redundant test --- tests/integ/test_mxnet.py | 49 --------------------------------------- 1 file changed, 49 deletions(-) diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index 05c68e0ac1..d13108d471 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -231,55 +231,6 @@ def test_register_model_package( sagemaker_session.sagemaker_client.delete_model_package(ModelPackageName=model_package_name) -def test_register_model_package_via_group( - mxnet_training_job, - sagemaker_session, - mxnet_inference_latest_version, - mxnet_inference_latest_py_version, - cpu_instance_type, -): - endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp()) - with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): - desc = sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=mxnet_training_job - ) - model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] - script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py") - model = MXNetModel( - model_data, - "SageMakerRole", - entry_point=script_path, - py_version=mxnet_inference_latest_py_version, - sagemaker_session=sagemaker_session, - framework_version=mxnet_inference_latest_version, - ) - model_package_group_name = "register-model-package-{}".format( - sagemaker_timestamp() - ) - model_pkg = model.register( - content_types=["application/json"], - response_types=["application/json"], - inference_instances=["ml.m5.large"], - transform_instances=["ml.m5.large"], - model_package_group_name=model_package_group_name, - ) - assert isinstance(model_pkg, ModelPackage) - predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name) - data = numpy.zeros(shape=(1, 1, 28, 28)) - result = predictor.predict(data) - assert result is not None - model_packages = sagemaker_session.sagemaker_client.list_model_packages( - ModelPackageGroupName=model_package_group_name - )["ModelPackageSummaryList"] - for model_package in model_packages: - sagemaker_session.sagemaker_client.delete_model_package( - ModelPackageName=model_package["ModelPackageArn"] - ) - sagemaker_session.sagemaker_client.delete_model_package_group( - ModelPackageGroupName=model_package_group_name - ) - - def test_register_model_package_versioned( mxnet_training_job, sagemaker_session, From cda6b3c3e40db713e8297bd683d2be6408c9b984 Mon Sep 17 00:00:00 2001 From: Madhubalasri Date: Tue, 15 Feb 2022 15:40:03 +0530 Subject: [PATCH 8/8] Fixing black-check errors --- src/sagemaker/session.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 78650774b8..c50a22d3f8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2825,14 +2825,15 @@ def create_model_package_from_containers( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, ) - try: - self.sagemaker_client.describe_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] - ) - except ClientError: - self.sagemaker_client.create_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] - ) + if model_package_group_name is not None: + try: + self.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) + except ClientError: + self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) return self.sagemaker_client.create_model_package(**request) def wait_for_model_package(self, model_package_name, poll=5):