diff --git a/src/sagemaker/environment_variables.py b/src/sagemaker/environment_variables.py index b67066fcde..17f8ebdf2c 100644 --- a/src/sagemaker/environment_variables.py +++ b/src/sagemaker/environment_variables.py @@ -30,6 +30,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, include_aws_sdk_env_vars: bool = True, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default environment variables. (Default: None). model_version (str): Optional. The version of the model for which to retrieve the default environment variables. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -80,6 +83,7 @@ def retrieve_default( return artifacts._retrieve_default_environment_variables( model_id, model_version, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/src/sagemaker/hyperparameters.py b/src/sagemaker/hyperparameters.py index 5873e37b9f..83554b302b 100644 --- a/src/sagemaker/hyperparameters.py +++ b/src/sagemaker/hyperparameters.py @@ -31,6 +31,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -46,6 +47,8 @@ def retrieve_default( retrieve the default hyperparameters. (Default: None). model_version (str): The version of the model for which to retrieve the default hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get hyperparameters specific for the instance type. include_container_hyperparameters (bool): ``True`` if the container hyperparameters @@ -80,6 +83,7 @@ def retrieve_default( return artifacts._retrieve_default_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, include_container_hyperparameters=include_container_hyperparameters, @@ -93,6 +97,7 @@ def validate( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, hyperparameters: Optional[dict] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, tolerate_vulnerable_model: bool = False, @@ -107,6 +112,8 @@ def validate( (Default: None). model_version (str): The version of the model for which to validate hyperparameters. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). hyperparameters (dict): Hyperparameters to validate. (Default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with @@ -148,6 +155,7 @@ def validate( return validate_hyperparameters( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, hyperparameters=hyperparameters, validation_mode=validation_mode, region=region, diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 252bf3c504..2b6870a11c 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -61,6 +61,7 @@ def retrieve( training_compiler_config=None, model_id=None, model_version=None, + hub_arn=None, tolerate_vulnerable_model=False, tolerate_deprecated_model=False, sdk_version=None, @@ -101,6 +102,8 @@ def retrieve( (default: None). model_version (str): The version of the JumpStart model for which to retrieve the image URI (default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications should be tolerated without an exception raised. If ``False``, raises an exception if the script used by this version of the model has dependencies with known security @@ -146,6 +149,7 @@ def retrieve( model_id, model_version, image_scope, + hub_arn, framework, region, version, diff --git a/src/sagemaker/instance_types.py b/src/sagemaker/instance_types.py index 0471f374ae..d277f8cf3b 100644 --- a/src/sagemaker/instance_types.py +++ b/src/sagemaker/instance_types.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -44,6 +45,8 @@ def retrieve_default( retrieve the default instance type. (Default: None). model_version (str): The version of the model for which to retrieve the default instance type. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -80,6 +83,7 @@ def retrieve_default( model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, @@ -92,6 +96,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -107,6 +112,8 @@ def retrieve( retrieve the supported instance types. (Default: None). model_version (str): The version of the model for which to retrieve the supported instance types. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -142,6 +149,7 @@ def retrieve( model_id, model_version, scope, + hub_arn, region, tolerate_vulnerable_model, tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index e03a13a7a3..456ba6baf6 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -19,6 +19,7 @@ from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs from sagemaker.jumpstart import cache +from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -239,7 +240,11 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel @staticmethod def get_model_specs( - region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None + region: str, + model_id: str, + version: str, + hub_arn: Optional[str] = None, + s3_client: Optional[boto3.client] = None, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. @@ -259,6 +264,13 @@ def get_model_specs( {**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs} ) JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + + if hub_arn: + hub_model_arn = construct_hub_model_arn_from_inputs( + hub_arn=hub_arn, model_name=model_id, version=version + ) + return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn) + return JumpStartModelsAccessor._cache.get_specs( # type: ignore model_id=model_id, semantic_version_str=version ) diff --git a/src/sagemaker/jumpstart/artifacts/environment_variables.py b/src/sagemaker/jumpstart/artifacts/environment_variables.py index 0e666e4c14..b85cfe4572 100644 --- a/src/sagemaker/jumpstart/artifacts/environment_variables.py +++ b/src/sagemaker/jumpstart/artifacts/environment_variables.py @@ -31,6 +31,7 @@ def _retrieve_default_environment_variables( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -46,6 +47,8 @@ def _retrieve_default_environment_variables( retrieve the default environment variables. model_version (str): Version of the JumpStart model for which to retrieve the default environment variables. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve default environment variables. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -77,6 +80,7 @@ def _retrieve_default_environment_variables( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -113,6 +117,7 @@ def _retrieve_default_environment_variables( gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, @@ -131,6 +136,7 @@ def _retrieve_default_environment_variables( def _retrieve_gated_model_uri_env_var_value( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -144,6 +150,8 @@ def _retrieve_gated_model_uri_env_var_value( retrieve the gated model env var URI. model_version (str): Version of the JumpStart model for which to retrieve the gated model env var URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve the gated model env var URI. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -174,6 +182,7 @@ def _retrieve_gated_model_uri_env_var_value( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/hyperparameters.py b/src/sagemaker/jumpstart/artifacts/hyperparameters.py index e9e6f613f8..6b9689485f 100644 --- a/src/sagemaker/jumpstart/artifacts/hyperparameters.py +++ b/src/sagemaker/jumpstart/artifacts/hyperparameters.py @@ -30,6 +30,7 @@ def _retrieve_default_hyperparameters( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, include_container_hyperparameters: bool = False, tolerate_vulnerable_model: bool = False, @@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters( retrieve the default hyperparameters. model_version (str): Version of the JumpStart model for which to retrieve the default hyperparameters. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (str): Region for which to retrieve default hyperparameters. (Default: None). include_container_hyperparameters (bool): True if container hyperparameters @@ -76,6 +79,7 @@ def _retrieve_default_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/image_uris.py b/src/sagemaker/jumpstart/artifacts/image_uris.py index 6ea1ca84a1..bb7b0c7528 100644 --- a/src/sagemaker/jumpstart/artifacts/image_uris.py +++ b/src/sagemaker/jumpstart/artifacts/image_uris.py @@ -33,6 +33,7 @@ def _retrieve_image_uri( model_id: str, model_version: str, image_scope: str, + hub_arn: Optional[str] = None, framework: Optional[str] = None, region: Optional[str] = None, version: Optional[str] = None, @@ -57,6 +58,8 @@ def _retrieve_image_uri( model_id (str): JumpStart model ID for which to retrieve image URI. model_version (str): Version of the JumpStart model for which to retrieve the image URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. @@ -110,6 +113,7 @@ def _retrieve_image_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=image_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/incremental_training.py b/src/sagemaker/jumpstart/artifacts/incremental_training.py index 753a911422..f877068e77 100644 --- a/src/sagemaker/jumpstart/artifacts/incremental_training.py +++ b/src/sagemaker/jumpstart/artifacts/incremental_training.py @@ -30,6 +30,7 @@ def _model_supports_incremental_training( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -43,6 +44,8 @@ def _model_supports_incremental_training( support status for incremental training. region (Optional[str]): Region for which to retrieve the support status for incremental training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -64,6 +67,7 @@ def _model_supports_incremental_training( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/instance_types.py b/src/sagemaker/jumpstart/artifacts/instance_types.py index 38e02e3ebd..5f252f00ad 100644 --- a/src/sagemaker/jumpstart/artifacts/instance_types.py +++ b/src/sagemaker/jumpstart/artifacts/instance_types.py @@ -33,6 +33,7 @@ def _retrieve_default_instance_type( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -48,6 +49,8 @@ def _retrieve_default_instance_type( default instance type. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve default instance type. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -80,6 +83,7 @@ def _retrieve_default_instance_type( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -119,6 +123,7 @@ def _retrieve_instance_types( model_id: str, model_version: str, scope: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -134,6 +139,8 @@ def _retrieve_instance_types( supported instance types. scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve supported instance types. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -166,6 +173,7 @@ def _retrieve_instance_types( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/kwargs.py b/src/sagemaker/jumpstart/artifacts/kwargs.py index 7acad9b793..69b8bfb51a 100644 --- a/src/sagemaker/jumpstart/artifacts/kwargs.py +++ b/src/sagemaker/jumpstart/artifacts/kwargs.py @@ -140,6 +140,7 @@ def _retrieve_estimator_init_kwargs( model_id: str, model_version: str, instance_type: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -154,6 +155,8 @@ def _retrieve_estimator_init_kwargs( kwargs. instance_type (str): Instance type of the training job, to determine if volume size is supported. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). region (Optional[str]): Region for which to retrieve kwargs. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -177,6 +180,7 @@ def _retrieve_estimator_init_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -198,6 +202,7 @@ def _retrieve_estimator_init_kwargs( def _retrieve_estimator_fit_kwargs( model_id: str, model_version: str, + hub_arn: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -234,6 +239,7 @@ def _retrieve_estimator_fit_kwargs( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/metric_definitions.py b/src/sagemaker/jumpstart/artifacts/metric_definitions.py index b6f6019641..64e53dbbbe 100644 --- a/src/sagemaker/jumpstart/artifacts/metric_definitions.py +++ b/src/sagemaker/jumpstart/artifacts/metric_definitions.py @@ -31,6 +31,7 @@ def _retrieve_default_training_metric_definitions( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -45,6 +46,8 @@ def _retrieve_default_training_metric_definitions( default training metric definitions. region (Optional[str]): Region for which to retrieve default training metric definitions. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -68,6 +71,7 @@ def _retrieve_default_training_metric_definitions( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_packages.py b/src/sagemaker/jumpstart/artifacts/model_packages.py index bd0ae365d9..e49d14682d 100644 --- a/src/sagemaker/jumpstart/artifacts/model_packages.py +++ b/src/sagemaker/jumpstart/artifacts/model_packages.py @@ -110,6 +110,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -124,6 +125,8 @@ def _retrieve_model_package_model_artifact_s3_uri( model package artifact. region (Optional[str]): Region for which to retrieve the model package artifact. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). scope (Optional[str]): Scope for which to retrieve the model package artifact. (Default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -152,6 +155,7 @@ def _retrieve_model_package_model_artifact_s3_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/model_uris.py b/src/sagemaker/jumpstart/artifacts/model_uris.py index c41f0a75b7..df9e1fd507 100644 --- a/src/sagemaker/jumpstart/artifacts/model_uris.py +++ b/src/sagemaker/jumpstart/artifacts/model_uris.py @@ -89,6 +89,7 @@ def _retrieve_training_artifact_key(model_specs: JumpStartModelSpecs, instance_t def _retrieve_model_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, region: Optional[str] = None, @@ -105,6 +106,8 @@ def _retrieve_model_uri( the model artifact S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). model_scope (str): The model type, i.e. what it is used for. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -135,6 +138,7 @@ def _retrieve_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=model_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, @@ -178,6 +182,7 @@ def _model_supports_training_model_uri( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -191,6 +196,8 @@ def _model_supports_training_model_uri( support status for model uri with training. region (Optional[str]): Region for which to retrieve the support status for model uri with training. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -212,6 +219,7 @@ def _model_supports_training_model_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.TRAINING, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/payloads.py b/src/sagemaker/jumpstart/artifacts/payloads.py index 3ea2c16f80..21fca2abd2 100644 --- a/src/sagemaker/jumpstart/artifacts/payloads.py +++ b/src/sagemaker/jumpstart/artifacts/payloads.py @@ -32,6 +32,7 @@ def _retrieve_example_payloads( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -45,6 +46,8 @@ def _retrieve_example_payloads( example payloads. region (Optional[str]): Region for which to retrieve the example payloads. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -67,6 +70,7 @@ def _retrieve_example_payloads( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/resource_names.py b/src/sagemaker/jumpstart/artifacts/resource_names.py index 6b05f07b15..ca3044068b 100644 --- a/src/sagemaker/jumpstart/artifacts/resource_names.py +++ b/src/sagemaker/jumpstart/artifacts/resource_names.py @@ -30,6 +30,7 @@ def _retrieve_resource_name_base( model_id: str, model_version: str, region: Optional[str], + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -43,6 +44,8 @@ def _retrieve_resource_name_base( default resource name. region (Optional[str]): Region for which to retrieve the default resource name. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -64,6 +67,7 @@ def _retrieve_resource_name_base( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=JumpStartScriptScope.INFERENCE, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/artifacts/script_uris.py b/src/sagemaker/jumpstart/artifacts/script_uris.py index c1b037ce61..c04ae88ca3 100644 --- a/src/sagemaker/jumpstart/artifacts/script_uris.py +++ b/src/sagemaker/jumpstart/artifacts/script_uris.py @@ -32,6 +32,7 @@ def _retrieve_script_uri( model_id: str, model_version: str, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, region: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -47,6 +48,8 @@ def _retrieve_script_uri( retrieve the script S3 URI. model_version (str): Version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). script_scope (str): The script type, i.e. what it is used for. Valid values: "training" and "inference". region (str): Region for which to retrieve model script S3 URI. @@ -77,6 +80,7 @@ def _retrieve_script_uri( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, scope=script_scope, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index d733f39864..9dc505a2ff 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -30,6 +30,7 @@ MODEL_ID_LIST_WEB_URL, ) from sagemaker.jumpstart.curated_hub.curated_hub import CuratedHub +from sagemaker.jumpstart.curated_hub.utils import get_info_from_hub_resource_arn from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, @@ -44,7 +45,7 @@ JumpStartModelSpecs, JumpStartS3FileType, JumpStartVersionedModelId, - HubDataType, + HubContentType, ) from sagemaker.jumpstart import utils from sagemaker.utilities.cache import LRUCache @@ -338,12 +339,14 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubDataType.MODEL: - hub_name, region, model_name, model_version = utils.extract_info_from_hub_content_arn( + if data_type == HubContentType.MODEL: + info = get_info_from_hub_resource_arn( id_info ) - hub = CuratedHub(hub_name=hub_name, region=region) - hub_content = hub.describe_model(model_name=model_name, model_version=model_version) + hub = CuratedHub(hub_name=info.hub_name, region=info.region) + hub_content = hub.describe_model( + model_name=info.hub_content_name, model_version=info.hub_content_version + ) utils.emit_logs_based_on_model_specs( hub_content.content_document, self.get_region(), @@ -353,14 +356,16 @@ def _retrieval_function( return JumpStartCachedContentValue( formatted_content=model_specs ) - if data_type == HubDataType.HUB: - hub_name, region, _, _ = utils.extract_info_from_hub_content_arn(id_info) - hub = CuratedHub(hub_name=hub_name, region=region) + if data_type == HubContentType.HUB: + info = get_info_from_hub_resource_arn( + id_info + ) + hub = CuratedHub(hub_name=info.hub_name, region=info.region) hub_info = hub.describe() return JumpStartCachedContentValue(formatted_content=hub_info) raise ValueError( f"Bad value for key '{key}': must be in", - f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubDataType.HUB, HubDataType.MODEL]}" + f"{[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS, HubContentType.HUB, HubContentType.MODEL]}" ) def get_manifest(self) -> List[JumpStartModelHeader]: @@ -474,7 +479,7 @@ def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs: """ details, _ = self._content_cache.get( - JumpStartCachedContentKey(HubDataType.MODEL, hub_model_arn) + JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn) ) return details.formatted_content @@ -485,7 +490,7 @@ def get_hub(self, hub_arn: str) -> Dict[str, Any]: hub_arn (str): Arn for the Hub to get info for """ - details, _ = self._content_cache.get(JumpStartCachedContentKey(HubDataType.HUB, hub_arn)) + details, _ = self._content_cache.get(JumpStartCachedContentKey(HubContentType.HUB, hub_arn)) return details.formatted_content def clear(self) -> None: diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 0e552aaac6..1fec16b9b3 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -170,8 +170,7 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" -# works cross-partition -HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/curated_hub/curated_hub.py b/src/sagemaker/jumpstart/curated_hub/curated_hub.py index 273deb097b..b8885ff250 100644 --- a/src/sagemaker/jumpstart/curated_hub/curated_hub.py +++ b/src/sagemaker/jumpstart/curated_hub/curated_hub.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from typing import Optional, Dict, Any +from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION from sagemaker.session import Session @@ -21,11 +22,15 @@ class CuratedHub: """Class for creating and managing a curated JumpStart hub""" - def __init__(self, hub_name: str, region: str, session: Optional[Session] = None): + def __init__( + self, + hub_name: str, + region: str, + session: Optional[Session] = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + ): self.hub_name = hub_name self.region = region - self.session = session - self._sm_session = session or Session() + self._sm_session = session def describe_model(self, model_name: str, model_version: str = "*") -> Dict[str, Any]: """Returns descriptive information about the Hub Model""" diff --git a/src/sagemaker/jumpstart/curated_hub/types.py b/src/sagemaker/jumpstart/curated_hub/types.py new file mode 100644 index 0000000000..d400137905 --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/types.py @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores types related to SageMaker JumpStart CuratedHub.""" +from __future__ import absolute_import +from typing import Optional + +from sagemaker.jumpstart.types import JumpStartDataHolderType + + +class HubArnExtractedInfo(JumpStartDataHolderType): + """Data class for info extracted from Hub arn.""" + + __slots__ = [ + "partition", + "region", + "account_id", + "hub_name", + "hub_content_type", + "hub_content_name", + "hub_content_version", + ] + + def __init__( + self, + partition: str, + region: str, + account_id: str, + hub_name: str, + hub_content_type: Optional[str] = None, + hub_content_name: Optional[str] = None, + hub_content_version: Optional[str] = None, + ) -> None: + """Instantiates HubArnExtractedInfo object.""" + + self.partition = partition + self.region = region + self.account_id = account_id + self.hub_name = hub_name + self.hub_content_type = hub_content_type + self.hub_content_name = hub_content_name + self.hub_content_version = hub_content_version diff --git a/src/sagemaker/jumpstart/curated_hub/utils.py b/src/sagemaker/jumpstart/curated_hub/utils.py new file mode 100644 index 0000000000..5c7a91382b --- /dev/null +++ b/src/sagemaker/jumpstart/curated_hub/utils.py @@ -0,0 +1,111 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains utilities related to SageMaker JumpStart CuratedHub.""" +from __future__ import absolute_import +import re +from typing import Optional +from sagemaker.jumpstart import constants + +from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo +from sagemaker.jumpstart.types import HubContentType +from sagemaker.session import Session +from sagemaker.utils import aws_partition + + +def get_info_from_hub_resource_arn( + arn: str, +) -> HubArnExtractedInfo: + """Extracts descriptive information from a Hub or HubContent Arn.""" + + match = re.match(constants.HUB_CONTENT_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + hub_content_type = match.group(5) + hub_content_name = match.group(6) + hub_content_version = match.group(7) + + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + hub_content_type=hub_content_type, + hub_content_name=hub_content_name, + hub_content_version=hub_content_version, + ) + + match = re.match(constants.HUB_ARN_REGEX, arn) + if match: + partition = match.group(1) + hub_region = match.group(2) + account_id = match.group(3) + hub_name = match.group(4) + return HubArnExtractedInfo( + partition=partition, + region=hub_region, + account_id=account_id, + hub_name=hub_name, + ) + + return None + + +def construct_hub_arn_from_name( + hub_name: str, + region: Optional[str] = None, + session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, +) -> str: + """Constructs a Hub arn from the Hub name using default Session values.""" + + account_id = session.account_id() + region = region or session.boto_region_name + partition = aws_partition(region) + + return f"arn:{partition}:sagemaker:{region}:{account_id}:hub/{hub_name}" + + +def construct_hub_model_arn_from_inputs(hub_arn: str, model_name: str, version: str) -> str: + """Constructs a HubContent model arn from the Hub name, model name, and model version.""" + + info = get_info_from_hub_resource_arn(hub_arn) + arn = ( + f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" + f"{info.hub_name}/{HubContentType.MODEL}/{model_name}/{version}" + ) + + return arn + + +# TODO: Update to recognize JumpStartHub hub_name +def generate_hub_arn_for_estimator_init_kwargs( + hub_name: str, region: Optional[str] = None, session: Optional[Session] = None +): + """Generates the Hub Arn for JumpStartEstimator from a HubName or Arn. + + Args: + hub_name (str): HubName or HubArn from JumpStartEstimator args + region (str): Region from JumpStartEstimator args + session (Session): Custom SageMaker Session from JumpStartEstimator args + """ + + hub_arn = None + if hub_name: + match = re.match(constants.HUB_ARN_REGEX, hub_name) + if match: + hub_arn = hub_name + else: + hub_arn = construct_hub_arn_from_name(hub_name=hub_name, region=region, session=session) + return hub_arn diff --git a/src/sagemaker/jumpstart/enums.py b/src/sagemaker/jumpstart/enums.py index e33daca046..f962fdca80 100644 --- a/src/sagemaker/jumpstart/enums.py +++ b/src/sagemaker/jumpstart/enums.py @@ -79,6 +79,8 @@ class JumpStartTag(str, Enum): MODEL_ID = "sagemaker-sdk:jumpstart-model-id" MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version" + HUB_ARN = "sagemaker-hub:hub-arn" + class SerializerType(str, Enum): """Enum class for serializers associated with JumpStart models.""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 24105c4369..53d12b46a6 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -28,6 +28,7 @@ from sagemaker.instance_group import InstanceGroup from sagemaker.jumpstart.accessors import JumpStartModelsAccessor from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION +from sagemaker.jumpstart.curated_hub.utils import generate_hub_arn_for_estimator_init_kwargs from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG @@ -57,6 +58,7 @@ def __init__( self, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_name: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -122,6 +124,7 @@ def __init__( https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html for list of model IDs. model_version (Optional[str]): Version for JumpStart model to use (Default: None). + hub_name (Optional[str]): Hub name or arn where the model is stored (Default: None). tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies @@ -518,9 +521,16 @@ def _is_valid_model_id_hook(): if not _is_valid_model_id_hook(): raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id)) + hub_arn = None + if hub_name: + hub_arn = generate_hub_arn_for_estimator_init_kwargs( + hub_name=hub_name, region=region, session=sagemaker_session + ) + estimator_init_kwargs = get_init_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, @@ -576,6 +586,7 @@ def _is_valid_model_id_hook(): enable_remote_debug=enable_remote_debug, ) + self.hub_arn = estimator_init_kwargs.hub_arn self.model_id = estimator_init_kwargs.model_id self.model_version = estimator_init_kwargs.model_version self.instance_type = estimator_init_kwargs.instance_type @@ -652,6 +663,7 @@ def fit( estimator_fit_kwargs = get_fit_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, inputs=inputs, wait=wait, @@ -1018,6 +1030,7 @@ def deploy( estimator_deploy_kwargs = get_deploy_kwargs( model_id=self.model_id, model_version=self.model_version, + hub_arn=self.hub_arn, region=self.region, tolerate_vulnerable_model=self.tolerate_vulnerable_model, tolerate_deprecated_model=self.tolerate_deprecated_model, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7ccf57983b..71a6419f82 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -61,6 +61,7 @@ JumpStartModelInitKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_arn_tags, add_jumpstart_model_id_version_tags, update_dict_if_key_not_present, resolve_estimator_sagemaker_config_field, @@ -77,6 +78,7 @@ def get_init_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -134,6 +136,7 @@ def get_init_kwargs( estimator_init_kwargs: JumpStartEstimatorInitKwargs = JumpStartEstimatorInitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, role=role, region=region, instance_count=instance_count, @@ -209,6 +212,7 @@ def get_init_kwargs( def get_fit_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, wait: Optional[bool] = None, @@ -224,6 +228,7 @@ def get_fit_kwargs( estimator_fit_kwargs: JumpStartEstimatorFitKwargs = JumpStartEstimatorFitKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, inputs=inputs, wait=wait, @@ -246,6 +251,7 @@ def get_fit_kwargs( def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -290,6 +296,7 @@ def get_deploy_kwargs( model_deploy_kwargs: JumpStartModelDeployKwargs = model.get_deploy_kwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, @@ -432,6 +439,7 @@ def _add_instance_type_and_count_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -454,6 +462,7 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima full_model_version = verify_model_region_and_return_specs( model_id=kwargs.model_id, version=kwargs.model_version, + hub_arn=kwargs.hub_arn, scope=JumpStartScriptScope.TRAINING, region=kwargs.region, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -465,6 +474,10 @@ def _add_tags_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstima kwargs.tags = add_jumpstart_model_id_version_tags( kwargs.tags, kwargs.model_id, full_model_version ) + + if kwargs.hub_arn: + kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -477,6 +490,7 @@ def _add_image_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE image_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -492,6 +506,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE if _model_supports_training_model_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -501,6 +516,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE model_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -513,6 +529,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE and not _model_supports_incremental_training( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -548,6 +565,7 @@ def _add_source_dir_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStart script_scope=JumpStartScriptScope.TRAINING, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -564,6 +582,7 @@ def _add_env_to_kwargs( extra_env_vars = environment_variables.retrieve_default( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, include_aws_sdk_env_vars=False, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -576,6 +595,7 @@ def _add_env_to_kwargs( model_package_artifact_uri = _retrieve_model_package_model_artifact_s3_uri( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, scope=JumpStartScriptScope.TRAINING, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -616,6 +636,7 @@ def _add_training_job_name_to_kwargs( default_training_job_name = _retrieve_resource_name_base( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, @@ -642,6 +663,7 @@ def _add_hyperparameters_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -675,6 +697,7 @@ def _add_metric_definitions_to_kwargs( region=kwargs.region, model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, sagemaker_session=kwargs.sagemaker_session, @@ -703,6 +726,7 @@ def _add_estimator_extra_kwargs( estimator_kwargs_to_add = _retrieve_estimator_init_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, instance_type=kwargs.instance_type, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, @@ -728,6 +752,7 @@ def _add_fit_extra_kwargs(kwargs: JumpStartEstimatorFitKwargs) -> JumpStartEstim fit_kwargs_to_add = _retrieve_estimator_fit_kwargs( model_id=kwargs.model_id, model_version=kwargs.model_version, + hub_arn=kwargs.hub_arn, region=kwargs.region, tolerate_deprecated_model=kwargs.tolerate_deprecated_model, tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 64e4727116..1dfe9ef5e2 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -44,6 +44,7 @@ JumpStartModelRegisterKwargs, ) from sagemaker.jumpstart.utils import ( + add_hub_arn_tags, add_jumpstart_model_id_version_tags, update_dict_if_key_not_present, resolve_model_sagemaker_config_field, @@ -447,6 +448,9 @@ def _add_tags_to_kwargs(kwargs: JumpStartModelDeployKwargs) -> Dict[str, Any]: kwargs.tags, kwargs.model_id, full_model_version ) + if kwargs.hub_arn: + kwargs.tags = add_hub_arn_tags(kwargs.tags, kwargs.hub_arn) + return kwargs @@ -489,6 +493,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel def get_deploy_kwargs( model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -521,6 +526,7 @@ def get_deploy_kwargs( deploy_kwargs: JumpStartModelDeployKwargs = JumpStartModelDeployKwargs( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 5a4e91d092..22cc70fcab 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -106,15 +106,15 @@ class JumpStartS3FileType(str, Enum): SPECS = "specs" -class HubDataType(str, Enum): +class HubContentType(str, Enum): """Enum for Hub data storage objects.""" - HUB = "hub" - MODEL = "model" - NOTEBOOK = "notebook" + HUB = "Hub" + MODEL = "Model" + NOTEBOOK = "Notebook" -JumpStartContentDataType = Union[JumpStartS3FileType, HubDataType] +JumpStartContentDataType = Union[JumpStartS3FileType, HubContentType] class JumpStartLaunchedRegionInfo(JumpStartDataHolderType): @@ -867,7 +867,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None: self.inference_enable_network_isolation: bool = json_obj.get( "inference_enable_network_isolation", False ) - self.resource_name_base: bool = json_obj.get("resource_name_base") + self.resource_name_base: Optional[str] = json_obj.get("resource_name_base") self.hosting_eula_key: Optional[str] = json_obj.get("hosting_eula_key") @@ -1151,6 +1151,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "initial_instance_count", "instance_type", "region", @@ -1182,6 +1183,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1193,6 +1195,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1224,6 +1227,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1258,6 +1262,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "instance_count", "region", @@ -1317,12 +1322,14 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "tolerate_vulnerable_model", "model_id", "model_version", + "hub_arn", } def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, image_uri: Optional[Union[str, Any]] = None, role: Optional[str] = None, @@ -1379,6 +1386,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.instance_count = instance_count self.region = region @@ -1440,6 +1448,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "region", "inputs", "wait", @@ -1454,6 +1463,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): SERIALIZATION_EXCLUSION_SET = { "model_id", "model_version", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1464,6 +1474,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, inputs: Optional[Union[str, Dict, Any, Any]] = None, wait: Optional[bool] = None, @@ -1478,6 +1489,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.region = region self.inputs = inputs self.wait = wait @@ -1495,6 +1507,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): __slots__ = [ "model_id", "model_version", + "hub_arn", "instance_type", "initial_instance_count", "region", @@ -1540,6 +1553,7 @@ class JumpStartEstimatorDeployKwargs(JumpStartKwargs): "region", "model_id", "model_version", + "hub_arn", "sagemaker_session", } @@ -1547,6 +1561,7 @@ def __init__( self, model_id: str, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -1589,6 +1604,7 @@ def __init__( self.model_id = model_id self.model_version = model_version + self.hub_arn = hub_arn self.instance_type = instance_type self.initial_instance_count = initial_instance_count self.region = region diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 6de2761d20..1e2bb11d45 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,7 +14,6 @@ from __future__ import absolute_import import logging import os -import re from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse import boto3 @@ -368,6 +367,21 @@ def add_jumpstart_model_id_version_tags( return tags +def add_hub_arn_tags( + tags: Optional[List[TagsDict]], + hub_arn: str, +) -> Optional[List[TagsDict]]: + """Adds custom Hub arn tag to JumpStart related resources.""" + + tags = add_single_jumpstart_tag( + hub_arn, + enums.JumpStartTag.HUB_ARN, + tags, + is_uri=False, + ) + return tags + + def add_jumpstart_uri_tags( tags: Optional[List[TagsDict]] = None, inference_model_uri: Optional[Union[str, dict]] = None, @@ -528,6 +542,7 @@ def verify_model_region_and_return_specs( version: Optional[str], scope: Optional[str], region: str, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION, @@ -542,6 +557,8 @@ def verify_model_region_and_return_specs( scope (Optional[str]): scope of the JumpStart model to verify. region (Optional[str]): region of the JumpStart model to verify and obtains specs. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications should be tolerated (exception not raised). If False, raises an exception if the script used by this version of the model has dependencies with known @@ -577,6 +594,7 @@ def verify_model_region_and_return_specs( model_specs = accessors.JumpStartModelsAccessor.get_model_specs( # type: ignore region=region, model_id=model_id, + hub_arn=hub_arn, version=version, s3_client=sagemaker_session.s3_client, ) @@ -811,26 +829,3 @@ def get_jumpstart_model_id_version_from_resource_arn( model_version = model_version_from_tag return model_id, model_version - - -def extract_info_from_hub_content_arn( - arn: str, -) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: - """Extracts hub_name, content_name, and content_version from a HubContentArn""" - - match = re.match(constants.HUB_MODEL_ARN_REGEX, arn) - if match: - hub_name = match.group(4) - hub_region = match.group(2) - content_name = match.group(5) - content_version = match.group(6) - - return hub_name, hub_region, content_name, content_version - - match = re.match(constants.HUB_ARN_REGEX, arn) - if match: - hub_name = match.group(4) - hub_region = match.group(2) - return hub_name, hub_region, None, None - - return None, None, None, None diff --git a/src/sagemaker/jumpstart/validators.py b/src/sagemaker/jumpstart/validators.py index 3199e5fc2e..55cbdd90eb 100644 --- a/src/sagemaker/jumpstart/validators.py +++ b/src/sagemaker/jumpstart/validators.py @@ -167,6 +167,7 @@ def validate_hyperparameters( model_id: str, model_version: str, hyperparameters: Dict[str, Any], + hub_arn: Optional[str] = None, validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED, region: Optional[str] = JUMPSTART_DEFAULT_REGION_NAME, sagemaker_session: Optional[session.Session] = None, @@ -179,6 +180,8 @@ def validate_hyperparameters( model_id (str): Model ID of the model for which to validate hyperparameters. model_version (str): Version of the model for which to validate hyperparameters. hyperparameters (dict): Hyperparameters to validate. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). validation_mode (HyperparameterValidationMode): Method of validation to use with hyperparameters. If set to ``VALIDATE_PROVIDED``, only hyperparameters provided to this function will be validated, the missing hyperparameters will be ignored. @@ -211,6 +214,7 @@ def validate_hyperparameters( model_specs = verify_model_region_and_return_specs( model_id=model_id, version=model_version, + hub_arn=hub_arn, region=region, scope=JumpStartScriptScope.TRAINING, sagemaker_session=sagemaker_session, diff --git a/src/sagemaker/metric_definitions.py b/src/sagemaker/metric_definitions.py index 71dd26db45..a31d5d930d 100644 --- a/src/sagemaker/metric_definitions.py +++ b/src/sagemaker/metric_definitions.py @@ -29,6 +29,7 @@ def retrieve_default( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -43,6 +44,8 @@ def retrieve_default( retrieve the default training metric definitions. (Default: None). model_version (str): The version of the model for which to retrieve the default training metric definitions. (Default: None). + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). instance_type (str): An instance type to optionally supply in order to get metric definitions specific for the instance type. tolerate_vulnerable_model (bool): True if vulnerable versions of model @@ -71,6 +74,7 @@ def retrieve_default( return artifacts._retrieve_default_training_metric_definitions( model_id=model_id, model_version=model_version, + hub_arn=hub_arn, instance_type=instance_type, region=region, tolerate_vulnerable_model=tolerate_vulnerable_model, diff --git a/src/sagemaker/model_uris.py b/src/sagemaker/model_uris.py index 937180bd44..a2177c0ec5 100644 --- a/src/sagemaker/model_uris.py +++ b/src/sagemaker/model_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, model_scope: Optional[str] = None, instance_type: Optional[str] = None, tolerate_vulnerable_model: bool = False, @@ -43,6 +44,8 @@ def retrieve( the model artifact S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model artifact S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). model_scope (str): The model type. Valid values: "training" and "inference". instance_type (str): The ML compute instance type for the specified scope. (Default: None). @@ -75,6 +78,7 @@ def retrieve( return artifacts._retrieve_model_uri( model_id=model_id, model_version=model_version, # type: ignore + hub_arn=hub_arn, model_scope=model_scope, instance_type=instance_type, region=region, diff --git a/src/sagemaker/script_uris.py b/src/sagemaker/script_uris.py index 9a1c4933d2..9341a4198f 100644 --- a/src/sagemaker/script_uris.py +++ b/src/sagemaker/script_uris.py @@ -29,6 +29,7 @@ def retrieve( region: Optional[str] = None, model_id: Optional[str] = None, model_version: Optional[str] = None, + hub_arn: Optional[str] = None, script_scope: Optional[str] = None, tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, @@ -42,6 +43,8 @@ def retrieve( retrieve the script S3 URI. model_version (str): The version of the JumpStart model for which to retrieve the model script S3 URI. + hub_arn (str): The arn of the SageMaker Hub for which to retrieve + model details from. (default: None). script_scope (str): The script type. Valid values: "training" and "inference". tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model @@ -73,6 +76,7 @@ def retrieve( return artifacts._retrieve_script_uri( model_id, model_version, + hub_arn, script_scope, region, tolerate_vulnerable_model, diff --git a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py index 28211d06f1..4284ac4d84 100644 --- a/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py +++ b/tests/unit/sagemaker/accept_types/jumpstart/test_accept_types.py @@ -45,7 +45,7 @@ def test_jumpstart_default_accept_types( assert default_accept_type == "application/json" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) @@ -73,5 +73,5 @@ def test_jumpstart_supported_accept_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) diff --git a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py index 4b2db7d7f4..c924417946 100644 --- a/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py +++ b/tests/unit/sagemaker/content_types/jumpstart/test_content_types.py @@ -45,7 +45,7 @@ def test_jumpstart_default_content_types( assert default_content_type == "application/x-text" patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) @@ -72,5 +72,5 @@ def test_jumpstart_supported_content_types( ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) diff --git a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py index 9d6e2f21de..4807fc7933 100644 --- a/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py +++ b/tests/unit/sagemaker/deserializers/jumpstart/test_deserializers.py @@ -47,7 +47,7 @@ def test_jumpstart_default_deserializers( assert isinstance(default_deserializer, base_deserializers.JSONDeserializer) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) @@ -79,5 +79,5 @@ def test_jumpstart_deserializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) diff --git a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py index acd8d19923..d89687b3f4 100644 --- a/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py +++ b/tests/unit/sagemaker/environment_variables/jumpstart/test_default.py @@ -48,7 +48,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -68,7 +68,7 @@ def test_jumpstart_default_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -122,7 +122,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -143,7 +143,7 @@ def test_jumpstart_sdk_environment_variables(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py index eebc079164..ba08aa0825 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py @@ -43,10 +43,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version="*", - s3_client=mock_client, + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -60,10 +57,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): assert params == {"adam-learning-rate": "0.05", "batch-size": "4", "epochs": "3"} patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version="1.*", - s3_client=mock_client, + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -85,10 +79,7 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs): } patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version="1.*", - s3_client=mock_client, + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0054ed9dbd..ae8138b7c5 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -136,10 +136,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -437,7 +434,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -491,7 +488,7 @@ def test_jumpstart_validate_all_hyperparameters(patched_get_model_specs): ) patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index 8a41891280..b37b770875 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -49,6 +49,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -69,6 +70,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -89,6 +91,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -109,6 +112,7 @@ def test_jumpstart_common_image_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index bed2e50674..50f35cb872 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -46,6 +46,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, ) @@ -64,6 +65,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, ) @@ -88,6 +90,7 @@ def test_jumpstart_instance_types(patched_get_model_specs): region=region, model_id=model_id, version=model_version, + hub_arn=None, s3_client=mock_client, ) @@ -111,7 +114,11 @@ def test_jumpstart_instance_types(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version=model_version, s3_client=mock_client + region=region, + model_id=model_id, + version=model_version, + hub_arn=None, + s3_client=mock_client, ) patched_get_model_specs.reset_mock() @@ -163,6 +170,116 @@ def test_jumpstart_instance_types(patched_get_model_specs): instance_types.retrieve(model_id=model_id, scope="training") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +def test_jumpstart_instance_types_from_hub(patched_get_model_specs): + + patched_get_model_specs.side_effect = get_spec_from_base_spec + + model_id, model_version = "huggingface-eqa-bert-base-cased", "*" + region = "us-west-2" + hub_arn = f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub" + + mock_client = boto3.client("s3") + mock_session = Mock(s3_client=mock_client) + + default_training_instance_types = instance_types.retrieve_default( + region=region, + model_id=model_id, + hub_arn=hub_arn, + model_version=model_version, + scope="training", + sagemaker_session=mock_session, + ) + + assert default_training_instance_types == "ml.p3.2xlarge" + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + default_inference_instance_types = instance_types.retrieve_default( + region=region, + model_id=model_id, + hub_arn=hub_arn, + model_version=model_version, + scope="inference", + sagemaker_session=mock_session, + ) + + assert default_inference_instance_types == "ml.p2.xlarge" + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + default_training_instance_types = instance_types.retrieve( + region=region, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + scope="training", + sagemaker_session=mock_session, + ) + assert default_training_instance_types == [ + "ml.p3.2xlarge", + "ml.p2.xlarge", + "ml.g4dn.2xlarge", + "ml.m5.xlarge", + "ml.c5.2xlarge", + ] + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + default_inference_instance_types = instance_types.retrieve( + region=region, + model_id=model_id, + model_version=model_version, + hub_arn=hub_arn, + scope="inference", + sagemaker_session=mock_session, + ) + + assert default_inference_instance_types == [ + "ml.p2.xlarge", + "ml.p3.2xlarge", + "ml.g4dn.xlarge", + "ml.m5.large", + "ml.m5.xlarge", + "ml.c5.xlarge", + "ml.c5.2xlarge", + ] + + patched_get_model_specs.assert_called_once_with( + hub_arn=hub_arn, + region=region, + model_id=model_id, + version=model_version, + s3_client=mock_client, + ) + + patched_get_model_specs.reset_mock() + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_inference_instance_type_variants(patched_get_model_specs): patched_get_model_specs.side_effect = get_special_model_spec diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index a3c4c747f7..a60f8f9315 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -1831,7 +1831,7 @@ "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "fit_kwargs": {"job_name": "some-estimator-fit-value"}, "predictor_specs": { "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], @@ -2058,7 +2058,7 @@ "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "fit_kwargs": {"job_name": "some-estimator-fit-value"}, "predictor_specs": { "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], @@ -2284,7 +2284,7 @@ "estimator_kwargs": { "encrypt_inter_container_traffic": True, }, - "fit_kwargs": {"some-estimator-fit-key": "some-estimator-fit-value"}, + "fit_kwargs": {"job_name": "some-estimator-fit-value"}, "predictor_specs": { "supported_content_types": ["application/x-image"], "supported_accept_types": ["application/json;verbose", "application/json"], diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py new file mode 100644 index 0000000000..2f0841b4ea --- /dev/null +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -0,0 +1,153 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from unittest.mock import Mock +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME + +from sagemaker.jumpstart.curated_hub import utils +from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo + + +def test_get_info_from_hub_resource_arn(): + model_arn = ( + "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Model/my-mock-model/1.0.2" + ) + assert utils.get_info_from_hub_resource_arn(model_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Model", + hub_content_name="my-mock-model", + hub_content_version="1.0.2", + ) + + notebook_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MockHub/Notebook/my-mock-notebook/1.0.2" + assert utils.get_info_from_hub_resource_arn(notebook_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + hub_content_type="Notebook", + hub_content_name="my-mock-notebook", + hub_content_version="1.0.2", + ) + + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" + assert utils.get_info_from_hub_resource_arn(hub_arn) == HubArnExtractedInfo( + partition="aws", + region="us-west-2", + account_id="000000000000", + hub_name="MockHub", + ) + + invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "nonsense-string" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + invalid_arn = "" + assert None is utils.get_info_from_hub_resource_arn(invalid_arn) + + +def test_construct_hub_arn_from_name(): + mock_sagemaker_session = Mock() + mock_sagemaker_session.account_id.return_value = "123456789123" + mock_sagemaker_session.boto_region_name = "us-west-2" + hub_name = "my-cool-hub" + + assert ( + utils.construct_hub_arn_from_name(hub_name=hub_name, session=mock_sagemaker_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-cool-hub" + ) + + assert ( + utils.construct_hub_arn_from_name( + hub_name=hub_name, region="us-east-1", session=mock_sagemaker_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-cool-hub" + ) + + +def test_construct_hub_model_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_arn_from_inputs(hub_arn, model_name, version) + == "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub/Model/pytorch-ic-imagenet-v2/*" + ) + + +def test_generate_hub_arn_for_estimator_init_kwargs(): + hub_name = "my-hub-name" + hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub" + # Mock default session with default values + mock_default_session = Mock() + mock_default_session.account_id.return_value = "123456789123" + mock_default_session.boto_region_name = JUMPSTART_DEFAULT_REGION_NAME + # Mock custom session with custom values + mock_custom_session = Mock() + mock_custom_session.account_id.return_value = "000000000000" + mock_custom_session.boto_region_name = "us-east-2" + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, session=mock_default_session) + == "arn:aws:sagemaker:us-west-2:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs( + hub_name, "us-east-1", session=mock_default_session + ) + == "arn:aws:sagemaker:us-east-1:123456789123:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, "eu-west-1", mock_custom_session) + == "arn:aws:sagemaker:eu-west-1:000000000000:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_name, None, mock_custom_session) + == "arn:aws:sagemaker:us-east-2:000000000000:hub/my-hub-name" + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, session=mock_default_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs( + hub_arn, "us-east-1", session=mock_default_session + ) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, "us-east-1", mock_custom_session) + == hub_arn + ) + + assert ( + utils.generate_hub_arn_for_estimator_init_kwargs(hub_arn, None, mock_custom_session) + == hub_arn + ) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4dc35b65ca..c8fc541816 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -41,6 +41,7 @@ from sagemaker.model import Model from sagemaker.predictor import Predictor from tests.unit.sagemaker.jumpstart.utils import ( + get_spec_from_base_spec, get_special_model_spec, overwrite_dictionary, ) @@ -281,6 +282,146 @@ def test_prepacked( ], ) + @mock.patch("sagemaker.jumpstart.artifacts.resource_names._retrieve_resource_name_base") + @mock.patch("sagemaker.session.Session.account_id") + @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs") + @mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_deploy_kwargs") + @mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_fit_kwargs") + @mock.patch("sagemaker.jumpstart.curated_hub.utils.construct_hub_arn_from_name") + @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") + @mock.patch("sagemaker.jumpstart.factory.model.Session") + @mock.patch("sagemaker.jumpstart.factory.estimator.Session") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_hub_model( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_is_valid_model_id: mock.Mock, + mock_construct_hub_arn_from_name: mock.Mock, + mock_retrieve_estimator_fit_kwargs: mock.Mock, + mock_retrieve_model_deploy_kwargs: mock.Mock, + mock_retrieve_model_init_kwargs: mock.Mock, + mock_get_caller_identity: mock.Mock, + mock_retrieve_resource_name_base: mock.Mock, + ): + mock_retrieve_resource_name_base.return_value = "go-blue" + mock_get_caller_identity.return_value = "123456789123" + mock_estimator_deploy.return_value = default_predictor + + mock_is_valid_model_id.return_value = True + + model_id, _ = "pytorch-hub-model-1", "*" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + mock_retrieve_estimator_fit_kwargs.return_value = {} + mock_retrieve_model_deploy_kwargs.return_value = {} + mock_retrieve_model_init_kwargs.return_value = {} + + mock_get_model_specs.side_effect = get_spec_from_base_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + estimator = JumpStartEstimator( + model_id=model_id, + hub_name=hub_arn, + ) + + mock_construct_hub_arn_from_name.assert_not_called() + + mock_estimator_init.assert_called_once_with( + instance_type="ml.p3.2xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + model_uri="s3://jumpstart-cache-prod-us-west-2/pytorch-training/" + "train-pytorch-ic-mobilenet-v2.tar.gz", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "epochs": "3", + "adam-learning-rate": "0.05", + "batch-size": "4", + }, + metric_definitions=[ + {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"} + ], + role=execution_role, + volume_size=456, + sagemaker_session=sagemaker_session, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-hub-model-1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "*"}, + { + "Key": JumpStartTag.HUB_ARN, + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + }, + ], + encrypt_inter_container_traffic=True, + enable_network_isolation=False, + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels, job_name="go-blue") + + mock_estimator_fit.assert_called_once_with(inputs=channels, wait=True, job_name="go-blue") + + estimator.deploy(endpoint_name="go-blue", model_name="go-blue") + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.p2.xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference" + ":1.5.0-gpu-py3", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/" + "pytorch/inference/ic/v1.0.0/sourcedir.tar.gz", + entry_point="inference.py", + endpoint_name="go-blue", + model_name="go-blue", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + tags=[ + {"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-hub-model-1"}, + {"Key": JumpStartTag.MODEL_VERSION, "Value": "*"}, + { + "Key": JumpStartTag.HUB_ARN, + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + }, + ], + role=execution_role, + wait=True, + use_compiled_model=False, + ) + + # Test Hub arn util + estimator = JumpStartEstimator( + model_id=model_id, + hub_name="my-mock-hub", + ) + + mock_construct_hub_arn_from_name.assert_called_once_with( + hub_name="my-mock-hub", region=None, session=None + ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -1082,6 +1223,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): "model_id", "model_version", "region", + "hub_name", "tolerate_vulnerable_model", "tolerate_deprecated_model", } @@ -1338,6 +1480,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( model_id=model_id, model_version="*", region=region, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, @@ -1389,6 +1532,7 @@ def test_incremental_training_with_supported_model_doesnt_log_warning( model_id=model_id, model_version="*", region=region, + hub_arn=None, tolerate_deprecated_model=False, tolerate_vulnerable_model=False, sagemaker_session=sagemaker_session, diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index 97427be1ae..460494e116 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -45,6 +45,7 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) mock_cache.get_header = Mock(side_effect=get_header_from_base_header) mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) + mock_cache.get_hub_model = Mock(side_effect=get_spec_from_base_spec) assert get_header_from_base_header( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" @@ -56,6 +57,14 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): ) == accessors.JumpStartModelsAccessor.get_model_specs( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) + assert get_spec_from_base_spec( + hub_arn="arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + ) == accessors.JumpStartModelsAccessor.get_model_specs( + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + hub_arn="arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + ) assert len(accessors.JumpStartModelsAccessor._get_manifest()) > 0 @@ -63,6 +72,35 @@ def test_jumpstart_models_cache_get_fxs(mock_cache): reload(accessors) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_models_cache_get_model_specs(mock_cache): + mock_cache.get_specs = Mock() + mock_cache.get_hub_model = Mock() + model_id, version = "pytorch-ic-mobilenet-v2", "*" + region = "us-west-2" + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=version + ) + mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version) + mock_cache.get_hub_model.assert_not_called() + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=version, + hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub", + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}" + ) + ) + + # necessary because accessors is a static module + reload(accessors) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache") def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): @@ -85,9 +123,14 @@ def test_jumpstart_models_cache_set_reset_fxs(mock_model_cache: Mock): mock_model_cache.assert_called_once() mock_model_cache.reset_mock() + # shouldn't matter if hub_arn is passed through accessors.JumpStartModelsAccessor.get_model_specs( - region="us-west-1", model_id="pytorch-ic-mobilenet-v2", version="*" + region="us-west-1", + model_id="pytorch-ic-mobilenet-v2", + version="*", + hub_arn="arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", ) + mock_model_cache.assert_called_once() mock_model_cache.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/test_artifacts.py b/tests/unit/sagemaker/jumpstart/test_artifacts.py index 1a770f785f..8acd04f1f6 100644 --- a/tests/unit/sagemaker/jumpstart/test_artifacts.py +++ b/tests/unit/sagemaker/jumpstart/test_artifacts.py @@ -20,11 +20,22 @@ import copy from sagemaker.jumpstart import artifacts +from sagemaker.jumpstart.artifacts.environment_variables import ( + _retrieve_default_environment_variables, +) +from sagemaker.jumpstart.artifacts.hyperparameters import _retrieve_default_hyperparameters +from sagemaker.jumpstart.artifacts.image_uris import _retrieve_image_uri +from sagemaker.jumpstart.artifacts.incremental_training import _model_supports_incremental_training +from sagemaker.jumpstart.artifacts.instance_types import _retrieve_default_instance_type +from sagemaker.jumpstart.artifacts.metric_definitions import ( + _retrieve_default_training_metric_definitions, +) from sagemaker.jumpstart.artifacts.model_uris import ( _retrieve_hosting_prepacked_artifact_key, _retrieve_hosting_artifact_key, _retrieve_training_artifact_key, ) +from sagemaker.jumpstart.artifacts.script_uris import _retrieve_script_uri from sagemaker.jumpstart.types import JumpStartModelSpecs from tests.unit.sagemaker.jumpstart.constants import ( BASE_SPEC, @@ -456,3 +467,212 @@ def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs): ), "s3://jumpstart-private-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", ) + + +class HubModelTest(unittest.TestCase): + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_environment_variables(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_environment_variables( + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + script=JumpStartScriptScope.INFERENCE, + ), + { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_image_uri(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_image_uri( + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + instance_type="ml.p3.2xlarge", + image_scope=JumpStartScriptScope.TRAINING, + ), + "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.5.0-gpu-py3", + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_hyperparameters(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_hyperparameters( + model_id=model_id, model_version=version, hub_arn=hub_arn + ), + { + "epochs": "3", + "adam-learning-rate": "0.05", + "batch-size": "4", + }, + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_model_supports_incremental_training(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _model_supports_incremental_training( + model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2" + ), + True, + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_instance_type(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_instance_type( + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + scope=JumpStartScriptScope.TRAINING, + ), + "ml.p3.2xlarge", + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + self.assertEqual( + _retrieve_default_instance_type( + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + scope=JumpStartScriptScope.INFERENCE, + ), + "ml.p2.xlarge", + ) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_default_training_metric_definitions(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_default_training_metric_definitions( + model_id=model_id, model_version=version, hub_arn=hub_arn, region="us-west-2" + ), + [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}], + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_model_uri(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_model_uri( + model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="training" + ), + "s3://jumpstart-cache-prod-us-west-2/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz", + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + self.assertEqual( + _retrieve_model_uri( + model_id=model_id, model_version=version, hub_arn=hub_arn, model_scope="inference" + ), + "s3://jumpstart-cache-prod-us-west-2/pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz", + ) + + @patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_hub_model") + def test_retrieve_script_uri(self, mock_get_hub_model): + mock_get_hub_model.return_value = JumpStartModelSpecs(spec=copy.deepcopy(BASE_SPEC)) + + model_id, version = "pytorch-ic-mobilenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/my-cool-hub" + + self.assertEqual( + _retrieve_script_uri( + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + script_scope=JumpStartScriptScope.TRAINING, + ), + "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "transfer_learning/ic/v1.0.0/sourcedir.tar.gz", + ) + mock_get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:us-west-2:000000000000:hub-content/my-cool-hub/Model/{model_id}/{version}" + ) + ) + + self.assertEqual( + _retrieve_script_uri( + model_id=model_id, + model_version=version, + hub_arn=hub_arn, + script_scope=JumpStartScriptScope.INFERENCE, + ), + "s3://jumpstart-cache-prod-us-west-2/source-directory-tarballs/pytorch/" + "inference/ic/v1.0.0/sourcedir.tar.gz", + ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 8aae4c36a8..2fac16cc72 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -697,4 +697,5 @@ def test_get_model_url( version=version, region=region, s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index f1d6ccba43..c42c15ecf5 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -37,7 +37,10 @@ DeprecatedJumpStartModelError, VulnerableJumpStartModelError, ) -from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId +from sagemaker.jumpstart.types import ( + JumpStartModelHeader, + JumpStartVersionedModelId, +) from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec from mock import MagicMock @@ -259,6 +262,44 @@ def test_add_jumpstart_model_id_version_tags(): ) +def test_add_hub_arn_tags(): + tags = None + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + + assert [ + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + } + ] == utils.add_hub_arn_tags(tags=tags, hub_arn=hub_arn) + + tags = [ + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + } + ] + # If tags are already present, don't modify existing tags + assert [ + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + } + ] == utils.add_hub_arn_tags(tags=tags, hub_arn=hub_arn) + + tags = [ + {"Key": "random key", "Value": "random_value"}, + ] + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + assert [ + {"Key": "random key", "Value": "random_value"}, + { + "Key": "sagemaker-hub:hub-arn", + "Value": "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub", + }, + ] == utils.add_hub_arn_tags(tags=tags, hub_arn=hub_arn) + + def test_add_jumpstart_uri_tags_inference(): tags = None inference_model_uri = "dfsdfsd" @@ -1177,35 +1218,6 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type -def test_extract_info_from_hub_content_arn(): - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" - ) - assert utils.extract_info_from_hub_content_arn(model_arn) == ( - "MockHub", - "us-west-2", - "my-mock-model", - "1.0.2", - ) - - hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) - - invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "nonsense-string" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" - ) - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 76682c0f9e..81526485f9 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -12,7 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import import copy -from typing import List +from typing import List, Optional import boto3 from sagemaker.jumpstart.cache import JumpStartModelsCache @@ -22,7 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubDataType, + HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -92,6 +92,7 @@ def get_prototype_model_spec( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, @@ -107,6 +108,7 @@ def get_special_model_spec( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, @@ -122,6 +124,7 @@ def get_special_model_spec_for_inference_component_based_endpoint( region: str = None, model_id: str = None, version: str = None, + hub_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: """This function mocks cache accessor functions. For this mock, @@ -145,25 +148,28 @@ def get_spec_from_base_spec( model_id: str = None, semantic_version_str: str = None, version: str = None, + hub_arn: Optional[str] = None, + hub_model_arn: Optional[str] = None, s3_client: boto3.client = None, ) -> JumpStartModelSpecs: if version and semantic_version_str: raise ValueError("Cannot specify both `version` and `semantic_version_str` fields.") - if all( - [ - "pytorch" not in model_id, - "tensorflow" not in model_id, - "huggingface" not in model_id, - "mxnet" not in model_id, - "xgboost" not in model_id, - "catboost" not in model_id, - "lightgbm" not in model_id, - "sklearn" not in model_id, - ] - ): - raise KeyError("Bad model ID") + if model_id is not None: + if all( + [ + "pytorch" not in model_id, + "tensorflow" not in model_id, + "huggingface" not in model_id, + "mxnet" not in model_id, + "xgboost" not in model_id, + "catboost" not in model_id, + "lightgbm" not in model_id, + "sklearn" not in model_id, + ] + ): + raise KeyError("Bad model ID") if region is not None and region not in JUMPSTART_REGION_NAME_SET: raise ValueError( @@ -197,14 +203,14 @@ def patched_retrieval_function( formatted_content=get_spec_from_base_spec(model_id=model_id, version=version) ) - if datatype == HubDataType.MODEL: + if datatype == HubContentType.MODEL: _, _, _, model_name, model_version = id_info.split("/") return JumpStartCachedContentValue( formatted_content=get_spec_from_base_spec(model_id=model_name, version=model_version) ) # TODO: Implement - if datatype == HubDataType.HUB: + if datatype == HubContentType.HUB: return None raise ValueError(f"Bad value for filetype: {datatype}") diff --git a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py index ffc6000c91..afba23b5a4 100644 --- a/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py +++ b/tests/unit/sagemaker/metric_definitions/jumpstart/test_default.py @@ -47,7 +47,7 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="*", s3_client=mock_client + region=region, model_id=model_id, version="*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -63,7 +63,7 @@ def test_jumpstart_default_metric_definitions(patched_get_model_specs): ] patched_get_model_specs.assert_called_once_with( - region=region, model_id=model_id, version="1.*", s3_client=mock_client + region=region, model_id=model_id, version="1.*", s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 000540e12e..cde3258133 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -47,6 +47,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +65,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -82,6 +84,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -100,6 +103,7 @@ def test_jumpstart_common_model_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 28b53270f8..86031fbd57 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -42,10 +42,7 @@ def test_jumpstart_resource_requirements(patched_get_model_specs): assert default_inference_resource_requirements.requests["memory"] == 34360 patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -69,10 +66,7 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs): assert default_inference_resource_requirements is None patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 3f38326608..5811a9d822 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -47,6 +47,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -64,6 +65,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,7 +80,11 @@ def test_jumpstart_common_script_uri( sagemaker_session=mock_session, ) patched_get_model_specs.assert_called_once_with( - region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*", s3_client=mock_client + region="us-west-2", + model_id="pytorch-ic-mobilenet-v2", + version="*", + s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -97,6 +103,7 @@ def test_jumpstart_common_script_uri( model_id="pytorch-ic-mobilenet-v2", version="1.*", s3_client=mock_client, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py index b22b61dc40..94354e782e 100644 --- a/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py +++ b/tests/unit/sagemaker/serializers/jumpstart/test_serializers.py @@ -46,10 +46,7 @@ def test_jumpstart_default_serializers( assert isinstance(default_serializer, base_serializers.IdentitySerializer) patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None ) patched_get_model_specs.reset_mock() @@ -85,8 +82,5 @@ def test_jumpstart_serializer_options( ) patched_get_model_specs.assert_called_once_with( - region=region, - model_id=model_id, - version=model_version, - s3_client=mock_client, + region=region, model_id=model_id, version=model_version, s3_client=mock_client, hub_arn=None )