# Deploy HuggingFaceH4/zephyr-7b-beta on Amazon SageMaker using the Large Model Inference (LMI) container

## Resources
- [Zephyr-7B-beta model card](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)
- [Deep Learning Containers](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-dlc.html)
- [Deep Java Library - Large Model Inference](https://docs.djl.ai/docs/serving/serving/docs/large_model_inference.html)

## Step 1: Setup

In [None]:
%pip install --upgrade --quiet sagemaker

In [None]:
import io
import json
import boto3
import sagemaker

In [None]:
# Execution context shared by the rest of the notebook.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Use the public accessor (the same one printed below) rather than the
# private `_region_name` attribute of the session.
region = sess.boto_region_name  # region name of the current SageMaker session
account_id = sess.account_id()  # account_id of the current SageMaker session

sm_client = boto3.client("sagemaker")  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime")  # client to interact with SageMaker Endpoints

# Echo the environment so a failed run can be diagnosed from the cell output.
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

## Step 2: Endpoint Deployment

In [None]:
# Resolve the container image URIs for both serving options in the current
# region. Both images are looked up with the same container version.
version = "0.26.0"

deepspeed_image = sagemaker.image_uris.retrieve(
    "djl-deepspeed", region=region, version=version
)
print(f"DeepSpeed image going to be used is ----> {deepspeed_image}")

trtllm_image = sagemaker.image_uris.retrieve(
    "djl-tensorrtllm", region=region, version=version
)
print(f"TensorRT-LLM image going to be used is ----> {trtllm_image}")

### LMI container configuration
The notebook contains configurations for 2 use cases:
1. Open-ended generation (vllm_config and deepspeed_image)
2. Summarization (trtllm_config and trtllm_image)

Please pick ***one*** based on your use case

In [None]:
#
# Please pick _one_ based on your use case by setting `use_case` below
#

number_of_gpu = 1

# vLLM config (open-ended generation, served from the DeepSpeed image)
vllm_config = {
    "SERVING_LOAD_MODELS": "test::Python=/opt/ml/model",
    "OPTION_MODEL_ID": "HuggingFaceH4/zephyr-7b-beta",
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "32",
    "OPTION_MAX_INPUT_LEN": "1024",
    "OPTION_MAX_OUTPUT_LEN": "2048",
    "OPTION_MAX_MODEL_LEN": "2048",
    "OPTION_DTYPE": "fp16",
    # NOTE(review): presumably the options to enable for the streaming
    # example (Step 3.1) -- confirm against the container documentation.
    #"OPTION_OUTPUT_FORMATTER": "jsonlines",
    #"OPTION_ENABLE_STREAMING": "True"
}

# TensorRT-LLM config (summarization, served from the TensorRT-LLM image)
trtllm_config = {
    "SERVING_LOAD_MODELS": "test::MPI=/opt/ml/model",
    "OPTION_MODEL_ID": "HuggingFaceH4/zephyr-7b-beta",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_ROLLING_BATCH": "trtllm",
    "OPTION_MAX_INPUT_LEN": "1024",
    "OPTION_MAX_OUTPUT_LEN": "2048",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "64"
}

# The image, the model name and the environment must always be switched
# together; selecting them through a single key keeps them from getting out
# of sync when changing use case.
deployment_options = {
    "vllm": (deepspeed_image, "Zephyr-7b-beta-vLLM", vllm_config),
    "trtllm": (trtllm_image, "Zephyr-7b-beta-TRTLLM", trtllm_config),
}
use_case = "vllm"
#use_case = "trtllm"
image_uri, model_name, env = deployment_options[use_case]

# create Model
print(model_name)

create_model_response = sm_client.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        "Image": image_uri,
        "Environment": env,
    }
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
# Endpoint configuration: one production variant serving the model created
# in the previous cell on a single instance.
health_check_timeout = 600  # container startup health check timeout, in seconds
instance_type = "ml.g5.2xlarge"
#
# REQUIRED for TensorRT-LLM Just In Time (JIT) compilation
#instance_type = "ml.g5.16xlarge"
endpoint_config_name = f"{model_name}-EP-config"

production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": health_check_timeout,
    "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
# Last expression of the cell: the notebook displays the API response.
endpoint_config_response

In [None]:
#
# Create the endpoint from the endpoint configuration
#
endpoint_name = f"{model_name}-EP"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
endpoint_arn = create_endpoint_response["EndpointArn"]
print(f"Created Endpoint: {endpoint_arn}")

In [None]:
#
# Use the SageMaker session helper to wait for the endpoint to be ready
# before sending any inference request to it.
#
sess.wait_for_endpoint(endpoint_name)

## Step 3: Run Inference

In [None]:
# Non-streaming inference request.
#
# NOTE: with the streaming config (step 3.1) this cell gives a weird output.
#
prompt = """You are an helpful Assistant, called Zephyr. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Zephyr:"""

# generation hyperparameters for the llm
generation_parameters = {
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.8,
    "max_new_tokens": 512,
    "repetition_penalty": 1.03,
    #"stop": ["<|endoftext|>","</s>"]
}
payload = {"inputs": prompt, "parameters": generation_parameters}

response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)

# The response body is read once, decoded, and parsed as JSON; the text to
# print is taken from its "generated_text" entry.
raw_body = response_model["Body"].read()
assistant = json.loads(raw_body.decode("utf8"))["generated_text"]
#assistant = raw_body.decode("utf8")
print(assistant)

## (Optional) Step 3.1: Run Inference (streaming)
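Requires a change to the model configuration above (e.g. enabling the commented-out `OPTION_OUTPUT_FORMATTER` / `OPTION_ENABLE_STREAMING` settings in `vllm_config`) and redeploying the endpoint.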

In [None]:
class LineIterator:
    r"""
    Iterate over the newline-delimited lines of a response event stream.

    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```

    While usually each PayloadPart event from the event stream will contain a byte array
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class accounts for this by appending the bytes of every PayloadPart event to an
    internal buffer and yielding one line at a time (without its trailing '\n') as soon as
    a complete line is available. It keeps the position of the last byte handed out so that
    previous bytes are not exposed again. If the stream ends with data that is not
    terminated by '\n', that remainder is yielded as the final line.
    """

    def __init__(self, stream):
        """Wrap *stream*, an iterable of event dicts such as {'PayloadPart': {'Bytes': b'...'}}."""
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next line as bytes; raise StopIteration once the stream is drained."""
        while True:
            # Try to serve a complete line from what is already buffered.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended without a trailing newline: hand out the
                    # remainder once instead of re-reading it forever.
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # `chunk` is a dict, so it has to be formatted, not concatenated.
                print(f"Unknown event type: {chunk}")
                continue
            # Always append at the end; the read position is tracked separately.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])

In [None]:
# Streaming request: each JSON line of the response is parsed as it arrives
# and its token text is printed without a line break.
body = {"inputs": "what is Amazon SageMaker?", "parameters": {"max_new_tokens":400}}
resp = smr_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    Body=json.dumps(body),
    ContentType="application/json",
)
event_stream = resp['Body']

# `map` is lazy, so every line is still decoded only when it is consumed.
for resp in map(json.loads, LineIterator(event_stream)):
    #print(resp)
    print(resp["token"]["text"], end='')
    #print(resp.get("outputs")[0], end='')

## Step 3.2: Test inference

In [None]:
#
# Calculate runtime performance
#
import time
import numpy as np

# define payload
prompt = """You are an helpful Assistant, called Zephyr. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Zephyr:"""

# generation hyperparameters for the llm
payload = {
  "inputs": prompt,
  "parameters": {
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.8,
    "max_new_tokens": 512,
    "repetition_penalty": 1.03,
    #"stop": ["<|endoftext|>","</s>"]
  }
}

num_requests = 10
results = []  # latency of each request, in milliseconds
for _ in range(num_requests):
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    response_model = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    # Consume the body inside the timed region so the measurement covers the
    # complete response, not only the time until the call returns.
    response_model["Body"].read()
    results.append((time.perf_counter() - start) * 1000)

print("\nPredictions for model latency: \n")
print("\nP95: " + str(np.percentile(results, 95)) + " ms\n")
print("P90: " + str(np.percentile(results, 90)) + " ms\n")
print("Average: " + str(np.average(results)) + " ms\n")

## Step 4: Cleanup

In [None]:
# Delete the resources created in Step 2 in the reverse order of their
# creation: the endpoint first, then its endpoint configuration, then the model.
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_config_name)
sess.delete_model(model_name)