Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JumpStart CuratedHub Launch #4748

Merged
merged 19 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/sagemaker/accept_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def retrieve_options(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -37,6 +38,8 @@ def retrieve_options(
retrieve the supported accept types. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported accept types. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -60,11 +63,12 @@ def retrieve_options(
)

return artifacts._retrieve_supported_accept_types(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

Expand All @@ -73,6 +77,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -87,6 +92,8 @@ def retrieve_default(
retrieve the default accept type. (Default: None).
model_version (str): The version of the model for which to retrieve the
default accept type. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -110,11 +117,11 @@ def retrieve_default(
)

return artifacts._retrieve_default_accept_type(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)
2 changes: 2 additions & 0 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def prepare_container_def(
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
model_reference_arn=None,
):
"""Return a container definition with framework configuration set in model environment.

Expand Down Expand Up @@ -333,6 +334,7 @@ def prepare_container_def(
self.model_data,
deploy_env,
accept_eula=accept_eula,
model_reference_arn=model_reference_arn,
)

def serving_image_uri(
Expand Down
28 changes: 18 additions & 10 deletions src/sagemaker/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def retrieve_options(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -37,6 +38,8 @@ def retrieve_options(
retrieve the supported content types. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported content types. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -60,11 +63,12 @@ def retrieve_options(
)

return artifacts._retrieve_supported_content_types(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

Expand All @@ -73,6 +77,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -87,6 +92,8 @@ def retrieve_default(
retrieve the default content type. (Default: None).
model_version (str): The version of the model for which to retrieve the
default content type. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -110,11 +117,12 @@ def retrieve_default(
)

return artifacts._retrieve_default_content_type(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)
Expand Down
28 changes: 18 additions & 10 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def retrieve_options(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -56,6 +57,8 @@ def retrieve_options(
retrieve the supported deserializers. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported deserializers. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -80,11 +83,12 @@ def retrieve_options(
)

return artifacts._retrieve_deserializer_options(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
)

Expand All @@ -93,6 +97,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -107,6 +112,8 @@ def retrieve_default(
retrieve the default deserializer. (Default: None).
model_version (str): The version of the model for which to retrieve the
default deserializer. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -131,11 +138,12 @@ def retrieve_default(
)

return artifacts._retrieve_default_deserializer(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)
1 change: 1 addition & 0 deletions src/sagemaker/djl_inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ def prepare_container_def(
accelerator_type=None,
serverless_inference_config=None,
accept_eula=None,
model_reference_arn=None,
): # pylint: disable=unused-argument
"""A container definition with framework configuration set in model environment variables.

Expand Down
15 changes: 9 additions & 6 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
include_aws_sdk_env_vars: bool = True,
Expand All @@ -46,6 +47,7 @@ def retrieve_default(
retrieve the default environment variables. (Default: None).
model_version (str): Optional. The version of the model for which to retrieve the
default environment variables. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve model details from. (Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand Down Expand Up @@ -78,12 +80,13 @@ def retrieve_default(
)

return artifacts._retrieve_default_environment_variables(
model_id,
model_version,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
include_aws_sdk_env_vars,
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
include_aws_sdk_env_vars=include_aws_sdk_env_vars,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
script=script,
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ def prepare_container_def(
serverless_inference_config=None,
inference_tool=None,
accept_eula=None,
model_reference_arn=None,
):
"""A container definition with framework configuration set in model environment variables.

Expand Down Expand Up @@ -533,6 +534,7 @@ def prepare_container_def(
self.repacked_model_data or self.model_data,
deploy_env,
accept_eula=accept_eula,
model_reference_arn=model_reference_arn,
)

def serving_image_uri(
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
instance_type: Optional[str] = None,
include_container_hyperparameters: bool = False,
tolerate_vulnerable_model: bool = False,
Expand All @@ -46,6 +47,8 @@ def retrieve_default(
retrieve the default hyperparameters. (Default: None).
model_version (str): The version of the model for which to retrieve the
default hyperparameters. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
Expand Down Expand Up @@ -80,6 +83,7 @@ def retrieve_default(
return artifacts._retrieve_default_hyperparameters(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
instance_type=instance_type,
region=region,
include_container_hyperparameters=include_container_hyperparameters,
Expand All @@ -92,6 +96,7 @@ def retrieve_default(
def validate(
region: Optional[str] = None,
model_id: Optional[str] = None,
hub_arn: Optional[str] = None,
model_version: Optional[str] = None,
hyperparameters: Optional[dict] = None,
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
Expand All @@ -107,6 +112,8 @@ def validate(
(Default: None).
model_version (str): The version of the model for which to validate hyperparameters.
(Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
hyperparameters (dict): Hyperparameters to validate.
(Default: None).
validation_mode (HyperparameterValidationMode): Method of validation to use with
Expand Down Expand Up @@ -148,6 +155,7 @@ def validate(
return validate_hyperparameters(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
hyperparameters=hyperparameters,
validation_mode=validation_mode,
region=region,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
hub_arn=None,
tolerate_vulnerable_model=False,
tolerate_deprecated_model=False,
sdk_version=None,
Expand Down Expand Up @@ -104,6 +105,8 @@ def retrieve(
(default: None).
model_version (str): The version of the JumpStart model for which to retrieve the
image URI (default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (Default: None).
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
should be tolerated without an exception raised. If ``False``, raises an exception if
the script used by this version of the model has dependencies with known security
Expand Down Expand Up @@ -149,6 +152,7 @@ def retrieve(
model_id,
model_version,
image_scope,
hub_arn,
framework,
region,
version,
Expand Down
Loading
Loading