From d23619443cbc274a5956dca0b5f18039772c886e Mon Sep 17 00:00:00 2001 From: hardianlawi Date: Tue, 7 Jun 2022 08:36:20 +0000 Subject: [PATCH] Add variant_name kwarg --- src/sagemaker/huggingface/model.py | 4 ++++ src/sagemaker/model.py | 4 ++++ src/sagemaker/multidatamodel.py | 9 ++++++++- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 6b93470c3a..8d80730839 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -200,6 +200,7 @@ def deploy( deserializer=None, accelerator_type=None, endpoint_name=None, + variant_name="AllTraffic", tags=None, kms_key=None, wait=True, @@ -251,6 +252,8 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html endpoint_name (str): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. + variant_name (string): The ``VariantName`` of this production variant + (default: 'AllTraffic'). tags (List[dict[str, str]]): The list of tags to attach to this specific endpoint. kms_key (str): The ARN of the KMS key that is used to encrypt the @@ -308,6 +311,7 @@ def deploy( deserializer, accelerator_type, endpoint_name, + variant_name, tags, kms_key, wait, diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 4fc0552d64..dfeb153572 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1023,6 +1023,7 @@ def deploy( deserializer=None, accelerator_type=None, endpoint_name=None, + variant_name="AllTraffic", tags=None, kms_key=None, wait=True, @@ -1074,6 +1075,8 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html endpoint_name (str): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. + variant_name (string): The ``VariantName`` of this production variant + (default: 'AllTraffic'). tags (List[dict[str, str]]): The list of tags to attach to this specific endpoint. kms_key (str): The ARN of the KMS key that is used to encrypt the @@ -1166,6 +1169,7 @@ def deploy( self.name, instance_type, initial_instance_count, + variant_name=variant_name, accelerator_type=accelerator_type, serverless_inference_config=serverless_inference_config_dict, volume_size=volume_size, diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 2cb6674ffd..7f83353872 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -158,6 +158,7 @@ def deploy( deserializer=None, accelerator_type=None, endpoint_name=None, + variant_name="AllTraffic", tags=None, kms_key=None, wait=True, @@ -203,6 +204,8 @@ def deploy( https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html endpoint_name (str): The name of the endpoint to create (default: None). If not specified, a unique endpoint name will be created. + variant_name (string): The ``VariantName`` of this production variant + (default: 'AllTraffic'). tags (List[dict[str, str]]): The list of tags to attach to this specific endpoint. kms_key (str): The ARN of the KMS key that is used to encrypt the @@ -250,7 +253,11 @@ def deploy( ) production_variant = sagemaker.production_variant( - self.name, instance_type, initial_instance_count, accelerator_type=accelerator_type + self.name, + instance_type, + initial_instance_count, + variant_name=variant_name, + accelerator_type=accelerator_type, ) if endpoint_name: self.endpoint_name = endpoint_name