From fced4088b5ac4a683d4890925d829afa51b5c347 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 28 Mar 2024 20:07:14 +0000 Subject: [PATCH] fix: attach jumpstart estimator for gated model --- src/sagemaker/jumpstart/estimator.py | 19 ++++++++++- .../estimator/test_jumpstart_estimator.py | 12 +++++-- .../jumpstart/estimator/test_estimator.py | 34 +++++++++++++++++++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index bac076ea4a..88927ae931 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -37,6 +37,7 @@ from sagemaker.jumpstart.utils import ( validate_model_id_and_get_type, resolve_model_sagemaker_config_field, + verify_model_region_and_return_specs, ) from sagemaker.utils import stringify_object, format_tags, Tags from sagemaker.model_monitor.data_capture_config import DataCaptureConfig @@ -729,11 +730,27 @@ def attach( model_version = model_version or "*" + additional_kwargs = {"model_id": model_id, "model_version": model_version} + + model_specs = verify_model_region_and_return_specs( + model_id=model_id, + version=model_version, + region=sagemaker_session.boto_region_name, + scope=JumpStartScriptScope.TRAINING, + tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated + tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable + sagemaker_session=sagemaker_session, + ) + + # eula was already accepted if the model was successfully trained + if model_specs.is_gated_model(): + additional_kwargs.update({"environment": {"accept_eula": "true"}}) + return cls._attach( training_job_name=training_job_name, sagemaker_session=sagemaker_session, model_channel_name=model_channel_name, - additional_kwargs={"model_id": model_id, "model_version": model_version}, + additional_kwargs=additional_kwargs, ) def deploy( diff --git a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py index 5faa40ccda..a839a293c5 100644 --- a/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py +++ b/tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py @@ -160,8 +160,16 @@ def test_gated_model_training_v2(setup): } ) + # test that we can create a JumpStartEstimator from existing job with `attach` + attached_estimator = JumpStartEstimator.attach( + training_job_name=estimator.latest_training_job.name, + model_id=model_id, + model_version=model_version, + sagemaker_session=get_sm_session(), + ) + # uses ml.g5.2xlarge instance - predictor = estimator.deploy( + predictor = attached_estimator.deploy( tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], role=get_sm_session().get_caller_identity_arn(), sagemaker_session=get_sm_session(), @@ -172,7 +180,7 @@ def test_gated_model_training_v2(setup): "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, } - response = predictor.predict(payload, custom_attributes="accept_eula=true") + response = predictor.predict(payload) assert response is not None diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 1e048ef0dd..5a11a5e88d 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -979,6 +979,40 @@ def test_jumpstart_estimator_tags( [{"Key": "blah", "Value": "blahagain"}] + js_tags, ) + @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_jumpstart_estimator_attach_eula_model( + self, + mock_get_model_specs: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_attach: mock.Mock, + ): + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session = mock.MagicMock(sagemaker_config={}, boto_region_name="us-west-2") + + JumpStartEstimator.attach( + training_job_name="some-training-job-name", + model_id="gemma-model", + sagemaker_session=mock_session, + ) + + mock_attach.assert_called_once_with( + training_job_name="some-training-job-name", + sagemaker_session=mock_session, + model_channel_name="model", + additional_kwargs={ + "model_id": "gemma-model", + "model_version": "*", + "environment": {"accept_eula": "true"}, + }, + ) + @mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach") @mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job") @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")