diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 56c5ec1a2f..985be14fdc 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -13,21 +13,41 @@ """Placeholder docstring""" from __future__ import absolute_import +import copy from abc import ABC, abstractmethod +from datetime import datetime, timedelta from typing import Type import logging from sagemaker.model import Model from sagemaker import model_uris from sagemaker.serve.model_server.djl_serving.prepare import prepare_djl_js_resources +from sagemaker.serve.model_server.djl_serving.utils import _get_admissible_tensor_parallel_degrees from sagemaker.serve.model_server.tgi.prepare import prepare_tgi_js_resources, _create_dir_structure from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.utils.exceptions import ( + LocalDeepPingException, + LocalModelOutOfMemoryException, + LocalModelInvocationException, + LocalModelLoadException, + SkipTuningComboException, +) from sagemaker.serve.utils.predictors import ( DjlLocalModePredictor, TgiLocalModePredictor, ) -from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb +from sagemaker.serve.utils.local_hardware import ( + _get_nb_instance, + _get_ram_usage_mb, +) from sagemaker.serve.utils.telemetry_logger import _capture_telemetry +from sagemaker.serve.utils.tuning import ( + _pretty_print_results_jumpstart, + _serial_benchmark, + _concurrent_benchmark, + _more_performant, + _sharded_supported, +) from sagemaker.serve.utils.types import ModelServer from sagemaker.base_predictor import PredictorBase from sagemaker.jumpstart.model import JumpStartModel @@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]: model_data=self.pysdk_model.model_data, ) elif not hasattr(self, "prepared_for_tgi"): - self.prepared_for_tgi = prepare_tgi_js_resources( + self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources( model_path=self.model_path, js_id=self.model, dependencies=self.dependencies, @@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self): env = {} if self.mode == Mode.LOCAL_CONTAINER: if not hasattr(self, "prepared_for_tgi"): - self.prepared_for_tgi = prepare_tgi_js_resources( + self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources( model_path=self.model_path, js_id=self.model, dependencies=self.dependencies, @@ -234,6 +254,183 @@ def _build_for_tgi_jumpstart(self): self.pysdk_model.env.update(env) + def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800): + """Tune for Jumpstart Models in Local Mode. + + Args: + sharded_supported (bool): Indicates whether sharding is supported by this ``Model`` + max_tuning_duration (int): The maximum timeout to deploy this ``Model`` locally. + Default: ``1800`` + returns: + Tuned Model. + """ + if self.mode != Mode.LOCAL_CONTAINER: + logger.warning( + "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER + ) + return self.pysdk_model + + num_shard_env_var_name = "SM_NUM_GPUS" + if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env.keys(): + num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE" + + initial_env_vars = copy.deepcopy(self.pysdk_model.env) + admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees( + self.js_model_config + ) + + if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported: + admissible_tensor_parallel_degrees = [1] + logger.warning( + "Sharding across multiple GPUs is not supported for this model. " + "Model can only be sharded across [1] GPU" + ) + + benchmark_results = {} + best_tuned_combination = None + timeout = datetime.now() + timedelta(seconds=max_tuning_duration) + for tensor_parallel_degree in admissible_tensor_parallel_degrees: + if datetime.now() > timeout: + logger.info("Max tuning duration reached. Tuning stopped.") + break + + self.pysdk_model.env.update({num_shard_env_var_name: str(tensor_parallel_degree)}) + try: + logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree) + + predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration) + + avg_latency, p90, avg_tokens_per_second = _serial_benchmark( + predictor, self.schema_builder.sample_input + ) + throughput_per_second, standard_deviation = _concurrent_benchmark( + predictor, self.schema_builder.sample_input + ) + + tested_env = copy.deepcopy(self.pysdk_model.env) + logger.info( + "Average latency: %s, throughput/s: %s for configuration: %s", + avg_latency, + throughput_per_second, + tested_env, + ) + benchmark_results[avg_latency] = [ + tested_env, + p90, + avg_tokens_per_second, + throughput_per_second, + standard_deviation, + ] + + if not best_tuned_combination: + best_tuned_combination = [ + avg_latency, + tensor_parallel_degree, + None, + p90, + avg_tokens_per_second, + throughput_per_second, + standard_deviation, + ] + else: + tuned_configuration = [ + avg_latency, + tensor_parallel_degree, + None, + p90, + avg_tokens_per_second, + throughput_per_second, + standard_deviation, + ] + if _more_performant(best_tuned_combination, tuned_configuration): + best_tuned_combination = tuned_configuration + except LocalDeepPingException as e: + logger.warning( + "Deployment unsuccessful with %s: %s. " "Failed to invoke the model server: %s", + num_shard_env_var_name, + tensor_parallel_degree, + str(e), + ) + except LocalModelOutOfMemoryException as e: + logger.warning( + "Deployment unsuccessful with %s: %s. " + "Out of memory when loading the model: %s", + num_shard_env_var_name, + tensor_parallel_degree, + str(e), + ) + except LocalModelInvocationException as e: + logger.warning( + "Deployment unsuccessful with %s: %s. " + "Failed to invoke the model server: %s" + "Please check that model server configurations are as expected " + "(Ex. serialization, deserialization, content_type, accept).", + num_shard_env_var_name, + tensor_parallel_degree, + str(e), + ) + except LocalModelLoadException as e: + logger.warning( + "Deployment unsuccessful with %s: %s. " "Failed to load the model: %s.", + num_shard_env_var_name, + tensor_parallel_degree, + str(e), + ) + except SkipTuningComboException as e: + logger.warning( + "Deployment with %s: %s" + "was expected to be successful. However failed with: %s. " + "Trying next combination.", + num_shard_env_var_name, + tensor_parallel_degree, + str(e), + ) + except Exception: # pylint: disable=W0703 + logger.exception( + "Deployment unsuccessful with %s: %s. " "with uncovered exception", + num_shard_env_var_name, + tensor_parallel_degree, + ) + + if best_tuned_combination: + self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])}) + + _pretty_print_results_jumpstart(benchmark_results, [num_shard_env_var_name]) + logger.info( + "Model Configuration: %s was most performant with avg latency: %s, " + "p90 latency: %s, average tokens per second: %s, throughput/s: %s, " + "standard deviation of request %s", + self.pysdk_model.env, + best_tuned_combination[0], + best_tuned_combination[3], + best_tuned_combination[4], + best_tuned_combination[5], + best_tuned_combination[6], + ) + else: + self.pysdk_model.env.update(initial_env_vars) + logger.debug( + "Failed to gather any tuning results. " + "Please inspect the stack trace emitted from live logging for more details. " + "Falling back to default model configurations: %s", + self.pysdk_model.env, + ) + + return self.pysdk_model + + @_capture_telemetry("djl_jumpstart.tune") + def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800): + """Tune for Jumpstart Models with DJL DLC""" + return self._tune_for_js(sharded_supported=True, max_tuning_duration=max_tuning_duration) + + @_capture_telemetry("tgi_jumpstart.tune") + def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800): + """Tune for Jumpstart Models with TGI DLC""" + sharded_supported = _sharded_supported(self.model, self.js_model_config) + return self._tune_for_js( + sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration + ) + def _build_for_jumpstart(self): """Placeholder docstring""" # we do not pickle for jumpstart. set to none @@ -254,6 +451,8 @@ def _build_for_jumpstart(self): self.image_uri = self.pysdk_model.image_uri self._build_for_djl_jumpstart() + + self.pysdk_model.tune = self.tune_for_djl_jumpstart elif "tgi-inference" in image_uri: logger.info("Building for TGI JumpStart Model ID...") self.model_server = ModelServer.TGI @@ -262,6 +461,8 @@ def _build_for_jumpstart(self): self.image_uri = self.pysdk_model.image_uri self._build_for_tgi_jumpstart() + + self.pysdk_model.tune = self.tune_for_tgi_jumpstart else: raise ValueError( "JumpStart Model ID was not packaged with djl-inference or tgi-inference container." diff --git a/src/sagemaker/serve/model_server/tgi/prepare.py b/src/sagemaker/serve/model_server/tgi/prepare.py index af09515da9..5dcd760844 100644 --- a/src/sagemaker/serve/model_server/tgi/prepare.py +++ b/src/sagemaker/serve/model_server/tgi/prepare.py @@ -13,6 +13,8 @@ """Prepare TgiModel for Deployment""" from __future__ import absolute_import + +import json import tarfile import logging from typing import List @@ -32,7 +34,7 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str): custom_extractall_tarfile(resources, code_dir) -def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool: +def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> tuple: """Copy the associated JumpStart Resource into the code directory""" logger.info("Downloading JumpStart artifacts from S3...") @@ -56,7 +58,13 @@ def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bo else: raise ValueError("JumpStart model data compression format is unsupported: %s", model_data) - return True + config_json_file = code_dir.joinpath("config.json") + hf_model_config = None + if config_json_file.is_file(): + with open(str(config_json_file)) as config_json: + hf_model_config = json.load(config_json) + + return (hf_model_config, True) def _create_dir_structure(model_path: str) -> tuple: @@ -82,7 +90,7 @@ def prepare_tgi_js_resources( shared_libs: List[str] = None, dependencies: str = None, model_data: str = None, -) -> bool: +) -> tuple: """Prepare serving when a JumpStart model id is given Args: diff --git a/src/sagemaker/serve/utils/tuning.py b/src/sagemaker/serve/utils/tuning.py index c095791ef9..de02708278 100644 --- a/src/sagemaker/serve/utils/tuning.py +++ b/src/sagemaker/serve/utils/tuning.py @@ -1,5 +1,6 @@ """Holds mixin logic to support deployment of Model ID""" from __future__ import absolute_import + import logging from time import perf_counter import collections @@ -98,6 +99,52 @@ def _pretty_print_results_tgi(results: dict): ) +def _pretty_print_results_jumpstart(results: dict, model_env_vars=None): + """Pretty prints benchmark results""" + if model_env_vars is None: + model_env_vars = [] + + __env_var_data = {} + for model_env_var in model_env_vars: + __env_var_data[model_env_var] = [] + + avg_latencies = [] + p90s = [] + avg_tokens_per_seconds = [] + throughput_per_seconds = [] + standard_deviations = [] + ordered = collections.OrderedDict(sorted(results.items())) + + for key, value in ordered.items(): + avg_latencies.append(key) + p90s.append(value[1]) + avg_tokens_per_seconds.append(value[2]) + throughput_per_seconds.append(value[3]) + standard_deviations.append(value[4]) + + for model_env_var in __env_var_data: + __env_var_data[model_env_var].append(value[0][model_env_var]) + + df = pd.DataFrame( + { + "AverageLatency (Serial)": avg_latencies, + "P90_Latency (Serial)": p90s, + "AverageTokensPerSecond (Serial)": avg_tokens_per_seconds, + "ThroughputPerSecond (Concurrent)": throughput_per_seconds, + "StandardDeviationResponse (Concurrent)": standard_deviations, + **__env_var_data, + } + ) + + logger.info( + "\n================================================================== Benchmark " + "Results ==================================================================\n%s" + "\n============================================================================" + "===========================================================================\n", + df.to_string(), + ) + + def _tokens_per_second(generated_text: str, max_token_length: int, latency: float) -> int: """Placeholder docstring""" est_tokens = (_tokens_from_chars(generated_text) + _tokens_from_words(generated_text)) / 2 @@ -216,3 +263,24 @@ def _more_performant(best_tuned_configuration: list, tuned_configuration: list) return True return False return tuned_avg_latency <= best_avg_latency + + +def _sharded_supported(model_id: str, config_dict: dict) -> bool: + """Check if sharded is supported for this ``Model``""" + model_type = config_dict.get("model_type", None) + + if model_type is None: + return False + + if model_id.startswith("facebook/galactica"): + return True + + if model_type in ["bloom", "mpt", "ssm", "gpt_neox", "phi", "phi-msft", "opt", "t5"]: + return True + + if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"] and not config_dict.get( + "alibi", False + ): + return True + + return False diff --git a/tests/integ/sagemaker/serve/constants.py b/tests/integ/sagemaker/serve/constants.py index cf4c6919aa..701c16d07a 100644 --- a/tests/integ/sagemaker/serve/constants.py +++ b/tests/integ/sagemaker/serve/constants.py @@ -20,6 +20,7 @@ SERVE_IN_PROCESS_TIMEOUT = 5 SERVE_MODEL_PACKAGE_TIMEOUT = 10 SERVE_LOCAL_CONTAINER_TIMEOUT = 10 +SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT = 15 SERVE_SAGEMAKER_ENDPOINT_TIMEOUT = 15 SERVE_SAVE_TIMEOUT = 2 diff --git a/tests/integ/sagemaker/serve/test_serve_js_happy.py b/tests/integ/sagemaker/serve/test_serve_js_happy.py index 77a1e34eec..6145fc21a8 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_happy.py +++ b/tests/integ/sagemaker/serve/test_serve_js_happy.py @@ -13,10 +13,13 @@ from __future__ import absolute_import import pytest + +from sagemaker.serve import Mode from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.builder.schema_builder import SchemaBuilder from tests.integ.sagemaker.serve.constants import ( SERVE_SAGEMAKER_ENDPOINT_TIMEOUT, + SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT, PYTHON_VERSION_IS_NOT_310, ) @@ -74,3 +77,32 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type): ) if caught_ex: raise caught_ex + + +@pytest.mark.skipif( + PYTHON_VERSION_IS_NOT_310, + reason="The goal of these tests are to test the serving components of our feature", +) +@pytest.mark.local_mode +def test_happy_tune_tgi_local_mode(sagemaker_local_session): + logger.info("Running in LOCAL_CONTAINER mode...") + caught_ex = None + + model_builder = ModelBuilder( + model="huggingface-llm-bilingual-rinna-4b-instruction-ppo-bf16", + schema_builder=SchemaBuilder(SAMPLE_PROMPT, SAMPLE_RESPONSE), + mode=Mode.LOCAL_CONTAINER, + sagemaker_session=sagemaker_local_session, + ) + + model = model_builder.build() + + with timeout(minutes=SERVE_LOCAL_CONTAINER_TUNE_TIMEOUT): + try: + tuned_model = model.tune() + assert tuned_model.env is not None + except Exception as e: + caught_ex = e + finally: + if caught_ex: + raise caught_ex diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py new file mode 100644 index 0000000000..743c02719d --- /dev/null +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -0,0 +1,529 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from unittest.mock import MagicMock, patch + +import unittest + +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.serve.utils.exceptions import ( + LocalDeepPingException, + LocalModelLoadException, + LocalModelOutOfMemoryException, + LocalModelInvocationException, +) + +mock_model_id = "huggingface-llm-amazon-falconlite" +mock_t5_model_id = "google/flan-t5-xxl" +mock_prompt = "Hello, I'm a language model," +mock_response = "Hello, I'm a language model, and I'm here to help you with your English." +mock_sample_input = {"inputs": mock_prompt, "parameters": {}} +mock_sample_output = [{"generated_text": mock_response}] + +mock_set_serving_properties = (4, "fp16", 1, 256, 256) + +mock_tgi_most_performant_model_serving_properties = { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SM_NUM_GPUS": "2", +} +mock_tgi_model_serving_properties = { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "SM_NUM_GPUS": "2", +} + +mock_djl_most_performant_model_serving_properties = { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_TENSOR_PARALLEL_DEGREE": "4", +} +mock_djl_model_serving_properties = { + "SAGEMAKER_PROGRAM": "inference.py", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + "OPTION_TENSOR_PARALLEL_DEGREE": "4", +} + +mock_schema_builder = MagicMock() +mock_schema_builder.sample_input = mock_sample_input +mock_schema_builder.sample_output = mock_sample_output + +mock_tgi_image_uri = ( + "123456789712.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi" + "-inference:2.1.1-tgi1.4.0-gpu-py310-cu121-ubuntu20.04" +) +mock_djl_image_uri = ( + "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1" +) + + +class TestJumpStartBuilder(unittest.TestCase): + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 2, 1], + ) + @patch( + "sagemaker.serve.utils.tuning._serial_benchmark", + side_effect=[(5, 5, 25), (5.4, 5.4, 20), (5.2, 5.2, 15)], + ) + @patch( + "sagemaker.serve.utils.tuning._concurrent_benchmark", + side_effect=[(0.9, 1), (0.10, 4), (0.13, 2)], + ) + def test_tune_for_tgi_js_local_container( + self, + mock_concurrent_benchmarks, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock-model-id", + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_tgi_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_tgi_most_performant_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "sharding_not_supported", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 2, 1], + ) + @patch( + "sagemaker.serve.utils.tuning._serial_benchmark", + side_effect=[(5, 5, 25), (5.4, 5.4, 20), (5.2, 5.2, 15)], + ) + @patch( + "sagemaker.serve.utils.tuning._concurrent_benchmark", + side_effect=[(0.9, 1), (0.10, 4), (0.13, 2)], + ) + def test_tune_for_tgi_js_local_container_sharding_not_supported( + self, + mock_concurrent_benchmarks, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_tgi_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_tgi_most_performant_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 2, 1], + ) + @patch( + "sagemaker.serve.builder.djl_builder._serial_benchmark", + **{"return_value.raiseError.side_effect": LocalDeepPingException("mock_exception")} + ) + def test_tune_for_tgi_js_local_container_deep_ping_ex( + self, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_tgi_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_tgi_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "RefinedWebModel", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 2, 1], + ) + @patch( + "sagemaker.serve.builder.djl_builder._serial_benchmark", + **{"return_value.raiseError.side_effect": LocalModelLoadException("mock_exception")} + ) + def test_tune_for_tgi_js_local_container_load_ex( + self, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_tgi_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_tgi_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 2, 1], + ) + @patch( + "sagemaker.serve.builder.djl_builder._serial_benchmark", + **{"return_value.raiseError.side_effect": LocalModelOutOfMemoryException("mock_exception")} + ) + def test_tune_for_tgi_js_local_container_oom_ex( + self, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_tgi_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_tgi_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 2, 1], + ) + @patch( + "sagemaker.serve.builder.djl_builder._serial_benchmark", + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + ) + def test_tune_for_tgi_js_local_container_invoke_ex( + self, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER + ) + + mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_tgi_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_tgi_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_djl_js_resources", + return_value=( + mock_set_serving_properties, + {"model_type": "t5", "n_head": 71}, + True, + ), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[4, 2, 1], + ) + @patch( + "sagemaker.serve.utils.tuning._serial_benchmark", + side_effect=[(5, 5, 25), (5.4, 5.4, 20), (5.2, 5.2, 15)], + ) + @patch( + "sagemaker.serve.utils.tuning._concurrent_benchmark", + side_effect=[(0.9, 1), (0.10, 4), (0.13, 2)], + ) + def test_tune_for_djl_js_local_container( + self, + mock_concurrent_benchmarks, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model="facebook/galactica-mock", + schema_builder=mock_schema_builder, + mode=Mode.LOCAL_CONTAINER, + ) + + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_djl_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_djl_most_performant_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_djl_js_resources", + return_value=( + mock_set_serving_properties, + {"model_type": "RefinedWebModel", "n_head": 71}, + True, + ), + ) + @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge" + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder._get_admissible_tensor_parallel_degrees", + return_value=[1], + ) + @patch( + "sagemaker.serve.builder.djl_builder._serial_benchmark", + **{"return_value.raiseError.side_effect": LocalModelInvocationException("mock_exception")} + ) + def test_tune_for_djl_js_local_container_invoke_ex( + self, + mock_serial_benchmarks, + mock_admissible_tensor_parallel_degrees, + mock_get_nb_instance, + mock_get_ram_usage_mb, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.LOCAL_CONTAINER + ) + + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + mock_pre_trained_model.return_value.env = mock_djl_model_serving_properties + + tuned_model = model.tune() + assert tuned_model.env == mock_djl_model_serving_properties + + @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id", + return_value=True, + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model", + return_value=MagicMock(), + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources", + return_value=({"model_type": "t5", "n_head": 71}, True), + ) + def test_tune_for_djl_js_endpoint_mode_ex( + self, + mock_prepare_for_tgi, + mock_pre_trained_model, + mock_is_jumpstart_model, + mock_telemetry, + ): + builder = ModelBuilder( + model=mock_model_id, schema_builder=mock_schema_builder, mode=Mode.SAGEMAKER_ENDPOINT + ) + + mock_pre_trained_model.return_value.image_uri = mock_djl_image_uri + + model = builder.build() + builder.serve_settings.telemetry_opt_out = True + + tuned_model = model.tune() + assert tuned_model == model diff --git a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py index c072f3cb99..88d109831d 100644 --- a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py @@ -66,7 +66,11 @@ def test_create_dir_structure_invalid_path(self, mock_path): self.assertEquals("model_dir is not a valid directory", str(context.exception)) @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") - def test_prepare_tgi_js_resources_for_jumpstart_uncompressed_str(self, mock_s3_downloader): + @patch("builtins.open", read_data="data") + @patch("json.load", return_value={}) + def test_prepare_tgi_js_resources_for_jumpstart_uncompressed_str( + self, mock_json_load, mock_open, mock_s3_downloader + ): mock_code_dir = Mock() mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj @@ -80,7 +84,11 @@ def test_prepare_tgi_js_resources_for_jumpstart_uncompressed_str(self, mock_s3_d ) @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") - def test_prepare_tgi_js_resources_for_jumpstart_invalid_model_data(self, mock_s3_downloader): + @patch("builtins.open", read_data="data") + @patch("json.load", return_value={}) + def test_prepare_tgi_js_resources_for_jumpstart_invalid_model_data( + self, mock_json_load, mock_open, mock_s3_downloader + ): mock_code_dir = Mock() mock_s3_downloader_obj = Mock() mock_s3_downloader.return_value = mock_s3_downloader_obj @@ -108,8 +116,12 @@ def test_prepare_tgi_js_resources_for_jumpstart_invalid_format(self): @patch("sagemaker.serve.model_server.tgi.prepare.S3Downloader") @patch("sagemaker.serve.model_server.tgi.prepare._tmpdir") @patch("sagemaker.serve.model_server.tgi.prepare._extract_js_resource") + @patch("builtins.open", read_data="data") + @patch("json.load", return_value={}) def test_prepare_tgi_js_resources_for_jumpstart_compressed_str( self, + mock_open, + mock_json_load, mock_extract_js_resource, mock_tmpdir, mock_s3_downloader,