Skip to content

Commit

Permalink
Feature: register proprietary models from jumpstart
Browse files Browse the repository at this point in the history
  • Loading branch information
selvask-aws committed Jun 21, 2024
1 parent c9be3cd commit 58e904c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 21 deletions.
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def get_register_kwargs(
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
) -> JumpStartModelRegisterKwargs:
"""Returns kwargs required to call `register` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -756,6 +757,7 @@ def get_register_kwargs(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
accept_eula=accept_eula,
)

model_specs = verify_model_region_and_return_specs(
Expand Down
13 changes: 12 additions & 1 deletion src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,8 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,

):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -809,11 +811,19 @@ def register(
(default: None).
model_card (ModeCard or ModelPackageModelCard): document contains qualitative and
quantitative information about a model (default: None).
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
Returns:
A `sagemaker.model.ModelPackage` instance.
"""

if model_package_group_name is None and self.model_type is JumpStartModelType.PROPRIETARY:
model_package_group_name = self.model_id
source_uri = self.model_package_arn

register_kwargs = get_register_kwargs(
model_id=self.model_id,
model_version=self.model_version,
Expand Down Expand Up @@ -846,6 +856,7 @@ def register(
skip_model_validation=skip_model_validation,
source_uri=source_uri,
model_card=model_card,
accept_eula=accept_eula,
)

model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict())
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2399,6 +2399,7 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
"skip_model_validation",
"source_uri",
"model_card",
"accept_eula",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2445,6 +2446,7 @@ def __init__(
skip_model_validation: Optional[str] = None,
source_uri: Optional[str] = None,
model_card: Optional[Dict[ModelCard, ModelPackageModelCard]] = None,
accept_eula: Optional[bool] = None,
) -> None:
"""Instantiates JumpStartModelRegisterKwargs object."""

Expand Down Expand Up @@ -2480,3 +2482,4 @@ def __init__(
self.skip_model_validation = skip_model_validation
self.source_uri = source_uri
self.model_card = model_card
self.accept_eula = accept_eula
34 changes: 17 additions & 17 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ def register(
skip_model_validation: Optional[Union[str, PipelineVariable]] = None,
source_uri: Optional[Union[str, PipelineVariable]] = None,
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
accept_eula: Optional[bool] = None,
):
"""Creates a model package for creating SageMaker models or listing on Marketplace.
Expand Down Expand Up @@ -516,23 +517,22 @@ def register(

if image_uri is not None:
self.image_uri = image_uri

if model_package_group_name is None and model_package_name is None and self.model_type is not JumpStartModelType.PROPRIETARY:
# If model package group and model package name is not set
# then register to auto-generated model package group
model_package_group_name = utils.base_name_from_image(
self.image_uri, default_base_name=ModelPackage.__name__
)

if model_package_group_name is not None:
container_def = self.prepare_container_def()
container_def = update_container_with_inference_params(
framework=framework,
framework_version=framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
container_def=container_def,
)
if self.model_type is not JumpStartModelType.PROPRIETARY:
if model_package_group_name is None and model_package_name is None:
# If model package group and model package name is not set
# then register to auto-generated model package group
model_package_group_name = utils.base_name_from_image(
self.image_uri, default_base_name=ModelPackage.__name__
)
if model_package_group_name is not None:
container_def = self.prepare_container_def(accept_eula=accept_eula)
container_def = update_container_with_inference_params(
framework=framework,
framework_version=framework_version,
nearest_model_name=nearest_model_name,
data_input_configuration=data_input_configuration,
container_def=container_def,
)
else:
container_def = {
"Image": self.image_uri,
Expand Down
35 changes: 32 additions & 3 deletions tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def test_jumpstart_model_register(setup):

response = predictor.predict("hello world!")

predictor.delete_predictor()

assert response is not None


Expand Down Expand Up @@ -306,10 +308,10 @@ def test_register_proprietary_jumpstart_model(setup):
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)
model_package = model.register()


pp = model.register()

predictor = pp.deploy(
predictor = model_package.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
)
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
Expand All @@ -320,3 +322,30 @@ def test_register_proprietary_jumpstart_model(setup):

assert response is not None


@pytest.mark.skipif(
True,
reason="Only enable if test account is subscribed to the proprietary model",
)
def test_register_gated_jumpstart_model(setup):

model_id="meta-textgenerationneuron-llama-2-7b"
model = JumpStartModel(
model_id=model_id,
model_version="1.1.0",
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
)
model_package = model.register(accept_eula=True)

predictor = model_package.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], accept_eula=True
)
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}

response = predictor.predict(payload)

predictor.delete_predictor()

assert response is not None

0 comments on commit 58e904c

Please sign in to comment.