diff --git a/src/sagemaker/serve/builder/djl_builder.py b/src/sagemaker/serve/builder/djl_builder.py
index 75acd0d1fe..7ad326e1c8 100644
--- a/src/sagemaker/serve/builder/djl_builder.py
+++ b/src/sagemaker/serve/builder/djl_builder.py
@@ -13,7 +13,9 @@
 """Holds mixin logic to support deployment of Model ID"""
 from __future__ import absolute_import
 import logging
+import os
 from typing import Type
+from pathlib import Path
 from abc import ABC, abstractmethod
 from datetime import datetime, timedelta
@@ -46,7 +48,12 @@
 )
 from sagemaker.serve.model_server.djl_serving.prepare import (
     _create_dir_structure,
+    prepare_for_djl,
 )
+from sagemaker.serve.detector.image_detector import (
+    auto_detect_container,
+)
+from sagemaker.serve.detector.pickler import save_pkl
 from sagemaker.serve.utils.predictors import DjlLocalModePredictor
 from sagemaker.serve.utils.types import ModelServer
 from sagemaker.serve.mode.function_pointers import Mode
@@ -92,6 +99,8 @@ def __init__(self):
         self.nb_instance_type = None
         self.ram_usage_model_load = None
         self.role_arn = None
+        self.inference_spec = None
+        self.shared_libs = None
 
     @abstractmethod
     def _prepare_for_mode(self):
@@ -247,17 +256,22 @@ def _build_for_hf_djl(self):
         _create_dir_structure(self.model_path)
 
         if not hasattr(self, "pysdk_model"):
-            self.env_vars.update({"HF_MODEL_ID": self.model})
+            if self.inference_spec is not None:
+                self.env_vars.update({"HF_MODEL_ID": self.inference_spec.get_model()})
+            else:
+                self.env_vars.update({"HF_MODEL_ID": self.model})
+
             self.hf_model_config = _get_model_config_properties_from_hf(
-                self.model, self.env_vars.get("HF_TOKEN")
+                self.env_vars.get("HF_MODEL_ID"), self.env_vars.get("HF_TOKEN")
             )
             default_djl_configurations, _default_max_new_tokens = _get_default_djl_configurations(
-                self.model, self.hf_model_config, self.schema_builder
+                self.env_vars.get("HF_MODEL_ID"), self.hf_model_config, self.schema_builder
             )
             self.env_vars.update(default_djl_configurations)
             self.schema_builder.sample_input["parameters"][
                 "max_new_tokens"
             ] = _default_max_new_tokens
+
         self.pysdk_model = self._create_djl_model()
 
         if self.mode == Mode.LOCAL_CONTAINER:
@@ -445,10 +459,67 @@ def _tune_for_hf_djl(self, max_tuning_duration: int = 1800):
 
         return self.pysdk_model
 
+    def _auto_detect_container(self):
+        """Sets image_uri by detecting the container via model name or inference spec"""
+        # Auto detect the container image uri
+        if self.image_uri:
+            logger.info(
+                "Skipping auto detection as the image uri is provided %s",
+                self.image_uri,
+            )
+            return
+
+        if self.model:
+            logger.info(
+                "Auto detecting container URI for the provided model on instance %s",
+                self.nb_instance_type,
+            )
+            self.image_uri = auto_detect_container(
+                self.model, self.sagemaker_session.boto_region_name, self.nb_instance_type
+            )
+
+        elif self.inference_spec:
+            # TODO: this won't work for larger image.
+            # Fail and let the customer include the image uri
+            logger.warning(
+                "model_path provided with no image_uri. Attempting to autodetect the image\
+                by loading the model using inference_spec.load()..."
+            )
+            self.image_uri = auto_detect_container(
+                self.inference_spec.load(self.model_path),
+                self.sagemaker_session.boto_region_name,
+                self.nb_instance_type,
+            )
+        else:
+            raise ValueError(
+                "Cannot detect and set image_uri. Please pass model or inference spec."
+            )
+
     def _build_for_djl(self):
-        """Placeholder docstring"""
+        """Checks whether an inference spec was passed and builds the DJL server accordingly"""
         self._validate_djl_serving_sample_data()
         self.secret_key = None
+        self.model_server = ModelServer.DJL_SERVING
+
+        if self.inference_spec:
+
+            os.makedirs(self.model_path, exist_ok=True)
+
+            code_path = Path(self.model_path).joinpath("code")
+
+            save_pkl(code_path, (self.inference_spec, self.schema_builder))
+            logger.info("PKL file saved to file: %s", code_path)
+
+            self._auto_detect_container()
+
+            self.secret_key = prepare_for_djl(
+                model_path=self.model_path,
+                shared_libs=self.shared_libs,
+                dependencies=self.dependencies,
+                session=self.sagemaker_session,
+                image_uri=self.image_uri,
+                inference_spec=self.inference_spec,
+            )
 
         self.pysdk_model = self._build_for_hf_djl()
         self.pysdk_model.tune = self._tune_for_hf_djl
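Note on the builder changes above: when an `inference_spec` is present, `_build_for_hf_djl` now resolves `HF_MODEL_ID` via `inference_spec.get_model()`, and `_build_for_djl` pickles the spec together with the schema builder before preparing the DJL artifacts. Below is a minimal sketch of a spec that could satisfy this path, assuming the `InferenceSpec` interface from `sagemaker.serve.spec.inference_spec`; the model id and the `transformers` usage are illustrative choices, not part of this change.

```python
from sagemaker.serve.spec.inference_spec import InferenceSpec


class MyDJLSpec(InferenceSpec):
    """Illustrative spec; model id and pipeline choice are placeholders."""

    def get_model(self):
        # Hypothetical HF hub id; this is what resolves HF_MODEL_ID above.
        return "TheBloke/Llama-2-7b-fp16"

    def load(self, model_dir: str):
        # Any loading logic works; transformers is just an illustrative choice.
        from transformers import pipeline

        return pipeline("text-generation", model=self.get_model())

    def invoke(self, input_object: object, model: object):
        # Called by the DJL handler with the deserialized request payload.
        return model(input_object)
```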
diff --git a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
index 2f09d3d572..d055e92942 100644
--- a/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
+++ b/src/sagemaker/serve/mode/sagemaker_endpoint_mode.py
@@ -95,6 +95,7 @@ def prepare(
                 upload_artifacts = self._upload_djl_artifacts(
                     model_path=model_path,
                     sagemaker_session=sagemaker_session,
+                    secret_key=secret_key,
                     s3_model_data_url=s3_model_data_url,
                     image=image,
                     should_upload_artifacts=True,
diff --git a/src/sagemaker/serve/model_server/djl_serving/inference.py b/src/sagemaker/serve/model_server/djl_serving/inference.py
new file mode 100644
index 0000000000..984f71a3db
--- /dev/null
+++ b/src/sagemaker/serve/model_server/djl_serving/inference.py
@@ -0,0 +1,146 @@
+"""This module is for SageMaker inference.py."""
+
+from __future__ import absolute_import
+import io
+import yaml
+import logging
+
+from pathlib import Path
+from djl_python import Input
+from djl_python import Output
+
+
+class DJLPythonInference(object):
+    """A class for DJL inference"""
+
+    def __init__(self) -> None:
+        self.inference_spec = None
+        self.model_dir = None
+        self.model = None
+        self.schema_builder = None
+        self.inferenceSpec = None
+        self.metadata = None
+        self.default_serializer = None
+        self.default_deserializer = None
+        self.initialized = False
+
+    def load_yaml(self, path: str):
+        """Loads a YAML file from the given path"""
+        with open(path, mode="r") as file:
+            return yaml.full_load(file)
+
+    def load_metadata(self):
+        """Loads metadata.yaml from the model directory"""
+        metadata_path = Path(self.model_dir).joinpath("metadata.yaml")
+        return self.load_yaml(metadata_path)
+
+    def load_and_validate_pkl(self, path, hash_tag):
+        """Validates the HMAC digest of a pickle file, then loads it"""
+
+        import os
+        import hmac
+        import hashlib
+        import cloudpickle
+
+        with open(path, mode="rb") as file:
+            buffer = file.read()
+            secret_key = os.getenv("SAGEMAKER_SERVE_SECRET_KEY")
+            stored_hash_tag = hmac.new(
+                secret_key.encode(), msg=buffer, digestmod=hashlib.sha256
+            ).hexdigest()
+            if not hmac.compare_digest(stored_hash_tag, hash_tag):
+                raise Exception("Object is not valid: " + path)
+
+        with open(path, mode="rb") as file:
+            return cloudpickle.load(file)
+
+    def load(self):
+        """Loads the inference spec, if present, and then the model"""
+        self.metadata = self.load_metadata()
+        if "InferenceSpec" in self.metadata:
+            inference_spec_path = (
+                Path(self.model_dir).joinpath(self.metadata.get("InferenceSpec")).absolute()
+            )
+            self.inference_spec = self.load_and_validate_pkl(
+                inference_spec_path, self.metadata.get("InferenceSpecHMAC")
+            )
+
+        # Load model
+        if self.inference_spec:
+            self.model = self.inference_spec.load(self.model_dir)
+        else:
+            raise Exception(
+                "SageMaker model format does not support model type: "
+                + self.metadata.get("ModelType")
+            )
+
+    def initialize(self, properties):
+        """Initializes the SageMaker service, loading the model and inference spec"""
+        self.model_dir = properties.get("model_dir")
+        self.load()
+        self.initialized = True
+        logging.info("SageMaker saved format entry-point is applied, service is initialized")
+
+    def preprocess_djl(self, inputs: Input):
+        """Deserializes the request payload using the schema builder"""
+        content_type = inputs.get_property("content-type")
+        logging.info(f"Content-type is: {content_type}")
+        if self.schema_builder:
+            logging.info("Customized input deserializer is applied")
+            try:
+                if hasattr(self.schema_builder, "custom_input_translator"):
+                    return self.schema_builder.custom_input_translator.deserialize(
+                        io.BytesIO(inputs.get_as_bytes()), content_type
+                    )
+                else:
+                    raise Exception("No custom input translator in customized schema builder.")
+            except Exception as e:
+                raise Exception("Encountered error in deserialize_request.") from e
+        elif self.default_deserializer:
+            return self.default_deserializer.deserialize(
+                io.BytesIO(inputs.get_as_bytes()), content_type
+            )
+
+    def postprocess_djl(self, output):
+        """Serializes the model output using the schema builder"""
+        if self.schema_builder:
+            logging.info("Customized output serializer is applied")
+            try:
+                if hasattr(self.schema_builder, "custom_output_translator"):
+                    return self.schema_builder.custom_output_translator.serialize(output)
+                else:
+                    raise Exception("No custom output translator in customized schema builder.")
+            except Exception as e:
+                raise Exception("Encountered error in serialize_response.") from e
+        elif self.default_serializer:
+            return self.default_serializer.serialize(output)
+
+    def inference(self, inputs: Input):
+        """Runs inference through the inference spec when one is provided"""
+        processed_input = self.preprocess_djl(inputs=inputs)
+        if self.inference_spec:
+            output = self.inference_spec.invoke(processed_input, self.model)
+        else:
+            raise Exception(
+                "SageMaker model format does not support model type: "
+                + self.metadata.get("ModelType")
+            )
+        processed_output = self.postprocess_djl(output=output)
+        output_data = Output()
+        return output_data.add(processed_output)
+
+
+_service = DJLPythonInference()
+
+
+def handle(inputs: Input) -> Output:
+    """DJL serving entry point"""
+    if not _service.initialized:
+        properties = inputs.get_properties()
+        _service.initialize(properties)
+
+    if inputs.is_empty():
+        # initialization request
+        return None
+
+    return _service.inference(inputs)
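The `load_and_validate_pkl` handler above refuses to unpickle artifacts whose HMAC digest does not match the one recorded at build time. Here is a self-contained sketch of the signing and verification scheme it implements, using only the standard library; the payload and key below are placeholders for the `serve.pkl` contents and the generated secret key.

```python
import hashlib
import hmac
import secrets

payload = b"pickled-bytes"          # stands in for the serve.pkl contents
secret_key = secrets.token_hex(32)  # analogous to generate_secret_key()

# Build time: sign the payload and record the hex digest in the metadata.
expected = hmac.new(secret_key.encode(), msg=payload, digestmod=hashlib.sha256).hexdigest()

# Container start: recompute with SAGEMAKER_SERVE_SECRET_KEY and compare.
actual = hmac.new(secret_key.encode(), msg=payload, digestmod=hashlib.sha256).hexdigest()
assert hmac.compare_digest(expected, actual)
```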
to %s" @@ -109,3 +118,56 @@ def prepare_djl_js_resources( model_path, code_dir = _create_dir_structure(model_path) return _copy_jumpstart_artifacts(model_data, js_id, code_dir) + + +def prepare_for_djl( + model_path: str, + shared_libs: List[str], + dependencies: dict, + session: Session, + image_uri: str, + inference_spec: InferenceSpec = None, +) -> str: + """Prepares for InferenceSpec using model_path, writes inference.py, and captures dependencies to generate secret_key. + + Args:to + model_path (str) : Argument + shared_libs (List[]) : Argument + dependencies (dict) : Argument + session (Session) : Argument + inference_spec (InferenceSpec, optional) : Argument + (default is None) + Returns: + ( str ) : secret_key + """ + model_path = Path(model_path) + if not model_path.exists(): + model_path.mkdir() + elif not model_path.is_dir(): + raise Exception("model_dir is not a valid directory") + + if inference_spec: + inference_spec.prepare(str(model_path)) + + code_dir = model_path.joinpath("code") + code_dir.mkdir(exist_ok=True) + + shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir) + + logger.info("Finished writing inference.py to code directory") + + shared_libs_dir = model_path.joinpath("shared_libs") + shared_libs_dir.mkdir(exist_ok=True) + for shared_lib in shared_libs: + shutil.copy2(Path(shared_lib), shared_libs_dir) + + capture_dependencies(dependencies=dependencies, work_dir=code_dir) + + secret_key = generate_secret_key() + with open(str(code_dir.joinpath("serve.pkl")), "rb") as f: + buffer = f.read() + hash_value = compute_hash(buffer=buffer, secret_key=secret_key) + with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata: + metadata.write(_MetaData(hash_value).to_json()) + + return secret_key diff --git a/src/sagemaker/serve/model_server/djl_serving/server.py b/src/sagemaker/serve/model_server/djl_serving/server.py index 4ba7dd227d..b0502cc213 100644 --- a/src/sagemaker/serve/model_server/djl_serving/server.py +++ b/src/sagemaker/serve/model_server/djl_serving/server.py @@ -4,6 +4,7 @@ import requests import logging +import platform from pathlib import Path from docker.types import DeviceRequest from sagemaker import Session, fw_utils @@ -28,13 +29,27 @@ class LocalDJLServing: - """Placeholder docstring""" + """Local DJL server instance""" def _start_djl_serving( - self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict + self, + client: object, + model_path: str, + secret_key: str, + env_vars: dict, + image: str, ): - """Placeholder docstring""" - updated_env_vars = _update_env_vars(env_vars) + """Initializes the start of the server""" + env = { + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": secret_key, + "LOCAL_PYTHON": platform.python_version(), + } + if env_vars: + env_vars.update(env) + else: + env_vars = env self.container = client.containers.run( image, @@ -50,11 +65,11 @@ def _start_djl_serving( "mode": "rw", }, }, - environment=updated_env_vars, + environment=env_vars, ) def _invoke_djl_serving(self, request: object, content_type: str, accept: str): - """Placeholder docstring""" + """Invokes DJL server by hitting the docker host""" try: response = requests.post( f"http://{get_docker_host()}:8080/predictions/model", @@ -68,7 +83,7 @@ def _invoke_djl_serving(self, request: object, content_type: str, accept: str): raise Exception("Unable to send request to the local container server %s", str(e)) def _djl_deep_ping(self, predictor: 
diff --git a/src/sagemaker/serve/model_server/djl_serving/server.py b/src/sagemaker/serve/model_server/djl_serving/server.py
index 4ba7dd227d..b0502cc213 100644
--- a/src/sagemaker/serve/model_server/djl_serving/server.py
+++ b/src/sagemaker/serve/model_server/djl_serving/server.py
@@ -4,6 +4,7 @@
 
 import requests
 import logging
+import platform
 from pathlib import Path
 from docker.types import DeviceRequest
 from sagemaker import Session, fw_utils
@@ -28,13 +29,27 @@
 
 
 class LocalDJLServing:
-    """Placeholder docstring"""
+    """Local DJL server instance"""
 
     def _start_djl_serving(
-        self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
+        self,
+        client: object,
+        model_path: str,
+        secret_key: str,
+        env_vars: dict,
+        image: str,
     ):
-        """Placeholder docstring"""
-        updated_env_vars = _update_env_vars(env_vars)
+        """Starts the local DJL serving container"""
+        env = {
+            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+            "SAGEMAKER_PROGRAM": "inference.py",
+            "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+            "LOCAL_PYTHON": platform.python_version(),
+        }
+        if env_vars:
+            env_vars.update(env)
+        else:
+            env_vars = env
 
         self.container = client.containers.run(
             image,
@@ -50,11 +65,11 @@
                     "mode": "rw",
                 },
             },
-            environment=updated_env_vars,
+            environment=env_vars,
         )
 
     def _invoke_djl_serving(self, request: object, content_type: str, accept: str):
-        """Placeholder docstring"""
+        """Invokes the DJL server through the Docker host"""
        try:
             response = requests.post(
                 f"http://{get_docker_host()}:8080/predictions/model",
@@ -68,7 +83,7 @@
             raise Exception("Unable to send request to the local container server %s", str(e))
 
     def _djl_deep_ping(self, predictor: PredictorBase):
-        """Placeholder docstring"""
+        """Deep pings the server to verify it can serve predictions"""
         response = None
         try:
             response = predictor.predict(self.schema_builder.sample_input)
@@ -83,18 +98,19 @@
 
 
 class SageMakerDjlServing:
-    """Placeholder docstring"""
+    """SageMaker endpoint mode for DJL serving"""
 
     def _upload_djl_artifacts(
         self,
         model_path: str,
+        secret_key: str,
         sagemaker_session: Session,
         s3_model_data_url: str = None,
         image: str = None,
         env_vars: dict = None,
         should_upload_artifacts: bool = False,
     ):
-        """Placeholder docstring"""
+        """Uploads DJL serving artifacts to S3"""
         model_data_url = None
         if _is_s3_uri(model_path):
             model_data_url = model_path
@@ -135,11 +151,21 @@
             else None
         )
 
+        if secret_key:
+            env_vars = {
+                "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
+                "SAGEMAKER_PROGRAM": "inference.py",
+                "SAGEMAKER_SERVE_SECRET_KEY": secret_key,
+                "SAGEMAKER_REGION": sagemaker_session.boto_region_name,
+                "SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
+                "LOCAL_PYTHON": platform.python_version(),
+            }
+
         return (model_data, _update_env_vars(env_vars))
 
 
 def _update_env_vars(env_vars: dict) -> dict:
-    """Placeholder docstring"""
+    """Merges the given environment variables with the defaults"""
     updated_env_vars = {}
     updated_env_vars.update(_DEFAULT_ENV_VARS)
     if env_vars:
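For reference, `_invoke_djl_serving` above boils down to the following request against the local container. `get_docker_host()` is simplified to `localhost` here, and the payload shape depends on the configured schema builder; both are assumptions for illustration.

```python
import requests

response = requests.post(
    "http://localhost:8080/predictions/model",
    data=b'{"inputs": "Hello"}',  # payload shape depends on your schema builder
    headers={"Content-Type": "application/json", "Accept": "application/json"},
    timeout=60,
)
print(response.status_code, response.content)
```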
"inference.py", + "SAGEMAKER_SERVE_SECRET_KEY": "secret_key", + "LOCAL_PYTHON": platform.python_version(), + }, + ) + @patch("sagemaker.serve.model_server.djl_serving.prepare._check_disk_space") @patch("sagemaker.serve.model_server.djl_serving.prepare._check_docker_disk_usage") @patch("sagemaker.serve.model_server.djl_serving.prepare.Path")