From 38b911407edbfe9f6dd686c2d78d22064b08800e Mon Sep 17 00:00:00 2001 From: Bobby Lindsey Date: Thu, 19 May 2022 22:38:54 -0700 Subject: [PATCH] fix: fix missing register method params for framework models --- src/sagemaker/huggingface/model.py | 4 ++++ src/sagemaker/sklearn/model.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index d1e876fa21..9aca7b62d0 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -304,6 +304,7 @@ def register( approval_status=None, description=None, drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -331,6 +332,8 @@ def register( or "PendingManualApproval". Defaults to ``PendingManualApproval``. description (str): Model Package description. Defaults to ``None``. drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -359,6 +362,7 @@ def register( approval_status, description, drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def( diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 4cdfb99af7..27df03d3dd 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -151,6 +151,8 @@ def register( marketplace_cert=False, approval_status=None, description=None, + drift_check_baselines=None, + customer_metadata_properties=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -175,6 +177,9 @@ def register( approval_status (str): Model Approval Status, values can be "Approved", "Rejected", or "PendingManualApproval" (default: "PendingManualApproval"). description (str): Model Package description (default: None). + drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -202,6 +207,8 @@ def register( marketplace_cert, approval_status, description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, ) def prepare_container_def(