diff --git a/src/sagemaker/accept_types.py b/src/sagemaker/accept_types.py index 78aa655e04..0327ef3845 100644 --- a/src/sagemaker/accept_types.py +++ b/src/sagemaker/accept_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported accept types. (Default: None). model_version (str): The version of the model for which to retrieve the supported accept types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_accept_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,6 +77,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -87,6 +92,8 @@ def retrieve_default( retrieve the default accept type. (Default: None). model_version (str): The version of the model for which to retrieve the default accept type. 
(Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -110,11 +117,12 @@ def retrieve_default( ) return artifacts._retrieve_default_accept_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 59c8310587..963eaaa474 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -282,6 +282,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration set in model environment. @@ -333,6 +334,7 @@ def prepare_container_def( self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/content_types.py b/src/sagemaker/content_types.py index 46d0361f67..3154c1e4fe 100644 --- a/src/sagemaker/content_types.py +++ b/src/sagemaker/content_types.py @@ -24,6 +24,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -37,6 +38,8 @@ def retrieve_options( retrieve the supported content types. 
(Default: None). model_version (str): The version of the model for which to retrieve the supported content types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -60,11 +63,12 @@ def retrieve_options( ) return artifacts._retrieve_supported_content_types( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -73,6 +77,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -87,6 +92,8 @@ def retrieve_default( retrieve the default content type. (Default: None). model_version (str): The version of the model for which to retrieve the default content type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -110,11 +117,12 @@ def retrieve_default( ) return artifacts._retrieve_default_content_type( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 1a4be43897..3081daea23 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -43,6 +43,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -56,6 +57,8 @@ def retrieve_options( retrieve the supported deserializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported deserializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -80,11 +83,12 @@ def retrieve_options( ) return artifacts._retrieve_deserializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -93,6 +97,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -107,6 +112,8 @@ def retrieve_default( retrieve the default deserializer. (Default: None). model_version (str): The version of the model for which to retrieve the default deserializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -131,11 +138,12 @@ def retrieve_default( ) return artifacts._retrieve_default_deserializer( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index efbb44460c..61db6759f8 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -732,6 +732,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): # pylint: disable=unused-argument """A container definition with framework configuration set in model environment variables. diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b67066fcde..57851d112a 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default environment variables. (Default: None). model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to + retrieve model details from. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -78,12 +81,13 @@ def retrieve_default( ) return artifacts._retrieve_default_environment_variables( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, - include_aws_sdk_env_vars, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, + include_aws_sdk_env_vars=include_aws_sdk_env_vars, sagemaker_session=sagemaker_session, instance_type=instance_type, script=script, diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 8c1978c156..533a427747 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -479,6 +479,7 @@ def prepare_container_def( serverless_inference_config=None, inference_tool=None, accept_eula=None, + model_reference_arn=None, ): """A container definition with framework configuration set in model environment variables. 
@@ -533,6 +534,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..49ced478dd 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -31,6 +31,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default hyperparameters. (Default: None). model_version (str): The version of the model for which to retrieve the default hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. include_container_hyperparameters (bool): ``True`` if the container hyperparameters @@ -80,6 +83,7 @@ def retrieve_default( return artifacts._retrieve_default_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, include_container_hyperparameters=include_container_hyperparameters, @@ -92,6 +96,7 @@ def retrieve_default( def validate( region: Optional[str] = None, model_id: Optional[str] = None, + hub_arn: Optional[str] = None, model_version: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, @@ -107,6 +112,8 @@ def validate( (Default: None). model_version (str): The version of the model for which to validate hyperparameters. (Default: None). 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). hyperparameters (dict): Hyperparameters to validate. (Default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with @@ -148,6 +155,7 @@ def validate( return validate_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 2bef305aeb..22328c4183 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -64,6 +64,7 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + hub_arn=None, tolerate_vulnerable_model=False, tolerate_deprecated_model=False, sdk_version=None, @@ -104,6 +105,8 @@ def retrieve( (default: None). model_version (str): The version of the JumpStart model for which to retrieve the image URI (default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. 
If ``False``, raises an exception if the script used by this version of the model has dependencies with known security @@ -149,6 +152,7 @@ def retrieve( model_id, model_version, image_scope, + hub_arn, framework, region, version, diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 48aaab0ac8..66e8e5127f 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default instance type. (Default: None). model_version (str): The version of the model for which to retrieve the default instance type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -82,6 +85,7 @@ def retrieve_default( model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, @@ -95,6 +99,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -110,6 +115,8 @@ def retrieve( retrieve the supported instance types. (Default: None). model_version (str): The version of the model for which to retrieve the supported instance types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -142,12 +149,13 @@ def retrieve( raise ValueError("Must specify scope for instance types.") return artifacts._retrieve_instance_types( - model_id, - model_version, - scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + scope=scope, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, training_instance_type=training_instance_type, ) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index 35df030ddc..66003c9f03 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -10,6 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
+# pylint: skip-file """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import import functools @@ -17,10 +18,16 @@ import boto3 from sagemaker.deprecations import deprecated -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs, HubContentType from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import cache +from sagemaker.jumpstart.hub.utils import ( + construct_hub_model_arn_from_inputs, + construct_hub_model_reference_arn_from_inputs, +) from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.session import Session +from sagemaker.jumpstart import constants class SageMakerSettings(object): @@ -253,8 +260,10 @@ def get_model_specs( region: str, model_id: str, version: str, + hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. 
@@ -270,10 +279,34 @@ def get_model_specs( if s3_client is not None: additional_kwargs.update({"s3_client": s3_client}) + if hub_arn: + additional_kwargs.update({"sagemaker_session": sagemaker_session}) + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + + if hub_arn: + try: + hub_model_arn = construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + model_specs = JumpStartModelsAccessor._cache.get_hub_model( + hub_model_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL) + return model_specs + except: # noqa: E722 + hub_model_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + model_specs = JumpStartModelsAccessor._cache.get_hub_model_reference( + hub_model_reference_arn=hub_model_arn + ) + model_specs.set_hub_content_type(HubContentType.MODEL_REFERENCE) + return model_specs + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, version_str=version, model_type=model_type ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index c28c27ed4e..5f00783ed3 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -32,6 +32,7 @@ def _retrieve_default_environment_variables( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -47,6 +48,8 @@ def _retrieve_default_environment_variables( retrieve the default environment variables. model_version (str): Version of the JumpStart model for which to retrieve the default environment variables. 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default environment variables. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -79,6 +82,7 @@ def _retrieve_default_environment_variables( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -116,6 +120,7 @@ def _retrieve_default_environment_variables( lambda instance_type: _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -162,6 +167,7 @@ def _retrieve_default_environment_variables( def _retrieve_gated_model_uri_env_var_value( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -175,6 +181,8 @@ def _retrieve_gated_model_uri_env_var_value( retrieve the gated model env var URI. model_version (str): Version of the JumpStart model for which to retrieve the gated model env var URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve the gated model env var URI. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -206,6 +214,7 @@ def _retrieve_gated_model_uri_env_var_value( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -221,4 +230,7 @@ def _retrieve_gated_model_uri_env_var_value( if s3_key is None: return None + if hub_arn: + return s3_key + return f"s3://{get_jumpstart_gated_content_bucket(region)}/{s3_key}" diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index d19530ecfb..308c3a5386 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -30,6 +30,7 @@ def _retrieve_default_hyperparameters( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters( retrieve the default hyperparameters. model_version (str): Version of the JumpStart model for which to retrieve the default hyperparameters. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (str): Region for which to retrieve default hyperparameters. (Default: None). 
include_container_hyperparameters (bool): True if container hyperparameters @@ -77,6 +80,7 @@ def _retrieve_default_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 9d19d5e069..4f34cdd1e2 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -33,6 +33,7 @@ def _retrieve_image_uri( model_id: str, model_version: str, image_scope: str, + hub_arn: Optional[str] = None, framework: Optional[str] = None, region: Optional[str] = None, version: Optional[str] = None, @@ -57,6 +58,8 @@ def _retrieve_image_uri( model_id (str): JumpStart model ID for which to retrieve image URI. model_version (str): Version of the JumpStart model for which to retrieve the image URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. 
@@ -111,6 +114,7 @@ def _retrieve_image_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=image_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -126,6 +130,10 @@ def _retrieve_image_uri( ) if image_uri is not None: return image_uri + if hub_arn: + ecr_uri = model_specs.hosting_ecr_uri + return ecr_uri + ecr_specs = model_specs.hosting_ecr_specs if ecr_specs is None: raise ValueError( @@ -141,6 +149,10 @@ def _retrieve_image_uri( ) if image_uri is not None: return image_uri + if hub_arn: + ecr_uri = model_specs.training_ecr_uri + return ecr_uri + ecr_specs = model_specs.training_ecr_specs if ecr_specs is None: raise ValueError( @@ -194,6 +206,7 @@ def _retrieve_image_uri( version=version_override or ecr_specs.framework_version, py_version=ecr_specs.py_version, instance_type=instance_type, + hub_arn=hub_arn, accelerator_type=accelerator_type, image_scope=image_scope, container_version=container_version, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 1b3c6f4b29..17328c44e0 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -30,6 +30,7 @@ def _model_supports_incremental_training( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -43,6 +44,8 @@ def _model_supports_incremental_training( support status for incremental training. region (Optional[str]): Region for which to retrieve the support status for incremental training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -65,6 +68,7 @@ def _model_supports_incremental_training( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index e7c9c5911d..4c9e8075c5 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -34,6 +34,7 @@ def _retrieve_default_instance_type( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -50,6 +51,8 @@ def _retrieve_default_instance_type( default instance type. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default instance type. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -83,6 +86,7 @@ def _retrieve_default_instance_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -123,6 +127,7 @@ def _retrieve_instance_types( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -138,6 +143,8 @@ def _retrieve_instance_types( supported instance types. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve supported instance types. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -171,6 +178,7 @@ def _retrieve_instance_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -196,7 +204,7 @@ def _retrieve_instance_types( elif scope == JumpStartScriptScope.TRAINING: if training_instance_type is not None: - raise ValueError("Cannot use `training_instance_type` argument " "with training scope.") + raise ValueError("Cannot use `training_instance_type` argument with training scope.") instance_types = model_specs.supported_training_instance_types else: raise NotImplementedError( diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 9cd152b0bb..84c26bdda2 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -32,6 +32,7 @@ def _retrieve_model_init_kwargs( model_id: str, model_version: str, + hub_arn: 
Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -45,6 +46,8 @@ def _retrieve_model_init_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -69,6 +72,7 @@ def _retrieve_model_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -89,6 +93,7 @@ def _retrieve_model_deploy_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -104,6 +109,8 @@ def _retrieve_model_deploy_kwargs( kwargs. instance_type (str): Instance type of the hosting endpoint, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -129,6 +136,7 @@ def _retrieve_model_deploy_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -147,6 +155,7 @@ def _retrieve_estimator_init_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -161,6 +170,8 @@ def _retrieve_estimator_init_kwargs( kwargs. instance_type (str): Instance type of the training job, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -185,6 +196,7 @@ def _retrieve_estimator_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -206,6 +218,7 @@ def _retrieve_estimator_init_kwargs( def _retrieve_estimator_fit_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -218,6 +231,8 @@ def _retrieve_estimator_fit_kwargs( retrieve the kwargs. model_version (str): Version of the JumpStart model for which to retrieve the kwargs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -243,6 +258,7 @@ def _retrieve_estimator_fit_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index 57f66155c7..901f5cc455 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -31,6 +31,7 @@ def _retrieve_default_training_metric_definitions( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -45,6 +46,8 @@ def _retrieve_default_training_metric_definitions( default training metric definitions. region (Optional[str]): Region for which to retrieve default training metric definitions. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -69,6 +72,7 @@ def _retrieve_default_training_metric_definitions( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index aa22351771..b1f931eac4 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -32,6 +32,7 @@ def _retrieve_model_package_arn( model_version: str, instance_type: Optional[str], region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -48,6 +49,8 @@ def _retrieve_model_package_arn( instance_type (Optional[str]): An instance type to optionally supply in order to get an arn specific for the instance type. region (Optional[str]): Region for which to retrieve the model package arn. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package arn. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -72,6 +75,7 @@ def _retrieve_model_package_arn( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -114,6 +118,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -128,6 +133,8 @@ def _retrieve_model_package_model_artifact_s3_uri( model package artifact. region (Optional[str]): Region for which to retrieve the model package artifact. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (Optional[str]): Scope for which to retrieve the model package artifact. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -157,6 +164,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index 6bb2e576fc..2cebacb9c0 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -89,6 +89,7 @@ def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_t def _retrieve_model_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, region: Optional[str] = None, @@ -105,6 +106,8 @@ def _retrieve_model_uri( the model artifact S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model artifact S3 URI. 
+ hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -136,6 +139,7 @@ def _retrieve_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=model_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -149,6 +153,9 @@ def _retrieve_model_uri( is_prepacked = not model_specs.use_inference_script_uri() + if hub_arn: + model_artifact_uri = model_specs.hosting_artifact_uri + return model_artifact_uri model_artifact_key = ( _retrieve_hosting_prepacked_artifact_key(model_specs, instance_type) if is_prepacked @@ -179,6 +186,7 @@ def _model_supports_training_model_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -192,6 +200,8 @@ def _model_supports_training_model_uri( support status for model uri with training. region (Optional[str]): Region for which to retrieve the support status for model uri with training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -214,6 +224,7 @@ def _model_supports_training_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3359e32732..41c9c93ad2 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -33,6 +33,7 @@ def _retrieve_example_payloads( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -47,6 +48,8 @@ def _retrieve_example_payloads( example payloads. region (Optional[str]): Region for which to retrieve the example payloads. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -70,6 +73,7 @@ def _retrieve_example_payloads( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/predictors.py b/src/sagemaker/jumpstart/artifacts/predictors.py index 4f6dfe1fe3..96d1c1f7fb 100644 --- a/src/sagemaker/jumpstart/artifacts/predictors.py +++ b/src/sagemaker/jumpstart/artifacts/predictors.py @@ -73,6 +73,7 @@ def _retrieve_deserializer_from_accept_type( def _retrieve_default_deserializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -86,6 +87,8 @@ def _retrieve_default_deserializer( retrieve the default deserializer. model_version (str): Version of the JumpStart model for which to retrieve the default deserializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default deserializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -106,6 +109,7 @@ def _retrieve_default_deserializer( default_accept_type = _retrieve_default_accept_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -119,6 +123,7 @@ def _retrieve_default_deserializer( def _retrieve_default_serializer( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -132,6 +137,8 @@ def _retrieve_default_serializer( retrieve the default serializer. model_version (str): Version of the JumpStart model for which to retrieve the default serializer. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default serializer. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -151,6 +158,7 @@ def _retrieve_default_serializer( default_content_type = _retrieve_default_content_type( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -164,6 +172,7 @@ def _retrieve_default_serializer( def _retrieve_deserializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -177,6 +186,8 @@ def _retrieve_deserializer_options( retrieve the supported deserializers. model_version (str): Version of the JumpStart model for which to retrieve the supported deserializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
region (Optional[str]): Region for which to retrieve deserializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -196,6 +207,7 @@ def _retrieve_deserializer_options( supported_accept_types = _retrieve_supported_accept_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -223,6 +235,7 @@ def _retrieve_deserializer_options( def _retrieve_serializer_options( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -235,6 +248,8 @@ def _retrieve_serializer_options( retrieve the supported serializers. model_version (str): Version of the JumpStart model for which to retrieve the supported serializers. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve serializer options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -254,6 +269,7 @@ def _retrieve_serializer_options( supported_content_types = _retrieve_supported_content_types( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -280,6 +296,7 @@ def _retrieve_serializer_options( def _retrieve_default_content_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -293,6 +310,8 @@ def _retrieve_default_content_type( retrieve the default content type. 
model_version (str): Version of the JumpStart model for which to retrieve the default content type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default content type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -316,6 +335,7 @@ def _retrieve_default_content_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -331,6 +351,7 @@ def _retrieve_default_content_type( def _retrieve_default_accept_type( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -344,6 +365,8 @@ def _retrieve_default_accept_type( retrieve the default accept type. model_version (str): Version of the JumpStart model for which to retrieve the default accept type. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default accept type. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an @@ -367,6 +390,7 @@ def _retrieve_default_accept_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -383,6 +407,7 @@ def _retrieve_default_accept_type( def _retrieve_supported_accept_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -396,6 +421,8 @@ def _retrieve_supported_accept_types( retrieve the supported accept types. model_version (str): Version of the JumpStart model for which to retrieve the supported accept types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve accept type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -419,6 +446,7 @@ def _retrieve_supported_accept_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -435,6 +463,7 @@ def _retrieve_supported_accept_types( def _retrieve_supported_content_types( model_id: str, model_version: str, + hub_arn: Optional[str], region: Optional[str], tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -448,6 +477,8 @@ def _retrieve_supported_content_types( retrieve the supported content types. model_version (str): Version of the JumpStart model for which to retrieve the supported content types. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). 
region (Optional[str]): Region for which to retrieve content type options. tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an @@ -471,6 +502,7 @@ def _retrieve_supported_content_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index cffd46d043..0b92d46a23 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -31,6 +31,7 @@ def _retrieve_resource_name_base( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -45,6 +46,8 @@ def _retrieve_resource_name_base( default resource name. region (Optional[str]): Region for which to retrieve the default resource name. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -67,6 +70,7 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_requirements.py b/src/sagemaker/jumpstart/artifacts/resource_requirements.py index 369acac85f..8936a3f824 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_requirements.py +++ b/src/sagemaker/jumpstart/artifacts/resource_requirements.py @@ -48,6 +48,7 @@ def _retrieve_default_resources( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -64,6 +65,8 @@ def _retrieve_default_resources( default resource requirements. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve default resource requirements. (Default: None). 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -96,6 +99,7 @@ def _retrieve_default_resources( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index f69732d2e0..3c79f93985 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -32,6 +32,7 @@ def _retrieve_script_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -47,6 +48,8 @@ def _retrieve_script_uri( retrieve the script S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. @@ -78,6 +81,7 @@ def _retrieve_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -104,7 +108,8 @@ def _retrieve_script_uri( def _model_supports_inference_script_uri( model_id: str, model_version: str, - region: Optional[str], + hub_arn: Optional[str] = None, + region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -116,6 +121,8 @@ def _model_supports_inference_script_uri( retrieve the support status for script uri with inference. 
model_version (str): Version of the JumpStart model for which to retrieve the support status for script uri with inference. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). region (Optional[str]): Region for which to retrieve the support status for script uri with inference. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -140,6 +147,7 @@ def _model_supports_inference_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index e9a34a21a8..257a9e71af 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -15,13 +15,14 @@ import datetime from difflib import get_close_matches import os -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import json import boto3 import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, @@ -42,16 +43,23 @@ JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelHeader, JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, + HubContentType, +) +from sagemaker.jumpstart.hub import utils as hub_utils +from sagemaker.jumpstart.hub.interfaces import DescribeHubContentResponse +from sagemaker.jumpstart.hub.parsers import ( + 
make_model_specs_from_describe_hub_content_response, ) from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache +from sagemaker.session import Session class JumpStartModelsCache: @@ -77,6 +85,7 @@ def __init__( s3_bucket_name: Optional[str] = None, s3_client_config: Optional[botocore.config.Config] = None, s3_client: Optional[boto3.client] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -> None: """Initialize a ``JumpStartModelsCache`` instance. @@ -98,13 +107,15 @@ def __init__( s3_client_config (Optional[botocore.config.Config]): s3 client config to use for cache. Default: None (no config). s3_client (Optional[boto3.client]): s3 client to use. Default: None. + sagemaker_session: sagemaker session object to use. + Default: session object from default region us-west-2. """ self._region = region or utils.get_region_fallback( s3_bucket_name=s3_bucket_name, s3_client=s3_client ) - self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( + self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, retrieval_function=self._retrieval_function, @@ -139,6 +150,7 @@ def __init__( if s3_client_config else boto3.client("s3", region_name=self._region) ) + self._sagemaker_session = sagemaker_session def set_region(self, region: str) -> None: """Set region for cache. 
Clears cache after new region is set.""" @@ -230,8 +242,8 @@ def _model_id_retrieval_function( model_id, version = key.model_id, key.version sm_version = utils.get_sagemaker_version() - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -392,53 +404,96 @@ def _get_json_file_from_local_override( def _retrieval_function( self, - key: JumpStartCachedS3ContentKey, - value: Optional[JumpStartCachedS3ContentValue], - ) -> JumpStartCachedS3ContentValue: - """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey``. + key: JumpStartCachedContentKey, + value: Optional[JumpStartCachedContentValue], + ) -> JumpStartCachedContentValue: + """Return s3 content given a file type and s3_key in ``JumpStartCachedContentKey``. If a manifest file is being fetched, we only download the object if the md5 hash in ``head_object`` does not match the current md5 hash for the stored value. This prevents unnecessarily downloading the full manifest when it hasn't changed. Args: - key (JumpStartCachedS3ContentKey): key for which to fetch s3 content. + key (JumpStartCachedContentKey): key for which to fetch s3 content. value (Optional[JumpStartVersionedModelId]): Current value of old cached s3 content. This is used for the manifest file, so that it is only downloaded when its content changes. 
""" - file_type, s3_key = key.file_type, key.s3_key - if file_type in { + data_type, id_info = key.data_type, key.id_info + + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, JumpStartS3FileType.PROPRIETARY_MANIFEST, }: if value is not None and not self._is_local_metadata_mode(): - etag = self._get_json_md5_hash(s3_key) + etag = self._get_json_md5_hash(id_info) if etag == value.md5_hash: return value - formatted_body, etag = self._get_json_file(s3_key, file_type) - return JumpStartCachedS3ContentValue( + formatted_body, etag = self._get_json_file(id_info, data_type) + return JumpStartCachedContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) - if file_type in { + if data_type in { JumpStartS3FileType.OPEN_WEIGHT_SPECS, JumpStartS3FileType.PROPRIETARY_SPECS, }: - formatted_body, _ = self._get_json_file(s3_key, file_type) + formatted_body, _ = self._get_json_file(id_info, data_type) model_specs = JumpStartModelSpecs(formatted_body) utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client) - return JumpStartCachedS3ContentValue(formatted_content=model_specs) - raise ValueError(self._file_type_error_msg(file_type)) + return JumpStartCachedContentValue(formatted_content=model_specs) + + if data_type == HubContentType.NOTEBOOK: + hub_name, _, notebook_name, notebook_version = hub_utils.get_info_from_hub_resource_arn( + id_info + ) + response: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_name, + hub_content_name=notebook_name, + hub_content_version=notebook_version, + hub_content_type=data_type, + ) + hub_notebook_description = DescribeHubContentResponse(response) + return JumpStartCachedContentValue(formatted_content=hub_notebook_description) + + if data_type in { + HubContentType.MODEL, + HubContentType.MODEL_REFERENCE, + }: + + hub_arn_extracted_info = hub_utils.get_info_from_hub_resource_arn(id_info) + + model_version: str = 
hub_utils.get_hub_model_version( + hub_model_name=hub_arn_extracted_info.hub_content_name, + hub_model_type=data_type.value, + hub_name=hub_arn_extracted_info.hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=hub_arn_extracted_info.hub_content_version, + ) + + hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=hub_arn_extracted_info.hub_name, + hub_content_name=hub_arn_extracted_info.hub_content_name, + hub_content_version=model_version, + hub_content_type=data_type.value, + ) + + model_specs = make_model_specs_from_describe_hub_content_response( + DescribeHubContentResponse(hub_model_description), + ) + + return JumpStartCachedContentValue(formatted_content=model_specs) + + raise ValueError(self._file_type_error_msg(data_type)) def get_manifest( self, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> List[JumpStartModelHeader]: """Return entire JumpStart models manifest.""" - manifest_dict = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest_dict = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -525,8 +580,8 @@ def _get_header_impl( JumpStartVersionedModelId(model_id, semantic_version_str) )[0] - manifest = self._s3_cache.get( - JumpStartCachedS3ContentKey( + manifest = self._content_cache.get( + JumpStartCachedContentKey( MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type] ) )[0].formatted_content @@ -556,8 +611,8 @@ def get_specs( """ header = self.get_header(model_id, version_str, model_type) spec_key = header.spec_key - specs, cache_hit = self._s3_cache.get( - JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) + specs, cache_hit = self._content_cache.get( + JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) ) if not cache_hit and "*" in version_str: @@ -566,8 +621,38 @@ def 
get_specs( ) return specs.formatted_content + def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL, + hub_model_arn, + ) + ) + return details.formatted_content + + def get_hub_model_reference(self, hub_model_reference_arn: str) -> JumpStartModelSpecs: + """Return JumpStart-compatible specs for a given Hub model reference + + Args: + hub_model_arn (str): Arn for the Hub model to get specs for + """ + + details, _ = self._content_cache.get( + JumpStartCachedContentKey( + HubContentType.MODEL_REFERENCE, + hub_model_reference_arn, + ) + ) + return details.formatted_content + def clear(self) -> None: """Clears the model ID/version and s3 cache.""" - self._s3_cache.clear() + self._content_cache.clear() self._open_weight_model_id_manifest_key_cache.clear() self._proprietary_model_id_manifest_key_cache.clear() diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 5b0f749c64..076e1d2fa1 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -185,9 +185,14 @@ JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2" +JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub" + JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" +HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" TRAINING_ENTRY_POINT_SCRIPT_NAME = "transfer_learning.py" diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index ca49fd41a3..6c77e72b9b 100644 --- a/src/sagemaker/jumpstart/enums.py +++ 
b/src/sagemaker/jumpstart/enums.py @@ -15,6 +15,7 @@ from __future__ import absolute_import from enum import Enum +from typing import List class ModelFramework(str, Enum): @@ -93,6 +94,8 @@ class JumpStartTag(str, Enum): MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type" + HUB_CONTENT_ARN = "sagemaker-sdk:hub-content-arn" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" @@ -126,6 +129,28 @@ def from_suffixed_type(mime_type_with_suffix: str) -> "MIMEType": return MIMEType(base_type) +class NamingConventionType(str, Enum): + """Enum class for naming conventions.""" + + SNAKE_CASE = "snake_case" + UPPER_CAMEL_CASE = "upper_camel_case" + DEFAULT = UPPER_CAMEL_CASE + + +class ModelSpecKwargType(str, Enum): + """Enum class for types of kwargs for model hub content document and model specs.""" + + FIT = "fit_kwargs" + MODEL = "model_kwargs" + ESTIMATOR = "estimator_kwargs" + DEPLOY = "deploy_kwargs" + + @classmethod + def arg_keys(cls) -> List[str]: + """Returns a list of kwargs keys that each type can have""" + return [member.value for member in cls] + + class JumpStartConfigRankingName(str, Enum): """Enum class for ranking of JumpStart config.""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index f53d109dc8..7d600ddfbc 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -28,6 +28,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -58,6 +59,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = 
None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -124,6 +126,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies @@ -508,6 +511,12 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. """ + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_get_type_hook(): return validate_model_id_and_get_type( model_id=model_id, @@ -515,18 +524,20 @@ def _validate_model_id_and_get_type_hook(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, + hub_arn=hub_arn, ) self.model_type = _validate_model_id_and_get_type_hook() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_get_type_hook() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=self.model_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -584,6 +595,7 @@ def _validate_model_id_and_get_type_hook(): enable_session_tag_chaining=enable_session_tag_chaining, ) + self.hub_arn = 
estimator_init_kwargs.hub_arn self.model_id = estimator_init_kwargs.model_id self.model_version = estimator_init_kwargs.model_version self.instance_type = estimator_init_kwargs.instance_type @@ -660,6 +672,7 @@ def fit( estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, inputs=inputs, wait=wait, @@ -679,6 +692,7 @@ def attach( training_job_name: str, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, sagemaker_session: session.Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, model_channel_name: str = "model", ) -> "JumpStartEstimator": @@ -744,6 +758,7 @@ def attach( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=sagemaker_session.boto_region_name, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated @@ -1047,6 +1062,7 @@ def deploy( estimator_deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, @@ -1097,6 +1113,7 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index d6cea3cf09..d3e597c395 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -60,6 +60,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_content_arn_tags, add_jumpstart_model_id_version_tags, get_eula_message, 
get_default_jumpstart_session_with_user_agent_suffix, @@ -78,6 +79,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -137,6 +139,7 @@ def get_init_kwargs( estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, role=role, region=region, @@ -214,6 +217,7 @@ def get_init_kwargs( def get_fit_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, wait: Optional[bool] = None, @@ -229,6 +233,7 @@ def get_fit_kwargs( estimator_fit_kwargs: JumpStartEstimatorFitKwargs = JumpStartEstimatorFitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, inputs=inputs, wait=wait, @@ -251,6 +256,7 @@ def get_fit_kwargs( def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -296,6 +302,7 @@ def get_deploy_kwargs( model_deploy_kwargs: JumpStartModelDeployKwargs = model.get_deploy_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -324,6 +331,7 @@ def get_deploy_kwargs( model_id=model_id, model_from_estimator=True, model_version=model_version, + hub_arn=hub_arn, instance_type=( model_deploy_kwargs.instance_type if training_instance_type is None @@ -356,6 +364,7 @@ def get_deploy_kwargs( estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( 
model_id=model_init_kwargs.model_id, model_version=model_init_kwargs.model_version, + hub_arn=hub_arn, instance_type=model_init_kwargs.instance_type, initial_instance_count=model_deploy_kwargs.initial_instance_count, region=model_init_kwargs.region, @@ -424,6 +433,20 @@ def _add_model_version_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs: kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + hub_content_version = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.TRAINING, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).version + kwargs.model_version = hub_content_version + return kwargs @@ -451,6 +474,7 @@ def _add_instance_type_and_count_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -473,6 +497,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -484,6 +509,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima kwargs.tags = add_jumpstart_model_id_version_tags( kwargs.tags, kwargs.model_id, full_model_version ) + + if kwargs.hub_arn: + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -496,6 +525,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) 
-> JumpStartE image_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -511,6 +541,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE if _model_supports_training_model_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -533,6 +564,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE and not _model_supports_incremental_training( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -568,6 +600,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart script_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, region=kwargs.region, @@ -585,6 +618,7 @@ def _add_env_to_kwargs( extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -597,6 +631,7 @@ def _add_env_to_kwargs( model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -624,6 +659,7 @@ def _add_env_to_kwargs( model_specs = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -657,6 +693,7 @@ def _add_training_job_name_to_kwargs( default_training_job_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -683,6 +720,7 @@ def _add_hyperparameters_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -716,6 +754,7 @@ def _add_metric_definitions_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -744,6 +783,7 @@ def _add_estimator_extra_kwargs( estimator_kwargs_to_add = _retrieve_estimator_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -769,6 +809,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim fit_kwargs_to_add = _retrieve_estimator_fit_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 380afbb433..55dfa1394a 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -34,16 +34,19 @@ JUMPSTART_LOGGER, ) from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard +from sagemaker.jumpstart.hub.utils import construct_hub_model_reference_arn_from_inputs from sagemaker.model_metrics import ModelMetrics from sagemaker.metadata_properties import MetadataProperties from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType from sagemaker.jumpstart.types import ( + HubContentType, JumpStartModelDeployKwargs, JumpStartModelInitKwargs, JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_content_arn_tags, add_jumpstart_model_id_version_tags, get_default_jumpstart_session_with_user_agent_suffix, update_dict_if_key_not_present, @@ -68,6 +71,7 @@ def get_default_predictor( predictor: Predictor, model_id: str, model_version: str, + hub_arn: Optional[str], region: str, tolerate_vulnerable_model: bool, tolerate_deprecated_model: bool, @@ -90,6 +94,7 @@ def get_default_predictor( predictor.serializer = serializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -99,6 +104,7 @@ def get_default_predictor( predictor.deserializer = deserializers.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -108,6 +114,7 @@ def get_default_predictor( predictor.accept = 
accept_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -117,6 +124,7 @@ def get_default_predictor( predictor.content_type = content_types.retrieve_default( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -170,6 +178,20 @@ def _add_model_version_to_kwargs( kwargs.model_version = kwargs.model_version or "*" + if kwargs.hub_arn: + hub_content_version = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).version + kwargs.model_version = hub_content_version + return kwargs @@ -195,6 +217,7 @@ def _add_instance_type_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -228,6 +251,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel image_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -237,6 +261,32 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel return kwargs +def _add_model_reference_arn_to_kwargs( + 
kwargs: JumpStartModelInitKwargs, +) -> JumpStartModelInitKwargs: + """Sets Model Reference ARN if the hub content type is Model Reference, returns full kwargs.""" + hub_content_type = verify_model_region_and_return_specs( + model_id=kwargs.model_id, + version=kwargs.model_version, + hub_arn=kwargs.hub_arn, + scope=JumpStartScriptScope.INFERENCE, + region=kwargs.region, + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, + sagemaker_session=kwargs.sagemaker_session, + model_type=kwargs.model_type, + ).hub_content_type + kwargs.hub_content_type = hub_content_type if kwargs.hub_arn else None + + if hub_content_type == HubContentType.MODEL_REFERENCE: + kwargs.model_reference_arn = construct_hub_model_reference_arn_from_inputs( + hub_arn=kwargs.hub_arn, model_name=kwargs.model_id, version=kwargs.model_version + ) + else: + kwargs.model_reference_arn = None + return kwargs + + def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: """Sets model data based on default or override, returns full kwargs.""" @@ -248,6 +298,7 @@ def _add_model_data_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode model_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -289,6 +340,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -298,6 +350,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMode 
script_scope=JumpStartScriptScope.INFERENCE, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -321,6 +374,7 @@ def _add_entry_point_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartMod if _model_supports_inference_script_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -349,6 +403,7 @@ def _add_env_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKw extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -379,6 +434,7 @@ def _add_model_package_arn_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSt model_package_arn = kwargs.model_package_arn or _retrieve_model_package_arn( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, @@ -398,6 +454,7 @@ def _add_extra_model_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelI model_kwargs_to_add = _retrieve_model_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -434,6 +491,7 @@ def _add_endpoint_name_to_kwargs( default_endpoint_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, 
tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -456,6 +514,7 @@ def _add_model_name_to_kwargs( default_model_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -476,6 +535,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -489,6 +549,9 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: kwargs.tags, kwargs.model_id, full_model_version, kwargs.model_type ) + if kwargs.hub_arn: + kwargs.tags = add_hub_content_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -498,6 +561,7 @@ def _add_deploy_extra_kwargs(kwargs: JumpStartModelInitKwargs) -> Dict[str, Any] deploy_kwargs_to_add = _retrieve_model_deploy_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -520,6 +584,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.INFERENCE, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -534,6 +599,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + 
hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, initial_instance_count: Optional[int] = None, @@ -558,6 +624,7 @@ def get_deploy_kwargs( tolerate_deprecated_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, managed_instance_scaling: Optional[str] = None, @@ -569,6 +636,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, region=region, initial_instance_count=initial_instance_count, @@ -593,6 +661,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, routing_config=routing_config, @@ -623,6 +692,7 @@ def get_deploy_kwargs( def get_register_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, tolerate_vulnerable_model: Optional[bool] = None, @@ -656,6 +726,7 @@ def get_register_kwargs( register_kwargs = JumpStartModelRegisterKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -688,6 +759,7 @@ def get_register_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, @@ -709,6 +781,7 @@ def get_init_kwargs( model_id: str, model_from_estimator: bool = False, 
model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, @@ -741,6 +814,7 @@ def get_init_kwargs( model_init_kwargs: JumpStartModelInitKwargs = JumpStartModelInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, instance_type=instance_type, region=region, @@ -768,13 +842,12 @@ def get_init_kwargs( resources=resources, ) - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) - model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs) + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( @@ -783,6 +856,12 @@ def get_init_kwargs( model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs) + if hub_arn: + model_init_kwargs = _add_model_reference_arn_to_kwargs(kwargs=model_init_kwargs) + else: + model_init_kwargs.model_reference_arn = None + model_init_kwargs.hub_content_type = None + # we use the model artifact from the training job output if not model_from_estimator: model_init_kwargs = _add_model_data_to_kwargs(kwargs=model_init_kwargs) diff --git a/src/sagemaker/jumpstart/hub/__init__.py b/src/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/jumpstart/hub/constants.py b/src/sagemaker/jumpstart/hub/constants.py new file mode 100644 index 0000000000..e3a6b7752a --- /dev/null +++ b/src/sagemaker/jumpstart/hub/constants.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores constants related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import + +LATEST_VERSION_WILDCARD = "*" diff --git a/src/sagemaker/jumpstart/hub/hub.py b/src/sagemaker/jumpstart/hub/hub.py new file mode 100644 index 0000000000..1545fe3a36 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/hub.py @@ -0,0 +1,304 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+# pylint: skip-file +"""This module provides the JumpStart Hub class.""" +from __future__ import absolute_import +from datetime import datetime +import logging +from typing import Optional, Dict, List, Any, Union +from botocore import exceptions + +from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME +from sagemaker.jumpstart.enums import JumpStartScriptScope +from sagemaker.session import Session + +from sagemaker.jumpstart.constants import ( + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + JUMPSTART_LOGGER, +) +from sagemaker.jumpstart.types import ( + HubContentType, +) +from sagemaker.jumpstart.filters import Constant, Operator, BooleanValues +from sagemaker.jumpstart.hub.utils import ( + get_hub_model_version, + get_info_from_hub_resource_arn, + create_hub_bucket_if_it_does_not_exist, + generate_default_hub_bucket_name, + create_s3_object_reference_from_uri, + construct_hub_arn_from_name, +) + +from sagemaker.jumpstart.notebook_utils import ( + list_jumpstart_models, +) + +from sagemaker.jumpstart.hub.types import ( + S3ObjectLocation, +) +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubResponse, + DescribeHubContentResponse, +) +from sagemaker.jumpstart.hub.constants import ( + LATEST_VERSION_WILDCARD, +) +from sagemaker.jumpstart import utils + + +class Hub: + """Class for creating and managing a curated JumpStart hub""" + + # Setting LOGGER for backward compatibility, in case users import it... + logger = LOGGER = logging.getLogger("sagemaker") + + _list_hubs_cache: List[Dict[str, Any]] = [] + + def __init__( + self, + hub_name: str, + bucket_name: Optional[str] = None, + sagemaker_session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ) -> None: + """Instantiates a SageMaker ``Hub``. + + Args: + hub_name (str): The name of the Hub to create. + sagemaker_session (sagemaker.session.Session): A SageMaker Session + object, used for SageMaker interactions. 
+ """ + self.hub_name = hub_name + self.region = sagemaker_session.boto_region_name + self._sagemaker_session = sagemaker_session + self.hub_storage_location = self._generate_hub_storage_location(bucket_name) + + def _fetch_hub_bucket_name(self) -> str: + """Retrieves hub bucket name from Hub config if exists""" + try: + hub_response = self._sagemaker_session.describe_hub(hub_name=self.hub_name) + hub_output_location = hub_response["S3StorageConfig"].get("S3OutputPath") + if hub_output_location: + location = create_s3_object_reference_from_uri(hub_output_location) + return location.bucket + default_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + default_bucket_name, + ) + return default_bucket_name + except exceptions.ClientError: + hub_bucket_name = generate_default_hub_bucket_name(self._sagemaker_session) + JUMPSTART_LOGGER.warning( + "There is not a Hub bucket associated with %s. Using %s", + self.hub_name, + hub_bucket_name, + ) + return hub_bucket_name + + def _generate_hub_storage_location(self, bucket_name: Optional[str] = None) -> None: + """Generates an ``S3ObjectLocation`` given a Hub name.""" + hub_bucket_name = bucket_name or self._fetch_hub_bucket_name() + curr_timestamp = datetime.now().timestamp() + return S3ObjectLocation(bucket=hub_bucket_name, key=f"{self.hub_name}-{curr_timestamp}") + + def _get_latest_model_version(self, model_id: str) -> str: + """Populates the lastest version of a model from specs no matter what is passed. 
+ + Returns model ({ model_id: str, version: str }) + """ + model_specs = utils.verify_model_region_and_return_specs( + model_id, LATEST_VERSION_WILDCARD, JumpStartScriptScope.INFERENCE, self.region + ) + return model_specs.version + + def create( + self, + description: str, + display_name: Optional[str] = None, + search_keywords: Optional[str] = None, + tags: Optional[str] = None, + ) -> Dict[str, str]: + """Creates a hub with the given description""" + + create_hub_bucket_if_it_does_not_exist( + self.hub_storage_location.bucket, self._sagemaker_session + ) + + return self._sagemaker_session.create_hub( + hub_name=self.hub_name, + hub_description=description, + hub_display_name=display_name, + hub_search_keywords=search_keywords, + s3_storage_config={"S3OutputPath": self.hub_storage_location.get_uri()}, + tags=tags, + ) + + def describe(self, hub_name: Optional[str] = None) -> DescribeHubResponse: + """Returns descriptive information about the Hub""" + + hub_description: DescribeHubResponse = self._sagemaker_session.describe_hub( + hub_name=self.hub_name if not hub_name else hub_name + ) + + return hub_description + + def _list_and_paginate_models(self, **kwargs) -> List[Dict[str, Any]]: + """List and paginate models from Hub.""" + next_token: Optional[str] = None + first_iteration: bool = True + hub_model_summaries: List[Dict[str, Any]] = [] + + while first_iteration or next_token: + first_iteration = False + list_hub_content_response = self._sagemaker_session.list_hub_contents(**kwargs) + hub_model_summaries.extend(list_hub_content_response.get("HubContentSummaries", [])) + next_token = list_hub_content_response.get("NextToken") + + return hub_model_summaries + + def list_models(self, clear_cache: bool = True, **kwargs) -> Dict[str, Any]: + """Lists the models and model references in this SageMaker Hub. + + This function caches the models in local memory + + **kwargs: Passed to invocation of ``Session:list_hub_contents``. 
+ """ + response = {} + + if clear_cache: + self._list_hubs_cache = None + if self._list_hubs_cache is None: + + hub_model_reference_summaries = self._list_and_paginate_models( + **{ + "hub_name": self.hub_name, + "hub_content_type": HubContentType.MODEL_REFERENCE.value, + } + | kwargs + ) + + hub_model_summaries = self._list_and_paginate_models( + **{"hub_name": self.hub_name, "hub_content_type": HubContentType.MODEL.value} + | kwargs + ) + response["hub_content_summaries"] = hub_model_reference_summaries + hub_model_summaries + response["next_token"] = None # Temporary until pagination is implemented + return response + + def list_sagemaker_public_hub_models( + self, + filter: Union[Operator, str] = Constant(BooleanValues.TRUE), + next_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Lists the models and model arns from AmazonSageMakerJumpStart Public Hub. + + Args: + filter (Union[Operator, str]): Optional. The filter to apply to list models. This can be + either an ``Operator`` type filter (e.g. ``And("task == ic", "framework == pytorch")``), + or simply a string filter which will get serialized into an Identity filter. + (e.g. ``"task == ic"``). If this argument is not supplied, all models will be listed. + (Default: Constant(BooleanValues.TRUE)). + next_token (str): Optional. A token to resume pagination of list_inference_components. + This is currently not implemented. 
+ """ + + response = {} + + jumpstart_public_hub_arn = construct_hub_arn_from_name( + JUMPSTART_MODEL_HUB_NAME, self.region, self._sagemaker_session + ) + + hub_content_summaries = [] + models = list_jumpstart_models(filter=filter, list_versions=True) + for model in models: + if len(model) <= 63: + info = get_info_from_hub_resource_arn(jumpstart_public_hub_arn) + hub_model_arn = ( + f"arn:{info.partition}:" + f"sagemaker:{info.region}:" + f"aws:hub-content/{info.hub_name}/" + f"{HubContentType.MODEL}/{model[0]}" + ) + hub_content_summary = { + "hub_content_name": model[0], + "hub_content_arn": hub_model_arn, + } + hub_content_summaries.append(hub_content_summary) + response["hub_content_summaries"] = hub_content_summaries + + response["next_token"] = None # Temporary until pagination is implemented for this function + + return response + + def delete(self) -> None: + """Deletes this SageMaker Hub.""" + return self._sagemaker_session.delete_hub(self.hub_name) + + def create_model_reference( + self, model_arn: str, model_name: Optional[str] = None, min_version: Optional[str] = None + ): + """Adds model reference to this SageMaker Hub.""" + return self._sagemaker_session.create_hub_content_reference( + hub_name=self.hub_name, + source_hub_content_arn=model_arn, + hub_content_name=model_name, + min_version=min_version, + ) + + def delete_model_reference(self, model_name: str) -> None: + """Deletes model reference from this SageMaker Hub.""" + return self._sagemaker_session.delete_hub_content_reference( + hub_name=self.hub_name, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + hub_content_name=model_name, + ) + + def describe_model( + self, model_name: str, hub_name: Optional[str] = None, model_version: Optional[str] = None + ) -> DescribeHubContentResponse: + """Describe model in the SageMaker Hub.""" + try: + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL.value, + hub_name=self.hub_name, + 
sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name if not hub_name else hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL.value, + ) + + except Exception as ex: + logging.info("Recieved expection while calling APIs for ContentType Model: " + str(ex)) + model_version = get_hub_model_version( + hub_model_name=model_name, + hub_model_type=HubContentType.MODEL_REFERENCE.value, + hub_name=self.hub_name, + sagemaker_session=self._sagemaker_session, + hub_model_version=model_version, + ) + + hub_content_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content( + hub_name=self.hub_name, + hub_content_name=model_name, + hub_content_version=model_version, + hub_content_type=HubContentType.MODEL_REFERENCE.value, + ) + + return DescribeHubContentResponse(hub_content_description) diff --git a/src/sagemaker/jumpstart/hub/interfaces.py b/src/sagemaker/jumpstart/hub/interfaces.py new file mode 100644 index 0000000000..2748409927 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/interfaces.py @@ -0,0 +1,831 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module stores types related to SageMaker JumpStart HubAPI requests and responses.""" +from __future__ import absolute_import + +import re +import json +import datetime + +from typing import Any, Dict, List, Union, Optional +from sagemaker.jumpstart.types import ( + HubContentType, + HubArnExtractedInfo, + JumpStartPredictorSpecs, + JumpStartHyperparameter, + JumpStartDataHolderType, + JumpStartEnvironmentVariable, + JumpStartSerializablePayload, + JumpStartInstanceTypeVariants, +) +from sagemaker.jumpstart.hub.parser_utils import ( + snake_to_upper_camel, + walk_and_apply_json, +) + + +class HubDataHolderType(JumpStartDataHolderType): + """Base class for many Hub API interfaces.""" + + def to_json(self) -> Dict[str, Any]: + """Returns json representation of object.""" + json_obj = {} + for att in self.__slots__: + if att in self._non_serializable_slots: + continue + if hasattr(self, att): + cur_val = getattr(self, att) + # Do not serialize null values. + if cur_val is None: + continue + if issubclass(type(cur_val), JumpStartDataHolderType): + json_obj[att] = cur_val.to_json() + elif isinstance(cur_val, list): + json_obj[att] = [] + for obj in cur_val: + if issubclass(type(obj), JumpStartDataHolderType): + json_obj[att].append(obj.to_json()) + else: + json_obj[att].append(obj) + elif isinstance(cur_val, datetime.datetime): + json_obj[att] = str(cur_val) + else: + json_obj[att] = cur_val + return json_obj + + def __str__(self) -> str: + """Returns string representation of object. + + Example: "{'content_bucket': 'bucket', 'region_name': 'us-west-2'}" + """ + + att_dict = walk_and_apply_json(self.to_json(), snake_to_upper_camel) + return f"{json.dumps(att_dict, default=lambda o: o.to_json())}" + + +class CreateHubResponse(HubDataHolderType): + """Data class for the Hub from session.create_hub()""" + + __slots__ = [ + "hub_arn", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates CreateHubResponse object. 
+ + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.create_hub() response. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.hub_arn: str = json_obj["HubArn"] + + +class HubContentDependency(HubDataHolderType): + """Data class for any dependencies related to hub content. + + Content can be scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["dependency_copy_path", "dependency_origin_path", "dependency_type"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentDependency object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + + self.dependency_copy_path: Optional[str] = json_obj.get("DependencyCopyPath", "") + self.dependency_origin_path: Optional[str] = json_obj.get("DependencyOriginPath", "") + self.dependency_type: Optional[str] = json_obj.get("DependencyType", "") + + +class DescribeHubContentResponse(HubDataHolderType): + """Data class for the Hub Content from session.describe_hub_contents()""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "failure_reason", + "hub_arn", + "hub_content_arn", + "hub_content_dependencies", + "hub_content_description", + "hub_content_display_name", + "hub_content_document", + "hub_content_markdown", + "hub_content_name", + "hub_content_search_keywords", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "reference_min_version", + "hub_name", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubContentResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. 
+ """ + self.creation_time: datetime.datetime = json_obj["CreationTime"] + self.document_schema_version: str = json_obj["DocumentSchemaVersion"] + self.failure_reason: Optional[str] = json_obj.get("FailureReason") + self.hub_arn: str = json_obj["HubArn"] + self.hub_content_arn: str = json_obj["HubContentArn"] + self.hub_content_dependencies = [] + if "Dependencies" in json_obj: + self.hub_content_dependencies: Optional[List[HubContentDependency]] = [ + HubContentDependency(dep) for dep in json_obj.get(["Dependencies"]) + ] + self.hub_content_description: str = json_obj.get("HubContentDescription") + self.hub_content_display_name: str = json_obj.get("HubContentDisplayName") + hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn) + self._region = hub_region + self.hub_content_type: str = json_obj.get("HubContentType") + hub_content_document = json.loads(json_obj["HubContentDocument"]) + if self.hub_content_type == HubContentType.MODEL: + self.hub_content_document: HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies, + ) + elif self.hub_content_type == HubContentType.MODEL_REFERENCE: + self.hub_content_document: HubContentDocument = HubModelDocument( + json_obj=hub_content_document, + region=self._region, + dependencies=self.hub_content_dependencies, + ) + elif self.hub_content_type == HubContentType.NOTEBOOK: + self.hub_content_document: HubContentDocument = HubNotebookDocument( + json_obj=hub_content_document, region=self._region + ) + else: + raise ValueError( + f"[{self.hub_content_type}] is not a valid HubContentType." + f"Should be one of: {[item.name for item in HubContentType]}." 
+ ) + + self.hub_content_markdown: str = json_obj.get("HubContentMarkdown") + self.hub_content_name: str = json_obj["HubContentName"] + self.hub_content_search_keywords: List[str] = json_obj.get("HubContentSearchKeywords") + self.hub_content_status: str = json_obj["HubContentStatus"] + self.hub_content_version: str = json_obj["HubContentVersion"] + self.hub_name: str = json_obj["HubName"] + + def get_hub_region(self) -> Optional[str]: + """Returns the region hub is in.""" + return self._region + + +class HubS3StorageConfig(HubDataHolderType): + """Data class for any dependencies related to hub content. + + Includes scripts, model artifacts, datasets, or notebooks. + """ + + __slots__ = ["s3_output_path"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubS3StorageConfig object + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + + self.s3_output_path: Optional[str] = json_obj.get("S3OutputPath", "") + + +class DescribeHubResponse(HubDataHolderType): + """Data class for the Hub from session.describe_hub()""" + + __slots__ = [ + "creation_time", + "failure_reason", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + "s3_storage_config", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates DescribeHubResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. 
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub description.
+        """
+
+        self.creation_time: datetime.datetime = json_obj["CreationTime"]
+        self.failure_reason: Optional[str] = json_obj.get("FailureReason")
+        self.hub_arn: str = json_obj["HubArn"]
+        hub_region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(self.hub_arn)
+        self._region = hub_region
+        self.hub_description: str = json_obj["HubDescription"]
+        self.hub_display_name: str = json_obj["HubDisplayName"]
+        self.hub_name: str = json_obj["HubName"]
+        self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"]
+        self.hub_status: str = json_obj["HubStatus"]
+        self.last_modified_time: datetime.datetime = json_obj["LastModifiedTime"]
+        self.s3_storage_config: HubS3StorageConfig = HubS3StorageConfig(json_obj["S3StorageConfig"])
+
+    def get_hub_region(self) -> Optional[str]:
+        """Returns the region hub is in."""
+        return self._region
+
+
+class ImportHubResponse(HubDataHolderType):
+    """Data class for the Hub from session.import_hub()"""
+
+    __slots__ = [
+        "hub_arn",
+        "hub_content_arn",
+    ]
+
+    def __init__(self, json_obj: Dict[str, Any]) -> None:
+        """Instantiates ImportHubResponse object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub description.
+        """
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub description.
+ """ + self.hub_arn: str = json_obj["HubArn"] + self.hub_content_arn: str = json_obj["HubContentArn"] + + +class HubSummary(HubDataHolderType): + """Data class for the HubSummary from session.list_hubs()""" + + __slots__ = [ + "creation_time", + "hub_arn", + "hub_description", + "hub_display_name", + "hub_name", + "hub_search_keywords", + "hub_status", + "last_modified_time", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubSummary object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. + """ + self.creation_time: datetime.datetime = datetime.datetime(json_obj["CreationTime"]) + self.hub_arn: str = json_obj["HubArn"] + self.hub_description: str = json_obj["HubDescription"] + self.hub_display_name: str = json_obj["HubDisplayName"] + self.hub_name: str = json_obj["HubName"] + self.hub_search_keywords: List[str] = json_obj["HubSearchKeywords"] + self.hub_status: str = json_obj["HubStatus"] + self.last_modified_time: datetime.datetime = datetime.datetime(json_obj["LastModifiedTime"]) + + +class ListHubsResponse(HubDataHolderType): + """Data class for the Hub from session.list_hubs()""" + + __slots__ = [ + "hub_summaries", + "next_token", + ] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates ListHubsResponse object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of session.list_hubs() response. 
+ """ + self.hub_summaries: List[HubSummary] = [ + HubSummary(item) for item in json_obj["HubSummaries"] + ] + self.next_token: str = json_obj["NextToken"] + + +class EcrUri(HubDataHolderType): + """Data class for ECR image uri.""" + + __slots__ = ["account", "region_name", "repository", "tag"] + + def __init__(self, uri: str): + """Instantiates EcrUri object.""" + self.from_ecr_uri(uri) + + def from_ecr_uri(self, uri: str) -> None: + """Parse a given aws ecr image uri into its various components.""" + uri_regex = ( + r"^(?:(?P[a-zA-Z0-9][\w-]*)\.dkr\.ecr\.(?P[a-zA-Z0-9][\w-]*)" + r"\.(?P[a-zA-Z0-9\.-]+))\/(?P([a-z0-9]+" + r"(?:[._-][a-z0-9]+)*\/)*[a-z0-9]+(?:[._-][a-z0-9]+)*)(:*)(?P.*)?" + ) + + parsed_image_uri = re.compile(uri_regex).match(uri) + + account = parsed_image_uri.group("account_id") + region = parsed_image_uri.group("region") + repository = parsed_image_uri.group("repository_name") + tag = parsed_image_uri.group("image_tag") + + self.account = account + self.region_name = region + self.repository = repository + self.tag = tag + + +class NotebookLocationUris(HubDataHolderType): + """Data class for Notebook Location uri.""" + + __slots__ = ["demo_notebook", "model_fit", "model_deploy"] + + def __init__(self, json_obj: Dict[str, Any]): + """Instantiates EcrUri object.""" + self.from_json(json_obj) + + def from_json(self, json_obj: str) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub description. 
+ """ + self.demo_notebook = json_obj.get("demo_notebook") + self.model_fit = json_obj.get("model_fit") + self.model_deploy = json_obj.get("model_deploy") + + +class HubModelDocument(HubDataHolderType): + """Data class for model type HubContentDocument from session.describe_hub_content().""" + + SCHEMA_VERSION = "2.2.0" + + __slots__ = [ + "url", + "min_sdk_version", + "training_supported", + "incremental_training_supported", + "dynamic_container_deployment_supported", + "hosting_ecr_uri", + "hosting_artifact_s3_data_type", + "hosting_artifact_compression_type", + "hosting_artifact_uri", + "hosting_prepacked_artifact_uri", + "hosting_prepacked_artifact_version", + "hosting_script_uri", + "hosting_use_script_uri", + "hosting_eula_uri", + "hosting_model_package_arn", + "training_artifact_s3_data_type", + "training_artifact_compression_type", + "training_model_package_artifact_uri", + "hyperparameters", + "inference_environment_variables", + "training_script_uri", + "training_prepacked_script_uri", + "training_prepacked_script_version", + "training_ecr_uri", + "training_metrics", + "training_artifact_uri", + "inference_dependencies", + "training_dependencies", + "default_inference_instance_type", + "supported_inference_instance_types", + "default_training_instance_type", + "supported_training_instance_types", + "sage_maker_sdk_predictor_specifications", + "inference_volume_size", + "training_volume_size", + "inference_enable_network_isolation", + "training_enable_network_isolation", + "fine_tuning_supported", + "validation_supported", + "default_training_dataset_uri", + "resource_name_base", + "gated_bucket", + "default_payloads", + "hosting_resource_requirements", + "hosting_instance_type_variants", + "training_instance_type_variants", + "notebook_location_uris", + "model_provider_icon_uri", + "task", + "framework", + "datatype", + "license", + "contextual_help", + "model_data_download_timeout", + "container_startup_health_check_timeout", + 
"encrypt_inter_container_traffic", + "max_runtime_in_seconds", + "disable_output_compression", + "model_dir", + "dependencies", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__( + self, + json_obj: Dict[str, Any], + region: str, + dependencies: List[HubContentDependency] = None, + ) -> None: + """Instantiates HubModelDocument object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content document. + + Raises: + ValueError: When one of (json_obj) or (model_specs and studio_specs) is not provided. + """ + self._region = region + self.dependencies = dependencies or [] + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub model document. + """ + self.url: str = json_obj["Url"] + self.min_sdk_version: str = json_obj["MinSdkVersion"] + self.hosting_ecr_uri: Optional[str] = json_obj["HostingEcrUri"] + self.hosting_artifact_uri = json_obj["HostingArtifactUri"] + self.hosting_script_uri = json_obj["HostingScriptUri"] + self.inference_dependencies: List[str] = json_obj["InferenceDependencies"] + self.inference_environment_variables: List[JumpStartEnvironmentVariable] = [ + JumpStartEnvironmentVariable(env_variable, is_hub_content=True) + for env_variable in json_obj["InferenceEnvironmentVariables"] + ] + self.training_supported: bool = bool(json_obj["TrainingSupported"]) + self.incremental_training_supported: bool = bool(json_obj["IncrementalTrainingSupported"]) + self.dynamic_container_deployment_supported: Optional[bool] = ( + bool(json_obj.get("DynamicContainerDeploymentSupported")) + if json_obj.get("DynamicContainerDeploymentSupported") + else None + ) + self.hosting_artifact_s3_data_type: Optional[str] = json_obj.get( + "HostingArtifactS3DataType" + ) + self.hosting_artifact_compression_type: Optional[str] = json_obj.get( + "HostingArtifactCompressionType" + ) + 
self.hosting_prepacked_artifact_uri: Optional[str] = json_obj.get( + "HostingPrepackedArtifactUri" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "HostingPrepackedArtifactVersion" + ) + self.hosting_use_script_uri: Optional[bool] = ( + bool(json_obj.get("HostingUseScriptUri")) + if json_obj.get("HostingUseScriptUri") is not None + else None + ) + self.hosting_eula_uri: Optional[str] = json_obj.get("HostingEulaUri") + self.hosting_model_package_arn: Optional[str] = json_obj.get("HostingModelPackageArn") + self.default_inference_instance_type: Optional[str] = json_obj.get( + "DefaultInferenceInstanceType" + ) + self.supported_inference_instance_types: Optional[str] = json_obj.get( + "SupportedInferenceInstanceTypes" + ) + self.sage_maker_sdk_predictor_specifications: Optional[JumpStartPredictorSpecs] = ( + JumpStartPredictorSpecs( + json_obj.get("SageMakerSdkPredictorSpecifications"), + is_hub_content=True, + ) + if json_obj.get("SageMakerSdkPredictorSpecifications") + else None + ) + self.inference_volume_size: Optional[int] = json_obj.get("InferenceVolumeSize") + self.inference_enable_network_isolation: Optional[str] = json_obj.get( + "InferenceEnableNetworkIsolation", False + ) + self.fine_tuning_supported: Optional[bool] = ( + bool(json_obj.get("FineTuningSupported")) + if json_obj.get("FineTuningSupported") + else None + ) + self.validation_supported: Optional[bool] = ( + bool(json_obj.get("ValidationSupported")) + if json_obj.get("ValidationSupported") + else None + ) + self.default_training_dataset_uri: Optional[str] = json_obj.get("DefaultTrainingDatasetUri") + self.resource_name_base: Optional[str] = json_obj.get("ResourceNameBase") + self.gated_bucket: bool = bool(json_obj.get("GatedBucket", False)) + self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( + { + alias: JumpStartSerializablePayload(payload, is_hub_content=True) + for alias, payload in json_obj.get("DefaultPayloads").items() + } + if 
json_obj.get("DefaultPayloads") + else None + ) + self.hosting_resource_requirements: Optional[Dict[str, int]] = json_obj.get( + "HostingResourceRequirements", None + ) + self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("HostingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("HostingInstanceTypeVariants") + else None + ) + self.notebook_location_uris: Optional[NotebookLocationUris] = ( + NotebookLocationUris(json_obj.get("NotebookLocationUris")) + if json_obj.get("NotebookLocationUris") + else None + ) + self.model_provider_icon_uri: Optional[str] = None # Not needed for private beta + self.task: Optional[str] = json_obj.get("Task") + self.framework: Optional[str] = json_obj.get("Framework") + self.datatype: Optional[str] = json_obj.get("Datatype") + self.license: Optional[str] = json_obj.get("License") + self.contextual_help: Optional[str] = json_obj.get("ContextualHelp") + self.model_dir: Optional[str] = json_obj.get("ModelDir") + # Deploy kwargs + self.model_data_download_timeout: Optional[str] = json_obj.get("ModelDataDownloadTimeout") + self.container_startup_health_check_timeout: Optional[str] = json_obj.get( + "ContainerStartupHealthCheckTimeout" + ) + + if self.training_supported: + self.training_model_package_artifact_uri: Optional[str] = json_obj.get( + "TrainingModelPackageArtifactUri" + ) + self.training_artifact_compression_type: Optional[str] = json_obj.get( + "TrainingArtifactCompressionType" + ) + self.training_artifact_s3_data_type: Optional[str] = json_obj.get( + "TrainingArtifactS3DataType" + ) + self.hyperparameters: List[JumpStartHyperparameter] = [] + hyperparameters: Any = json_obj.get("Hyperparameters") + if hyperparameters is not None: + self.hyperparameters.extend( + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=True) + for hyperparameter in hyperparameters + ] + ) + + self.training_script_uri: Optional[str] = 
json_obj.get("TrainingScriptUri") + self.training_prepacked_script_uri: Optional[str] = json_obj.get( + "TrainingPrepackedScriptUri" + ) + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "TrainingPrepackedScriptVersion" + ) + self.training_ecr_uri: Optional[str] = json_obj.get("TrainingEcrUri") + self._non_serializable_slots.append("training_ecr_specs") + self.training_metrics: Optional[List[Dict[str, str]]] = json_obj.get( + "TrainingMetrics", None + ) + self.training_artifact_uri: Optional[str] = json_obj.get("TrainingArtifactUri") + self.training_dependencies: Optional[str] = json_obj.get("TrainingDependencies") + self.default_training_instance_type: Optional[str] = json_obj.get( + "DefaultTrainingInstanceType" + ) + self.supported_training_instance_types: Optional[str] = json_obj.get( + "SupportedTrainingInstanceTypes" + ) + self.training_volume_size: Optional[int] = json_obj.get("TrainingVolumeSize") + self.training_enable_network_isolation: Optional[str] = json_obj.get( + "TrainingEnableNetworkIsolation", False + ) + self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( + JumpStartInstanceTypeVariants( + json_obj.get("TrainingInstanceTypeVariants"), + is_hub_content=True, + ) + if json_obj.get("TrainingInstanceTypeVariants") + else None + ) + # Estimator kwargs + self.encrypt_inter_container_traffic: Optional[bool] = ( + bool(json_obj.get("EncryptInterContainerTraffic")) + if json_obj.get("EncryptInterContainerTraffic") + else None + ) + self.max_runtime_in_seconds: Optional[str] = json_obj.get("MaxRuntimeInSeconds") + self.disable_output_compression: Optional[bool] = ( + bool(json_obj.get("DisableOutputCompression")) + if json_obj.get("DisableOutputCompression") + else None + ) + + def get_schema_version(self) -> str: + """Returns schema version.""" + return self.SCHEMA_VERSION + + def get_region(self) -> str: + """Returns hub region.""" + return self._region + + +class 
HubNotebookDocument(HubDataHolderType): + """Data class for notebook type HubContentDocument from session.describe_hub_content().""" + + SCHEMA_VERSION = "1.0.0" + + __slots__ = ["notebook_location", "dependencies", "_region"] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any], region: str) -> None: + """Instantiates HubNotebookDocument object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content document. + """ + self._region = region + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.notebook_location = json_obj["NotebookLocation"] + self.dependencies: List[HubContentDependency] = [ + HubContentDependency(dep) for dep in json_obj["Dependencies"] + ] + + def get_schema_version(self) -> str: + """Returns schema version.""" + return self.SCHEMA_VERSION + + def get_region(self) -> str: + """Returns hub region.""" + return self._region + + +HubContentDocument = Union[HubModelDocument, HubNotebookDocument] + + +class HubContentInfo(HubDataHolderType): + """Data class for the HubContentInfo from session.list_hub_contents().""" + + __slots__ = [ + "creation_time", + "document_schema_version", + "hub_content_arn", + "hub_content_name", + "hub_content_status", + "hub_content_type", + "hub_content_version", + "hub_content_description", + "hub_content_display_name", + "hub_content_search_keywords", + "_region", + ] + + _non_serializable_slots = ["_region"] + + def __init__(self, json_obj: Dict[str, Any]) -> None: + """Instantiates HubContentInfo object. + + Args: + json_obj (Dict[str, Any]): Dictionary representation of hub content description. + """ + self.from_json(json_obj) + + def from_json(self, json_obj: Dict[str, Any]) -> None: + """Sets fields in object based on json. 
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub content description.
+        """
+        self.creation_time: str = json_obj["CreationTime"]
+        self.document_schema_version: str = json_obj["DocumentSchemaVersion"]
+        self.hub_content_arn: str = json_obj["HubContentArn"]
+        self.hub_content_name: str = json_obj["HubContentName"]
+        self.hub_content_status: str = json_obj["HubContentStatus"]
+        self.hub_content_type: HubContentType = HubContentType(json_obj["HubContentType"])
+        self.hub_content_version: str = json_obj["HubContentVersion"]
+        self.hub_content_description: Optional[str] = json_obj.get("HubContentDescription")
+        self.hub_content_display_name: Optional[str] = json_obj.get("HubContentDisplayName")
+        self._region: Optional[str] = HubArnExtractedInfo.extract_region_from_arn(
+            self.hub_content_arn
+        )
+        self.hub_content_search_keywords: Optional[List[str]] = json_obj.get(
+            "HubContentSearchKeywords"
+        )
+
+    def get_hub_region(self) -> Optional[str]:
+        """Returns the region hub is in."""
+        return self._region
+
+
+class ListHubContentsResponse(HubDataHolderType):
+    """Data class for the Hub from session.list_hub_contents()"""
+
+    __slots__ = [
+        "hub_content_summaries",
+        "next_token",
+    ]
+
+    def __init__(self, json_obj: Dict[str, Any]) -> None:
+        """Instantiates ListHubContentsResponse object.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub description.
+        """
+        self.from_json(json_obj)
+
+    def from_json(self, json_obj: Dict[str, Any]) -> None:
+        """Sets fields in object based on json.
+
+        Args:
+            json_obj (Dict[str, Any]): Dictionary representation of hub description.
+ """ + self.hub_content_summaries: List[HubContentInfo] = [ + HubContentInfo(item) for item in json_obj["HubContentSummaries"] + ] + self.next_token: str = json_obj["NextToken"] diff --git a/src/sagemaker/jumpstart/hub/parser_utils.py b/src/sagemaker/jumpstart/hub/parser_utils.py new file mode 100644 index 0000000000..140c089b11 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/parser_utils.py @@ -0,0 +1,56 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains utilities related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import + +import re +from typing import Any, Dict + + +def camel_to_snake(camel_case_string: str) -> str: + """Converts camelCaseString or UpperCamelCaseString to snake_case_string.""" + snake_case_string = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_string) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake_case_string).lower() + + +def snake_to_upper_camel(snake_case_string: str) -> str: + """Converts snake_case_string to UpperCamelCaseString.""" + upper_camel_case_string = "".join(word.title() for word in snake_case_string.split("_")) + return upper_camel_case_string + + +def walk_and_apply_json(json_obj: Dict[Any, Any], apply) -> Dict[Any, Any]: + """Recursively walks a json object and applies a given function to the keys.""" + + def _walk_and_apply_json(json_obj, new): + if isinstance(json_obj, dict) and isinstance(new, dict): + for key, value in json_obj.items(): + new_key = apply(key) 
+ if isinstance(value, dict): + new[new_key] = {} + _walk_and_apply_json(value, new=new[new_key]) + elif isinstance(value, list): + new[new_key] = [] + for item in value: + _walk_and_apply_json(item, new=new[new_key]) + else: + new[new_key] = value + elif isinstance(json_obj, dict) and isinstance(new, list): + new.append(_walk_and_apply_json(json_obj, new={})) + elif isinstance(json_obj, list) and isinstance(new, dict): + new.update(json_obj) + elif isinstance(json_obj, list) and isinstance(new, list): + new.append(json_obj) + return new + + return _walk_and_apply_json(json_obj, new={}) diff --git a/src/sagemaker/jumpstart/hub/parsers.py b/src/sagemaker/jumpstart/hub/parsers.py new file mode 100644 index 0000000000..8226a380fd --- /dev/null +++ b/src/sagemaker/jumpstart/hub/parsers.py @@ -0,0 +1,262 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+# pylint: skip-file +"""This module stores Hub converter utilities for JumpStart.""" +from __future__ import absolute_import + +from typing import Any, Dict, List +from sagemaker.jumpstart.enums import ModelSpecKwargType, NamingConventionType +from sagemaker.s3 import parse_s3_url +from sagemaker.jumpstart.types import ( + JumpStartModelSpecs, + HubContentType, + JumpStartDataHolderType, +) +from sagemaker.jumpstart.hub.interfaces import ( + DescribeHubContentResponse, + HubModelDocument, +) +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + snake_to_upper_camel, + walk_and_apply_json, +) + + +def _to_json(dictionary: Dict[Any, Any]) -> Dict[Any, Any]: + """Convert a nested dictionary of JumpStartDataHolderType into json with UpperCamelCase keys""" + for key, value in dictionary.items(): + if issubclass(type(value), JumpStartDataHolderType): + dictionary[key] = walk_and_apply_json(value.to_json(), snake_to_upper_camel) + elif isinstance(value, list): + new_value = [] + for value_in_list in value: + new_value_in_list = value_in_list + if issubclass(type(value_in_list), JumpStartDataHolderType): + new_value_in_list = walk_and_apply_json( + value_in_list.to_json(), snake_to_upper_camel + ) + new_value.append(new_value_in_list) + dictionary[key] = new_value + elif isinstance(value, dict): + for key_in_dict, value_in_dict in value.items(): + if issubclass(type(value_in_dict), JumpStartDataHolderType): + value[key_in_dict] = walk_and_apply_json( + value_in_dict.to_json(), snake_to_upper_camel + ) + return dictionary + + +def get_model_spec_arg_keys( + arg_type: ModelSpecKwargType, + naming_convention: NamingConventionType = NamingConventionType.DEFAULT, +) -> List[str]: + """Returns a list of arg keys for a specific model spec arg type. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + naming_convention (NamingConventionType): Type of naming convention to return. 
+ + Raises: + ValueError: If the naming convention is not valid. + """ + arg_keys: List[str] = [] + if arg_type == ModelSpecKwargType.DEPLOY: + arg_keys = ["ModelDataDownloadTimeout", "ContainerStartupHealthCheckTimeout"] + elif arg_type == ModelSpecKwargType.ESTIMATOR: + arg_keys = [ + "EncryptInterContainerTraffic", + "MaxRuntimeInSeconds", + "DisableOutputCompression", + "ModelDir", + ] + elif arg_type == ModelSpecKwargType.MODEL: + arg_keys = [] + elif arg_type == ModelSpecKwargType.FIT: + arg_keys = [] + + if naming_convention == NamingConventionType.SNAKE_CASE: + arg_keys = [camel_to_snake(key) for key in arg_keys] + elif naming_convention == NamingConventionType.UPPER_CAMEL_CASE: + return arg_keys + else: + raise ValueError("Please provide a valid naming convention.") + return arg_keys + + +def get_model_spec_kwargs_from_hub_model_document( + arg_type: ModelSpecKwargType, + hub_content_document: Dict[str, Any], + naming_convention: NamingConventionType = NamingConventionType.UPPER_CAMEL_CASE, +) -> Dict[str, Any]: + """Returns a map of arg type to arg keys for a given hub content document. + + Args: + arg_type (ModelSpecKwargType): Type of the model spec's kwarg. + hub_content_document: A dictionary representation of hub content document. + naming_convention (NamingConventionType): Type of naming convention to return. 
+ + """ + kwargs = dict() + keys = get_model_spec_arg_keys(arg_type, naming_convention=naming_convention) + for k in keys: + kwarg_value = hub_content_document.get(k) + if kwarg_value is not None: + kwargs[k] = kwarg_value + return kwargs + + +def make_model_specs_from_describe_hub_content_response( + response: DescribeHubContentResponse, +) -> JumpStartModelSpecs: + """Sets fields in JumpStartModelSpecs based on values in DescribeHubContentResponse + + Args: + response (Dict[str, any]): parsed DescribeHubContentResponse returned + from SageMaker:DescribeHubContent + """ + if response.hub_content_type not in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}: + raise AttributeError( + "Invalid content type, use either HubContentType.MODEL or HubContentType.MODEL_REFERENCE." + ) + region = response.get_hub_region() + specs = {} + model_id = response.hub_content_name + specs["model_id"] = model_id + specs["version"] = response.hub_content_version + hub_model_document: HubModelDocument = response.hub_content_document + specs["url"] = hub_model_document.url + specs["min_sdk_version"] = hub_model_document.min_sdk_version + specs["training_supported"] = bool(hub_model_document.training_supported) + specs["incremental_training_supported"] = bool( + hub_model_document.incremental_training_supported + ) + specs["hosting_ecr_uri"] = hub_model_document.hosting_ecr_uri + + hosting_artifact_bucket, hosting_artifact_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_artifact_uri + ) + specs["hosting_artifact_key"] = hosting_artifact_key + specs["hosting_artifact_uri"] = hub_model_document.hosting_artifact_uri + hosting_script_bucket, hosting_script_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_script_uri + ) + specs["hosting_script_key"] = hosting_script_key + specs["inference_environment_variables"] = hub_model_document.inference_environment_variables + specs["inference_vulnerable"] = False + 
specs["inference_dependencies"] = hub_model_document.inference_dependencies + specs["inference_vulnerabilities"] = [] + specs["training_vulnerable"] = False + specs["training_vulnerabilities"] = [] + specs["deprecated"] = False + specs["deprecated_message"] = None + specs["deprecate_warn_message"] = None + specs["usage_info_message"] = None + specs["default_inference_instance_type"] = hub_model_document.default_inference_instance_type + specs["supported_inference_instance_types"] = ( + hub_model_document.supported_inference_instance_types + ) + specs["dynamic_container_deployment_supported"] = ( + hub_model_document.dynamic_container_deployment_supported + ) + specs["hosting_resource_requirements"] = hub_model_document.hosting_resource_requirements + + specs["hosting_prepacked_artifact_key"] = None + if hub_model_document.hosting_prepacked_artifact_uri is not None: + ( + hosting_prepacked_artifact_bucket, # pylint: disable=unused-variable + hosting_prepacked_artifact_key, + ) = parse_s3_url(hub_model_document.hosting_prepacked_artifact_uri) + specs["hosting_prepacked_artifact_key"] = hosting_prepacked_artifact_key + + hub_content_document_dict: Dict[str, Any] = hub_model_document.to_json() + + specs["fit_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.FIT, hub_content_document_dict + ) + specs["model_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.MODEL, hub_content_document_dict + ) + specs["deploy_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.DEPLOY, hub_content_document_dict + ) + specs["estimator_kwargs"] = get_model_spec_kwargs_from_hub_model_document( + ModelSpecKwargType.ESTIMATOR, hub_content_document_dict + ) + + specs["predictor_specs"] = hub_model_document.sage_maker_sdk_predictor_specifications + default_payloads: Dict[str, Any] = {} + if hub_model_document.default_payloads is not None: + for alias, payload in hub_model_document.default_payloads.items(): + 
default_payloads[alias] = walk_and_apply_json(payload.to_json(), camel_to_snake) + specs["default_payloads"] = default_payloads + specs["gated_bucket"] = hub_model_document.gated_bucket + specs["inference_volume_size"] = hub_model_document.inference_volume_size + specs["inference_enable_network_isolation"] = ( + hub_model_document.inference_enable_network_isolation + ) + specs["resource_name_base"] = hub_model_document.resource_name_base + + specs["hosting_eula_key"] = None + if hub_model_document.hosting_eula_uri is not None: + hosting_eula_bucket, hosting_eula_key = parse_s3_url( # pylint: disable=unused-variable + hub_model_document.hosting_eula_uri + ) + specs["hosting_eula_key"] = hosting_eula_key + + if hub_model_document.hosting_model_package_arn: + specs["hosting_model_package_arns"] = {region: hub_model_document.hosting_model_package_arn} + + specs["hosting_use_script_uri"] = hub_model_document.hosting_use_script_uri + + specs["hosting_instance_type_variants"] = hub_model_document.hosting_instance_type_variants + + if specs["training_supported"]: + specs["training_ecr_uri"] = hub_model_document.training_ecr_uri + ( + training_artifact_bucket, # pylint: disable=unused-variable + training_artifact_key, + ) = parse_s3_url(hub_model_document.training_artifact_uri) + specs["training_artifact_key"] = training_artifact_key + ( + training_script_bucket, # pylint: disable=unused-variable + training_script_key, + ) = parse_s3_url(hub_model_document.training_script_uri) + specs["training_script_key"] = training_script_key + specs["training_dependencies"] = hub_model_document.training_dependencies + specs["default_training_instance_type"] = hub_model_document.default_training_instance_type + specs["supported_training_instance_types"] = ( + hub_model_document.supported_training_instance_types + ) + specs["metrics"] = hub_model_document.training_metrics + specs["training_prepacked_script_key"] = None + if hub_model_document.training_prepacked_script_uri is not None: + ( 
+ training_prepacked_script_bucket, # pylint: disable=unused-variable + training_prepacked_script_key, + ) = parse_s3_url(hub_model_document.training_prepacked_script_uri) + specs["training_prepacked_script_key"] = training_prepacked_script_key + + specs["hyperparameters"] = hub_model_document.hyperparameters + specs["training_volume_size"] = hub_model_document.training_volume_size + specs["training_enable_network_isolation"] = ( + hub_model_document.training_enable_network_isolation + ) + if hub_model_document.training_model_package_artifact_uri: + specs["training_model_package_artifact_uris"] = { + region: hub_model_document.training_model_package_artifact_uri + } + specs["training_instance_type_variants"] = ( + hub_model_document.training_instance_type_variants + ) + return JumpStartModelSpecs(_to_json(specs), is_hub_content=True) diff --git a/src/sagemaker/jumpstart/hub/types.py b/src/sagemaker/jumpstart/hub/types.py new file mode 100644 index 0000000000..1a68f84bbc --- /dev/null +++ b/src/sagemaker/jumpstart/hub/types.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This module stores types related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import +from typing import Dict +from dataclasses import dataclass + + +@dataclass +class S3ObjectLocation: + """Helper class for S3 object references.""" + + bucket: str + key: str + + def format_for_s3_copy(self) -> Dict[str, str]: + """Returns a dict formatted for S3 copy calls""" + return { + "Bucket": self.bucket, + "Key": self.key, + } + + def get_uri(self) -> str: + """Returns the s3 URI""" + return f"s3://{self.bucket}/{self.key}" diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py new file mode 100644 index 0000000000..3dfe99a8c4 --- /dev/null +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -0,0 +1,219 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+# pylint: skip-file +"""This module contains utilities related to SageMaker JumpStart Hub.""" +from __future__ import absolute_import +import re +from typing import Optional +from sagemaker.jumpstart.hub.types import S3ObjectLocation +from sagemaker.s3_utils import parse_s3_url +from sagemaker.session import Session +from sagemaker.utils import aws_partition +from sagemaker.jumpstart.types import HubContentType, HubArnExtractedInfo +from sagemaker.jumpstart import constants +from packaging.specifiers import SpecifierSet, InvalidSpecifier + + +def get_info_from_hub_resource_arn( + arn: str, +) -> HubArnExtractedInfo: + """Extracts descriptive information from a Hub or HubContent Arn.""" + + match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + hub_content_type = match.group(5) + hub_content_name = match.group(6) + hub_content_version = match.group(7) + + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + hub_content_type=hub_content_type, + hub_content_name=hub_content_name, + hub_content_version=hub_content_version, + ) + + match = re.match(constants.HUB_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + ) + + +def construct_hub_arn_from_name( + hub_name: str, + region: Optional[str] = None, + session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Constructs a Hub arn from the Hub name using default Session values.""" + + account_id = session.account_id() + region = region or session.boto_region_name + partition = aws_partition(region) + + return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}" + + 
+def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and model version.""" + + info = get_info_from_hub_resource_arn(hub_arn) + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" + f"{info.hub_name}/{HubContentType.MODEL.value}/{model_name}/{version}" + ) + + return arn + + +def construct_hub_model_reference_arn_from_inputs( + hub_arn: str, model_name: str, version: str +) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and model version.""" + + info = get_info_from_hub_resource_arn(hub_arn) + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" + f"{info.hub_name}/{HubContentType.MODEL_REFERENCE}/{model_name}/{version}" + ) + + return arn + + +def generate_hub_arn_for_init_kwargs( + hub_name: str, region: Optional[str] = None, session: Optional[Session] = None +): + """Generates the Hub Arn for JumpStart class args from a HubName or Arn. + + Args: + hub_name (str): HubName or HubArn from JumpStart class args + region (str): Region from JumpStart class args + session (Session): Custom SageMaker Session from JumpStart class args + """ + + hub_arn = None + if hub_name: + if hub_name == constants.JUMPSTART_MODEL_HUB_NAME: + return None + match = re.match(constants.HUB_ARN_REGEX, hub_name) + if match: + hub_arn = hub_name + else: + hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) + return hub_arn + + +def generate_default_hub_bucket_name( + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Return the name of the default bucket to use in relevant Amazon SageMaker Hub interactions. + + Returns: + str: The name of the default bucket. 
If the name was not explicitly specified through
+        the Session or sagemaker_config, the bucket will take the form:
+        ``sagemaker-hubs-{region}-{AWS account ID}``.
+    """
+
+    region: str = sagemaker_session.boto_region_name
+    account_id: str = sagemaker_session.account_id()
+
+    # TODO: Validate and fast fail
+
+    return f"sagemaker-hubs-{region}-{account_id}"
+
+
+def create_s3_object_reference_from_uri(s3_uri: Optional[str]) -> Optional[S3ObjectLocation]:
+    """Utility to help generate an S3 object reference"""
+    if not s3_uri:
+        return None
+
+    bucket, key = parse_s3_url(s3_uri)
+
+    return S3ObjectLocation(
+        bucket=bucket,
+        key=key,
+    )
+
+
+def create_hub_bucket_if_it_does_not_exist(
+    bucket_name: Optional[str] = None,
+    sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
+) -> str:
+    """Creates the default SageMaker Hub bucket if it does not exist.
+
+    Returns:
+        str: The name of the default bucket. Takes the form:
+            ``sagemaker-hubs-{region}-{AWS account ID}``.
+ """ + + region: str = sagemaker_session.boto_region_name + if bucket_name is None: + bucket_name: str = generate_default_hub_bucket_name(sagemaker_session) + + sagemaker_session._create_s3_bucket_if_it_does_not_exist( + bucket_name=bucket_name, + region=region, + ) + + return bucket_name + + +def is_gated_bucket(bucket_name: str) -> bool: + """Returns true if the bucket name is the JumpStart gated bucket.""" + return bucket_name in constants.JUMPSTART_GATED_BUCKET_NAME_SET + + +def get_hub_model_version( + hub_name: str, + hub_model_name: str, + hub_model_type: str, + hub_model_version: Optional[str] = None, + sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Returns available Jumpstart hub model version""" + + try: + hub_content_summaries = sagemaker_session.list_hub_content_versions( + hub_name=hub_name, hub_content_name=hub_model_name, hub_content_type=hub_model_type + ).get("HubContentSummaries") + except Exception as ex: + raise Exception(f"Failed calling list_hub_content_versions: {str(ex)}") + + available_model_versions = [model.get("HubContentVersion") for model in hub_content_summaries] + + if hub_model_version == "*" or hub_model_version is None: + return str(max(available_model_versions)) + + try: + spec = SpecifierSet(f"=={hub_model_version}") + except InvalidSpecifier: + raise KeyError(f"Bad semantic version: {hub_model_version}") + available_versions_filtered = list(spec.filter(available_model_versions)) + if not available_versions_filtered: + raise KeyError("Model version not available in the Hub") + hub_model_version = str(max(available_versions_filtered)) + + return hub_model_version diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index d2a09345a1..e99cbcc57a 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -24,6 +24,7 @@ from sagemaker.enums import EndpointType from sagemaker.explainer.explainer_config import ExplainerConfig from 
sagemaker.jumpstart.accessors import JumpStartModelsAccessor +from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import ( INVALID_MODEL_ID_ERROR_MSG, @@ -74,6 +75,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -111,6 +113,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with @@ -286,6 +289,12 @@ def __init__( ValueError: If the model ID is not recognized by JumpStart. 
""" + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + def _validate_model_id_and_type(): return validate_model_id_and_get_type( model_id=model_id, @@ -293,13 +302,14 @@ def _validate_model_id_and_type(): region=region or getattr(sagemaker_session, "boto_region_name", None), script=JumpStartScriptScope.INFERENCE, sagemaker_session=sagemaker_session, + hub_arn=hub_arn, ) self.model_type = _validate_model_id_and_type() if not self.model_type: JumpStartModelsAccessor.reset_cache() self.model_type = _validate_model_id_and_type() - if not self.model_type: + if not self.model_type and not hub_arn: raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) self._model_data_is_set = model_data is not None @@ -308,6 +318,7 @@ def _validate_model_id_and_type(): model_from_estimator=False, model_type=self.model_type, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -337,17 +348,21 @@ def _validate_model_id_and_type(): self.model_id = model_init_kwargs.model_id self.model_version = model_init_kwargs.model_version + self.hub_arn = model_init_kwargs.hub_arn self.instance_type = model_init_kwargs.instance_type self.resources = model_init_kwargs.resources self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model self.region = model_init_kwargs.region self.sagemaker_session = model_init_kwargs.sagemaker_session + self.model_reference_arn = model_init_kwargs.model_reference_arn if self.model_type == JumpStartModelType.PROPRIETARY: self.log_subscription_warning() - super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict()) + model_init_kwargs_dict = model_init_kwargs.to_kwargs_dict() + + super(JumpStartModel, 
self).__init__(**model_init_kwargs_dict) self.model_package_arn = model_init_kwargs.model_package_arn @@ -357,6 +372,7 @@ def log_subscription_warning(self) -> None: region=self.region, model_id=self.model_id, version=self.model_version, + hub_arn=self.hub_arn, model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, @@ -378,6 +394,7 @@ def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]: return payloads.retrieve_all_examples( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -647,6 +664,7 @@ def deploy( model_id=self.model_id, model_version=self.model_version, region=self.region, + hub_arn=self.hub_arn, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, initial_instance_count=initial_instance_count, @@ -669,6 +687,7 @@ def deploy( explainer_config=explainer_config, sagemaker_session=self.sagemaker_session, accept_eula=accept_eula, + model_reference_arn=self.model_reference_arn, endpoint_logging=endpoint_logging, resources=resources, managed_instance_scaling=managed_instance_scaling, @@ -694,6 +713,7 @@ def deploy( model_type=self.model_type, scope=JumpStartScriptScope.INFERENCE, sagemaker_session=self.sagemaker_session, + hub_arn=self.hub_arn, ).model_subscription_link get_proprietary_model_subscription_error(e, subscription_link) raise @@ -704,6 +724,7 @@ def deploy( predictor=predictor, model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, @@ -796,6 +817,7 @@ def register( register_kwargs = get_register_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, 
region=self.region, tolerate_deprecated_model=self.tolerate_deprecated_model, tolerate_vulnerable_model=self.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 88e25f8a94..420505c508 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -10,8 +10,10 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +# pylint: skip-file """This module stores types related to SageMaker JumpStart.""" from __future__ import absolute_import +import re from copy import deepcopy from enum import Enum from typing import Any, Dict, List, Optional, Set, Union @@ -30,6 +32,10 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.enums import EndpointType +from sagemaker.jumpstart.hub.parser_utils import ( + camel_to_snake, + walk_and_apply_json, +) class JumpStartDataHolderType: @@ -114,6 +120,23 @@ class JumpStartS3FileType(str, Enum): PROPRIETARY_SPECS = "proprietary_specs" +class HubType(str, Enum): + """Enum for Hub objects.""" + + HUB = "Hub" + + +class HubContentType(str, Enum): + """Enum for Hub content objects.""" + + MODEL = "Model" + NOTEBOOK = "Notebook" + MODEL_REFERENCE = "ModelReference" + + +JumpStartContentDataType = Union[JumpStartS3FileType, HubType, HubContentType] + + class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): """Data class for launched region info.""" @@ -178,14 +201,18 @@ class JumpStartECRSpecs(JumpStartDataHolderType): "framework_version", "py_version", "huggingface_transformers_version", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = 
False): """Initializes a JumpStartECRSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -198,6 +225,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if not json_obj: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) + self.framework = json_obj.get("framework") self.framework_version = json_obj.get("framework_version") self.py_version = json_obj.get("py_version") @@ -207,7 +237,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartECRSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -224,14 +258,18 @@ class JumpStartHyperparameter(JumpStartDataHolderType): "max", "exclusive_min", "exclusive_max", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartHyperparameter object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of hyperparameter. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -241,6 +279,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: json_obj (Dict[str, Any]): Dictionary representation of hyperparameter. 
""" + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -258,17 +298,24 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: if max_val is not None: self.max = max_val + # HubContentDocument model schema does not allow exclusive min/max. + if self._is_hub_content: + return + exclusive_min_val = json_obj.get("exclusive_min") + exclusive_max_val = json_obj.get("exclusive_max") if exclusive_min_val is not None: self.exclusive_min = exclusive_min_val - - exclusive_max_val = json_obj.get("exclusive_max") if exclusive_max_val is not None: self.exclusive_max = exclusive_max_val def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartHyperparameter object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -281,14 +328,18 @@ class JumpStartEnvironmentVariable(JumpStartDataHolderType): "default", "scope", "required_for_model_class", + "_is_hub_content", ] - def __init__(self, spec: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartEnvironmentVariable object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of environment variable. """ + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -297,7 +348,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: Args: json_obj (Dict[str, Any]): Dictionary representation of environment variable. 
""" - + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.name = json_obj["name"] self.type = json_obj["type"] self.default = json_obj["default"] @@ -306,7 +357,11 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartEnvironmentVariable object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -318,14 +373,18 @@ class JumpStartPredictorSpecs(JumpStartDataHolderType): "supported_content_types", "default_accept_type", "supported_accept_types", + "_is_hub_content", ] - def __init__(self, spec: Optional[Dict[str, Any]]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartPredictorSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of predictor specs. 
""" + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -338,6 +397,8 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.default_content_type = json_obj["default_content_type"] self.supported_content_types = json_obj["supported_content_types"] self.default_accept_type = json_obj["default_accept_type"] @@ -345,7 +406,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartPredictorSpecs object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj @@ -358,16 +423,18 @@ class JumpStartSerializablePayload(JumpStartDataHolderType): "accept", "body", "prompt_key", + "_is_hub_content", ] - _non_serializable_slots = ["raw_payload", "prompt_key"] + _non_serializable_slots = ["raw_payload", "prompt_key", "_is_hub_content"] - def __init__(self, spec: Optional[Dict[str, Any]]): + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartSerializablePayload object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of payload specs. 
""" + self._is_hub_content = is_hub_content self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: @@ -384,9 +451,11 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + if self._is_hub_content: + json_obj = walk_and_apply_json(json_obj, camel_to_snake) self.raw_payload = json_obj self.content_type = json_obj["content_type"] - self.body = json_obj["body"] + self.body = json_obj.get("body") accept = json_obj.get("accept") self.prompt_key = json_obj.get("prompt_key") if accept: @@ -402,16 +471,26 @@ class JumpStartInstanceTypeVariants(JumpStartDataHolderType): __slots__ = [ "regional_aliases", + "aliases", "variants", + "_is_hub_content", ] - def __init__(self, spec: Optional[Dict[str, Any]]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, spec: Optional[Dict[str, Any]], is_hub_content: Optional[bool] = False): """Initializes a JumpStartInstanceTypeVariants object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of instance type variants. """ - self.from_json(spec) + + self._is_hub_content = is_hub_content + + if self._is_hub_content: + self.from_describe_hub_content_response(spec) + else: + self.from_json(spec) def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: """Sets fields in object based on json. 
@@ -423,14 +502,50 @@ def from_json(self, json_obj: Optional[Dict[str, Any]]) -> None: if json_obj is None: return + self.aliases = None self.regional_aliases: Optional[dict] = json_obj.get("regional_aliases") self.variants: Optional[dict] = json_obj.get("variants") def to_json(self) -> Dict[str, Any]: - """Returns json representation of JumpStartInstanceTypeVariants object.""" - json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} + """Returns json representation of JumpStartInstance object.""" + json_obj = { + att: getattr(self, att) + for att in self.__slots__ + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []) + } return json_obj + def from_describe_hub_content_response(self, response: Optional[Dict[str, Any]]) -> None: + """Sets fields in object based on DescribeHubContent response. + + Args: + response (Dict[str, Any]): Dictionary representation of instance type variants. + """ + + if response is None: + return + + response = walk_and_apply_json(response, camel_to_snake) + self.aliases: Optional[dict] = response.get("aliases") + self.regional_aliases = None + self.variants: Optional[dict] = response.get("variants") + + def regionalize( # pylint: disable=inconsistent-return-statements + self, region: str + ) -> Optional[Dict[str, Any]]: + """Returns regionalized instance type variants.""" + + if self.regional_aliases is None or self.aliases is not None: + return + aliases = self.regional_aliases.get(region, {}) + variants = {} + for instance_name, properties in self.variants.items(): + if properties.get("regional_properties") is not None: + variants.update({instance_name: properties.get("regional_properties")}) + if properties.get("properties") is not None: + variants.update({instance_name: properties.get("properties")}) + return {"Aliases": aliases, "Variants": variants} + def get_instance_specific_metric_definitions( self, instance_type: str ) -> List[JumpStartHyperparameter]: @@ -628,7 +743,12 
@@ def get_instance_specific_gated_model_key_env_var_value( Returns None if a model, instance type tuple does not have instance specific property. """ - return self._get_instance_specific_property(instance_type, "gated_model_key_env_var_value") + + gated_model_key_env_var_value = ( + "gated_model_env_var_uri" if self._is_hub_content else "gated_model_key_env_var_value" + ) + + return self._get_instance_specific_property(instance_type, gated_model_key_env_var_value) def get_instance_specific_default_inference_instance_type( self, instance_type: str @@ -680,7 +800,7 @@ def get_instance_specific_supported_inference_instance_types( ) ) - def get_image_uri(self, instance_type: str, region: str) -> Optional[str]: + def get_image_uri(self, instance_type: str, region: Optional[str] = None) -> Optional[str]: """Returns image uri from instance type and region. Returns None if no instance type is available or found. @@ -701,36 +821,63 @@ def get_model_package_arn(self, instance_type: str, region: str) -> Optional[str ) def _get_regional_property( - self, instance_type: str, region: str, property_name: str + self, instance_type: str, region: Optional[str], property_name: str ) -> Optional[str]: """Returns regional property from instance type and region. Returns None if no instance type is available or found. None is also returned if the metadata is improperly formatted. 
""" + # pylint: disable=too-many-return-statements + # if self.variants is None or (self.aliases is None and self.regional_aliases is None): + # return None - if None in [self.regional_aliases, self.variants]: + if self.variants is None: return None - regional_property_alias: Optional[str] = ( - self.variants.get(instance_type, {}).get("regional_properties", {}).get(property_name) - ) - if regional_property_alias is None: - instance_type_family = get_instance_type_family(instance_type) + if region is None and self.regional_aliases is not None: + return None - if instance_type_family in {"", None}: - return None + regional_property_alias: Optional[str] = None + regional_property_value: Optional[str] = None + if self.regional_aliases: regional_property_alias = ( - self.variants.get(instance_type_family, {}) + self.variants.get(instance_type, {}) .get("regional_properties", {}) .get(property_name) ) + else: + regional_property_value = ( + self.variants.get(instance_type, {}).get("properties", {}).get(property_name) + ) + + if regional_property_alias is None and regional_property_value is None: + instance_type_family = get_instance_type_family(instance_type) + + if instance_type_family in {"", None}: + return None - if regional_property_alias is None or len(regional_property_alias) == 0: + if self.regional_aliases: + regional_property_alias = ( + self.variants.get(instance_type_family, {}) + .get("regional_properties", {}) + .get(property_name) + ) + else: + # if reading from HubContent, aliases are already regionalized + regional_property_value = ( + self.variants.get(instance_type_family, {}) + .get("properties", {}) + .get(property_name) + ) + + if (regional_property_alias is None or len(regional_property_alias) == 0) and ( + regional_property_value is None or len(regional_property_value) == 0 + ): return None - if not regional_property_alias.startswith("$"): + if regional_property_alias and not regional_property_alias.startswith("$"): # No leading '$' indicates bad 
metadata. # There are tests to ensure this never happens. # However, to allow for fallback options in the unlikely event @@ -738,10 +885,13 @@ def _get_regional_property( # We return None, indicating the field does not exist. return None - if region not in self.regional_aliases: + if self.regional_aliases and region not in self.regional_aliases: return None - alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) - return alias_value + + if self.regional_aliases: + alias_value = self.regional_aliases[region].get(regional_property_alias[1:], None) + return alias_value + return regional_property_value class JumpStartBenchmarkStat(JumpStartDataHolderType): @@ -811,10 +961,13 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "min_sdk_version", "incremental_training_supported", "hosting_ecr_specs", + "hosting_ecr_uri", + "hosting_artifact_uri", "hosting_artifact_key", "hosting_script_key", "training_supported", "training_ecr_specs", + "training_ecr_uri", "training_artifact_key", "training_script_key", "hyperparameters", @@ -837,7 +990,9 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "supported_training_instance_types", "metrics", "training_prepacked_script_key", + "training_prepacked_script_version", "hosting_prepacked_artifact_key", + "hosting_prepacked_artifact_version", "model_kwargs", "deploy_kwargs", "estimator_kwargs", @@ -857,14 +1012,19 @@ class JumpStartMetadataBaseFields(JumpStartDataHolderType): "default_payloads", "gated_bucket", "model_subscription_link", + "hub_content_type", + "_is_hub_content", ] - def __init__(self, fields: Dict[str, Any]): + _non_serializable_slots = ["_is_hub_content"] + + def __init__(self, fields: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartMetadataFields object. Args: fields (Dict[str, Any]): Dictionary representation of metadata fields. 
""" + self._is_hub_content = is_hub_content self.from_json(fields) def from_json(self, json_obj: Dict[str, Any]) -> None: @@ -880,16 +1040,24 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.incremental_training_supported: bool = bool( json_obj.get("incremental_training_supported", False) ) - self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["hosting_ecr_specs"]) - if "hosting_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.hosting_ecr_uri: Optional[str] = json_obj["hosting_ecr_uri"] + self._non_serializable_slots.append("hosting_ecr_specs") + else: + self.hosting_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs( + json_obj["hosting_ecr_specs"], is_hub_content=self._is_hub_content + ) + if "hosting_ecr_specs" in json_obj + else None + ) + self._non_serializable_slots.append("hosting_ecr_uri") self.hosting_artifact_key: Optional[str] = json_obj.get("hosting_artifact_key") + self.hosting_artifact_uri: Optional[str] = json_obj.get("hosting_artifact_uri") self.hosting_script_key: Optional[str] = json_obj.get("hosting_script_key") self.training_supported: Optional[bool] = bool(json_obj.get("training_supported", False)) self.inference_environment_variables = [ - JumpStartEnvironmentVariable(env_variable) + JumpStartEnvironmentVariable(env_variable, is_hub_content=self._is_hub_content) for env_variable in json_obj.get("inference_environment_variables", []) ] self.inference_vulnerable: bool = bool(json_obj.get("inference_vulnerable", False)) @@ -927,16 +1095,26 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_prepacked_artifact_key: Optional[str] = json_obj.get( "hosting_prepacked_artifact_key", None ) + # New fields required for Hub model. 
+ if self._is_hub_content: + self.training_prepacked_script_version: Optional[str] = json_obj.get( + "training_prepacked_script_version" + ) + self.hosting_prepacked_artifact_version: Optional[str] = json_obj.get( + "hosting_prepacked_artifact_version" + ) self.model_kwargs = deepcopy(json_obj.get("model_kwargs", {})) self.deploy_kwargs = deepcopy(json_obj.get("deploy_kwargs", {})) self.predictor_specs: Optional[JumpStartPredictorSpecs] = ( - JumpStartPredictorSpecs(json_obj["predictor_specs"]) + JumpStartPredictorSpecs( + json_obj["predictor_specs"], is_hub_content=self._is_hub_content + ) if "predictor_specs" in json_obj else None ) self.default_payloads: Optional[Dict[str, JumpStartSerializablePayload]] = ( { - alias: JumpStartSerializablePayload(payload) + alias: JumpStartSerializablePayload(payload, is_hub_content=self._is_hub_content) for alias, payload in json_obj["default_payloads"].items() } if json_obj.get("default_payloads") @@ -955,24 +1133,34 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.hosting_use_script_uri: bool = json_obj.get("hosting_use_script_uri", True) self.hosting_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["hosting_instance_type_variants"]) + JumpStartInstanceTypeVariants( + json_obj["hosting_instance_type_variants"], self._is_hub_content + ) if json_obj.get("hosting_instance_type_variants") else None ) if self.training_supported: - self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( - JumpStartECRSpecs(json_obj["training_ecr_specs"]) - if "training_ecr_specs" in json_obj - else None - ) + if self._is_hub_content: + self.training_ecr_uri: Optional[str] = json_obj["training_ecr_uri"] + self._non_serializable_slots.append("training_ecr_specs") + else: + self.training_ecr_specs: Optional[JumpStartECRSpecs] = ( + JumpStartECRSpecs(json_obj["training_ecr_specs"]) + if "training_ecr_specs" in json_obj + else None + ) + 
self._non_serializable_slots.append("training_ecr_uri") self.training_artifact_key: str = json_obj["training_artifact_key"] self.training_script_key: str = json_obj["training_script_key"] hyperparameters: Any = json_obj.get("hyperparameters") self.hyperparameters: List[JumpStartHyperparameter] = [] if hyperparameters is not None: self.hyperparameters.extend( - [JumpStartHyperparameter(hyperparameter) for hyperparameter in hyperparameters] + [ + JumpStartHyperparameter(hyperparameter, is_hub_content=self._is_hub_content) + for hyperparameter in hyperparameters + ] ) self.estimator_kwargs = deepcopy(json_obj.get("estimator_kwargs", {})) self.fit_kwargs = deepcopy(json_obj.get("fit_kwargs", {})) @@ -984,7 +1172,9 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: "training_model_package_artifact_uris" ) self.training_instance_type_variants: Optional[JumpStartInstanceTypeVariants] = ( - JumpStartInstanceTypeVariants(json_obj["training_instance_type_variants"]) + JumpStartInstanceTypeVariants( + json_obj["training_instance_type_variants"], is_hub_content=self._is_hub_content + ) if json_obj.get("training_instance_type_variants") else None ) @@ -994,7 +1184,7 @@ def to_json(self) -> Dict[str, Any]: """Returns json representation of JumpStartMetadataBaseFields object.""" json_obj = {} for att in self.__slots__: - if hasattr(self, att): + if hasattr(self, att) and att not in getattr(self, "_non_serializable_slots", []): cur_val = getattr(self, att) if issubclass(type(cur_val), JumpStartDataHolderType): json_obj[att] = cur_val.to_json() @@ -1016,6 +1206,11 @@ def to_json(self) -> Dict[str, Any]: json_obj[att] = cur_val return json_obj + def set_hub_content_type(self, hub_content_type: HubContentType) -> None: + """Sets the hub content type.""" + if self._is_hub_content: + self.hub_content_type = hub_content_type + class JumpStartConfigComponent(JumpStartMetadataBaseFields): """Data class of JumpStart config component.""" @@ -1212,13 +1407,13 @@ class 
JumpStartModelSpecs(JumpStartMetadataBaseFields): __slots__ = JumpStartMetadataBaseFields.__slots__ + slots - def __init__(self, spec: Dict[str, Any]): + def __init__(self, spec: Dict[str, Any], is_hub_content: Optional[bool] = False): """Initializes a JumpStartModelSpecs object from its json representation. Args: spec (Dict[str, Any]): Dictionary representation of spec. """ - super().__init__(spec) + super().__init__(spec, is_hub_content) self.from_json(spec) if self.inference_configs and self.inference_configs.get_top_config_from_ranking(): super().from_json(self.inference_configs.get_top_config_from_ranking().resolved_config) @@ -1414,27 +1609,84 @@ def __init__( self.version = version -class JumpStartCachedS3ContentKey(JumpStartDataHolderType): - """Data class for the s3 cached content keys.""" +class JumpStartCachedContentKey(JumpStartDataHolderType): + """Data class for the cached content keys.""" - __slots__ = ["file_type", "s3_key"] + __slots__ = ["data_type", "id_info"] def __init__( self, - file_type: JumpStartS3FileType, - s3_key: str, + data_type: JumpStartContentDataType, + id_info: str, ) -> None: - """Instantiates JumpStartCachedS3ContentKey object. + """Instantiates JumpStartCachedContentKey object. Args: - file_type (JumpStartS3FileType): JumpStart file type. - s3_key (str): object key in s3. + data_type (JumpStartContentDataType): JumpStart content data type. + id_info (str): if S3Content, object key in s3. if HubContent, hub content arn. 
""" - self.file_type = file_type - self.s3_key = s3_key + self.data_type = data_type + self.id_info = id_info + + +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_name = hub_content_name + self.hub_content_type = hub_content_type + self.hub_content_version = hub_content_version + + @staticmethod + def extract_region_from_arn(arn: str) -> Optional[str]: + """Extracts hub_name, content_name, and content_version from a HubContentArn""" + + HUB_CONTENT_ARN_REGEX = ( + r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" + ) + HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" + + match = re.match(HUB_CONTENT_ARN_REGEX, arn) + hub_region = None + if match: + hub_region = match.group(2) + + return hub_region + + match = re.match(HUB_ARN_REGEX, arn) + if match: + hub_region = match.group(2) + return hub_region + + return hub_region -class JumpStartCachedS3ContentValue(JumpStartDataHolderType): +class JumpStartCachedContentValue(JumpStartDataHolderType): """Data class for the s3 cached content values.""" __slots__ = ["formatted_content", "md5_hash"] @@ -1447,7 +1699,7 @@ def __init__( ], md5_hash: Optional[str] = None, ) -> None: - """Instantiates JumpStartCachedS3ContentValue object. + """Instantiates JumpStartCachedContentValue object. 
Args: formatted_content (Union[Dict[JumpStartVersionedModelId, JumpStartModelHeader], @@ -1482,6 +1734,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "instance_type", "tolerate_vulnerable_model", @@ -1507,24 +1760,29 @@ class JumpStartModelInitKwargs(JumpStartKwargs): "model_package_arn", "training_instance_type", "resources", + "hub_content_type", + "model_reference_arn", ] SERIALIZATION_EXCLUSION_SET = { "instance_type", "model_id", "model_version", + "hub_arn", "model_type", "tolerate_vulnerable_model", "tolerate_deprecated_model", "region", "model_package_arn", "training_instance_type", + "hub_content_type", } def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, instance_type: Optional[str] = None, @@ -1555,6 +1813,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.instance_type = instance_type self.region = region @@ -1588,6 +1847,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "initial_instance_count", "instance_type", @@ -1613,6 +1873,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "sagemaker_session", "training_instance_type", "accept_eula", + "model_reference_arn", "endpoint_logging", "resources", "endpoint_type", @@ -1623,6 +1884,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_id", "model_version", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1634,6 +1896,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, 
initial_instance_count: Optional[int] = None, @@ -1659,6 +1922,7 @@ def __init__( sagemaker_session: Optional[Session] = None, training_instance_type: Optional[str] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, endpoint_logging: Optional[bool] = None, resources: Optional[ResourceRequirements] = None, endpoint_type: Optional[EndpointType] = None, @@ -1668,6 +1932,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.initial_instance_count = initial_instance_count self.instance_type = instance_type @@ -1693,6 +1958,7 @@ def __init__( self.sagemaker_session = sagemaker_session self.training_instance_type = training_instance_type self.accept_eula = accept_eula + self.model_reference_arn = model_reference_arn self.endpoint_logging = endpoint_logging self.resources = resources self.endpoint_type = endpoint_type @@ -1705,6 +1971,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "instance_type", "instance_count", @@ -1766,6 +2033,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "hub_arn", "model_type", } @@ -1773,6 +2041,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, @@ -1831,6 +2100,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = (model_type,) self.instance_type = instance_type self.instance_count = instance_count @@ -1894,6 +2164,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "model_type", "region", "inputs", @@ -1909,6 +2180,7 @@ class 
JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "hub_arn", "model_type", "region", "tolerate_deprecated_model", @@ -1920,6 +2192,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, @@ -1935,6 +2208,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.model_type = model_type self.region = region self.inputs = inputs @@ -1953,6 +2227,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "initial_instance_count", "region", @@ -1998,6 +2273,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", } @@ -2005,6 +2281,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -2047,6 +2324,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.initial_instance_count = initial_instance_count self.region = region @@ -2095,6 +2373,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", "content_types", "response_types", @@ -2127,6 +2406,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", } @@ -2134,6 +2414,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_deprecated_model: Optional[bool] = None, 
tolerate_vulnerable_model: Optional[bool] = None, @@ -2166,6 +2447,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.region = region self.image_uri = image_uri self.sagemaker_session = sagemaker_session diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 657ab11535..989ca426b5 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -382,6 +382,21 @@ def add_jumpstart_model_id_version_tags( return tags +def add_hub_content_arn_tags( + tags: Optional[List[TagsDict]], + hub_arn: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + hub_arn, + enums.JumpStartTag.HUB_CONTENT_ARN, + tags, + is_uri=False, + ) + return tags + + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, @@ -546,6 +561,7 @@ def verify_model_region_and_return_specs( version: Optional[str], scope: Optional[str], region: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -561,6 +577,8 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -600,9 +618,11 @@ def verify_model_region_and_return_specs( model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, model_id=model_id, + hub_arn=hub_arn, version=version, s3_client=sagemaker_session.s3_client, model_type=model_type, + sagemaker_session=sagemaker_session, ) if ( @@ -763,6 +783,7 @@ def validate_model_id_and_get_type( model_version: Optional[str] = None, script: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + hub_arn: Optional[str] = None, ) -> Optional[enums.JumpStartModelType]: """Returns model type if the model ID is supported for the given script. @@ -774,6 +795,8 @@ def validate_model_id_and_get_type( return None if not isinstance(model_id, str): return None + if hub_arn: + return None s3_client = sagemaker_session.s3_client if sagemaker_session else None region = region or constants.JUMPSTART_DEFAULT_REGION_NAME @@ -918,6 +941,7 @@ def get_benchmark_stats( model_id: str, model_version: str, config_names: Optional[List[str]] = None, + hub_arn: Optional[str] = None, sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE, model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS, @@ -927,6 +951,7 @@ def get_benchmark_stats( region=region, model_id=model_id, version=model_version, + hub_arn=hub_arn, sagemaker_session=sagemaker_session, scope=scope, model_type=model_type, diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index c7098a1185..1a4849522c 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -167,6 +167,7 @@ def validate_hyperparameters( model_version: str, hyperparameters: Dict[str, 
Any], validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, + hub_arn: Optional[str] = None, region: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, tolerate_vulnerable_model: bool = False, @@ -213,6 +214,7 @@ def validate_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 71dd26db45..a31d5d930d 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -43,6 +44,8 @@ def retrieve_default( retrieve the default training metric definitions. (Default: None). model_version (str): The version of the model for which to retrieve the default training metric definitions. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -71,6 +74,7 @@ def retrieve_default( return artifacts._retrieve_default_training_metric_definitions( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b6848800dd..09e04ad840 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -164,6 +164,7 @@ def __init__( dependencies: Optional[List[str]] = None, git_config: Optional[Dict[str, str]] = None, resources: Optional[ResourceRequirements] = None, + model_reference_arn: Optional[str] = None, ): """Initialize an SageMaker ``Model``. @@ -327,6 +328,8 @@ def __init__( for a model to be deployed to an endpoint. Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature. (Default: None). + model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type + content (default: None). """ self.model_data = model_data @@ -405,6 +408,7 @@ def __init__( self.content_types = None self.response_types = None self.accept_eula = None + self.model_reference_arn = model_reference_arn @classmethod def attach( @@ -586,6 +590,7 @@ def create( serverless_inference_config: Optional[ServerlessInferenceConfig] = None, tags: Optional[Tags] = None, accept_eula: Optional[bool] = None, + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -627,6 +632,7 @@ def create( tags=format_tags(tags), serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def _init_sagemaker_session_if_does_not_exist(self, instance_type=None): @@ -648,6 +654,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): # pylint: disable=unused-argument """Return a dict created by ``sagemaker.container_def()``. 
@@ -690,6 +697,11 @@ def prepare_container_def( accept_eula=( accept_eula if accept_eula is not None else getattr(self, "accept_eula", None) ), + model_reference_arn=( + model_reference_arn + if model_reference_arn is not None + else getattr(self, "model_reference_arn", None) + ), ) def is_repack(self) -> bool: @@ -832,6 +844,7 @@ def _create_sagemaker_model( tags: Optional[Tags] = None, serverless_inference_config=None, accept_eula=None, + model_reference_arn: Optional[str] = None, ): """Create a SageMaker Model Entity @@ -856,6 +869,8 @@ def _create_sagemaker_model( The `accept_eula` value must be explicitly defined as `True` in order to accept the end-user license agreement (EULA) that some models require. (Default: None). + model_reference_arn (Optional [str]): Hub Content Arn of a Model Reference type + content (default: None). """ if self.model_package_arn is not None or self.algorithm_arn is not None: model_package = ModelPackage( @@ -887,6 +902,7 @@ def _create_sagemaker_model( accelerator_type=accelerator_type, serverless_inference_config=serverless_inference_config, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) if not isinstance(self.sagemaker_session, PipelineSession): @@ -1650,6 +1666,7 @@ def deploy( accelerator_type=accelerator_type, tags=tags, serverless_inference_config=serverless_inference_config, + **kwargs, ) serverless_inference_config_dict = ( serverless_inference_config._to_request_dict() if is_serverless else None diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..a2177c0ec5 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -43,6 +44,8 @@ def retrieve( the model artifact S3 URI. 
model_version (str): The version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). model_scope (str): The model type. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -75,6 +78,7 @@ def retrieve( return artifacts._retrieve_model_uri( model_id=model_id, model_version=model_version, # type: ignore + hub_arn=hub_arn, model_scope=model_scope, instance_type=instance_type, region=region, diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 9c1e6ac4f4..9ed348c927 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -126,6 +126,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition set. @@ -154,6 +155,7 @@ def prepare_container_def( model_data_url=self.model_data_prefix, container_mode=self.container_mode, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def deploy( diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 8d389e9f59..487d336497 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -284,6 +284,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration. 
@@ -337,6 +338,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/payloads.py b/src/sagemaker/payloads.py index 06d2ecfcde..403445525b 100644 --- a/src/sagemaker/payloads.py +++ b/src/sagemaker/payloads.py @@ -32,6 +32,7 @@ def retrieve_all_examples( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -78,11 +79,12 @@ def retrieve_all_examples( unserialized_payload_dict: Optional[Dict[str, JumpStartSerializablePayload]] = ( artifacts._retrieve_example_payloads( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + region=region, + hub_arn=hub_arn, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) @@ -123,6 +125,7 @@ def retrieve_example( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, serialize: bool = False, tolerate_vulnerable_model: bool = False, @@ -168,6 +171,7 @@ def retrieve_example( region=region, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, model_type=model_type, serialize=serialize, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 6f846bba65..d3f41bd9a6 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -40,6 +40,7 @@ def retrieve_default( region: Optional[str] = None, model_id: 
Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -58,6 +59,8 @@ def retrieve_default( retrieve the default predictor. (Default: None). model_version (str): The version of the model for which to retrieve the default predictor. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -105,6 +108,7 @@ def retrieve_default( predictor=predictor, model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_deprecated_model=tolerate_deprecated_model, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 6d915772cd..92b96bd8c8 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -286,6 +286,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """A container definition with framework configuration set in model environment variables. 
@@ -337,6 +338,7 @@ def prepare_container_def( self.repacked_model_data or self.model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri( diff --git a/src/sagemaker/resource_requirements.py b/src/sagemaker/resource_requirements.py index df14ac558f..396a158939 100644 --- a/src/sagemaker/resource_requirements.py +++ b/src/sagemaker/resource_requirements.py @@ -31,6 +31,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -47,6 +48,8 @@ def retrieve_default( retrieve the default resource requirements. (Default: None). model_version (str): The version of the model for which to retrieve the default resource requirements. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". 
tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -78,12 +81,13 @@ def retrieve_default( raise ValueError("Must specify scope for resource requirements.") return artifacts._retrieve_default_resources( - model_id, - model_version, - scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + scope=scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, model_type=model_type, sagemaker_session=sagemaker_session, instance_type=instance_type, diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..91a5a97b1f 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -42,6 +43,8 @@ def retrieve( retrieve the script S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). script_scope (str): The script type. Valid values: "training" and "inference". 
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model @@ -71,11 +74,12 @@ def retrieve( ) return artifacts._retrieve_script_uri( - model_id, - model_version, - script_scope, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + script_scope=script_scope, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index aefb52bd97..4ffd121ad8 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -42,6 +42,7 @@ def retrieve_options( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -55,6 +56,8 @@ def retrieve_options( retrieve the supported serializers. (Default: None). model_version (str): The version of the model for which to retrieve the supported serializers. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -79,11 +82,12 @@ def retrieve_options( ) return artifacts._retrieve_serializer_options( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, ) @@ -92,6 +96,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, @@ -106,6 +111,8 @@ def retrieve_default( retrieve the default serializer. (Default: None). model_version (str): The version of the model for which to retrieve the default serializer. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). 
If False, raises an exception if the script used by this version of the model has dependencies with known @@ -130,11 +137,12 @@ def retrieve_default( ) return artifacts._retrieve_default_serializer( - model_id, - model_version, - region, - tolerate_vulnerable_model, - tolerate_deprecated_model, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + region=region, + tolerate_vulnerable_model=tolerate_vulnerable_model, + tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, model_type=model_type, ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 205b14a3e6..cd36cc739c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -6774,6 +6774,293 @@ def create_presigned_mlflow_tracking_server_url( **create_presigned_url_args ) + def create_hub( + self, + hub_name: str, + hub_description: str, + hub_display_name: str = None, + hub_search_keywords: List[str] = None, + s3_storage_config: Dict[str, Any] = None, + tags: List[Dict[str, Any]] = None, + ) -> Dict[str, str]: + """Creates a SageMaker Hub + + Args: + hub_name (str): The name of the Hub to create. + hub_description (str): A description of the Hub. + hub_display_name (str): The display name of the Hub. + hub_search_keywords (list): The searchable keywords for the Hub. + s3_storage_config (S3StorageConfig): The Amazon S3 storage configuration for the Hub. + tags (list): Any tags to associate with the Hub. + + Returns: + (dict): Return value from the ``CreateHub`` API. 
+ """ + request = {"HubName": hub_name, "HubDescription": hub_description} + + if hub_display_name: + request["HubDisplayName"] = hub_display_name + else: + request["HubDisplayName"] = hub_name + + if hub_search_keywords: + request["HubSearchKeywords"] = hub_search_keywords + if s3_storage_config: + request["S3StorageConfig"] = s3_storage_config + if tags: + request["Tags"] = tags + + return self.sagemaker_client.create_hub(**request) + + def describe_hub(self, hub_name: str) -> Dict[str, Any]: + """Describes a SageMaker Hub + + Args: + hub_name (str): The name of the hub to describe. + + Returns: + (dict): Return value for ``DescribeHub`` API + """ + request = {"HubName": hub_name} + + return self.sagemaker_client.describe_hub(**request) + + def list_hubs( + self, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + max_schema_version: str = None, + name_contains: str = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + """Lists all existing SageMaker Hubs + + Args: + creation_time_after (str): Only list HubContent that was created after + the time specified. + creation_time_before (str): Only list HubContent that was created + before the time specified. + max_results (int): The maximum amount of HubContent to list. + max_schema_version (str): The upper bound of the HubContentSchemaVersion. + name_contains (str): Only list HubContent if the name contains the specified string. + next_token (str): If the response to a previous ``ListHubContents`` request was + truncated, the response includes a ``NextToken``. To retrieve the next set of + hub content, use the token in the next request. + sort_by (str): Sort HubContent versions by either name or creation time. + sort_order (str): Sort Hubs by ascending or descending order. 
+ Returns: + (dict): Return value for ``ListHubs`` API + """ + request = {} + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if name_contains: + request["NameContains"] = name_contains + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hubs(**request) + + def list_hub_contents( + self, + hub_name: str, + hub_content_type: str, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + max_schema_version: str = None, + name_contains: str = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + """Lists the HubContents in a SageMaker Hub + + Args: + hub_name (str): The name of the Hub to list the contents of. + hub_content_type (str): The type of the HubContent to list. + creation_time_after (str): Only list HubContent that was created after the + time specified. + creation_time_before (str): Only list HubContent that was created before the + time specified. + max_results (int): The maximum amount of HubContent to list. + max_schema_version (str): The upper bound of the HubContentSchemaVersion. + name_contains (str): Only list HubContent if the name contains the specified string. + next_token (str): If the response to a previous ``ListHubContents`` request was + truncated, the response includes a ``NextToken``. To retrieve the next set of + hub content, use the token in the next request. + sort_by (str): Sort HubContent versions by either name or creation time. + sort_order (str): Sort Hubs by ascending or descending order. 
+ Returns: + (dict): Return value for ``ListHubContents`` API + """ + request = {"HubName": hub_name, "HubContentType": hub_content_type} + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if name_contains: + request["NameContains"] = name_contains + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hub_contents(**request) + + def delete_hub(self, hub_name: str) -> None: + """Deletes a SageMaker Hub + + Args: + hub_name (str): The name of the hub to delete. + """ + request = {"HubName": hub_name} + + return self.sagemaker_client.delete_hub(**request) + + def create_hub_content_reference( + self, + hub_name: str, + source_hub_content_arn: str, + hub_content_name: str = None, + min_version: str = None, + ) -> Dict[str, str]: + """Creates a given HubContent reference in a SageMaker Hub + + Args: + hub_name (str): The name of the Hub that you want to add the reference to. + source_hub_content_arn (str): Hub content arn in the public/source Hub. + hub_content_name (str): The name of the reference that you want to add to the Hub. + min_version (str): A minimum version of the hub content to add to the Hub. 
+ + Returns: + (dict): Return value for ``CreateHubContentReference`` API + """ + + request = {"HubName": hub_name, "SageMakerPublicHubContentArn": source_hub_content_arn} + + if hub_content_name: + request["HubContentName"] = hub_content_name + if min_version: + request["MinVersion"] = min_version + + return self.sagemaker_client.create_hub_content_reference(**request) + + def delete_hub_content_reference( + self, hub_name: str, hub_content_type: str, hub_content_name: str + ) -> None: + """Deletes a given HubContent reference in a SageMaker Hub + + Args: + hub_name (str): The name of the Hub that you want to delete content in. + hub_content_type (str): The type of the content that you want to delete from a Hub. + hub_content_name (str): The name of the content that you want to delete from a Hub. + """ + request = { + "HubName": hub_name, + "HubContentType": hub_content_type, + "HubContentName": hub_content_name, + } + + return self.sagemaker_client.delete_hub_content_reference(**request) + + def describe_hub_content( + self, + hub_content_name: str, + hub_content_type: str, + hub_name: str, + hub_content_version: str = None, + ) -> Dict[str, Any]: + """Describes a HubContent in a SageMaker Hub + + Args: + hub_content_name (str): The name of the HubContent to describe. + hub_content_type (str): The type of HubContent in the Hub. + hub_name (str): The name of the Hub that contains the HubContent to describe. 
+ hub_content_version (str): The version of the HubContent to describe + + Returns: + (dict): Return value for ``DescribeHubContent`` API + """ + request = { + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + "HubName": hub_name, + } + if hub_content_version: + request["HubContentVersion"] = hub_content_version + + return self.sagemaker_client.describe_hub_content(**request) + + def list_hub_content_versions( + self, + hub_name, + hub_content_type: str, + hub_content_name: str, + min_version: str = None, + max_schema_version: str = None, + creation_time_after: str = None, + creation_time_before: str = None, + max_results: int = None, + next_token: str = None, + sort_by: str = None, + sort_order: str = None, + ) -> Dict[str, Any]: + """List all versions of a HubContent in a SageMaker Hub + + Args: + hub_content_name (str): The name of the HubContent to describe. + hub_content_type (str): The type of HubContent in the Hub. + hub_name (str): The name of the Hub that contains the HubContent to describe. 
+ + Returns: + (dict): Return value for ``ListHubContentVersions`` API + """ + + request = { + "HubName": hub_name, + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + } + + if min_version: + request["MinVersion"] = min_version + if creation_time_after: + request["CreationTimeAfter"] = creation_time_after + if creation_time_before: + request["CreationTimeBefore"] = creation_time_before + if max_results: + request["MaxResults"] = max_results + if max_schema_version: + request["MaxSchemaVersion"] = max_schema_version + if next_token: + request["NextToken"] = next_token + if sort_by: + request["SortBy"] = sort_by + if sort_order: + request["SortOrder"] = sort_order + + return self.sagemaker_client.list_hub_content_versions(**request) + def get_model_package_args( content_types=None, @@ -7223,6 +7510,7 @@ def container_def( container_mode=None, image_config=None, accept_eula=None, + model_reference_arn=None, ): """Create a definition for executing a container as part of a SageMaker model. @@ -7275,6 +7563,11 @@ def container_def( c_def["ModelDataSource"]["S3DataSource"]["ModelAccessConfig"] = { "AcceptEula": accept_eula } + if model_reference_arn: + c_def["ModelDataSource"]["S3DataSource"]["HubAccessConfig"] = { + "HubContentArn": model_reference_arn + } + elif model_data_url is not None: c_def["ModelDataUrl"] = model_data_url diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 82d9510e53..1ab28eac37 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -279,6 +279,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Container definition with framework configuration set in model environment variables. 
@@ -328,6 +329,7 @@ def prepare_container_def( model_data_uri, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 4a22f1abcb..c06fe74887 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -397,6 +397,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Prepare the container definition. @@ -473,6 +474,7 @@ def prepare_container_def( model_data, env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def _get_container_env(self): diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 157f3cb8fd..6d69801847 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -267,6 +267,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): """Return a container definition with framework configuration. 
@@ -314,6 +315,7 @@ def prepare_container_def( model_data, deploy_env, accept_eula=accept_eula, + model_reference_arn=model_reference_arn, ) def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 11165a0625..91c132f053 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -13,7 +13,7 @@ from __future__ import absolute_import import boto3 -from mock.mock import patch, Mock +from mock.mock import patch, Mock, ANY from sagemaker import accept_types from sagemaker.jumpstart.utils import verify_model_region_and_return_specs @@ -54,9 +54,11 @@ def test_jumpstart_default_accept_types( patched_get_model_specs.assert_called_once_with( region=region, model_id=model_id, + hub_arn=None, version=model_version, - s3_client=mock_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) @@ -91,6 +93,8 @@ def test_jumpstart_supported_accept_types( region=region, model_id=model_id, version=model_version, - s3_client=mock_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index d116c8121b..dda1e30db2 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -56,6 +56,8 @@ def test_jumpstart_default_content_types( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) @@ -75,15 +77,12 @@ def test_jumpstart_supported_content_types( model_id, model_version = 
"predictor-specs-model", "*" region = "us-west-2" - supported_content_types = content_types.retrieve_options( + content_types.retrieve_options( region=region, model_id=model_id, model_version=model_version, sagemaker_session=mock_session, ) - assert supported_content_types == [ - "application/x-text", - ] patched_get_model_specs.assert_called_once_with( region=region, @@ -91,4 +90,6 @@ def test_jumpstart_supported_content_types( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index f0102068e7..9bbca51654 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -58,6 +58,8 @@ def test_jumpstart_default_deserializers( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) @@ -98,4 +100,6 @@ def test_jumpstart_deserializer_options( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index 5f00f93abf..13f720870c 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -61,6 +61,8 @@ def test_jumpstart_default_environment_variables( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -85,6 +87,8 @@ def test_jumpstart_default_environment_variables( version="1.*", s3_client=mock_client, 
model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -147,6 +151,8 @@ def test_jumpstart_sdk_environment_variables( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -172,6 +178,8 @@ def test_jumpstart_sdk_environment_variables( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index 40ee4978cf..565ebbce87 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -54,6 +54,8 @@ def test_jumpstart_default_hyperparameters( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -72,6 +74,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -98,6 +102,8 @@ def test_jumpstart_default_hyperparameters( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index fdc29b4d90..edf2cfca59 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -146,6 +146,8 @@ def add_options_to_hyperparameter(*largs, 
**kwargs): version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -452,6 +454,8 @@ def add_options_to_hyperparameter(*largs, **kwargs): version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -514,6 +518,8 @@ def test_jumpstart_validate_all_hyperparameters( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 88b95b9403..cc3723c3c5 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -53,9 +53,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -74,9 +76,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -95,9 +99,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + 
sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -116,9 +122,11 @@ def test_jumpstart_common_image_uri( patched_get_model_specs.assert_called_once_with( region=sagemaker_constants.JUMPSTART_DEFAULT_REGION_NAME, model_id="pytorch-ic-mobilenet-v2", + hub_arn=None, version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index 2e51afd3f7..5db149c4c3 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -51,6 +51,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -70,6 +72,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -95,6 +99,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -122,6 +128,8 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git 
a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index f165a513a9..6298a06db2 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1250,288 +1250,557 @@ "dynamic_container_deployment_supported": True, }, }, - "env-var-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": False, + # noqa: E501 + "gemma-model-2b-v1_1_0": { + "model_id": "huggingface-llm-gemma-2b-instruct", + "url": "https://huggingface.co/google/gemma-2b-it", + "version": "1.1.0", + "min_sdk_version": "2.189.0", + "training_supported": True, "incremental_training_supported": False, "hosting_ecr_specs": { "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - "huggingface_transformers_version": "4.29.2", + "framework_version": "1.4.2", + "py_version": "py310", + "huggingface_transformers_version": "4.33.2", }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_artifact_key": "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", + "hosting_prepacked_artifact_key": ( + "huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/" + ), + "hosting_prepacked_artifact_version": "1.0.0", "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/terms/gemmaTerms.txt", "inference_vulnerable": False, "inference_dependencies": [], "inference_vulnerabilities": [], "training_vulnerable": False, - "training_dependencies": [], + 
"training_dependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], "training_vulnerabilities": [], "deprecated": False, - "inference_environment_variables": [ - { - "name": "SAGEMAKER_PROGRAM", - "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, - }, + "hyperparameters": [ { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "peft_type", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "lora", + "options": ["lora", "None"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "instruction_tuned", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "chat_dataset", "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", + "name": "epoch", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, - }, - { - "name": "MODEL_CACHE_ROOT", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", - "type": "text", - "default": 
"1", - "scope": "container", - "required_for_model_class": True, + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", }, { - "name": "HF_MODEL_ID", - "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "name": "lora_r", + "type": "int", + "default": 64, + "min": 1, + "max": 1000, + "scope": "algorithm", }, + {"name": "lora_alpha", "type": "int", "default": 16, "min": 0, "scope": "algorithm"}, { - "name": "SM_NUM_GPUS", - "type": "text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "name": "lora_dropout", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", }, + {"name": "bits", "type": "int", "default": 4, "scope": "algorithm"}, { - "name": "MAX_INPUT_LENGTH", + "name": "double_quant", "type": "text", - "default": "1024", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_TOTAL_TOKENS", + "name": "quant_type", "type": "text", - "default": "2048", - "scope": "container", - "required_for_model_class": True, + "default": "nf4", + "options": ["fp4", "nf4"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "name": "per_device_train_batch_size", "type": "int", "default": 1, - "scope": "container", - "required_for_model_class": True, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "model_kwargs": {}, - "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, - }, - "predictor_specs": { - "supported_content_types": ["application/json"], - "supported_accept_types": ["application/json"], - "default_content_type": "application/json", - "default_accept_type": 
"application/json", - }, - "inference_volume_size": 512, - "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 2, + "min": 1, + "max": 1000, + "scope": "algorithm", }, - "variants": { - "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, - "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, - "ml.p4d.24xlarge": { - "properties": { - "environment_variables": { - "YODEL": "NACEREMA", - } - } - }, + { + "name": "warmup_ratio", + "type": "float", + "default": 0.1, + "min": 0, + "max": 1, + "scope": "algorithm", }, - }, - }, - "inference-instance-types-variant-model": { - "model_id": "huggingface-llm-falcon-180b-bf16", - "url": "https://huggingface.co/tiiuae/falcon-180B", - "version": "1.0.0", - "min_sdk_version": "2.175.0", - "training_supported": True, - "incremental_training_supported": False, - "hosting_ecr_specs": { - "framework": "huggingface-llm", - "framework_version": "0.9.3", - "py_version": "py39", - 
"huggingface_transformers_version": "4.29.2", - }, - "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", - "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" - "-huggingface-llm-falcon-180b-bf16.tar.gz", - "hosting_prepacked_artifact_version": "1.0.1", - "hosting_use_script_uri": False, - "inference_vulnerable": False, - "inference_dependencies": [], - "inference_vulnerabilities": [], - "training_vulnerable": False, - "training_dependencies": [], - "training_vulnerabilities": [], - "deprecated": False, - "inference_environment_variables": [ { - "name": "SAGEMAKER_PROGRAM", + "name": "train_from_scratch", "type": "text", - "default": "inference.py", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "name": "fp16", "type": "text", - "default": "/opt/ml/model/code", - "scope": "container", - "required_for_model_class": False, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "name": "bf16", "type": "text", - "default": "20", - "scope": "container", - "required_for_model_class": False, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "name": "evaluation_strategy", "type": "text", - "default": "3600", - "scope": "container", - "required_for_model_class": False, + "default": "steps", + "options": ["steps", "epoch", "no"], + "scope": "algorithm", }, { - "name": "ENDPOINT_SERVER_TIMEOUT", + "name": "eval_steps", "type": "int", - "default": 3600, - "scope": "container", - "required_for_model_class": True, + "default": 20, + "min": 1, + "max": 1000, + "scope": "algorithm", }, { - "name": "MODEL_CACHE_ROOT", + "name": 
"gradient_accumulation_steps", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "logging_steps", + "type": "int", + "default": 8, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "weight_decay", + "type": "float", + "default": 0.2, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + { + "name": "load_best_model_at_end", "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SAGEMAKER_ENV", + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": 1024, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", "type": "text", - "default": "1", - "scope": "container", - "required_for_model_class": True, + "default": "None", + "scope": "algorithm", }, + {"name": "max_steps", "type": "int", "default": -1, "scope": "algorithm"}, { - "name": "HF_MODEL_ID", + "name": "gradient_checkpointing", "type": "text", - "default": "/opt/ml/model", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "SM_NUM_GPUS", + "name": "early_stopping_patience", + "type": "int", + "default": 3, + "min": 1, + "scope": "algorithm", + }, + { + "name": "early_stopping_threshold", + 
"type": "float", + "default": 0.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "adam_beta1", + "type": "float", + "default": 0.9, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_beta2", + "type": "float", + "default": 0.999, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "adam_epsilon", + "type": "float", + "default": 1e-08, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "max_grad_norm", + "type": "float", + "default": 1.0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "label_smoothing_factor", + "type": "float", + "default": 0, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "logging_first_step", "type": "text", - "default": "8", - "scope": "container", - "required_for_model_class": True, + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", }, { - "name": "MAX_INPUT_LENGTH", + "name": "logging_nan_inf_filter", "type": "text", - "default": "1024", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "save_strategy", + "type": "text", + "default": "steps", + "options": ["no", "epoch", "steps"], + "scope": "algorithm", + }, + {"name": "save_steps", "type": "int", "default": 500, "min": 1, "scope": "algorithm"}, + {"name": "save_total_limit", "type": "int", "default": 1, "scope": "algorithm"}, + { + "name": "dataloader_drop_last", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "dataloader_num_workers", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "eval_accumulation_steps", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "auto_find_batch_size", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "lr_scheduler_type", + "type": "text", + "default": "constant_with_warmup", + "options": 
["constant_with_warmup", "linear"], + "scope": "algorithm", + }, + {"name": "warmup_steps", "type": "int", "default": 0, "min": 0, "scope": "algorithm"}, + { + "name": "deepspeed", + "type": "text", + "default": "False", + "options": ["False"], + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", - "required_for_model_class": True, }, { - "name": "MAX_TOTAL_TOKENS", + "name": "sagemaker_program", "type": "text", - "default": "2048", + "default": "transfer_learning.py", "scope": "container", - "required_for_model_class": True, }, { - "name": "SAGEMAKER_MODEL_SERVER_WORKERS", - "type": "int", - "default": 1, + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", "scope": "container", - "required_for_model_class": True, }, ], - "metrics": [], - "default_inference_instance_type": "ml.p4de.24xlarge", - "supported_inference_instance_types": ["ml.p4de.24xlarge"], - "default_training_instance_type": "ml.p4de.24xlarge", - "supported_training_instance_types": ["ml.p4de.24xlarge"], - "model_kwargs": {}, + "training_script_key": "source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", + "training_prepacked_script_key": ( + "source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz" + ), + "training_prepacked_script_version": "1.1.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "training_artifact_key": "huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + 
"scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "8192", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_BATCH_PREFILL_TOKENS", + "type": "text", + "default": "8191", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + {"Name": "huggingface-textgeneration:train-loss", "Regex": "'loss': ([0-9]+\\.[0-9]+)"}, + ], + "default_inference_instance_type": "ml.g5.xlarge", + "supported_inference_instance_types": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", 
+ "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.2xlarge", + "supported_training_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "model_kwargs": {}, "deploy_kwargs": { - "model_data_download_timeout": 3600, - "container_startup_health_check_timeout": 3600, + "model_data_download_timeout": 1200, + "container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": { + "encrypt_inter_container_traffic": True, + "disable_output_compression": True, + "max_run": 360000, }, + "fit_kwargs": {}, "predictor_specs": { "supported_content_types": ["application/json"], "supported_accept_types": ["application/json"], @@ -1539,103 +1808,769 @@ "default_accept_type": "application/json", }, "inference_volume_size": 512, + "training_volume_size": 512, "inference_enable_network_isolation": True, - "validation_supported": False, - "fine_tuning_supported": False, - "resource_name_base": "hf-llm-falcon-180b-bf16", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", - "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", - "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", - } - }, - "variants": { - "ml.p2.12xlarge": { - "properties": { - "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, - "supported_inference_instance_types": ["ml.p5.xlarge"], - "default_inference_instance_type": "ml.p5.xlarge", - "metrics": [ - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:instance-typemetric-loss", - "Regex": 
"'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", - }, - ], - } + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/oasst_top/train/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "hf-llm-gemma-2b-instruct", + "default_payloads": { + "HelloWorld": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "p2": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], - "default_inference_instance_type": "ml.p2.xlarge", - "metrics": [ - { - "Name": "huggingface-textgeneration:wtafigo", - "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:eval-loss", - "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:train-loss", - "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", - }, - { - "Name": "huggingface-textgeneration:noneyourbusiness-loss", - "Regex": "'loss-noyb': ([0-9]+\\.[0-9]+)", - }, - ], + "body": { + "inputs": ( + "user\nWrite a hello world program\nmodel" + ), + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, }, }, - "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, - "p4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" - }, + }, + "MachineLearningPoem": { + "content_type": 
"application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "g4": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + "body": { + "inputs": "Write me a poem about Machine Learning.", + "parameters": { + "max_new_tokens": 256, + "decoder_input_details": True, + "details": True, }, }, - "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, - "g9": { - "regional_properties": {"image_uri": "$gpu_image_uri"}, - "properties": { - "prepacked_artifact_key": "asfs/adsf/sda/f", - "hyperparameters": [ - { - "name": "num_bag_sets", - "type": "int", - "default": 5, - "min": 5, - "scope": "algorithm", - }, - { - "name": "num_stack_levels", - "type": "int", - "default": 6, - "min": 7, - "max": 3, + }, + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-east-1": { + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-south-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": ( + 
"763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "ca-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "cn-north-1": { + "gpu_ecr_uri_1": ( + "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-north-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-south-1": { + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "eu-west-3": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "il-central-1": { + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "me-south-1": { + "gpu_ecr_uri_1": ( 
+ "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "sa-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-east-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + "us-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g4dn.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": 
"8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "af-south-1": { + "gpu_ecr_uri_1": ( + "626614931356.dkr.ecr.af-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-east-1": { + "gpu_ecr_uri_1": ( + "871362719292.dkr.ecr.ap-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-northeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-northeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-northeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-south-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-southeast-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ap-southeast-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ap-southeast-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "ca-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.ca-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "cn-north-1": { + "gpu_ecr_uri_1": ( + "727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-central-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-central-1.amazonaws.com/" + 
"huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-north-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-north-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-south-1": { + "gpu_ecr_uri_1": ( + "692866216735.dkr.ecr.eu-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "eu-west-3": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.eu-west-3.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "il-central-1": { + "gpu_ecr_uri_1": ( + "780543022126.dkr.ecr.il-central-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "me-south-1": { + "gpu_ecr_uri_1": ( + "217643126080.dkr.ecr.me-south-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "sa-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.sa-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-east-1": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-east-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-east-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-west-1": { + "gpu_ecr_uri_1": ( + 
"763104351884.dkr.ecr.us-west-1.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + "us-west-2": { + "gpu_ecr_uri_1": ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + }, + }, + "variants": { + "g4dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": ( + "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, + }, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": ( + "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": ( + "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, + }, + "p4d": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "gated_model_key_env_var_value": ( + "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz" + ) + }, + }, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "hosting_artifact_s3_data_type": "S3Prefix", + "hosting_artifact_compression_type": "None", + "hosting_resource_requirements": {"min_memory_mb": 8192, "num_accelerators": 1}, + "dynamic_container_deployment_supported": True, + }, + # noqa: E501 + "env-var-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.0.0", + 
"min_sdk_version": "2.175.0", + "training_supported": False, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": 
"HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 512, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": 
{"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "80"}}}, + "ml.p4d.24xlarge": { + "properties": { + "environment_variables": { + "YODEL": "NACEREMA", + } + } + }, + }, + }, + }, + "inference-instance-types-variant-model": { + "model_id": "huggingface-llm-falcon-180b-bf16", + "url": "https://huggingface.co/tiiuae/falcon-180B", + "version": "1.0.0", + "min_sdk_version": "2.175.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "0.9.3", + "py_version": "py39", + "huggingface_transformers_version": "4.29.2", + }, + "hosting_artifact_key": "huggingface-infer/infer-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_script_key": "source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.1/infer-prepack" + "-huggingface-llm-falcon-180b-bf16.tar.gz", + "hosting_prepacked_artifact_version": "1.0.1", + "hosting_use_script_uri": False, + "inference_vulnerable": False, + "inference_dependencies": [], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [], + "training_vulnerabilities": [], + "deprecated": False, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": 
"SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "8", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "1024", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "2048", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [], + "default_inference_instance_type": "ml.p4de.24xlarge", + "supported_inference_instance_types": ["ml.p4de.24xlarge"], + "default_training_instance_type": "ml.p4de.24xlarge", + "supported_training_instance_types": ["ml.p4de.24xlarge"], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 3600, + "container_startup_health_check_timeout": 3600, + }, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": 
"application/json", + }, + "inference_volume_size": 512, + "inference_enable_network_isolation": True, + "validation_supported": False, + "fine_tuning_supported": False, + "resource_name_base": "hf-llm-falcon-180b-bf16", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_image_uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04", + "gpu_image_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/stud-gpu", + "cpu_image_uri": "867930986793.dkr.us-west-2.amazonaws.com/cpu-blah", + } + }, + "variants": { + "ml.p2.12xlarge": { + "properties": { + "environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"}, + "supported_inference_instance_types": ["ml.p5.xlarge"], + "default_inference_instance_type": "ml.p5.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:instance-typemetric-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance type specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb instance specific': ([0-9]+\\.[0-9]+)", + }, + ], + } + }, + "p2": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"], + "default_inference_instance_type": "ml.p2.xlarge", + "metrics": [ + { + "Name": "huggingface-textgeneration:wtafigo", + "Regex": "'evasadfasdl_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'instance family specific': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:noneyourbusiness-loss", + "Regex": "'loss-noyb': 
([0-9]+\\.[0-9]+)", + }, + ], + }, + }, + "p3": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "ml.p3.200xlarge": {"regional_properties": {"image_uri": "$gpu_image_uri_2"}}, + "p4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "path/to/prepacked/inference/artifact/prefix/number2/" + }, + }, + "g4": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "artifact_key": "path/to/prepacked/training/artifact/prefix/number2/" + }, + }, + "g4dn": {"regional_properties": {"image_uri": "$gpu_image_uri"}}, + "g9": { + "regional_properties": {"image_uri": "$gpu_image_uri"}, + "properties": { + "prepacked_artifact_key": "asfs/adsf/sda/f", + "hyperparameters": [ + { + "name": "num_bag_sets", + "type": "int", + "default": 5, + "min": 5, + "scope": "algorithm", + }, + { + "name": "num_stack_levels", + "type": "int", + "default": 6, + "min": 7, + "max": 3, "scope": "algorithm", }, { @@ -1852,6 +2787,7 @@ "inference_enable_network_isolation": True, "training_enable_network_isolation": False, }, + # noqa: E501 "variant-model": { "model_id": "pytorch-ic-mobilenet-v2", "url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/", @@ -7349,6 +8285,7 @@ }, "training_instance_type_variants": None, "hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + "hosting_artifact_uri": None, "training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", "hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", "training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz", @@ -7657,253 +8594,1248 @@ }, } - -INFERENCE_CONFIGS = { - "inference_configs": { - "neuron-inference": { - "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + +INFERENCE_CONFIGS = { + "inference_configs": { + "neuron-inference": { + "benchmark_metrics": { + 
"ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["neuron-inference"], + }, + "neuron-inference-budget": { + "benchmark_metrics": { + "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["neuron-base"], + }, + "gpu-inference-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["gpu-inference-budget"], + }, + "gpu-inference": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["gpu-inference"], + }, + }, + "inference_config_components": { + "neuron-base": { + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] + }, + "neuron-inference": { + "default_inference_instance_type": "ml.inf2.xlarge", + "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], + "hosting_ecr_specs": { + "framework": "huggingface-llm-neuronx", + "framework_version": "0.0.17", + "py_version": "py310", + }, + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + } + }, + "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, + "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, + "gpu-inference": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + 
"pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + "gpu-inference-budget": { + "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + }, +} + +TRAINING_CONFIGS = { + "training_configs": { + "neuron-training": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + }, + "component_names": ["neuron-training"], + }, + "neuron-training-budget": { + "benchmark_metrics": { + "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], + "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + }, + "component_names": ["neuron-training-budget"], + }, + "gpu-training": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], + }, + "component_names": ["gpu-training"], + }, + "gpu-training-budget": { + "benchmark_metrics": { + "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + }, + "component_names": ["gpu-training-budget"], + }, + }, + "training_config_components": { + "neuron-training": { + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": 
"artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, + "gpu-training": { + "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + "neuron-training-budget": { + "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + }, + }, + "gpu-training-budget": { + "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], + "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" + "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" + } + }, + "variants": { + "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + 
"p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + }, + }, + }, + }, +} + + +INFERENCE_CONFIG_RANKINGS = { + "inference_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "neuron-inference", + "neuron-inference-budget", + "gpu-inference", + "gpu-inference-budget", + ], + }, + "performance": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-inference", + "gpu-inference", + "neuron-inference-budget", + "gpu-inference-budget", + ], + }, + "cost": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-inference-budget", + "gpu-inference-budget", + "neuron-inference", + "gpu-inference", + ], + }, + } +} + +TRAINING_CONFIG_RANKINGS = { + "training_config_rankings": { + "overall": { + "description": "Overall rankings of configs", + "rankings": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ], + }, + "performance_training": { + "description": "Configs ranked based on performance", + "rankings": [ + "neuron-training", + "gpu-training", + "neuron-training-budget", + "gpu-training-budget", + ], + "instance_type_overrides": { + "ml.p2.xlarge": [ + "neuron-training", + "neuron-training-budget", + "gpu-training", + "gpu-training-budget", + ] + }, + }, + "cost_training": { + "description": "Configs ranked based on cost", + "rankings": [ + "neuron-training-budget", + "gpu-training-budget", + "neuron-training", + "gpu-training", + ], + }, + } +} + +HUB_MODEL_DOCUMENT_DICTS = { + "huggingface-llm-gemma-2b-instruct": { + "Url": "https://huggingface.co/google/gemma-2b-it", + "MinSdkVersion": "2.189.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04", # noqa: E501 + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + 
"HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/llm/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt", + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/llm/prepack/v1.1.1/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.1.1", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", # noqa: E501 + "TrainingArtifactS3DataType": "S3Prefix", + "TrainingArtifactCompressionType": "None", + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + "Hyperparameters": [ + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + "Scope": "algorithm", + }, + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "learning_rate", + 
"Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", }, - "component_names": ["neuron-inference"], - }, - "neuron-inference-budget": { - "benchmark_metrics": { - "ml.inf2.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["neuron-base"], - }, - "gpu-inference-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + { + "Name": "fp16", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-inference-budget"], - }, - "gpu-inference": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", 
"False"], + "Scope": "algorithm", }, - "component_names": ["gpu-inference"], - }, - }, - "inference_config_components": { - "neuron-base": { - "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"] - }, - "neuron-inference": { - "default_inference_instance_type": "ml.inf2.xlarge", - "supported_inference_instance_types": ["ml.inf2.xlarge", "ml.inf2.2xlarge"], - "hosting_ecr_specs": { - "framework": "huggingface-llm-neuronx", - "framework_version": "0.0.17", - "py_version": "py310", + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", }, - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-inference/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "huggingface-pytorch-hosting:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" - } + { + "Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + { + "Name": "max_input_length", + "Type": "int", 
+ "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + { + "Name": "validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "adam_epsilon", + "Type": "float", + "Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, # noqa: E501 + 
{ + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": "constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, # noqa: E501 + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + ], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + 
"RequiredForModelClass": False, + }, + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + { + "Name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, + }, + ], + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, # noqa: E501 + ], + "InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + 
"shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ], + "DefaultInferenceInstanceType": "ml.g5.xlarge", + "SupportedInferenceInstanceTypes": [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "DefaultTrainingInstanceType": "ml.g5.2xlarge", + "SupportedTrainingInstanceTypes": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 512, + "TrainingVolumeSize": 512, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "FineTuningSupported": True, + "ValidationSupported": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/", # noqa: E501 + "ResourceNameBase": "hf-llm-gemma-2b-instruct", + "DefaultPayloads": { + "HelloWorld": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program\nmodel", # noqa: E501 + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, }, - "variants": {"inf2": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, - }, - "neuron-budget": {"inference_environment_variables": {"BUDGET": "1234"}}, - "gpu-inference": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference/model/", - 
"hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "MachineLearningPoem": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", }, - "variants": { - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, }, }, }, - "gpu-inference-budget": { - "supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "hosting_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-inference-budget/model/", - "hosting_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "GatedBucket": True, + "HostingResourceRequirements": {"MinMemoryMb": 8192, "NumAccelerators": 1}, + "HostingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": 
"$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, + }, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-huggingface-llm-gemma-2b-instruct.tar.gz", # noqa: E501 + }, }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, + "ContextualHelp": { + "HubFormatTrainData": [ + "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. 
", + "- For CSV/JSON files, the text data is used from the column called 'text' or the first column if no column called 'text' is found", # noqa: E501 + "- The number of files under train and validation (if provided) should equal to one, respectively.", + " [Learn how to setup an AWS S3 bucket.](https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", # noqa: E501 + ], + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filing contains regulatory documents that companies and issuers of securities must submit to the Securities and Exchange Commission (SEC) on a regular basis.", # noqa: E501 + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", + ], + }, + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Dependencies": [], }, -} - -TRAINING_CONFIGS = { - "training_configs": { - "neuron-training": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": "Tokens/S"}], + "meta-textgeneration-llama-2-70b": { + "Url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "MinSdkVersion": "2.198.0", + "TrainingSupported": True, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference/v1.0.0/", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", 
# noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-textgeneration/meta-textgeneration-llama-2-70b/artifacts/inference-prepack/v1.0.0/", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.0", + "HostingUseScriptUri": False, + "HostingEulaUri": "s3://jumpstart-cache-prod-us-west-2/fmhMetadata/eula/llamaEula.txt", + "InferenceDependencies": [], + "TrainingDependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "huggingface-hub==0.20.3", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "nvidia-cublas-cu12==12.1.3.1", + "nvidia-cuda-cupti-cu12==12.1.105", + "nvidia-cuda-nvrtc-cu12==12.1.105", + "nvidia-cuda-runtime-cu12==12.1.105", + "nvidia-cudnn-cu12==8.9.2.26", + "nvidia-cufft-cu12==11.0.2.54", + "nvidia-curand-cu12==10.3.2.106", + "nvidia-cusolver-cu12==11.4.5.107", + "nvidia-cusolver-cu12==11.4.5.107", + "nvidia-cusparse-cu12==12.1.0.106", + "nvidia-nccl-cu12==2.19.3", + "nvidia-nvjitlink-cu12==12.3.101", + "nvidia-nvtx-cu12==12.1.105", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.4", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + "termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.2.0", + "transformers==4.33.3", + "triton==2.2.0", + "typing-extensions==4.8.0", + ], + "Hyperparameters": [ + { + "Name": "epoch", + "Type": "int", + "Default": 5, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", }, - "component_names": ["neuron-training"], - }, - "neuron-training-budget": { - "benchmark_metrics": { - "ml.tr1n1.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}], - "ml.tr1n1.4xlarge": [{"name": "Latency", "value": "50", "unit": 
"Tokens/S"}], + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", }, - "component_names": ["neuron-training-budget"], - }, - "gpu-training": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "200", "unit": "Tokens/S"}], + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", }, - "component_names": ["gpu-training"], - }, - "gpu-training-budget": { - "benchmark_metrics": { - "ml.p3.2xlarge": [{"name": "Latency", "value": "100", "unit": "Tokens/S"}] + ], + "TrainingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/v1.0.11/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.5/sourcedir.tar.gz", # noqa: E501 + "TrainingPrepackedScriptVersion": "1.0.5", + "TrainingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # TODO: not a training image # noqa: E501 + "TrainingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/meta-training/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + "InferenceEnvironmentVariables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, }, - "component_names": ["gpu-training-budget"], - }, - }, - "training_config_components": { - "neuron-training": { - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - 
"pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } - }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, }, - }, - "gpu-training": { - "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + ], + "TrainingMetrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + "DefaultInferenceInstanceType": "ml.g5.48xlarge", + "supported_inference_instance_types": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "default_training_instance_type": "ml.g5.48xlarge", + "SupportedInferenceInstanceTypes": ["ml.g5.48xlarge", "ml.p4d.24xlarge"], + "ModelDataDownloadTimeout": 1200, + "ContainerStartupHealthCheckTimeout": 1200, + "EncryptInterContainerTraffic": True, + "DisableOutputCompression": True, + "MaxRuntimeInSeconds": 360000, + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 256, + "TrainingVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "TrainingEnableNetworkIsolation": True, + "DefaultTrainingDatasetUri": "s3://jumpstart-cache-prod-us-west-2/training-datasets/sec_amazon/", # noqa: 
E501 + "ValidationSupported": True, + "FineTuningSupported": True, + "ResourceNameBase": "meta-textgeneration-llama-2-70b", + "DefaultPayloads": { + "meaningOfLife": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "Body": { + "inputs": "I believe the meaning of life is", + "parameters": { + "max_new_tokens": 64, + "top_p": 0.9, + "temperature": 0.6, + "decoder_input_details": True, + "details": True, + }, }, }, - }, - "neuron-training-budget": { - "supported_training_instance_types": ["ml.trn1.xlarge", "ml.trn1.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/neuron-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "neuron-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-training-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "theoryOfRelativity": { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": {"generated_text": "[0].generated_text"}, + "Body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, }, - "variants": {"trn1": {"regional_properties": {"image_uri": "$neuron-ecr-uri"}}}, }, }, - "gpu-training-budget": { - "supported_training_instance_types": ["ml.p2.xlarge", "ml.p3.2xlarge"], - "training_artifact_key": "artifacts/meta-textgeneration-llama-2-7b/gpu-training-budget/model/", - "training_instance_type_variants": { - "regional_aliases": { - "us-west-2": { - "gpu-ecr-uri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/" - "pytorch-hosting-neuronx:1.13.1-neuronx-py310-sdk2.14.1-ubuntu20.04" - } + "GatedBucket": True, + "HostingInstanceTypeVariants": { + 
"Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "TrainingInstanceTypeVariants": { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" # noqa: E501 + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "meta-training/g5/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, }, - "variants": { - "p2": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, - "p3": {"regional_properties": {"image_uri": "$gpu-ecr-uri"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + 
"gated_model_key_env_var_value": "meta-training/p4d/v1.0.0/train-meta-textgeneration-llama-2-70b.tar.gz", # noqa: E501 + }, }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, }, }, + "HostingArtifactS3DataType": "S3Prefix", + "HostingArtifactCompressionType": "None", + "HostingResourceRequirements": {"MinMemoryMb": 393216, "NumAccelerators": 8}, + "DynamicContainerDeploymentSupported": True, + "TrainingModelPackageArtifactUri": None, + "Task": "text generation", + "DataType": "text", + "Framework": "meta", + "Dependencies": [], + }, + "huggingface-textembedding-bloom-7b1": { + "Url": "https://huggingface.co/bigscience/bloom-7b1", + "MinSdkVersion": "2.144.0", + "TrainingSupported": False, + "IncrementalTrainingSupported": False, + "HostingEcrUri": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04", # noqa: E501 + "HostingArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/infer-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingScriptUri": "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/textembedding/v1.0.1/sourcedir.tar.gz", # noqa: E501 + "HostingPrepackedArtifactUri": "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/infer-prepack-huggingface-textembedding-bloom-7b1.tar.gz", # noqa: E501 + "HostingPrepackedArtifactVersion": "1.0.1", + "InferenceDependencies": [ + "accelerate==0.16.0", + "bitsandbytes==0.37.0", + "filelock==3.9.0", + "huggingface_hub==0.12.0", + "regex==2022.7.9", + "tokenizers==0.13.2", + "transformers==4.26.0", + ], + "TrainingDependencies": [], + "InferenceEnvironmentVariables": [ + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + } + ], + "TrainingMetrics": [], + "DefaultInferenceInstanceType": "ml.g5.12xlarge", + 
"SupportedInferenceInstanceTypes": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.g4dn.12xlarge", + ], + "deploy_kwargs": { + "ModelDataDownloadTimeout": 3600, + "ContainerStartupHealthCheckTimeout": 3600, + }, + "SageMakerSdkPredictorSpecifications": { + "SupportedContentTypes": ["application/json", "application/x-text"], + "SupportedAcceptTypes": ["application/json;verbose", "application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + "InferenceVolumeSize": 256, + "InferenceEnableNetworkIsolation": True, + "ValidationSupported": False, + "FineTuningSupported": False, + "ResourceNameBase": "hf-textembedding-bloom-7b1", + "HostingInstanceTypeVariants": { + "Aliases": { + "alias_ecr_uri_3": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training-neuron:1.11.0-neuron-py38-sdk2.4.0-ubuntu20.04", # noqa: E501 + "cpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-cpu-py38", + "gpu_ecr_uri_2": "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38", + }, + "Variants": { + "c4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c5n": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "c7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "local": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "m4": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m6i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "m7i": {"properties": {"image_uri": 
"$cpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_2"}}, + "r5": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r5d": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "r7i": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t2": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "t3": {"properties": {"image_uri": "$cpu_ecr_uri_1"}}, + "trn1": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, + "trn1n": {"properties": {"image_uri": "$alias_ecr_uri_3"}}, + }, + }, + "TrainingModelPackageArtifactUri": None, + "DynamicContainerDeploymentSupported": False, + "License": "BigScience RAIL", + "Dependencies": [], }, -} - - -INFERENCE_CONFIG_RANKINGS = { - "inference_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-inference", - "neuron-inference-budget", - "gpu-inference", - "gpu-inference-budget", - ], - }, - "performance": { - "description": "Configs ranked based on performance", - "rankings": [ - "neuron-inference", - "gpu-inference", - "neuron-inference-budget", - "gpu-inference-budget", - ], - }, - "cost": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-inference-budget", - "gpu-inference-budget", - "neuron-inference", - "gpu-inference", - ], - }, - } -} - -TRAINING_CONFIG_RANKINGS = { - "training_config_rankings": { - "overall": { - "description": "Overall rankings of configs", - "rankings": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", - ], - }, - "performance_training": { - "description": "Configs ranked based on performance", - "rankings": [ - "neuron-training", - "gpu-training", - "neuron-training-budget", - 
"gpu-training-budget", - ], - "instance_type_overrides": { - "ml.p2.xlarge": [ - "neuron-training", - "neuron-training-budget", - "gpu-training", - "gpu-training-budget", - ] - }, - }, - "cost_training": { - "description": "Configs ranked based on cost", - "rankings": [ - "neuron-training-budget", - "gpu-training-budget", - "neuron-training", - "gpu-training", - ], - }, - } } diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 96b00793b8..062209e3a0 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -1137,6 +1137,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "region", "tolerate_vulnerable_model", "tolerate_deprecated_model", + "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -1163,6 +1164,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "self", "name", "resources", + "model_reference_arn", } assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip @@ -1241,6 +1243,7 @@ def test_no_predictor_returns_default_predictor( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=estimator.sagemaker_session, + hub_arn=None, ) self.assertEqual(type(predictor), Predictor) self.assertEqual(predictor, default_predictor_with_presets) @@ -1410,6 +1413,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + hub_arn=None, ) @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") @@ -1465,6 +1469,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, + hub_arn=None, ) @mock.patch("sagemaker.utils.sagemaker_timestamp") @@ 
-1743,6 +1748,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1750,6 +1756,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), ] ) @@ -1771,6 +1778,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -1778,6 +1786,7 @@ def test_model_id_not_found_refeshes_cache_training( region=None, script=JumpStartScriptScope.TRAINING, sagemaker_session=None, + hub_arn=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/hub/__init__.py b/tests/unit/sagemaker/jumpstart/hub/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/jumpstart/hub/test_hub.py b/tests/unit/sagemaker/jumpstart/hub/test_hub.py new file mode 100644 index 0000000000..e2085e5ab9 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_hub.py @@ -0,0 +1,235 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +from datetime import datetime +from unittest.mock import patch, MagicMock +import pytest +from mock import Mock +from sagemaker.jumpstart.hub.hub import Hub +from sagemaker.jumpstart.hub.types import S3ObjectLocation + + +REGION = "us-east-1" +ACCOUNT_ID = "123456789123" +HUB_NAME = "mock-hub-name" + +MODULE_PATH = "sagemaker.jumpstart.hub.hub.Hub" + +FAKE_TIME = datetime(1997, 8, 14, 00, 00, 00) + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session") + sagemaker_session_mock = Mock( + name="sagemaker_session", boto_session=boto_mock, boto_region_name=REGION + ) + sagemaker_session_mock._client_config.user_agent = ( + "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" + ) + sagemaker_session_mock.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": "s3://mock-bucket-123"} + } + sagemaker_session_mock.account_id.return_value = ACCOUNT_ID + return sagemaker_session_mock + + +@pytest.fixture +def mock_instance(sagemaker_session): + mock_instance = MagicMock() + mock_instance.hub_name = "test-hub" + mock_instance._sagemaker_session = sagemaker_session + return mock_instance + + +def test_instantiates(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + assert hub.hub_name == HUB_NAME + assert hub.region == "us-east-1" + assert hub._sagemaker_session == sagemaker_session + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", None, None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + None, + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +def test_create_with_no_bucket_name( + mock_generate_hub_storage_location, + 
sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + storage_location = S3ObjectLocation( + "sagemaker-hubs-us-east-1-123456789123", f"{hub_name}-{FAKE_TIME.timestamp()}" + ) + mock_generate_hub_storage_location.return_value = storage_location + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + sagemaker_session.describe_hub.return_value = { + "S3StorageConfig": {"S3OutputPath": f"s3://{hub_bucket_name}/{storage_location.key}"} + } + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "s3_storage_config": { + "S3OutputPath": f"s3://sagemaker-hubs-us-east-1-123456789123/{storage_location.key}" + }, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@pytest.mark.parametrize( + ("hub_name,hub_description,hub_bucket_name,hub_display_name,hub_search_keywords,tags"), + [ + pytest.param("MockHub1", "this is my sagemaker hub", "mock-bucket-123", None, None, None), + pytest.param( + "MockHub2", + "this is my sagemaker hub two", + "mock-bucket-123", + "DisplayMockHub2", + ["mock", "hub", "123"], + [{"Key": "tag-key-1", "Value": "tag-value-1"}], + ), + ], +) +@patch("sagemaker.jumpstart.hub.hub.Hub._generate_hub_storage_location") +def test_create_with_bucket_name( + mock_generate_hub_storage_location, + sagemaker_session, + hub_name, + hub_description, + hub_bucket_name, + hub_display_name, + hub_search_keywords, + tags, +): + storage_location = 
S3ObjectLocation(hub_bucket_name, f"{hub_name}-{FAKE_TIME.timestamp()}") + mock_generate_hub_storage_location.return_value = storage_location + create_hub = {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + sagemaker_session.create_hub = Mock(return_value=create_hub) + hub = Hub(hub_name=hub_name, sagemaker_session=sagemaker_session, bucket_name=hub_bucket_name) + request = { + "hub_name": hub_name, + "hub_description": hub_description, + "hub_display_name": hub_display_name, + "hub_search_keywords": hub_search_keywords, + "s3_storage_config": {"S3OutputPath": f"s3://mock-bucket-123/{storage_location.key}"}, + "tags": tags, + } + response = hub.create( + description=hub_description, + display_name=hub_display_name, + search_keywords=hub_search_keywords, + tags=tags, + ) + sagemaker_session.create_hub.assert_called_with(**request) + assert response == {"HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{hub_name}"} + + +@patch("sagemaker.jumpstart.hub.interfaces.DescribeHubContentResponse.from_json") +def test_describe_model_success(mock_describe_hub_content_response, sagemaker_session): + mock_describe_hub_content_response.return_value = Mock() + mock_list_hub_content_versions = sagemaker_session.list_hub_content_versions + mock_list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0"}, + {"HubContentVersion": "2.0"}, + {"HubContentVersion": "3.0"}, + ] + } + + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + + with patch("sagemaker.jumpstart.hub.utils.get_hub_model_version") as mock_get_hub_model_version: + mock_get_hub_model_version.return_value = "3.0" + + hub.describe_model("test-model") + + mock_list_hub_content_versions.assert_called_with( + hub_name=HUB_NAME, hub_content_name="test-model", hub_content_type="Model" + ) + sagemaker_session.describe_hub_content.assert_called_with( + hub_name=HUB_NAME, + hub_content_name="test-model", + hub_content_version="3.0", + 
hub_content_type="Model", + ) + + +def test_create_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + min_version = "1.1.1" + public_model_arn = ( + f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/{model_name}" + ) + create_hub_content_reference = { + "HubArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub/{HUB_NAME}", + "HubContentReferenceArn": f"arn:aws:sagemaker:us-east-1:123456789123:hub-content/{HUB_NAME}/ModelRef/{model_name}", # noqa: E501 + } + sagemaker_session.create_hub_content_reference = Mock(return_value=create_hub_content_reference) + + request = { + "hub_name": HUB_NAME, + "source_hub_content_arn": public_model_arn, + "hub_content_name": model_name, + "min_version": min_version, + } + + response = hub.create_model_reference( + model_arn=public_model_arn, model_name=model_name, min_version=min_version + ) + sagemaker_session.create_hub_content_reference.assert_called_with(**request) + + assert response == { + "HubArn": "arn:aws:sagemaker:us-east-1:123456789123:hub/mock-hub-name", + "HubContentReferenceArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/mock-hub-name/ModelRef/mock-model-one-huggingface", # noqa: E501 + } + + +def test_delete_hub_content_reference(sagemaker_session): + hub = Hub(hub_name=HUB_NAME, sagemaker_session=sagemaker_session) + model_name = "mock-model-one-huggingface" + + hub.delete_model_reference(model_name) + sagemaker_session.delete_hub_content_reference.assert_called_with( + hub_name=HUB_NAME, + hub_content_type="ModelReference", + hub_content_name="mock-model-one-huggingface", + ) diff --git a/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py new file mode 100644 index 0000000000..c4b95443ec --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_interfaces.py @@ -0,0 +1,981 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import numpy as np +from sagemaker.jumpstart.types import ( + JumpStartHyperparameter, + JumpStartInstanceTypeVariants, + JumpStartEnvironmentVariable, + JumpStartPredictorSpecs, + JumpStartSerializablePayload, +) +from sagemaker.jumpstart.hub.interfaces import HubModelDocument +from tests.unit.sagemaker.jumpstart.constants import ( + SPECIAL_MODEL_SPECS_DICT, + HUB_MODEL_DOCUMENT_DICTS, +) + +gemma_model_spec = SPECIAL_MODEL_SPECS_DICT["gemma-model-2b-v1_1_0"] + + +def test_hub_content_document_from_json_obj(): + region = "us-west-2" + gemma_model_document = HubModelDocument( + json_obj=HUB_MODEL_DOCUMENT_DICTS["huggingface-llm-gemma-2b-instruct"], region=region + ) + assert gemma_model_document.url == "https://huggingface.co/google/gemma-2b-it" + assert gemma_model_document.min_sdk_version == "2.189.0" + assert gemma_model_document.training_supported is True + assert gemma_model_document.incremental_training_supported is False + assert ( + gemma_model_document.hosting_ecr_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:" + "2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.hosting_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'hosting_ecr_specs'" + assert gemma_model_document.hosting_artifact_s3_data_type == "S3Prefix" + assert 
gemma_model_document.hosting_artifact_compression_type == "None" + assert ( + gemma_model_document.hosting_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct" + "/artifacts/inference/v1.0.0/" + ) + assert ( + gemma_model_document.hosting_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/inference/" + "llm/v1.0.1/sourcedir.tar.gz" + ) + assert gemma_model_document.inference_dependencies == [] + assert gemma_model_document.training_dependencies == [ + "accelerate==0.26.1", + "bitsandbytes==0.42.0", + "deepspeed==0.10.3", + "docstring-parser==0.15", + "flash_attn==2.5.5", + "ninja==1.11.1", + "packaging==23.2", + "peft==0.8.2", + "py_cpuinfo==9.0.0", + "rich==13.7.0", + "safetensors==0.4.2", + "sagemaker_jumpstart_huggingface_script_utilities==1.2.1", + "sagemaker_jumpstart_script_utilities==1.1.9", + "sagemaker_jumpstart_tabular_script_utilities==1.0.0", + "shtab==1.6.5", + "tokenizers==0.15.1", + "transformers==4.38.1", + "trl==0.7.10", + "tyro==0.7.2", + ] + assert ( + gemma_model_document.hosting_prepacked_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/huggingface-llm-gemma-2b-instruct/" + "artifacts/inference-prepack/v1.0.0/" + ) + assert gemma_model_document.hosting_prepacked_artifact_version == "1.0.0" + assert gemma_model_document.hosting_use_script_uri is False + assert ( + gemma_model_document.hosting_eula_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-llm/fmhMetadata/terms/gemmaTerms.txt" + ) + assert ( + gemma_model_document.training_ecr_uri + == "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:2.0.0-transformers" + "4.28.1-gpu-py310-cu118-ubuntu20.04" + ) + with pytest.raises(AttributeError) as excinfo: + gemma_model_document.training_ecr_specs + assert str(excinfo.value) == "'HubModelDocument' object has no attribute 'training_ecr_specs'" + assert ( + gemma_model_document.training_prepacked_script_uri 
+ == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/prepack/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_prepacked_script_version == "1.1.1" + assert ( + gemma_model_document.training_script_uri + == "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/huggingface/transfer_learning/" + "llm/v1.1.1/sourcedir.tar.gz" + ) + assert gemma_model_document.training_artifact_s3_data_type == "S3Prefix" + assert gemma_model_document.training_artifact_compression_type == "None" + assert ( + gemma_model_document.training_artifact_uri + == "s3://jumpstart-cache-prod-us-west-2/huggingface-training/train-huggingface-llm-gemma-2b-instruct" + ".tar.gz" + ) + assert gemma_model_document.hyperparameters == [ + JumpStartHyperparameter( + { + "Name": "peft_type", + "Type": "text", + "Default": "lora", + "Options": ["lora", "None"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "instruction_tuned", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "chat_dataset", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "epoch", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "learning_rate", + "Type": "float", + "Default": 0.0001, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lora_r", + "Type": "int", + "Default": 64, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "lora_alpha", "Type": "int", "Default": 16, "Min": 0, "Scope": "algorithm"}, + is_hub_content=True, 
+ ), + JumpStartHyperparameter( + { + "Name": "lora_dropout", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "bits", "Type": "int", "Default": 4, "Scope": "algorithm"}, is_hub_content=True + ), + JumpStartHyperparameter( + { + "Name": "double_quant", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "quant_Type", + "Type": "text", + "Default": "nf4", + "Options": ["fp4", "nf4"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_train_batch_size", + "Type": "int", + "Default": 1, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "per_device_eval_batch_size", + "Type": "int", + "Default": 2, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_ratio", + "Type": "float", + "Default": 0.1, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "train_from_scratch", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "fp16", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "bf16", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "evaluation_strategy", + "Type": "text", + "Default": "steps", + "Options": ["steps", "epoch", "no"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + 
"Name": "eval_steps", + "Type": "int", + "Default": 20, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_accumulation_steps", + "Type": "int", + "Default": 4, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_steps", + "Type": "int", + "Default": 8, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "weight_decay", + "Type": "float", + "Default": 0.2, + "Min": 1e-08, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "load_best_model_at_end", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_train_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_val_samples", + "Type": "int", + "Default": -1, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "seed", + "Type": "int", + "Default": 10, + "Min": 1, + "Max": 1000, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_input_length", + "Type": "int", + "Default": 1024, + "Min": -1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "validation_split_ratio", + "Type": "float", + "Default": 0.2, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "train_data_split_seed", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "preprocessing_num_workers", + "Type": "text", + "Default": 
"None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + {"Name": "max_steps", "Type": "int", "Default": -1, "Scope": "algorithm"}, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "gradient_checkpointing", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_patience", + "Type": "int", + "Default": 3, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "early_stopping_threshold", + "Type": "float", + "Default": 0.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta1", + "Type": "float", + "Default": 0.9, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_beta2", + "Type": "float", + "Default": 0.999, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "adam_epsilon", + "Type": "float", + "Default": 1e-08, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "max_grad_norm", + "Type": "float", + "Default": 1.0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "label_smoothing_factor", + "Type": "float", + "Default": 0, + "Min": 0, + "Max": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_first_step", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "logging_nan_inf_filter", + "Type": "text", + "Default": "True", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + 
JumpStartHyperparameter( + { + "Name": "save_strategy", + "Type": "text", + "Default": "steps", + "Options": ["no", "epoch", "steps"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_steps", + "Type": "int", + "Default": 500, + "Min": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "save_total_limit", + "Type": "int", + "Default": 1, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_drop_last", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "dataloader_num_workers", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "eval_accumulation_steps", + "Type": "text", + "Default": "None", + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "auto_find_batch_size", + "Type": "text", + "Default": "False", + "Options": ["True", "False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "lr_scheduler_type", + "Type": "text", + "Default": "constant_with_warmup", + "Options": ["constant_with_warmup", "linear"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "warmup_steps", + "Type": "int", + "Default": 0, + "Min": 0, + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "deepspeed", + "Type": "text", + "Default": "False", + "Options": ["False"], + "Scope": "algorithm", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_submit_directory", + "Type": "text", + "Default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "Scope": "container", + }, + is_hub_content=True, + ), + 
JumpStartHyperparameter( + { + "Name": "sagemaker_program", + "Type": "text", + "Default": "transfer_learning.py", + "Scope": "container", + }, + is_hub_content=True, + ), + JumpStartHyperparameter( + { + "Name": "sagemaker_container_log_level", + "Type": "text", + "Default": "20", + "Scope": "container", + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.inference_environment_variables == [ + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_PROGRAM", + "Type": "text", + "Default": "inference.py", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_SUBMIT_DIRECTORY", + "Type": "text", + "Default": "/opt/ml/model/code", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "Type": "text", + "Default": "20", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "Type": "text", + "Default": "3600", + "Scope": "container", + "RequiredForModelClass": False, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "ENDPOINT_SERVER_TIMEOUT", + "Type": "int", + "Default": 3600, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MODEL_CACHE_ROOT", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_ENV", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "HF_MODEL_ID", + "Type": "text", + "Default": "/opt/ml/model", + "Scope": "container", + 
"RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_INPUT_LENGTH", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_TOTAL_TOKENS", + "Type": "text", + "Default": "8192", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "MAX_BATCH_PREFILL_TOKENS", + "Type": "text", + "Default": "8191", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SM_NUM_GPUS", + "Type": "text", + "Default": "1", + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + JumpStartEnvironmentVariable( + { + "Name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "Type": "int", + "Default": 1, + "Scope": "container", + "RequiredForModelClass": True, + }, + is_hub_content=True, + ), + ] + assert gemma_model_document.training_metrics == [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "'eval_loss': ([0-9]+\\.[0-9]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "'loss': ([0-9]+\\.[0-9]+)", + }, + ] + assert gemma_model_document.default_inference_instance_type == "ml.g5.xlarge" + assert gemma_model_document.supported_inference_instance_types == [ + "ml.g5.xlarge", + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ] + assert gemma_model_document.default_training_instance_type == "ml.g5.2xlarge" + assert np.array_equal( + gemma_model_document.supported_training_instance_types, + [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.16xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + ) + assert 
gemma_model_document.sage_maker_sdk_predictor_specifications == JumpStartPredictorSpecs( + { + "SupportedContentTypes": ["application/json"], + "SupportedAcceptTypes": ["application/json"], + "DefaultContentType": "application/json", + "DefaultAcceptType": "application/json", + }, + is_hub_content=True, + ) + assert gemma_model_document.inference_volume_size == 512 + assert gemma_model_document.training_volume_size == 512 + assert gemma_model_document.inference_enable_network_isolation is True + assert gemma_model_document.training_enable_network_isolation is True + assert gemma_model_document.fine_tuning_supported is True + assert gemma_model_document.validation_supported is True + assert ( + gemma_model_document.default_training_dataset_uri + == "s3://jumpstart-cache-prod-us-west-2/training-datasets/oasst_top/train/" + ) + assert gemma_model_document.resource_name_base == "hf-llm-gemma-2b-instruct" + assert gemma_model_document.default_payloads == { + "HelloWorld": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "user\nWrite a hello world program" + "\nmodel", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + "MachineLearningPoem": JumpStartSerializablePayload( + { + "ContentType": "application/json", + "PromptKey": "inputs", + "OutputKeys": { + "GeneratedText": "[0].generated_text", + "InputLogprobs": "[0].details.prefill[*].logprob", + }, + "Body": { + "Inputs": "Write me a poem about Machine Learning.", + "Parameters": { + "MaxNewTokens": 256, + "DecoderInputDetails": True, + "Details": True, + }, + }, + }, + is_hub_content=True, + ), + } + assert gemma_model_document.gated_bucket is True + assert gemma_model_document.hosting_resource_requirements == { + "MinMemoryMb": 8192, + 
"NumAccelerators": 1, + } + assert gemma_model_document.hosting_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch" + "-tgi-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" + }, + "Variants": { + "g4dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.training_instance_type_variants == JumpStartInstanceTypeVariants( + { + "Aliases": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-" + "training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + "Variants": { + "g4dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g4dn/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "g5": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/g5/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "local_gpu": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"properties": {"image_uri": 
"$gpu_ecr_uri_1"}}, + "p3": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p3dn/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "p4d": { + "properties": { + "image_uri": "$gpu_ecr_uri_1", + "gated_model_key_env_var_value": "huggingface-training/p4d/v1.0.0/train-" + "huggingface-llm-gemma-2b-instruct.tar.gz", + }, + }, + "p4de": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + is_hub_content=True, + ) + assert gemma_model_document.contextual_help == { + "HubFormatTrainData": [ + "A train and an optional validation directories. Each directory contains a CSV/JSON/TXT. ", + "- For CSV/JSON files, the text data is used from the column called 'text' or the " + "first column if no column called 'text' is found", + "- The number of files under train and validation (if provided) should equal to one," + " respectively.", + " [Learn how to setup an AWS S3 bucket.]" + "(https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingBucket.html)", + ], + "HubDefaultTrainData": [ + "Dataset: [SEC](https://www.sec.gov/edgar/searchedgar/companysearch)", + "SEC filing contains regulatory documents that companies and issuers of securities must " + "submit to the Securities and Exchange Commission (SEC) on a regular basis.", + "License: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/legalcode)", + ], + } + assert gemma_model_document.model_data_download_timeout == 1200 + assert gemma_model_document.container_startup_health_check_timeout == 1200 + assert gemma_model_document.encrypt_inter_container_traffic is True + assert gemma_model_document.disable_output_compression is True + assert gemma_model_document.max_runtime_in_seconds == 360000 + assert gemma_model_document.dynamic_container_deployment_supported is True + assert 
gemma_model_document.training_model_package_artifact_uri is None + assert gemma_model_document.dependencies == [] diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py new file mode 100644 index 0000000000..ee50805792 --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -0,0 +1,256 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import patch, Mock +from sagemaker.jumpstart.types import HubArnExtractedInfo +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.hub import utils + + +def test_get_info_from_hub_resource_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Model", + hub_content_name="my-mock-model", + hub_content_version="1.0.2", + ) + + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" + assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Notebook", + hub_content_name="my-mock-notebook", + 
hub_content_version="1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + ) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "nonsense-string" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + +def test_construct_hub_arn_from_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-west-2" + hub_name = "my-cool-hub" + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" + ) + + assert ( + utils.construct_hub_arn_from_name( + hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" + ) + + +def test_construct_hub_model_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" + ) + + +def test_generate_hub_arn_for_init_kwargs(): + hub_name = "my-hub-name" + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock default session with default values + 
mock_default_session = Mock() + mock_default_session.account_id.return_value = "123456789123" + mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, session=mock_default_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "us-east-1", session=mock_default_session) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, "eu-west-1", mock_custom_session) + == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_name, None, mock_custom_session) + == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, session=mock_default_session) == hub_arn + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", session=mock_default_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_init_kwargs(hub_arn, "us-east-1", mock_custom_session) == hub_arn + ) + + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + + +def test_create_hub_bucket_if_it_does_not_exist_hub_arn(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + 
mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name + assert utils.generate_hub_arn_for_init_kwargs(hub_arn, None, mock_custom_session) == hub_arn + + +def test_is_gated_bucket(): + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-west-2") is True + + assert utils.is_gated_bucket("jumpstart-private-cache-prod-us-east-1") is True + + assert utils.is_gated_bucket("jumpstart-cache-prod-us-west-2") is False + + assert utils.is_gated_bucket("") is False + + +def test_create_hub_bucket_if_it_does_not_exist(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.client("sts").get_caller_identity.return_value = { + "Account": "123456789123" + } + mock_sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + mock_sagemaker_session.boto_region_name = "us-east-1" + bucket_name = "sagemaker-hubs-us-east-1-123456789123" + created_hub_bucket_name = utils.create_hub_bucket_if_it_does_not_exist( + sagemaker_session=mock_sagemaker_session + ) + + mock_sagemaker_session.boto_session.resource("s3").create_bucketassert_called_once() + assert created_hub_bucket_name == bucket_name + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_success(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "1.0.0" + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + 
result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "1.0.0" + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_None(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = None + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "2.0.0" + + +@patch("sagemaker.session.Session") +def test_get_hub_model_version_wildcard_char(mock_session): + hub_name = "test_hub" + hub_model_name = "test_model" + hub_model_type = "test_type" + hub_model_version = "*" + mock_session.list_hub_content_versions.return_value = { + "HubContentSummaries": [ + {"HubContentVersion": "1.0.0"}, + {"HubContentVersion": "1.2.3"}, + {"HubContentVersion": "2.0.0"}, + ] + } + + result = utils.get_hub_model_version( + hub_name, hub_model_name, hub_model_type, hub_model_version, mock_session + ) + + assert result == "2.0.0" diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 58c08f5b3d..90a2c573d9 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -714,7 +714,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): Please add the new argument to the skip set below, and reach out to JumpStart team.""" - init_args_to_skip: Set[str] = set([]) + init_args_to_skip: Set[str] = set(["model_reference_arn"]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Model.__init__ @@ -731,6 +731,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): "tolerate_deprecated_model", 
"instance_type", "model_package_arn", + "hub_name", } assert parent_class_init_args - js_class_init_args == init_args_to_skip @@ -800,6 +801,7 @@ def test_no_predictor_returns_default_predictor( model_id=model_id, model_version="*", region=region, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=model.sagemaker_session, @@ -927,6 +929,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -934,6 +937,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), ] ) @@ -958,6 +962,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), mock.call( model_id="js-trainable-model", @@ -965,6 +970,7 @@ def test_model_id_not_found_refeshes_cache_inference( region=None, script=JumpStartScriptScope.INFERENCE, sagemaker_session=None, + hub_arn=None, ), ] ) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 50fe6da0a6..c97e6ba895 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -495,8 +495,8 @@ def test_jumpstart_cache_accepts_input_parameters(): assert cache.get_manifest_file_s3_key() == manifest_file_key assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._open_weight_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items @@ -535,8 
+535,8 @@ def test_jumpstart_proprietary_cache_accepts_input_parameters(): ) assert cache.get_region() == region assert cache.get_bucket() == bucket - assert cache._s3_cache._max_cache_items == max_s3_cache_items - assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon + assert cache._content_cache._max_cache_items == max_s3_cache_items + assert cache._content_cache._expiration_horizon == s3_cache_expiration_horizon assert ( cache._proprietary_model_id_manifest_key_cache._max_cache_items == max_semantic_version_cache_items diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index c0a37c5b38..9fd9cc8398 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -3,7 +3,9 @@ import datetime from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, ANY + +import boto3 import pytest from sagemaker.jumpstart.constants import ( @@ -754,6 +756,9 @@ def test_get_model_url( patched_get_manifest.side_effect = ( lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type) ) + mock_client = boto3.client("s3") + region = "us-west-2" + mock_session = Mock(s3_client=mock_client, boto_region_name=region) model_id, version = "xgboost-classification-model", "1.0.0" assert "https://xgboost.readthedocs.io/en/latest/" == get_model_url(model_id, version) @@ -772,12 +777,14 @@ def test_get_model_url( **{key: value for key, value in kwargs.items() if key != "region"}, ) - get_model_url(model_id, version, region="us-west-2") + get_model_url(model_id, version, region="us-west-2", sagemaker_session=mock_session) patched_get_model_specs.assert_called_once_with( model_id=model_id, version=version, region="us-west-2", - s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + s3_client=ANY, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + 
sagemaker_session=mock_session, ) diff --git a/tests/unit/sagemaker/jumpstart/test_predictor.py b/tests/unit/sagemaker/jumpstart/test_predictor.py index 52f28f2da1..6f86f724a9 100644 --- a/tests/unit/sagemaker/jumpstart/test_predictor.py +++ b/tests/unit/sagemaker/jumpstart/test_predictor.py @@ -128,6 +128,7 @@ def test_jumpstart_predictor_support_no_model_id_supplied_happy_case( tolerate_vulnerable_model=False, sagemaker_session=mock_session, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index b2758c73ef..987feef7da 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -369,6 +369,7 @@ def test_jumpstart_model_specs(): { "name": "epochs", "type": "int", + # "_is_hub_content": False, "default": 3, "min": 1, "max": 1000, @@ -379,6 +380,7 @@ def test_jumpstart_model_specs(): { "name": "adam-learning-rate", "type": "float", + # "_is_hub_content": False, "default": 0.05, "min": 1e-08, "max": 1, @@ -389,6 +391,7 @@ def test_jumpstart_model_specs(): { "name": "batch-size", "type": "int", + # "_is_hub_content": False, "default": 4, "min": 1, "max": 1024, @@ -399,6 +402,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_submit_directory", "type": "text", + # "_is_hub_content": False, "default": "/opt/ml/input/data/code/sourcedir.tar.gz", "scope": "container", } @@ -407,6 +411,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_program", "type": "text", + # "_is_hub_content": False, "default": "transfer_learning.py", "scope": "container", } @@ -415,6 +420,7 @@ def test_jumpstart_model_specs(): { "name": "sagemaker_container_log_level", "type": "text", + # "_is_hub_content": False, "default": "20", "scope": "container", } diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index e102251060..e599d4eee1 100644 --- 
a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy -from typing import List +from typing import List, Optional import boto3 from sagemaker.jumpstart.cache import JumpStartModelsCache @@ -22,8 +22,8 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - JumpStartCachedS3ContentKey, - JumpStartCachedS3ContentValue, + JumpStartCachedContentKey, + JumpStartCachedContentValue, JumpStartModelSpecs, JumpStartS3FileType, JumpStartModelHeader, @@ -108,7 +108,9 @@ def get_prototype_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. @@ -124,7 +126,9 @@ def get_special_model_spec( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, we only retrieve model specs based on the model ID. This is reserved @@ -140,7 +144,9 @@ def get_special_model_spec_for_inference_component_based_endpoint( model_id: str = None, version: str = None, s3_client: boto3.client = None, + hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. 
For this mock, we only retrieve model specs based on the model ID and adding @@ -163,8 +169,10 @@ def get_spec_from_base_spec( model_id: str = None, version_str: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: if version and version_str: @@ -187,6 +195,7 @@ def get_spec_from_base_spec( "catboost" not in model_id, "lightgbm" not in model_id, "sklearn" not in model_id, + "ai21" not in model_id, ] ): raise KeyError("Bad model ID") @@ -209,8 +218,10 @@ def get_base_spec_with_prototype_configs( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session: Optional[str] = None, ) -> JumpStartModelSpecs: spec = copy.deepcopy(BASE_SPEC) inference_configs = {**INFERENCE_CONFIGS, **INFERENCE_CONFIG_RANKINGS} @@ -224,33 +235,31 @@ def get_base_spec_with_prototype_configs( def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, - key: JumpStartCachedS3ContentKey, - value: JumpStartCachedS3ContentValue, -) -> JumpStartCachedS3ContentValue: + key: JumpStartCachedContentKey, + value: JumpStartCachedContentValue, +) -> JumpStartCachedContentValue: - filetype, s3_key = key.file_type, key.s3_key - if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: + data_type, id_info = key.data_type, key.id_info + if data_type == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST: - return JumpStartCachedS3ContentValue( - formatted_content=get_formatted_manifest(BASE_MANIFEST) - ) + return JumpStartCachedContentValue(formatted_content=get_formatted_manifest(BASE_MANIFEST)) - if filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.OPEN_WEIGHT_SPECS: + _, model_id, 
specs_version = id_info.split("/") version = specs_version.replace("specs_v", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - if filetype == JumpStartS3FileType.PROPRIETARY_MANIFEST: - return JumpStartCachedS3ContentValue( + if data_type == JumpStartS3FileType.PROPRIETARY_MANIFEST: + return JumpStartCachedContentValue( formatted_content=get_formatted_manifest(BASE_PROPRIETARY_MANIFEST) ) - if filetype == JumpStartS3FileType.PROPRIETARY_SPECS: - _, model_id, specs_version = s3_key.split("/") + if data_type == JumpStartS3FileType.PROPRIETARY_SPECS: + _, model_id, specs_version = id_info.split("/") version = specs_version.replace("proprietary_specs_", "").replace(".json", "") - return JumpStartCachedS3ContentValue( + return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec( model_id=model_id, version=version, @@ -258,7 +267,7 @@ def patched_retrieval_function( ) ) - raise ValueError(f"Bad value for filetype: {filetype}") + raise ValueError(f"Bad value for filetype: {data_type}") def overwrite_dictionary( diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index 835a09a58c..12d3a2169d 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -59,6 +59,8 @@ def test_jumpstart_default_metric_definitions( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -79,6 +81,8 @@ def test_jumpstart_default_metric_definitions( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git 
a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 69ea2c1f56..50f6c370d5 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -114,7 +114,11 @@ def test_deploy(name_from_base, prepare_container_def, production_variant, sagem assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, @@ -930,7 +934,11 @@ def test_deploy_customized_volume_size_and_timeout( assert 2 == name_from_base.call_count prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) production_variant.assert_called_with( MODEL_NAME, diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index c0b18a3eb3..e43ad0ed0a 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -287,7 +287,11 @@ def test_create_sagemaker_model(prepare_container_def, sagemaker_session): model._create_sagemaker_model() prepare_container_def.assert_called_with( - None, accelerator_type=None, serverless_inference_config=None, accept_eula=None + None, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) sagemaker_session.create_model.assert_called_with( name=MODEL_NAME, @@ -305,7 +309,11 @@ def test_create_sagemaker_model_instance_type(prepare_container_def, sagemaker_s model._create_sagemaker_model(INSTANCE_TYPE) prepare_container_def.assert_called_with( - INSTANCE_TYPE, accelerator_type=None, 
serverless_inference_config=None, accept_eula=None + INSTANCE_TYPE, + accelerator_type=None, + serverless_inference_config=None, + accept_eula=None, + model_reference_arn=None, ) @@ -321,6 +329,7 @@ def test_create_sagemaker_model_accelerator_type(prepare_container_def, sagemake accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ) @@ -336,6 +345,7 @@ def test_create_sagemaker_model_with_eula(prepare_container_def, sagemaker_sessi accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=True, + model_reference_arn=None, ) @@ -351,6 +361,7 @@ def test_create_sagemaker_model_with_eula_false(prepare_container_def, sagemaker accelerator_type=accelerator_type, serverless_inference_config=None, accept_eula=False, + model_reference_arn=None, ) diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 8ec9478d8a..e71207d439 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -54,6 +54,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -72,6 +74,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -91,6 +95,8 @@ def test_jumpstart_common_model_uri( version="*", s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -110,6 +116,8 @@ def test_jumpstart_common_model_uri( version="1.*", s3_client=mock_client, 
model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 1c0cfa35b3..d149e08cab 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -56,6 +56,8 @@ def test_jumpstart_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() @@ -76,6 +78,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -89,6 +92,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.g5.555xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -102,6 +106,7 @@ def test_jumpstart_resource_requirements_instance_type_variants(patched_get_mode scope="inference", sagemaker_session=mock_session, instance_type="ml.f9.555xlarge", + hub_arn=None, ) assert default_inference_resource_requirements.requests == { "memory": 81999, @@ -138,6 +143,8 @@ def test_jumpstart_no_supported_resource_requirements( version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, + sagemaker_session=mock_session, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py 
b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 16b7256ed2..b67f238cac 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -53,7 +53,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -71,7 +73,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -90,7 +94,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,7 +115,9 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index 90ec5df6b5..dde308dcfb 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -53,9 +53,11 @@ def test_jumpstart_default_serializers( patched_get_model_specs.assert_called_once_with( region=region, model_id=model_id, + hub_arn=None, version=model_version, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) 
patched_get_model_specs.reset_mock() @@ -99,6 +101,8 @@ def test_jumpstart_serializer_options( region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, model_type=JumpStartModelType.OPEN_WEIGHTS, + sagemaker_session=mock_session, ) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 295f1a8d24..b557a9c9f0 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -266,6 +266,7 @@ def prepare_container_def( accelerator_type=None, serverless_inference_config=None, accept_eula=None, + model_reference_arn=None, ): return MODEL_CONTAINER_DEF @@ -5256,6 +5257,7 @@ def test_all_framework_estimators_add_jumpstart_uri_tags( entry_point="inference.py", role=ROLE, tags=[{"Key": "blah", "Value": "yoyoma"}], + model_reference_arn=None, ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index e95359d52c..c776dfe479 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7009,3 +7009,170 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session): Filename="./foo/bar/mode.tar.gz", ExtraArgs=None, ) + + +def test_create_hub(sagemaker_session): + sagemaker_session.create_hub( + hub_name="mock-hub-name", + hub_description="this is my sagemaker hub", + hub_display_name="Mock Hub", + hub_search_keywords=["mock", "hub", "123"], + s3_storage_config={"S3OutputPath": "s3://my-hub-bucket/"}, + tags=[{"Key": "tag-key-1", "Value": "tag-value-1"}], + ) + + request = { + "HubName": "mock-hub-name", + "HubDescription": "this is my sagemaker hub", + "HubDisplayName": "Mock Hub", + "HubSearchKeywords": ["mock", "hub", "123"], + "S3StorageConfig": {"S3OutputPath": "s3://my-hub-bucket/"}, + "Tags": [{"Key": "tag-key-1", "Value": "tag-value-1"}], + } + + sagemaker_session.sagemaker_client.create_hub.assert_called_with(**request) + + +def test_describe_hub(sagemaker_session): + 
sagemaker_session.describe_hub( + hub_name="mock-hub-name", + ) + + request = { + "HubName": "mock-hub-name", + } + + sagemaker_session.sagemaker_client.describe_hub.assert_called_with(**request) + + +def test_list_hubs(sagemaker_session): + sagemaker_session.list_hubs( + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08-2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08-2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hubs.assert_called_with(**request) + + +def test_list_hub_contents(sagemaker_session): + sagemaker_session.list_hub_contents( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + name_contains="mock-hub", + sort_by="HubName", + sort_order="Ascending", + ) + + request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "NameContains": "mock-hub", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_contents.assert_called_with(**request) + + +def test_list_hub_content_versions(sagemaker_session): + sagemaker_session.list_hub_content_versions( + hub_name="mock-hub-123", + hub_content_type="MODELREF", + hub_content_name="mock-hub-content-1", + min_version="1.0.0", + creation_time_after="08-14-1997 12:00:00", + creation_time_before="01-08/2024 10:25:00", + max_results=25, + max_schema_version="1.0.5", + sort_by="HubName", + sort_order="Ascending", + ) + + 
request = { + "HubName": "mock-hub-123", + "HubContentType": "MODELREF", + "HubContentName": "mock-hub-content-1", + "MinVersion": "1.0.0", + "CreationTimeAfter": "08-14-1997 12:00:00", + "CreationTimeBefore": "01-08/2024 10:25:00", + "MaxResults": 25, + "MaxSchemaVersion": "1.0.5", + "SortBy": "HubName", + "SortOrder": "Ascending", + } + + sagemaker_session.sagemaker_client.list_hub_content_versions.assert_called_with(**request) + + +def test_delete_hub(sagemaker_session): + sagemaker_session.delete_hub( + hub_name="mock-hub-123", + ) + + request = { + "HubName": "mock-hub-123", + } + + sagemaker_session.sagemaker_client.delete_hub.assert_called_with(**request) + + +def test_create_hub_content_reference(sagemaker_session): + sagemaker_session.create_hub_content_reference( + hub_name="mock-hub-name", + source_hub_content_arn=( + "arn:aws:sagemaker:us-east-1:" + "123456789123:" + "hub-content/JumpStartHub/" + "model/mock-hub-content-1" + ), + hub_content_name="mock-hub-content-1", + min_version="1.1.1", + ) + + request = { + "HubName": "mock-hub-name", + "SageMakerPublicHubContentArn": "arn:aws:sagemaker:us-east-1:123456789123:hub-content/JumpStartHub/model/mock-hub-content-1", # noqa: E501 + "HubContentName": "mock-hub-content-1", + "MinVersion": "1.1.1", + } + + sagemaker_session.sagemaker_client.create_hub_content_reference.assert_called_with(**request) + + +def test_delete_hub_content_reference(sagemaker_session): + sagemaker_session.delete_hub_content_reference( + hub_name="mock-hub-name", + hub_content_type="ModelReference", + hub_content_name="mock-hub-content-1", + ) + + request = { + "HubName": "mock-hub-name", + "HubContentType": "ModelReference", + "HubContentName": "mock-hub-content-1", + } + + sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request)