diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6431ca8afc..fd74633584 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1263,6 +1263,7 @@ def register( compile_model_family=None, model_name=None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1292,6 +1293,8 @@ def register( model will be used (default: None). model_name (str): User defined model name (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1322,6 +1325,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) @property diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 00a04a3199..2d01bb4c0f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -303,6 +303,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -328,6 +329,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -355,6 +358,7 @@ def register( description=description, container_def_list=[container_def], drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index df0dd31a28..0a10cbf3c1 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -158,6 +158,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -183,6 +184,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -211,6 +214,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 3a0c3a283c..0f51788626 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -157,6 +157,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -182,6 +183,8 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -210,6 +213,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(self, instance_type=None, accelerator_type=None): diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 91b89ea4c9..c50a22d3f8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2778,6 +2778,7 @@ def create_model_package_from_containers( approval_status="PendingManualApproval", description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. @@ -2803,6 +2804,9 @@ def create_model_package_from_containers( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + """ request = get_create_model_package_request( @@ -2819,7 +2823,17 @@ def create_model_package_from_containers( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) + if model_package_group_name is not None: + try: + self.sagemaker_client.describe_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) + except ClientError: + self.sagemaker_client.create_model_package_group( + ModelPackageGroupName=request["ModelPackageGroupName"] + ) return self.sagemaker_client.create_model_package(**request) def wait_for_model_package(self, model_package_name, poll=5): @@ -4120,6 +4134,7 @@ def get_model_package_args( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get arguments for create_model_package method. @@ -4148,6 +4163,8 @@ def get_model_package_args( (default: None). container_def_list (list): A list of container defintiions (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: dict: A dictionary of method argument names and values. """ @@ -4185,6 +4202,8 @@ def get_model_package_args( model_package_args["description"] = description if tags is not None: model_package_args["tags"] = tags + if customer_metadata_properties is not None: + model_package_args["customer_metadata_properties"] = customer_metadata_properties return model_package_args @@ -4203,6 +4222,7 @@ def get_create_model_package_request( description=None, tags=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Get request dictionary for CreateModelPackage API. @@ -4229,6 +4249,8 @@ def get_create_model_package_request( tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). """ if all([model_package_name, model_package_group_name]): @@ -4250,6 +4272,8 @@ def get_create_model_package_request( request_dict["DriftCheckBaselines"] = drift_check_baselines if metadata_properties: request_dict["MetadataProperties"] = metadata_properties + if customer_metadata_properties is not None: + request_dict["CustomerMetadataProperties"] = customer_metadata_properties if containers is not None: if not all([content_types, response_types, inference_instances, transform_instances]): raise ValueError( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 0b8d2f7235..9f6a7841d5 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -201,6 +201,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -226,6 +227,9 @@ def register( or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + Returns: A `sagemaker.model.ModelPackage` instance. @@ -254,6 +258,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def deploy( diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index ca078fe7ea..d341af211d 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -310,6 +310,7 @@ def __init__( tags=None, container_def_list=None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Constructor of a register model step. @@ -347,6 +348,8 @@ def __init__( this step depends on retry_policies (List[RetryPolicy]): The list of retry policies for the current step drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). **kwargs: additional arguments to `create_model`. """ super(_RegisterModelStep, self).__init__( @@ -362,6 +365,7 @@ def __init__( self.tags = tags self.model_metrics = model_metrics self.drift_check_baselines = drift_check_baselines + self.customer_metadata_properties = customer_metadata_properties self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri @@ -435,6 +439,7 @@ def arguments(self) -> RequestType: description=self.description, tags=self.tags, container_def_list=self.container_def_list, + customer_metadata_properties=self.customer_metadata_properties, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index f4606488b2..27060d928e 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -75,6 +75,7 @@ def __init__( tags=None, model: Union[Model, PipelineModel] = None, drift_check_baselines=None, + customer_metadata_properties=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -95,7 +96,7 @@ def __init__( for the repack model step register_model_step_retry_policies (List[RetryPolicy]): The list of retry policies for register model step - model_package_group_name (str): The Model Package Group name, exclusive to + model_package_group_name (str): The Model Package Group name or Arn, exclusive to `model_package_name`, using `model_package_group_name` makes the Model Package versioned (default: None). model_metrics (ModelMetrics): ModelMetrics object (default: None). @@ -113,6 +114,9 @@ def __init__( model (object or Model): A PipelineModel object that comprises a list of models which gets executed as a serial inference pipeline or a Model object. drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + **kwargs: additional arguments to `create_model`. """ steps: List[Step] = [] @@ -229,6 +233,7 @@ def __init__( tags=tags, container_def_list=self.container_def_list, retry_policies=register_model_step_retry_policies, + customer_metadata_properties=customer_metadata_properties, **kwargs, ) if not repack_model: diff --git a/tests/integ/test_workflow.py b/tests/integ/test_workflow.py index 160f9f934b..14c2cf54b3 100644 --- a/tests/integ/test_workflow.py +++ b/tests/integ/test_workflow.py @@ -1952,6 +1952,7 @@ def test_model_registration_with_drift_check_baselines( content_type="application/json", ), ) + customer_metadata_properties = {"key1": "value1"} estimator = XGBoost( entry_point="training.py", source_dir=os.path.join(DATA_DIR, "sip"), @@ -1973,6 +1974,7 @@ def test_model_registration_with_drift_check_baselines( model_package_group_name="testModelPackageGroup", model_metrics=model_metrics, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) pipeline = Pipeline( @@ -2043,6 +2045,7 @@ def test_model_registration_with_drift_check_baselines( response["DriftCheckBaselines"]["ModelDataQuality"]["Statistics"]["ContentType"] == "application/json" ) + assert response["CustomerMetadataProperties"] == customer_metadata_properties break finally: try: diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8604835890..4523253a7f 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2385,6 +2385,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): marketplace_cert = (True,) approval_status = ("Approved",) description = "description" + customer_metadata_properties = {"key1": "value1"} sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -2398,6 +2399,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): approval_status=approval_status, description=description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) expected_args = { "ModelPackageName": model_package_name, @@ -2414,6 +2416,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "CertifyForMarketplace": marketplace_cert, "ModelApprovalStatus": approval_status, "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args)