diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 8159ab5cc8..0a7a29a8cd 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -327,8 +327,10 @@ def _retrieval_function( ) if file_type == JumpStartS3FileType.SPECS: formatted_body, _ = self._get_json_file(s3_key, file_type) + model_specs = JumpStartModelSpecs(formatted_body) + utils.emit_logs_based_on_model_specs(model_specs, self.get_region()) return JumpStartCachedS3ContentValue( - formatted_content=JumpStartModelSpecs(formatted_body) + formatted_content=model_specs ) raise ValueError( f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}" diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index c08ccac7f7..93f01416ee 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """This module stores constants related to SageMaker JumpStart.""" from __future__ import absolute_import +import logging from typing import Dict, Set, Type import boto3 from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer @@ -173,3 +174,5 @@ } MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" + +JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart") diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index d4c0016b36..129692262b 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. """This module stores JumpStart implementation of Estimator class.""" from __future__ import absolute_import -import logging from typing import Dict, List, Optional, Union @@ -45,8 +44,6 @@ from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from sagemaker.workflow.entities import PipelineVariable -logger = logging.getLogger(__name__) - class JumpStartEstimator(Estimator): """JumpStartEstimator class. diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index e09b6b86c6..cce77efd1a 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. """This module stores JumpStart Estimator factory methods.""" from __future__ import absolute_import -import logging from typing import Dict, List, Optional, Union @@ -41,6 +40,7 @@ ) from sagemaker.jumpstart.constants import ( JUMPSTART_DEFAULT_REGION_NAME, + JUMPSTART_LOGGER, TRAINING_ENTRY_POINT_SCRIPT_NAME, ) from sagemaker.jumpstart.enums import JumpStartScriptScope @@ -64,8 +64,6 @@ from sagemaker.utils import name_from_base from sagemaker.workflow.entities import PipelineVariable -logger = logging.getLogger("sagemaker") - def get_init_kwargs( model_id: str, @@ -421,7 +419,7 @@ def _add_instance_type_and_count_to_kwargs( kwargs.instance_count = kwargs.instance_count or 1 if orig_instance_type is None: - logger.info( + JUMPSTART_LOGGER.info( "No instance type selected for training job. Defaulting to %s.", kwargs.instance_type ) @@ -467,7 +465,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, ) ): - logger.warning( + JUMPSTART_LOGGER.warning( "'%s' does not support incremental training but is being trained with" " non-default model artifact.", kwargs.model_id, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 5d00c4032f..8cff97b77a 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -12,7 +12,6 @@ # language governing permissions and limitations under the License. """This module stores JumpStart Model factory methods.""" from __future__ import absolute_import -import logging from typing import Any, Dict, List, Optional, Union @@ -31,6 +30,7 @@ from sagemaker.jumpstart.constants import ( INFERENCE_ENTRY_POINT_SCRIPT_NAME, JUMPSTART_DEFAULT_REGION_NAME, + JUMPSTART_LOGGER, ) from sagemaker.jumpstart.enums import JumpStartScriptScope from sagemaker.jumpstart.types import ( @@ -51,8 +51,6 @@ from sagemaker.utils import name_from_base from sagemaker.workflow.entities import PipelineVariable -logger = logging.getLogger("sagemaker") - def get_default_predictor( predictor: Predictor, @@ -170,7 +168,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM ) if orig_instance_type is None: - logger.info( + JUMPSTART_LOGGER.info( "No instance type selected for inference hosting endpoint. Defaulting to %s.", kwargs.instance_type, ) diff --git a/src/sagemaker/jumpstart/model.py b/src/sagemaker/jumpstart/model.py index ba8ce1249f..8a912181af 100644 --- a/src/sagemaker/jumpstart/model.py +++ b/src/sagemaker/jumpstart/model.py @@ -13,7 +13,6 @@ """This module stores JumpStart implementation of Model class.""" from __future__ import absolute_import -import logging import re from typing import Dict, List, Optional, Union @@ -38,8 +37,6 @@ from sagemaker.session import Session from sagemaker.workflow.entities import PipelineVariable -logger = logging.getLogger(__name__) - class JumpStartModel(Model): """JumpStartModel class. diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index aed81d85c2..caac3b9c2f 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -42,8 +42,6 @@ from sagemaker.utils import resolve_value_from_config from sagemaker.workflow import is_pipeline_variable -LOGGER = logging.getLogger(__name__) - def get_jumpstart_launched_regions_message() -> str: """Returns formatted string indicating where JumpStart is launched.""" @@ -79,7 +77,7 @@ def get_jumpstart_content_bucket( and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0 ): bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE] - LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override) + constants.JUMPSTART_LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override) return bucket_override try: return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket @@ -343,6 +341,39 @@ def update_inference_tags_with_jumpstart_training_tags( return inference_tags +def emit_logs_based_on_model_specs(model_specs: JumpStartModelSpecs, region: str) -> None: + """Emits logs based on model specs and region.""" + + if model_specs.hosting_eula_key: + constants.JUMPSTART_LOGGER.info( + "Model '%s' requires accepting end-user license agreement (EULA). " + "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", + model_specs.model_id, + get_jumpstart_content_bucket(region=region), + region, + ".cn" if region.startswith("cn-") else "", + model_specs.hosting_eula_key, + ) + + if model_specs.deprecated: + deprecated_message = model_specs.deprecated_message or ( + "Using deprecated JumpStart model " + f"'{model_specs.model_id}' and version '{model_specs.version}'." + ) + + constants.JUMPSTART_LOGGER.warning(deprecated_message) + + if model_specs.deprecate_warn_message: + constants.JUMPSTART_LOGGER.warning(model_specs.deprecate_warn_message) + + if model_specs.inference_vulnerable or model_specs.training_vulnerable: + constants.JUMPSTART_LOGGER.warning( + "Using vulnerable JumpStart model '%s' and version '%s'.", + model_specs.model_id, + model_specs.version, + ) + + def verify_model_region_and_return_specs( model_id: Optional[str], version: Optional[str], @@ -402,26 +433,11 @@ def verify_model_region_and_return_specs( f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training." ) - if model_specs.hosting_eula_key and scope == constants.JumpStartScriptScope.INFERENCE.value: - LOGGER.info( - "Model '%s' requires accepting end-user license agreement (EULA). " - "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", - model_id, - get_jumpstart_content_bucket(region=region), - region, - ".cn" if region.startswith("cn-") else "", - model_specs.hosting_eula_key, - ) - if model_specs.deprecated: if not tolerate_deprecated_model: raise DeprecatedJumpStartModelError( model_id=model_id, version=version, message=model_specs.deprecated_message ) - LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version) - - if model_specs.deprecate_warn_message: - LOGGER.warning(model_specs.deprecate_warn_message) if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable: if not tolerate_vulnerable_model: @@ -431,9 +447,6 @@ def verify_model_region_and_return_specs( vulnerabilities=model_specs.inference_vulnerabilities, scope=constants.JumpStartScriptScope.INFERENCE, ) - LOGGER.warning( - "Using vulnerable JumpStart model '%s' and version '%s' (inference).", model_id, version - ) if scope == constants.JumpStartScriptScope.TRAINING.value and model_specs.training_vulnerable: if not tolerate_vulnerable_model: @@ -443,9 +456,6 @@ def verify_model_region_and_return_specs( vulnerabilities=model_specs.training_vulnerabilities, scope=constants.JumpStartScriptScope.TRAINING, ) - LOGGER.warning( - "Using vulnerable JumpStart model '%s' and version '%s' (training).", model_id, version - ) return model_specs diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 3520891cf7..0f561ad73d 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -742,7 +742,7 @@ def test_yes_predictor_returns_unmodified_predictor( @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") - @mock.patch("sagemaker.jumpstart.factory.estimator.logger.warning") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @@ -795,7 +795,7 @@ def test_incremental_training_with_unsupported_model_logs_warning( @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training") - @mock.patch("sagemaker.jumpstart.factory.estimator.logger.warning") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning") @mock.patch("sagemaker.jumpstart.factory.model.Session") @mock.patch("sagemaker.jumpstart.factory.estimator.Session") @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 09ca83ba38..c375563a94 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -768,8 +768,83 @@ def test_update_inference_tags_with_jumpstart_training_model_tags_inference(): ) +def test_jumpstart_accept_eula_logs(): + def make_accept_eula_inference_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") + spec.hosting_eula_key = "read/the/fine/print.txt" + return spec + + with patch("logging.Logger.info") as mocked_info_log: + utils.emit_logs_based_on_model_specs(make_accept_eula_inference_spec(), "us-east-1") + mocked_info_log.assert_called_once_with( + "Model '%s' requires accepting end-user license agreement (EULA). " + "See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.", + "pytorch-eqa-bert-base-cased", + "jumpstart-cache-prod-us-east-1", + "us-east-1", + "", + "read/the/fine/print.txt", + ) + + +def test_jumpstart_vulnerable_model_warnings(): + def make_vulnerable_inference_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") + spec.inference_vulnerable = True + spec.inference_vulnerabilities = ["some", "vulnerability"] + return spec + + with patch("logging.Logger.warning") as mocked_warning_log: + utils.emit_logs_based_on_model_specs(make_vulnerable_inference_spec(), "some-region") + mocked_warning_log.assert_called_once_with( + "Using vulnerable JumpStart model '%s' and version '%s'.", + "pytorch-eqa-bert-base-cased", + "*", + ) + + +def test_jumpstart_deprecated_model_warnings(): + def make_deprecated_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") + spec.deprecated = True + return spec + + with patch("logging.Logger.warning") as mocked_warning_log: + utils.emit_logs_based_on_model_specs(make_deprecated_spec(), "some-region") + + mocked_warning_log.assert_called_once_with( + "Using deprecated JumpStart model 'pytorch-eqa-bert-base-cased' and version '*'." + ) + + deprecated_message = "this model is deprecated" + + def make_deprecated_message_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") + spec.deprecated_message = deprecated_message + spec.deprecated = True + return spec + + with patch("logging.Logger.warning") as mocked_warning_log: + utils.emit_logs_based_on_model_specs(make_deprecated_message_spec(), "some-region") + + mocked_warning_log.assert_called_once_with(deprecated_message) + + deprecate_warn_message = "warn-msg" + + def make_deprecated_warning_message_spec(*largs, **kwargs): + spec = get_spec_from_base_spec(model_id="pytorch-eqa-bert-base-cased", version="*") + spec.deprecate_warn_message = deprecate_warn_message + return spec + + with patch("logging.Logger.warning") as mocked_warning_log: + utils.emit_logs_based_on_model_specs(make_deprecated_warning_message_spec(), "some-region") + mocked_warning_log.assert_called_once_with( + deprecate_warn_message, + ) + + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_vulnerable_model(patched_get_model_specs): +def test_jumpstart_vulnerable_model_errors(patched_get_model_specs): def make_vulnerable_inference_spec(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.inference_vulnerable = True @@ -792,23 +867,6 @@ def make_vulnerable_inference_spec(*largs, **kwargs): "List of vulnerabilities: some, vulnerability" ) == str(e.value.message) - with patch("logging.Logger.warning") as mocked_warning_log: - assert ( - utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", - version="*", - scope=JumpStartScriptScope.INFERENCE.value, - region="us-west-2", - tolerate_vulnerable_model=True, - ) - is not None - ) - mocked_warning_log.assert_called_once_with( - "Using vulnerable JumpStart model '%s' and version '%s' (inference).", - "pytorch-eqa-bert-base-cased", - "*", - ) - def make_vulnerable_training_spec(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.training_vulnerable = True @@ -831,26 +889,9 @@ def make_vulnerable_training_spec(*largs, **kwargs): "List of vulnerabilities: some, vulnerability" ) == str(e.value.message) - with patch("logging.Logger.warning") as mocked_warning_log: - assert ( - utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", - version="*", - scope=JumpStartScriptScope.TRAINING.value, - region="us-west-2", - tolerate_vulnerable_model=True, - ) - is not None - ) - mocked_warning_log.assert_called_once_with( - "Using vulnerable JumpStart model '%s' and version '%s' (training).", - "pytorch-eqa-bert-base-cased", - "*", - ) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") -def test_jumpstart_deprecated_model(patched_get_model_specs): +def test_jumpstart_deprecated_model_errors(patched_get_model_specs): def make_deprecated_spec(*largs, **kwargs): spec = get_spec_from_base_spec(*largs, **kwargs) spec.deprecated = True @@ -870,23 +911,6 @@ def make_deprecated_spec(*largs, **kwargs): e.value.message ) - with patch("logging.Logger.warning") as mocked_warning_log: - assert ( - utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", - version="*", - scope=JumpStartScriptScope.INFERENCE.value, - region="us-west-2", - tolerate_deprecated_model=True, - ) - is not None - ) - mocked_warning_log.assert_called_once_with( - "Using deprecated JumpStart model '%s' and version '%s'.", - "pytorch-eqa-bert-base-cased", - "*", - ) - deprecated_message = "this model is deprecated" def make_deprecated_message_spec(*largs, **kwargs): @@ -906,30 +930,6 @@ def make_deprecated_message_spec(*largs, **kwargs): ) assert deprecated_message == str(e.value.message) - deprecate_warn_message = "warn-msg" - - def make_deprecated_warning_message_spec(*largs, **kwargs): - spec = get_spec_from_base_spec(*largs, **kwargs) - spec.deprecate_warn_message = deprecate_warn_message - return spec - - patched_get_model_specs.side_effect = make_deprecated_warning_message_spec - - with patch("logging.Logger.warning") as mocked_warning_log: - assert ( - utils.verify_model_region_and_return_specs( - model_id="pytorch-eqa-bert-base-cased", - version="*", - scope=JumpStartScriptScope.INFERENCE.value, - region="us-west-2", - tolerate_deprecated_model=True, - ) - is not None - ) - mocked_warning_log.assert_called_once_with( - deprecate_warn_message, - ) - def test_get_jumpstart_base_name_if_jumpstart_model(): uris = [random_jumpstart_s3_uri("random_key") for _ in range(random.randint(1, 10))]