# Deploy HuggingFaceH4/zephyr-7b-beta on Amazon SageMaker using Hugging Face Text Generation Inference (TGI) container

## Resources
- [Zephyr-7B-beta model card](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)
- [TGI documentation](https://huggingface.co/docs/text-generation-inference/en/index)

## Step 1: Setup

In [None]:
%pip install --upgrade --quiet sagemaker

In [None]:
import sagemaker
import json
import boto3

In [None]:
# Resolve the execution context once; later cells reuse these names.
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
# Read the region through the public `boto_region_name` attribute instead of
# the private `_region_name`, which is not part of the SDK's public surface.
region = sess.boto_region_name  # region name the session is bound to
account_id = sess.account_id()  # account id the session is bound to

sm_client = boto3.client("sagemaker")  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime")  # client to interact with SageMaker endpoints

# Print the values resolved above rather than querying the session again.
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

## Step 2: Endpoint Deployment

In [None]:
from sagemaker.huggingface import get_huggingface_llm_image_uri

# Pin the container version, then look up the image URI for that version.
latest_version = "1.4.2"
llm_image = get_huggingface_llm_image_uri(
    backend="huggingface",
    version=latest_version,
)
print(f"llm image uri: {llm_image}")

In [None]:
from sagemaker.huggingface import HuggingFaceModel

# SageMaker hosting settings
instance_type = "ml.g5.2xlarge"
number_of_gpu = 1  # a g5.2xlarge has only one GPU
model_name = "Zephyr-7b-beta"

# Container environment for TGI; json.dumps turns every numeric setting into a string
config = dict(
    HF_MODEL_ID="HuggingFaceH4/zephyr-7b-beta",  # model id from hf.co/models
    SM_NUM_GPUS=json.dumps(number_of_gpu),  # number of GPUs used per replica
    MAX_INPUT_LENGTH=json.dumps(1024),  # max length of the input text
    MAX_TOTAL_TOKENS=json.dumps(2048),  # max length of the generation (including input text)
)

# Bundle role, container image and environment into a deployable model object
llm_model = HuggingFaceModel(
    name=model_name,
    env=config,
    image_uri=llm_image,
    role=role,
)

In [None]:
# Stand up an endpoint for the model; parameter reference:
# https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy
health_check_timeout = 600
endpoint_name = f"{model_name}-TGI-EP"

llm = llm_model.deploy(
    endpoint_name=endpoint_name,
    instance_type=instance_type,
    initial_instance_count=1,
    # allow the container this long to load the model before the health check fails
    container_startup_health_check_timeout=health_check_timeout,
)

## Step 3: Run Inference

In [None]:
# Prompt text; it ends on the "Zephyr:" turn so the model continues from there
prompt = """You are an helpful Assistant, called Zephyr. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Zephyr:"""

# Request body: the prompt plus the generation hyperparameters
payload = dict(
    inputs=prompt,
    parameters=dict(
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        max_new_tokens=512,
        repetition_penalty=1.03,
        stop=["<|endoftext|>", "</s>"],
    ),
)

# Invoke the endpoint through the predictor returned by deploy()
response = llm.predict(payload)

# Drop the first len(prompt) characters of the generated text and print the rest
generated_text = response[0]["generated_text"]
assistant = generated_text[len(prompt):]
print(assistant)

## Step 3.1: Run Inference (Streaming)

In [None]:
#
# Streaming
#
import io

class LineIterator:
    r"""Iterate over newline-delimited lines carried by a stream of events.

    ``stream`` is any iterable of event dicts. Events that contain a
    ``'PayloadPart'`` key contribute ``event['PayloadPart']['Bytes']`` to an
    internal byte buffer; events without that key are reported on stdout and
    skipped.

    A line is not guaranteed to arrive inside a single event; it may be split
    across several of them, for example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    The iterator therefore appends every payload to the buffer and only
    yields a line once its terminating ``\n`` has been seen. ``read_pos``
    tracks how many bytes have already been handed out, so previously
    returned bytes are never exposed again. Yielded lines are ``bytes`` with
    the trailing newline removed. If the stream ends while unterminated bytes
    remain in the buffer, they are yielded once as a final line.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next complete line, pulling more events as needed.

        Raises:
            StopIteration: when the stream is exhausted and every buffered
                byte has been returned.
        """
        while True:
            # Try to serve a complete line from what is already buffered.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b'\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended without a final newline: hand out the
                    # remainder once. Retrying here would spin forever,
                    # because no more data can arrive to complete the line.
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                # chunk is a dict, so it must be formatted, not concatenated.
                print(f'Unknown event type: {chunk}')
                continue
            # Always append at the end; the read cursor is restored on the
            # next loop iteration.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])


# Request body for the streaming call: the prompt, the generation parameters,
# and the flag that asks the container to stream tokens back.
body = {
    # The prompt previously started with a stray "t" ("tWhat ..."), which was
    # sent to the model verbatim.
    "inputs": "What is Amazon SageMaker",
    "parameters": {
        "max_new_tokens": 400,
        "return_full_text": False,
    },
    "stream": True,
}

# Token text that the streaming print loop suppresses instead of printing.
stop_token = '<|endoftext|>'


In [None]:
from sagemaker.base_deserializers import StreamDeserializer

# Swap the predictor's deserializer for a streaming one.
# NOTE(review): the request below is sent through `smr_client`, not through
# `llm`, so this assignment presumably only matters for later `llm.predict`
# calls - confirm it is still wanted.
llm.deserializer = StreamDeserializer()

# Ask the runtime client for a streamed response to the JSON-encoded body
resp = smr_client.invoke_endpoint_with_response_stream(
    ContentType='application/json',
    Body=json.dumps(body),
    EndpointName=llm.endpoint_name,
)

In [None]:
# Print the streamed tokens as they arrive.
event_stream = resp['Body']
start_json = b'{'
for line in LineIterator(event_stream):
    # Lines without a JSON object (including blank lines) carry no token.
    json_start = line.find(start_json)
    if json_start == -1:
        continue
    # Anything before the first '{' on the line is ignored.
    data = json.loads(line[json_start:].decode('utf-8'))
    token = data['token']
    # Suppress control tokens: TGI flags them with `special`, which also
    # covers end-of-sequence markers other than `stop_token` (for example the
    # "</s>" used as a stop sequence elsewhere in this notebook).
    if token.get('special') or token['text'] == stop_token:
        continue
    print(token['text'], end='')

## Step 3.2: Test Inference Performance

In [None]:
#
# Calculate runtime performance
#
import time
import numpy as np

# Prompt and generation hyperparameters used for every timed request
prompt = """You are an helpful Assistant, called Zephyr. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Zephyr:"""

payload = {
  "inputs": prompt,
  "parameters": {
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.8,
    "max_new_tokens": 512,
    "repetition_penalty": 1.03,
    "stop": ["<|endoftext|>","</s>"]
  }
}

# Time 10 sequential requests; each entry in `results` is a latency in milliseconds.
results = []
for _ in range(10):
    # perf_counter is monotonic, so system clock adjustments cannot distort
    # the measured interval the way they can with time.time().
    start = time.perf_counter()
    response_model = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    # Drain the response body so the measurement covers receiving the full payload.
    response_model["Body"].read()
    results.append((time.perf_counter() - start) * 1000)

print("\nPredictions for model latency: \n")
print("\nP95: " + str(np.percentile(results, 95)) + " ms\n")
print("P90: " + str(np.percentile(results, 90)) + " ms\n")
print("Average: " + str(np.average(results)) + " ms\n")

## Step 4: Cleanup

In [None]:
# Remove the model first, then the endpoint together with its endpoint configuration.
# NOTE(review): keep this order - `delete_model` presumably looks the model up
# through the endpoint's configuration, so it may fail once that is gone; confirm
# against the SageMaker SDK before reordering.
llm.delete_model()
llm.delete_endpoint(delete_endpoint_config=True)