Feat/jsch jumpstart estimator support #4439

Merged
32 commits
032cb80
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
ef042d9
update types and var names
bencrabtree Feb 19, 2024
49ae11b
update linter
bencrabtree Feb 19, 2024
6175087
linter
bencrabtree Feb 19, 2024
4c9b2d0
linter
bencrabtree Feb 19, 2024
63345ea
flake8 check
bencrabtree Feb 19, 2024
6efc206
add hub name support for jumpstart estimator
bencrabtree Feb 20, 2024
2e9f76f
linter
bencrabtree Feb 20, 2024
ac8dd60
linter2
bencrabtree Feb 20, 2024
d4f7a00
fix param
bencrabtree Feb 20, 2024
5492474
move to utils and test
bencrabtree Feb 21, 2024
1e26760
feat: add hub and hubcontent support in retrieval function for jumpst…
bencrabtree Feb 21, 2024
4ae201c
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
4a19a33
update types and var names
bencrabtree Feb 19, 2024
4d35379
update linter
bencrabtree Feb 19, 2024
174c4fd
linter
bencrabtree Feb 19, 2024
b7a8835
linter
bencrabtree Feb 19, 2024
bb7a9fb
flake8 check
bencrabtree Feb 19, 2024
8ba576a
pass hub_arn into all estimator utils/artifacts
bencrabtree Feb 21, 2024
ecd1f97
feat: add hub and hubcontent support in retrieval function for jumpst…
bencrabtree Feb 21, 2024
8df4478
add hub and hubcontent support in retrieval function for jumpstart mo…
bencrabtree Feb 19, 2024
4870c0b
update types and var names
bencrabtree Feb 19, 2024
4dc21f5
update linter
bencrabtree Feb 19, 2024
dd51314
remove duplicate
bencrabtree Feb 21, 2024
195b84b
Merge branch 'master-jumpstart-curated-hub' of https://github.com/aws…
bencrabtree Feb 21, 2024
a39ae5f
linter
bencrabtree Feb 21, 2024
354b33e
add important unit test
bencrabtree Feb 21, 2024
424254e
update tests
bencrabtree Feb 22, 2024
8a3160a
black styles
bencrabtree Feb 22, 2024
151350c
finish tests
bencrabtree Feb 22, 2024
dd087da
create curated hub utils and types
bencrabtree Feb 23, 2024
a61dfb4
fix linter
bencrabtree Feb 23, 2024
4 changes: 4 additions & 0 deletions src/sagemaker/environment_variables.py
@@ -30,6 +30,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
include_aws_sdk_env_vars: bool = True,
@@ -46,6 +47,8 @@ def retrieve_default(
retrieve the default environment variables. (Default: None).
model_version (str): Optional. The version of the model for which to retrieve the
default environment variables. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
Review comment (Member): nit: capitalize Default

tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
@@ -80,6 +83,7 @@ def retrieve_default(
return artifacts._retrieve_default_environment_variables(
model_id,
model_version,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
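The call above forwards `hub_arn` positionally, which only works because the helper's signature gained the new parameter in the same position. A minimal sketch of why that ordering matters, using hypothetical stand-in functions (`_retrieve_after` is not an SDK function; it only mirrors the shape of a helper that gained `hub_arn` between `model_version` and `region`):

```python
from typing import Dict, Optional


def _retrieve_after(
    model_id: str,
    model_version: str,
    hub_arn: Optional[str] = None,  # new parameter inserted mid-signature
    region: Optional[str] = None,
) -> Dict[str, Optional[str]]:
    """Hypothetical helper; echoes how its arguments were bound."""
    return {
        "model_id": model_id,
        "model_version": model_version,
        "hub_arn": hub_arn,
        "region": region,
    }


# A caller written against the old (model_id, model_version, region) order
# now binds the region string to hub_arn without any error being raised.
stale = _retrieve_after("my-model", "1.0.0", "us-west-2")

# Keyword arguments are unaffected by the insertion.
safe = _retrieve_after(model_id="my-model", model_version="1.0.0", region="us-west-2")
```

Under these assumptions, `stale["hub_arn"]` holds the region string while `safe` binds it correctly, which is one reason the keyword-style pass-through used in `hyperparameters.py` is the more robust pattern.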
8 changes: 8 additions & 0 deletions src/sagemaker/hyperparameters.py
@@ -31,6 +31,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
instance_type: Optional[str] = None,
include_container_hyperparameters: bool = False,
tolerate_vulnerable_model: bool = False,
@@ -46,6 +47,8 @@ def retrieve_default(
retrieve the default hyperparameters. (Default: None).
model_version (str): The version of the model for which to retrieve the
default hyperparameters. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
@@ -80,6 +83,7 @@ def retrieve_default(
return artifacts._retrieve_default_hyperparameters(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
instance_type=instance_type,
region=region,
include_container_hyperparameters=include_container_hyperparameters,
@@ -93,6 +97,7 @@ def validate(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
hyperparameters: Optional[dict] = None,
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
tolerate_vulnerable_model: bool = False,
@@ -107,6 +112,8 @@ def validate(
(Default: None).
model_version (str): The version of the model for which to validate hyperparameters.
(Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
hyperparameters (dict): Hyperparameters to validate.
(Default: None).
validation_mode (HyperparameterValidationMode): Method of validation to use with
@@ -148,6 +155,7 @@ def validate(
return validate_hyperparameters(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
hyperparameters=hyperparameters,
validation_mode=validation_mode,
region=region,
4 changes: 4 additions & 0 deletions src/sagemaker/image_uris.py
@@ -61,6 +61,7 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
hub_arn=None,
tolerate_vulnerable_model=False,
tolerate_deprecated_model=False,
sdk_version=None,
@@ -101,6 +102,8 @@ def retrieve(
(default: None).
model_version (str): The version of the JumpStart model for which to retrieve the
image URI (default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
should be tolerated without an exception raised. If ``False``, raises an exception if
the script used by this version of the model has dependencies with known security
@@ -146,6 +149,7 @@ def retrieve(
model_id,
model_version,
image_scope,
hub_arn,
framework,
region,
version,
8 changes: 8 additions & 0 deletions src/sagemaker/instance_types.py
@@ -29,6 +29,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
@@ -44,6 +45,8 @@ def retrieve_default(
retrieve the default instance type. (Default: None).
model_version (str): The version of the model for which to retrieve the
default instance type. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -80,6 +83,7 @@ def retrieve_default(
model_id,
model_version,
scope,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
@@ -92,6 +96,7 @@ def retrieve(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
@@ -107,6 +112,8 @@ def retrieve(
retrieve the supported instance types. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported instance types. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
@@ -142,6 +149,7 @@ def retrieve(
model_id,
model_version,
scope,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
14 changes: 13 additions & 1 deletion src/sagemaker/jumpstart/accessors.py
@@ -19,6 +19,7 @@
from sagemaker.deprecations import deprecated
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


@@ -239,7 +240,11 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel

@staticmethod
def get_model_specs(
region: str, model_id: str, version: str, s3_client: Optional[boto3.client] = None
region: str,
model_id: str,
version: str,
hub_arn: Optional[str] = None,
s3_client: Optional[boto3.client] = None,
) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.

@@ -259,6 +264,13 @@ def get_model_specs(
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)

if hub_arn:
hub_model_arn = construct_hub_model_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)

return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, semantic_version_str=version
)
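The branch in `get_model_specs` above can be illustrated with a small standalone sketch. Everything here is an assumption made for illustration: `build_hub_model_arn` and `resolve_specs_key` are hypothetical names, and the ARN layout (`...:hub/<hub-name>` mapping to `...:hub-content/<hub-name>/Model/<model-name>/<version>`) is assumed rather than taken from this diff. The actual `construct_hub_model_arn_from_inputs` in `sagemaker.jumpstart.curated_hub.utils` may behave differently.

```python
from typing import Optional


def build_hub_model_arn(hub_arn: str, model_name: str, version: str) -> str:
    """Hypothetical: derive a hub model ARN from a hub ARN (layout is assumed)."""
    # Split "arn:...:<account>:hub/<hub-name>" at the last colon.
    prefix, _, resource = hub_arn.rpartition(":")
    resource_type, _, hub_name = resource.partition("/")
    if resource_type != "hub" or not hub_name:
        raise ValueError(f"Not a hub ARN: {hub_arn}")
    return f"{prefix}:hub-content/{hub_name}/Model/{model_name}/{version}"


def resolve_specs_key(model_id: str, version: str, hub_arn: Optional[str] = None) -> str:
    """Hypothetical: pick the lookup key the way the hunk above picks a cache method."""
    if hub_arn:
        # Hub path: the model is addressed by its full hub model ARN.
        return build_hub_model_arn(hub_arn, model_id, version)
    # Default path: the model is addressed by ID and semantic version.
    return f"{model_id}/{version}"


hub_key = resolve_specs_key(
    "my-model", "1.0.0", hub_arn="arn:aws:sagemaker:us-west-2:123456789012:hub/my-hub"
)
public_key = resolve_specs_key("my-model", "1.0.0")
```

The point of the sketch is the control flow: when `hub_arn` is falsy, the existing public lookup is untouched, so callers that never pass a hub see no change in behavior.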
9 changes: 9 additions & 0 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
@@ -31,6 +31,7 @@
def _retrieve_default_environment_variables(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
@@ -46,6 +47,8 @@ def _retrieve_default_environment_variables(
retrieve the default environment variables.
model_version (str): Version of the JumpStart model for which to retrieve the
default environment variables.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve default environment variables.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -77,6 +80,7 @@ def _retrieve_default_environment_variables(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=script,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -113,6 +117,7 @@ def _retrieve_default_environment_variables(
gated_model_env_var: Optional[str] = _retrieve_gated_model_uri_env_var_value(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
@@ -131,6 +136,7 @@ def _retrieve_default_environment_variables(
def _retrieve_gated_model_uri_env_var_value(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
@@ -144,6 +150,8 @@ def _retrieve_gated_model_uri_env_var_value(
retrieve the gated model env var URI.
model_version (str): Version of the JumpStart model for which to retrieve the
gated model env var URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve the gated model env var URI.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -174,6 +182,7 @@ def _retrieve_gated_model_uri_env_var_value(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
@@ -30,6 +30,7 @@
def _retrieve_default_hyperparameters(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
include_container_hyperparameters: bool = False,
tolerate_vulnerable_model: bool = False,
@@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters(
retrieve the default hyperparameters.
model_version (str): Version of the JumpStart model for which to retrieve the
default hyperparameters.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (str): Region for which to retrieve default hyperparameters.
(Default: None).
include_container_hyperparameters (bool): True if container hyperparameters
@@ -76,6 +79,7 @@ def _retrieve_default_hyperparameters(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
@@ -33,6 +33,7 @@ def _retrieve_image_uri(
model_id: str,
model_version: str,
image_scope: str,
hub_arn: Optional[str] = None,
framework: Optional[str] = None,
region: Optional[str] = None,
version: Optional[str] = None,
@@ -57,6 +58,8 @@ def _retrieve_image_uri(
model_id (str): JumpStart model ID for which to retrieve image URI.
model_version (str): Version of the JumpStart model for which to retrieve
the image URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
image_scope (str): The image type, i.e. what it is used for.
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
``image_scope`` is ignored.
@@ -110,6 +113,7 @@ def _retrieve_image_uri(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=image_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
@@ -30,6 +30,7 @@ def _model_supports_incremental_training(
model_id: str,
model_version: str,
region: Optional[str],
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -43,6 +44,8 @@ def _model_supports_incremental_training(
support status for incremental training.
region (Optional[str]): Region for which to retrieve the
support status for incremental training.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
@@ -64,6 +67,7 @@ def _model_supports_incremental_training(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
@@ -33,6 +33,7 @@ def _retrieve_default_instance_type(
model_id: str,
model_version: str,
scope: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
@@ -48,6 +49,8 @@ def _retrieve_default_instance_type(
default instance type.
scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve default instance type.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -80,6 +83,7 @@ def _retrieve_default_instance_type(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -119,6 +123,7 @@ def _retrieve_instance_types(
model_id: str,
model_version: str,
scope: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
@@ -134,6 +139,8 @@ def _retrieve_instance_types(
supported instance types.
scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve supported instance types.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -166,6 +173,7 @@ def _retrieve_instance_types(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,