Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from sagemaker.jumpstart.utils import (
validate_model_id_and_get_type,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
)
from sagemaker.utils import stringify_object, format_tags, Tags
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
Expand Down Expand Up @@ -729,11 +730,27 @@ def attach(

model_version = model_version or "*"

additional_kwargs = {"model_id": model_id, "model_version": model_version}

model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
region=sagemaker_session.boto_region_name,
scope=JumpStartScriptScope.TRAINING,
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: if a model is vulnerable, would this code raise an exception alerting the user:

my_estimator = estimator.attach()
my_model = my_estimator.deploy()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nevertheless, it seems that you are only retrieving the specs to check if the model is gated or not, so that is non-blocking. Consider updating the comments to provide more details however.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the model was actually vulnerable or deprecated, you would get an exception when doing my_estimator.deploy(). As you mention, these tolerate flags are just to get the status of whether the model is gated or not

sagemaker_session=sagemaker_session,
)

# eula was already accepted if the model was successfully trained
if model_specs.is_gated_model():
additional_kwargs.update({"environment": {"accept_eula": "true"}})

return cls._attach(
training_job_name=training_job_name,
sagemaker_session=sagemaker_session,
model_channel_name=model_channel_name,
additional_kwargs={"model_id": model_id, "model_version": model_version},
additional_kwargs=additional_kwargs,
)

def deploy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,16 @@ def test_gated_model_training_v2(setup):
}
)

# test that we can create a JumpStartEstimator from existing job with `attach`
attached_estimator = JumpStartEstimator.attach(
training_job_name=estimator.latest_training_job.name,
model_id=model_id,
model_version=model_version,
sagemaker_session=get_sm_session(),
)

# uses ml.g5.2xlarge instance
predictor = estimator.deploy(
predictor = attached_estimator.deploy(
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
role=get_sm_session().get_caller_identity_arn(),
sagemaker_session=get_sm_session(),
Expand All @@ -172,7 +180,7 @@ def test_gated_model_training_v2(setup):
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
}

response = predictor.predict(payload, custom_attributes="accept_eula=true")
response = predictor.predict(payload)

assert response is not None

Expand Down
34 changes: 34 additions & 0 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,40 @@ def test_jumpstart_estimator_tags(
[{"Key": "blah", "Value": "blahagain"}] + js_tags,
)

@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_jumpstart_estimator_attach_eula_model(
self,
mock_get_model_specs: mock.Mock,
mock_validate_model_id_and_get_type: mock.Mock,
mock_attach: mock.Mock,
):

mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS

mock_get_model_specs.side_effect = get_special_model_spec

mock_session = mock.MagicMock(sagemaker_config={}, boto_region_name="us-west-2")

JumpStartEstimator.attach(
training_job_name="some-training-job-name",
model_id="gemma-model",
sagemaker_session=mock_session,
)

mock_attach.assert_called_once_with(
training_job_name="some-training-job-name",
sagemaker_session=mock_session,
model_channel_name="model",
additional_kwargs={
"model_id": "gemma-model",
"model_version": "*",
"environment": {"accept_eula": "true"},
},
)

@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
@mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job")
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
Expand Down