Skip to content

Commit

Permalink
Feat: Pull latest tei container for sentence similarity models on Hu…
Browse files Browse the repository at this point in the history
…ggingFace hub (#4686)

* Update: Pull latest tei container for sentence similarity models

* Fix formatting

* Address PR comments

* Fix formatting

* Fix check

* Switch sentence similarity to be deployed on tgi

* Fix formatting

* Fix formatting

* Fix formatting

* Fix formatting

* Introduce TEI builder with TGI server

* Fix formatting

* Add integ test

* Fix formatting

* Add integ test

* Add integ test

* Add integ test

* Add integ test

* Add integ test

* Fix formatting

* Move to G5 for integ test

* Fix formatting

* Integ test updates

* Integ test updates

* Integ test updates

* Fix formatting

* Integ test updates

* Move back to generate for ping

* Integ test updates

* Integ test updates
  • Loading branch information
samruds committed May 17, 2024
1 parent 06e6f9d commit c9b55a4
Show file tree
Hide file tree
Showing 6 changed files with 543 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
from sagemaker.serve.builder.serve_settings import _ServeSettings
from sagemaker.serve.builder.djl_builder import DJL
from sagemaker.serve.builder.tei_builder import TEI
from sagemaker.serve.builder.tgi_builder import TGI
from sagemaker.serve.builder.jumpstart_builder import JumpStart
from sagemaker.serve.builder.transformers_builder import Transformers
Expand Down Expand Up @@ -95,9 +96,9 @@
}


# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901
# pylint: disable=attribute-defined-outside-init, disable=E1101, disable=R0901, disable=R1705
@dataclass
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing):
class ModelBuilder(Triton, DJL, JumpStart, TGI, Transformers, TensorflowServing, TEI):
"""Class that builds a deployable model.
Args:
Expand Down Expand Up @@ -753,7 +754,7 @@ def build( # pylint: disable=R0911
model_task = self.model_metadata.get("HF_TASK")
if self._is_jumpstart_model_id():
return self._build_for_jumpstart()
if self._is_djl(): # pylint: disable=R1705
if self._is_djl():
return self._build_for_djl()
else:
hf_model_md = get_huggingface_model_metadata(
Expand All @@ -764,8 +765,10 @@ def build( # pylint: disable=R0911
model_task = hf_model_md.get("pipeline_tag")
if self.schema_builder is None and model_task is not None:
self._hf_schema_builder_init(model_task)
if model_task == "text-generation": # pylint: disable=R1705
if model_task == "text-generation":
return self._build_for_tgi()
if model_task == "sentence-similarity":
return self._build_for_tei()
elif self._can_fit_on_single_gpu():
return self._build_for_transformers()
elif (
Expand Down
222 changes: 222 additions & 0 deletions src/sagemaker/serve/builder/tei_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Holds mixin logic to support deployment of Model ID"""
from __future__ import absolute_import
import logging
from typing import Type
from abc import ABC, abstractmethod

from sagemaker import image_uris
from sagemaker.model import Model
from sagemaker.djl_inference.model import _get_model_config_properties_from_hf

from sagemaker.huggingface import HuggingFaceModel
from sagemaker.serve.utils.local_hardware import (
_get_nb_instance,
)
from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure
from sagemaker.serve.utils.predictors import TgiLocalModePredictor
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
from sagemaker.base_predictor import PredictorBase

logger = logging.getLogger(__name__)

_CODE_FOLDER = "code"


class TEI(ABC):
    """TEI (Text Embeddings Inference) build logic for ModelBuilder().

    Mixin that prepares a HuggingFace sentence-similarity model for
    deployment.  Despite the name, the model is currently routed through the
    TGI model server (see ``_set_to_tgi``) and served with the
    ``huggingface-tei`` container image.
    """

    def __init__(self):
        # These attributes are normally populated by ModelBuilder; the
        # defaults here only guarantee the names exist when the mixin
        # methods run.
        self.model = None
        self.serve_settings = None
        self.sagemaker_session = None
        self.model_path = None
        self.dependencies = None
        self.modes = None
        self.mode = None
        self.model_server = None
        self.image_uri = None
        self._is_custom_image_uri = False
        self.image_config = None
        self.vpc_config = None
        self._original_deploy = None
        self.hf_model_config = None
        self._default_tensor_parallel_degree = None
        self._default_data_type = None
        self._default_max_tokens = None
        self.pysdk_model = None
        self.schema_builder = None
        self.env_vars = None
        self.nb_instance_type = None
        self.ram_usage_model_load = None
        self.secret_key = None
        self.jumpstart = None
        self.role_arn = None

    @abstractmethod
    def _prepare_for_mode(self):
        """Prepare artifacts/config for the current deployment mode (implemented by ModelBuilder)."""

    @abstractmethod
    def _get_client_translators(self):
        """Return serializer/deserializer pair for the client (implemented by ModelBuilder)."""

    def _set_to_tgi(self):
        """Force the model server to TGI, warning if another server was configured."""
        if self.model_server != ModelServer.TGI:
            messaging = (
                "HuggingFace Model ID support on model server: "
                f"{self.model_server} is not currently supported. "
                f"Defaulting to {ModelServer.TGI}"
            )
            logger.warning(messaging)
            self.model_server = ModelServer.TGI

    def _create_tei_model(self, **kwargs) -> Type[Model]:
        """Build a HuggingFaceModel for TEI and patch its ``deploy`` method.

        Resolves the ``huggingface-tei`` image URI when no custom image was
        supplied, and swaps ``deploy`` for ``_tei_model_builder_deploy_wrapper``
        so mode handling and env-var setup run before the real deploy.
        """
        if self.nb_instance_type and "instance_type" not in kwargs:
            kwargs.update({"instance_type": self.nb_instance_type})

        if not self.image_uri:
            self.image_uri = image_uris.retrieve(
                "huggingface-tei",
                image_scope="inference",
                instance_type=kwargs.get("instance_type"),
                region=self.sagemaker_session.boto_region_name,
            )

        pysdk_model = HuggingFaceModel(
            image_uri=self.image_uri,
            image_config=self.image_config,
            vpc_config=self.vpc_config,
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
        )

        # Fix: original message read "with the the deployment".
        logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

        self._original_deploy = pysdk_model.deploy
        pysdk_model.deploy = self._tei_model_builder_deploy_wrapper
        return pysdk_model

    @_capture_telemetry("tei.deploy")
    def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
        """Deploy wrapper that handles mode overrides and env setup.

        In LOCAL_CONTAINER mode it starts a local server and returns a
        TgiLocalModePredictor; otherwise it forwards to the original
        ``deploy`` for a SageMaker endpoint.

        Raises:
            ValueError: if an unsupported mode is passed, or no instance type
                is available for SageMaker endpoint deployment.
        """
        timeout = kwargs.get("model_data_download_timeout")
        if timeout:
            self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})

        if "mode" in kwargs and kwargs.get("mode") != self.mode:
            overwrite_mode = kwargs.get("mode")
            # mode overwritten by customer during model.deploy()
            logger.warning(
                "Deploying in %s Mode, overriding existing configurations set for %s mode",
                overwrite_mode,
                self.mode,
            )

            if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
                self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            elif overwrite_mode == Mode.LOCAL_CONTAINER:
                self._prepare_for_mode()
                self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            else:
                raise ValueError("Mode %s is not supported!" % overwrite_mode)

        serializer = self.schema_builder.input_serializer
        deserializer = self.schema_builder._output_deserializer
        if self.mode == Mode.LOCAL_CONTAINER:
            timeout = kwargs.get("model_data_download_timeout")

            predictor = TgiLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )

            # Default 1800s model-load timeout when none was requested.
            self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
                self.image_uri,
                timeout if timeout else 1800,
                None,
                predictor,
                self.pysdk_model.env,
                jumpstart=False,
            )

            return predictor

        if "mode" in kwargs:
            del kwargs["mode"]
        if "role" in kwargs:
            self.pysdk_model.role = kwargs.get("role")
            del kwargs["role"]

        # set model_data to uncompressed s3 dict
        self.pysdk_model.model_data, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

        # if the weights have been cached via local container mode -> set to offline
        if str(Mode.LOCAL_CONTAINER) in self.modes:
            self.pysdk_model.env.update({"TRANSFORMERS_OFFLINE": "1"})
        else:
            # if has not been built for local container we must use cache
            # that hosting has write access to.
            self.pysdk_model.env["TRANSFORMERS_CACHE"] = "/tmp"
            self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"

        if "endpoint_logging" not in kwargs:
            kwargs["endpoint_logging"] = True

        if not self.nb_instance_type and "instance_type" not in kwargs:
            raise ValueError(
                "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
            )

        if "initial_instance_count" not in kwargs:
            kwargs.update({"initial_instance_count": 1})

        predictor = self._original_deploy(*args, **kwargs)

        predictor.serializer = serializer
        predictor.deserializer = deserializer
        return predictor

    def _build_for_hf_tei(self):
        """Create the PySDK model for a HuggingFace hub model id and return it."""
        self.nb_instance_type = _get_nb_instance()

        _create_dir_structure(self.model_path)
        # NOTE(review): __init__ sets self.pysdk_model = None, so this hasattr
        # check is always True when TEI.__init__ ran — confirm the HF_MODEL_ID
        # env var and hf_model_config are populated as intended.
        if not hasattr(self, "pysdk_model"):
            self.env_vars.update({"HF_MODEL_ID": self.model})
            self.hf_model_config = _get_model_config_properties_from_hf(
                self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
            )

        self.pysdk_model = self._create_tei_model()

        if self.mode == Mode.LOCAL_CONTAINER:
            self._prepare_for_mode()

        return self.pysdk_model

    def _build_for_tei(self):
        """Entry point used by ModelBuilder.build() for sentence-similarity models."""
        self.secret_key = None

        self._set_to_tgi()

        self.pysdk_model = self._build_for_hf_tei()
        return self.pysdk_model
123 changes: 123 additions & 0 deletions tests/integ/sagemaker/serve/test_serve_tei.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.builder.model_builder import ModelBuilder, Mode

from tests.integ.sagemaker.serve.constants import (
HF_DIR,
PYTHON_VERSION_IS_NOT_310,
SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
)

from tests.integ.timeout import timeout
from tests.integ.utils import cleanup_model_resources
import logging

logger = logging.getLogger(__name__)

# Request payload sent to the deployed endpoint by the integ test below.
sample_input = {
    "inputs": "The man worked as a [MASK].",
}

# Representative response used to seed SchemaBuilder's output (de)serialization.
# NOTE(review): this is a fill-mask style payload although the model is
# registered under the "sentence-similarity" HF_TASK — confirm intended.
loaded_response = [
    {
        "score": 0.0974755585193634,
        "token": 10533,
        "token_str": "carpenter",
        "sequence": "the man worked as a carpenter.",
    },
    {
        "score": 0.052383411675691605,
        "token": 15610,
        "token_str": "waiter",
        "sequence": "the man worked as a waiter.",
    },
    {
        "score": 0.04962712526321411,
        "token": 13362,
        "token_str": "barber",
        "sequence": "the man worked as a barber.",
    },
    {
        "score": 0.0378861166536808,
        "token": 15893,
        "token_str": "mechanic",
        "sequence": "the man worked as a mechanic.",
    },
    {
        "score": 0.037680838257074356,
        "token": 18968,
        "token_str": "salesman",
        "sequence": "the man worked as a salesman.",
    },
]


@pytest.fixture
def model_input():
    """Payload passed to predictor.predict() in the endpoint test."""
    payload = {"inputs": "The man worked as a [MASK]."}
    return payload


@pytest.fixture
def model_builder_model_schema_builder():
    """ModelBuilder wired for the BAAI/bge-m3 sentence-similarity model."""
    schema = SchemaBuilder(sample_input, loaded_response)
    metadata = {"HF_TASK": "sentence-similarity"}
    return ModelBuilder(
        model_path=HF_DIR,
        model="BAAI/bge-m3",
        schema_builder=schema,
        model_metadata=metadata,
    )


@pytest.fixture
def model_builder(request):
    """Indirect fixture: resolve the fixture named by the parametrize value."""
    fixture_name = request.param
    return request.getfixturevalue(fixture_name)


@pytest.mark.skipif(
    PYTHON_VERSION_IS_NOT_310,
    reason="Testing feature needs latest metadata",
)
@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
def test_tei_sagemaker_endpoint(sagemaker_session, model_builder, model_input):
    """Build the sentence-similarity model, deploy it to a real SageMaker
    endpoint on ml.g5.2xlarge, run one prediction, and always clean up the
    model + endpoint, failing the test afterwards if anything raised."""
    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
    caught_ex = None

    # Resolve the execution role the endpoint will assume.
    iam_client = sagemaker_session.boto_session.client("iam")
    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

    model = model_builder.build(
        mode=Mode.SAGEMAKER_ENDPOINT, role_arn=role_arn, sagemaker_session=sagemaker_session
    )

    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
        try:
            logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
            predictor = model.deploy(instance_type="ml.g5.2xlarge", initial_instance_count=1)
            predictor.predict(model_input)
            assert predictor is not None
        except Exception as e:
            # Defer failure so cleanup below always runs.
            caught_ex = e
        finally:
            cleanup_model_resources(
                sagemaker_session=model_builder.sagemaker_session,
                model_name=model.name,
                endpoint_name=model.endpoint_name,
            )
            if caught_ex:
                logger.exception(caught_ex)
                assert False, f"{caught_ex} was thrown when running tei sagemaker endpoint test"
Loading

0 comments on commit c9b55a4

Please sign in to comment.