# SageMaker RealTime Inference Streaming feature using HuggingFace TGI container with Falcon-7B model
In this tutorial, you will deploy the [Hugging Face LLM Inference Container](https://huggingface.co/blog/sagemaker-huggingface-llm) on SageMaker and run streaming inference against it. This Deep Learning Container (DLC) is powered by [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference), an open-source, purpose-built solution for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation using Tensor Parallelism and dynamic batching for the most popular open-source LLMs, including StarCoder, BLOOM, GPT-NeoX, Llama, and T5.

Please make sure the following permissions are granted before running the notebook:

- S3 bucket push access
- SageMaker access

## Step 1: Let's upgrade SageMaker and import dependencies

The wheel installed here is a private preview wheel; you need to be added to the allowlist to use this feature.

In [None]:
%pip install sagemaker pip boto3 botocore --upgrade  --quiet

In [None]:
hf_model_id = "tiiuae/falcon-7b-instruct" # model id from huggingface.co/models
number_of_gpu = 1 # number of gpus to use for inference and tensor parallelism
health_check_timeout = 300 # Increase the timeout for the health check to 5 minutes for downloading the model
instance_type = "ml.g5.2xlarge" # instance type to use for deployment

Unlike a regular Hugging Face model deployment, this one requires us to first retrieve the container URI and pass it to the `HuggingFaceModel` class as `image_uri`. To retrieve the Hugging Face LLM DLC in Amazon SageMaker, we can use the `get_huggingface_llm_image_uri` method provided by the SageMaker SDK. This method returns the URI of the desired Hugging Face LLM DLC based on the specified backend, session, region, and version. You can find the available versions [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-text-generation-inference-containers).

In [None]:
from sagemaker.huggingface import get_huggingface_llm_image_uri
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.session import Session
import sagemaker

sagemaker_session = Session()
role = sagemaker_session.get_caller_identity_arn()

# retrieve the llm image uri
llm_image = get_huggingface_llm_image_uri(
  "huggingface",
  version="0.8.2"
)

# print ecr image uri
print(f"llm image uri: {llm_image}")


## Step 2: Deploy the Hugging Face model using the TGI image
We create the `HuggingFaceModel` and deploy it to Amazon SageMaker using the `deploy` method. We will deploy the model with the ml.g5.2xlarge instance type as defined earlier. 

In [None]:
endpoint_name = sagemaker.utils.name_from_base("tgi-model-falcon-7b")
llm_model = HuggingFaceModel(
    role=role,
    image_uri=llm_image,
    env={
        "HF_MODEL_ID": hf_model_id,
        # "HF_MODEL_QUANTIZE": "bitsandbytes",  # uncomment to quantize
        "SM_NUM_GPUS": str(number_of_gpu),  # number of GPUs used for tensor parallelism
        "MAX_INPUT_LENGTH": "1900",  # max length of the input text, in tokens
        "MAX_TOTAL_TOKENS": "2048",  # max length of input plus generated text, in tokens
    },
)
llm = llm_model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  container_startup_health_check_timeout=health_check_timeout,
  endpoint_name=endpoint_name,
)
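
The two length settings above interact: in TGI, the prompt must fit within `MAX_INPUT_LENGTH`, and the prompt's token count plus `max_new_tokens` must fit within `MAX_TOTAL_TOKENS`. The following is a minimal sketch of that arithmetic, assuming you already know the prompt's token count. The helper name `max_new_tokens_budget` is hypothetical and is not part of the SageMaker SDK or TGI.

```python
MAX_INPUT_LENGTH = 1900   # mirrors the endpoint's MAX_INPUT_LENGTH setting
MAX_TOTAL_TOKENS = 2048   # mirrors the endpoint's MAX_TOTAL_TOKENS setting


def max_new_tokens_budget(prompt_tokens: int,
                          max_input_length: int = MAX_INPUT_LENGTH,
                          max_total_tokens: int = MAX_TOTAL_TOKENS) -> int:
    """Return how many tokens may still be generated for a prompt of this length."""
    if prompt_tokens < 0:
        raise ValueError("prompt_tokens must be non-negative")
    if prompt_tokens > max_input_length:
        raise ValueError(
            f"prompt has {prompt_tokens} tokens, but the limit is {max_input_length}"
        )
    return max_total_tokens - prompt_tokens


print(max_new_tokens_budget(100))   # 1948
print(max_new_tokens_budget(1900))  # 148
```

With these settings, even the longest accepted prompt leaves room for 148 generated tokens.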

In [None]:
import json
import boto3
import logging

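from sagemaker.base_deserializers import StreamDeserializer

boto3.set_stream_logger("", logging.INFO)

# Return the raw response stream instead of deserializing the whole body at once
llm.deserializer = StreamDeserializer()

# Runtime client used to call invoke_endpoint_with_response_stream
smr = boto3.client("sagemaker-runtime")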

body = {
    "inputs":"tell me one sentence",
    "parameters":{
        "max_new_tokens":400,
        "return_full_text": False
    },
    "stream": True
}
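
The request body is plain JSON: `inputs` holds the prompt, `parameters` holds the generation settings, and `stream` asks the container to emit tokens as they are produced. If you plan to send several prompts, a small builder can keep the shape consistent. This is only a sketch; the function name `build_stream_payload` is hypothetical.

```python
import json


def build_stream_payload(prompt: str, max_new_tokens: int = 400) -> str:
    """Serialize a streaming generation request in the same shape as `body`."""
    if max_new_tokens < 1:
        raise ValueError("max_new_tokens must be at least 1")
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False,  # return only the generated continuation
        },
        "stream": True,
    }
    return json.dumps(payload)


print(build_stream_payload("tell me one sentence", max_new_tokens=50))
```

The returned string can be passed directly as the `Body` argument of the invocation call.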


In [None]:
import io


class Parser:
    r"""
    A helper class for parsing the byte stream returned by the TGI container.

    The output of the model will be in the following format:
    ```
    b'data:{"token": {"text": " a"}}\n\n'
    b'data:{"token": {"text": " challenging"}}\n\n'
    b'data:{"token": {"text": " problem"'
    b'}}\n\n'
    ...
    ```

    While each PayloadPart event from the event stream usually contains a byte array
    with a full JSON object, this is not guaranteed, and some JSON objects may be split
    across PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'data:{"token": {"text": '}}
    {'PayloadPart': {'Bytes': b'" problem"}}\n\n'}}
    ```

    This class accounts for this by concatenating the bytes written via the 'write' function
    and exposing the 'scan_lines' function, which returns the lines (ending with a '\n'
    character) found in the buffer. It tracks the position of the last read to ensure that
    previous bytes are not exposed again. Any trailing bytes that do not yet end with a '\n'
    stay in the buffer so that truncated lines are concatenated with the next write.
    """
    
    def __init__(self):
        self.buff = io.BytesIO()
        self.read_pos = 0
        
    def write(self, content):
        self.buff.seek(0, io.SEEK_END)
        self.buff.write(content)
        
    def scan_lines(self):
        # Read only the bytes that have not been returned yet.
        self.buff.seek(self.read_pos)
        # Everything before the final b"\n" is complete; the remainder stays buffered.
        *complete, _partial = self.buff.read().split(b"\n")
        for line in complete:
            self.read_pos += len(line) + 1  # +1 for the newline delimiter
            yield line.rstrip(b"\r")
                
    def reset(self):
        self.read_pos = 0
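
To see the buffering idea in isolation, without a deployed endpoint, the sketch below feeds simulated chunks through a small generator that follows the same approach as `Parser`. The function name `iter_complete_lines` and the sample chunks are made up for illustration; real `PayloadPart` boundaries may fall anywhere.

```python
import io


def iter_complete_lines(chunks):
    """Yield each newline-terminated line once, even if it spans several chunks."""
    buffer = io.BytesIO()
    read_pos = 0
    for chunk in chunks:
        buffer.seek(0, io.SEEK_END)   # always append at the end of the buffer
        buffer.write(chunk)
        buffer.seek(read_pos)         # resume after the last complete line
        for line in buffer.read().splitlines(True):
            if not line.endswith(b"\n"):
                break                 # incomplete tail: wait for more bytes
            read_pos += len(line)
            yield line.rstrip(b"\n")


# Simulated PayloadPart bytes; the second JSON object is split across two chunks.
chunks = [
    b'data:{"token": {"text": " a"}}\n\n',
    b'data:{"token": {"text": " prob',
    b'lem"}}\n\n',
]
lines = list(iter_complete_lines(chunks))
print(lines)
```

The split object comes back as a single line, and each event is followed by an empty line, which is why the consuming loop has to skip blank lines.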

In [None]:
resp = smr.invoke_endpoint_with_response_stream(
    EndpointName=llm.endpoint_name,
    Body=json.dumps(body),
    ContentType="application/json",
)
print(resp)
event_stream = resp["Body"]
parser = Parser()
for event in event_stream:
    # Buffer the raw bytes; a single event may hold a partial line
    parser.write(event["PayloadPart"]["Bytes"])
    for line in parser.scan_lines():
        out = line.decode("utf-8")
        # Skip the blank separator lines between events
        if out.startswith("data:"):
            data = json.loads(out[len("data:"):])
            # Do not print the end-of-sequence marker
            if data["token"]["text"] != "<|endoftext|>":
                print(data["token"]["text"], end="")
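
The per-line handling in the loop above can also be factored into a small function, which makes it easy to check offline. This is a sketch under the assumption that every token event looks like the samples in the `Parser` docstring; the name `extract_token_text` and the sample lines are hypothetical.

```python
import json

END_OF_TEXT = "<|endoftext|>"  # end-of-sequence marker filtered out in the loop above


def extract_token_text(line: bytes):
    """Return the token text carried by one stream line, or None if there is none."""
    text = line.decode("utf-8").strip()
    if not text.startswith("data:"):
        return None  # blank separator line or a line without a data payload
    event = json.loads(text[len("data:"):])
    token_text = event.get("token", {}).get("text")
    if token_text is None or token_text == END_OF_TEXT:
        return None
    return token_text


sample_lines = [
    b'data:{"token": {"text": " a"}}',
    b'',
    b'data:{"token": {"text": " challenging"}}',
    b'data:{"token": {"text": "<|endoftext|>"}}',
]
tokens = [t for t in map(extract_token_text, sample_lines) if t is not None]
print("".join(tokens))
```

Returning `None` for lines without printable text keeps the calling loop to a single check.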

## Clean up

In [None]:
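# Uncomment to delete the model and endpoint when you are done;
# a running endpoint continues to incur charges.
# llm.delete_model()
# llm.delete_endpoint()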