Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,10 @@ def _retrieval_function(
)
if file_type == JumpStartS3FileType.SPECS:
formatted_body, _ = self._get_json_file(s3_key, file_type)
model_specs = JumpStartModelSpecs(formatted_body)
utils.emit_logs_based_on_model_specs(model_specs, self.get_region())
return JumpStartCachedS3ContentValue(
formatted_content=JumpStartModelSpecs(formatted_body)
formatted_content=model_specs
)
raise ValueError(
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
"""This module stores constants related to SageMaker JumpStart."""
from __future__ import absolute_import
import logging
from typing import Dict, Set, Type
import boto3
from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer
Expand Down Expand Up @@ -173,3 +174,5 @@
}

MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html"

JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart")
3 changes: 0 additions & 3 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.
"""This module stores JumpStart implementation of Estimator class."""
from __future__ import absolute_import
import logging


from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -45,8 +44,6 @@
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger(__name__)


class JumpStartEstimator(Estimator):
"""JumpStartEstimator class.
Expand Down
8 changes: 3 additions & 5 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.
"""This module stores JumpStart Estimator factory methods."""
from __future__ import absolute_import
import logging


from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -41,6 +40,7 @@
)
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
TRAINING_ENTRY_POINT_SCRIPT_NAME,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope
Expand All @@ -64,8 +64,6 @@
from sagemaker.utils import name_from_base
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger("sagemaker")


def get_init_kwargs(
model_id: str,
Expand Down Expand Up @@ -421,7 +419,7 @@ def _add_instance_type_and_count_to_kwargs(
kwargs.instance_count = kwargs.instance_count or 1

if orig_instance_type is None:
logger.info(
JUMPSTART_LOGGER.info(
"No instance type selected for training job. Defaulting to %s.", kwargs.instance_type
)

Expand Down Expand Up @@ -467,7 +465,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
)
):
logger.warning(
JUMPSTART_LOGGER.warning(
"'%s' does not support incremental training but is being trained with"
" non-default model artifact.",
kwargs.model_id,
Expand Down
6 changes: 2 additions & 4 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.
"""This module stores JumpStart Model factory methods."""
from __future__ import absolute_import
import logging


from typing import Any, Dict, List, Optional, Union
Expand All @@ -31,6 +30,7 @@
from sagemaker.jumpstart.constants import (
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope
from sagemaker.jumpstart.types import (
Expand All @@ -51,8 +51,6 @@
from sagemaker.utils import name_from_base
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger("sagemaker")


def get_default_predictor(
predictor: Predictor,
Expand Down Expand Up @@ -170,7 +168,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
)

if orig_instance_type is None:
logger.info(
JUMPSTART_LOGGER.info(
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
kwargs.instance_type,
)
Expand Down
3 changes: 0 additions & 3 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
"""This module stores JumpStart implementation of Model class."""

from __future__ import absolute_import
import logging
import re

from typing import Dict, List, Optional, Union
Expand All @@ -38,8 +37,6 @@
from sagemaker.session import Session
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger(__name__)


class JumpStartModel(Model):
"""JumpStartModel class.
Expand Down
58 changes: 34 additions & 24 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
from sagemaker.utils import resolve_value_from_config
from sagemaker.workflow import is_pipeline_variable

LOGGER = logging.getLogger(__name__)


def get_jumpstart_launched_regions_message() -> str:
"""Returns formatted string indicating where JumpStart is launched."""
Expand Down Expand Up @@ -79,7 +77,7 @@ def get_jumpstart_content_bucket(
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
):
bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
constants.JUMPSTART_LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
return bucket_override
try:
return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
Expand Down Expand Up @@ -343,6 +341,39 @@ def update_inference_tags_with_jumpstart_training_tags(
return inference_tags


def emit_logs_based_on_model_specs(model_specs: JumpStartModelSpecs, region: str) -> None:
"""Emits logs based on model specs and region."""

if model_specs.hosting_eula_key:
constants.JUMPSTART_LOGGER.info(
"Model '%s' requires accepting end-user license agreement (EULA). "
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
model_specs.model_id,
get_jumpstart_content_bucket(region=region),
region,
".cn" if region.startswith("cn-") else "",
model_specs.hosting_eula_key,
)

if model_specs.deprecated:
deprecated_message = model_specs.deprecated_message or (
"Using deprecated JumpStart model "
f"'{model_specs.model_id}' and version '{model_specs.version}'."
)

constants.JUMPSTART_LOGGER.warning(deprecated_message)

if model_specs.deprecate_warn_message:
constants.JUMPSTART_LOGGER.warning(model_specs.deprecate_warn_message)

if model_specs.inference_vulnerable or model_specs.training_vulnerable:
constants.JUMPSTART_LOGGER.warning(
"Using vulnerable JumpStart model '%s' and version '%s'.",
model_specs.model_id,
model_specs.version,
)


def verify_model_region_and_return_specs(
model_id: Optional[str],
version: Optional[str],
Expand Down Expand Up @@ -402,26 +433,11 @@ def verify_model_region_and_return_specs(
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
)

if model_specs.hosting_eula_key and scope == constants.JumpStartScriptScope.INFERENCE.value:
LOGGER.info(
"Model '%s' requires accepting end-user license agreement (EULA). "
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
model_id,
get_jumpstart_content_bucket(region=region),
region,
".cn" if region.startswith("cn-") else "",
model_specs.hosting_eula_key,
)

if model_specs.deprecated:
if not tolerate_deprecated_model:
raise DeprecatedJumpStartModelError(
model_id=model_id, version=version, message=model_specs.deprecated_message
)
LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version)

if model_specs.deprecate_warn_message:
LOGGER.warning(model_specs.deprecate_warn_message)

if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable:
if not tolerate_vulnerable_model:
Expand All @@ -431,9 +447,6 @@ def verify_model_region_and_return_specs(
vulnerabilities=model_specs.inference_vulnerabilities,
scope=constants.JumpStartScriptScope.INFERENCE,
)
LOGGER.warning(
"Using vulnerable JumpStart model '%s' and version '%s' (inference).", model_id, version
)

if scope == constants.JumpStartScriptScope.TRAINING.value and model_specs.training_vulnerable:
if not tolerate_vulnerable_model:
Expand All @@ -443,9 +456,6 @@ def verify_model_region_and_return_specs(
vulnerabilities=model_specs.training_vulnerabilities,
scope=constants.JumpStartScriptScope.TRAINING,
)
LOGGER.warning(
"Using vulnerable JumpStart model '%s' and version '%s' (training).", model_id, version
)

return model_specs

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/jumpstart/estimator/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def test_yes_predictor_returns_unmodified_predictor(

@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training")
@mock.patch("sagemaker.jumpstart.factory.estimator.logger.warning")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
Expand Down Expand Up @@ -795,7 +795,7 @@ def test_incremental_training_with_unsupported_model_logs_warning(

@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
@mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training")
@mock.patch("sagemaker.jumpstart.factory.estimator.logger.warning")
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
Expand Down
Loading