# SageMaker Real-Time Inference Response Streaming with the Hugging Face TGI Container and Falcon-40B
In this tutorial, you will use the [Hugging Face LLM Inference Container](https://huggingface.co/blog/sagemaker-huggingface-llm) on SageMaker and run inference with it. The DLC is powered by [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference), an open-source, purpose-built solution for deploying and serving Large Language Models (LLMs). TGI enables high-performance text generation using Tensor Parallelism and dynamic batching for the most popular open-source LLMs, including StarCoder, BLOOM, GPT-NeoX, Llama, and T5.

Make sure the following permissions are granted before running the notebook:

- S3 bucket push access
- SageMaker access

## Step 1: Upgrade the SageMaker SDK and import dependencies

The wheel installed here is a private preview wheel; your account must be added to the allowlist before you can use this feature.

In [None]:
%pip install sagemaker pip boto3 botocore --upgrade  --quiet

In [2]:
hf_model_id = "tiiuae/falcon-40b-instruct" # model id from huggingface.co/models
number_of_gpus = 4 # number of gpus to use for inference and tensor parallelism
health_check_timeout = 900 # Increase the timeout for the health check for downloading the model
instance_type = "ml.g5.12xlarge" # instance type to use for deployment

Compared to deploying regular Hugging Face models, we first need to retrieve the container URI and pass it to the `HuggingFaceModel` class through its `image_uri` argument. To retrieve the new Hugging Face LLM DLC in Amazon SageMaker, we can use the `get_huggingface_llm_image_uri` method provided by the SageMaker SDK. This method returns the URI for the desired Hugging Face LLM DLC based on the specified backend, session, region, and version. You can find the available versions [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#huggingface-text-generation-inference-containers).

In [3]:
from sagemaker.huggingface import get_huggingface_llm_image_uri
from sagemaker.huggingface import HuggingFaceModel
from sagemaker.session import Session
import sagemaker

sagemaker_session = Session()
role = sagemaker_session.get_caller_identity_arn()

# retrieve the llm image uri
llm_image = get_huggingface_llm_image_uri(
  "huggingface",
  version="0.8.2"
)

# print ecr image uri
print(f"llm image uri: {llm_image}")




llm image uri: 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.0.0-tgi0.8.2-gpu-py39-cu118-ubuntu20.04


## Step 2: Deploy the Hugging Face model using the TGI image
We create the `HuggingFaceModel` and deploy it to Amazon SageMaker using the `deploy` method. We will deploy the model with the `ml.g5.12xlarge` instance type as defined earlier. TGI will automatically distribute and shard the model across all GPUs.

In [4]:
endpoint_name = sagemaker.utils.name_from_base("tgi-model-falcon-40b")
llm_model = HuggingFaceModel(
      role=role,
      image_uri=llm_image,
      env={
        'HF_MODEL_ID': hf_model_id,
        'HF_MODEL_QUANTIZE': "bitsandbytes", # comment out to disable quantization
        'SM_NUM_GPUS': str(number_of_gpus),
        'MAX_INPUT_LENGTH': "1900",  # Max length of input text
        'MAX_TOTAL_TOKENS': "2048",  # Max length of the generation (including input text)
      }
    )
llm = llm_model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  container_startup_health_check_timeout=health_check_timeout,
  endpoint_name=endpoint_name,
)

----------------------!

In [5]:
import json
import boto3
import logging
import io

from sagemaker.base_deserializers import StreamDeserializer

boto3.set_stream_logger("",logging.INFO)

llm.deserializer=StreamDeserializer()

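# 'sagemaker-runtime-demo' is the private-preview service name; the generally available client is 'sagemaker-runtime'
smr = boto3.client('sagemaker-runtime-demo')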

stop_token = '<|endoftext|>'
body = {
    "inputs":"what is life?",
    "parameters":{
        "max_new_tokens":400,
        "return_full_text": False
    },
    "stream": True
}


2023-08-09 00:09:49,129 botocore.client [INFO] No endpoints ruleset found for service sagemaker-runtime-demo, falling back to legacy endpoint routing.
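
The `MAX_INPUT_LENGTH` and `MAX_TOTAL_TOKENS` settings from Step 2 interact with `max_new_tokens` in the request body. Here is a minimal sketch of that arithmetic. It assumes the container requires the prompt tokens plus `max_new_tokens` to fit within `MAX_TOTAL_TOKENS`, which is how the comment on that setting reads. The helper name `prompt_token_budget` is hypothetical and not part of any SDK.

```python
def prompt_token_budget(max_input_length: int, max_total_tokens: int, max_new_tokens: int) -> int:
    """Largest prompt (in tokens) that still leaves room for max_new_tokens.

    Assumption: prompt tokens + max_new_tokens must not exceed max_total_tokens.
    """
    return min(max_input_length, max_total_tokens - max_new_tokens)


# Values taken from the deployment env and the request body in this notebook.
budget = prompt_token_budget(max_input_length=1900, max_total_tokens=2048, max_new_tokens=400)
print(budget)
```

Under that assumption, a request for 400 new tokens leaves room for a prompt of 1648 tokens, which is below the 1900-token input limit.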


In [6]:
class LineIterator:
    """
    A helper class for parsing the byte stream input from TGI container. 
    
    The output of the model will be in the following format:
    ```
    b'data:{"token": {"text": " a"}}\n\n'
    b'data:{"token": {"text": " challenging"}}\n\n'
    b'data:{"token": {"text": " problem"
    b'}}'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class accounts for this by appending the bytes of each PayloadPart to an internal
    buffer and returning complete lines (ending with a '\n' character) from '__next__'.
    It tracks the position of the last read so that previous bytes are not returned again.
    Any pending line that does not end with a '\n' stays in the buffer until the rest of it
    arrives, so that lines split across events are concatenated.
    """
    
    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
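            except StopIteration:
                # Stream exhausted: return any trailing bytes that lack a final '\n'
                # once, instead of retrying forever on an incomplete line.
                if line:
                    self.read_pos += len(line)
                    return line
                raise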
            if 'PayloadPart' not in chunk:
                print('Unknown event type: ' + str(chunk))
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])
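
To see the buffering behaviour without a live endpoint, the sketch below feeds a fabricated event stream into a line iterator in which one JSON object is split across two `PayloadPart` events. The class `BufferedLineIterator` is a hypothetical stand-in that follows the same approach as `LineIterator`, re-declared so the cell runs on its own. The events in `fake_events` are made up for illustration and are not real endpoint output. One assumption of this sketch is that trailing bytes without a final newline are returned once when the stream ends.

```python
import io
import json


class BufferedLineIterator:
    """Hypothetical stand-in for LineIterator: yields newline-terminated lines."""

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Try to read one complete line starting at the last read position.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # Assumption of this sketch: flush leftover bytes once, then stop.
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                continue
            # Append the new bytes at the end of the buffer and retry the read.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


# Fabricated events: the second token's JSON is split across two parts.
fake_events = [
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": " a"}}\n\n'}},
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": " chall'}},
    {"PayloadPart": {"Bytes": b'enging"}}\n\n'}},
]

tokens = []
for line in BufferedLineIterator(fake_events):
    if line != b"":
        tokens.append(json.loads(line[5:].decode("utf-8"))["token"]["text"])

print(tokens)
```

The split token is reassembled because the incomplete line stays in the buffer until the next part supplies the rest of it.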

In [7]:
resp = smr.invoke_endpoint_with_response_stream(EndpointName=llm.endpoint_name, Body=json.dumps(body), ContentType='application/json')
print(resp)
event_stream = resp['Body']
for line in LineIterator(event_stream):
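    if line != b'':  # skip the blank separator lines between events
        # each event line starts with 'data:'; drop those 5 bytes before parsing the JSON
        data = json.loads(line[5:].decode('utf-8'))
        if data['token']['text'] != stop_token:
            print(data['token']['text'],end='')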

{'ResponseMetadata': {'RequestId': '7b92da4c-e993-40d4-886a-c7d9680ddba9', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '7b92da4c-e993-40d4-886a-c7d9680ddba9', 'x-amzn-invoked-production-variant': 'AllTraffic', 'x-amzn-sagemaker-content-type': 'text/event-stream', 'date': 'Wed, 09 Aug 2023 00:09:55 GMT', 'content-type': 'application/vnd.amazon.eventstream', 'transfer-encoding': 'chunked', 'connection': 'keep-alive'}, 'RetryAttempts': 0}, 'ContentType': 'text/event-stream', 'InvokedProductionVariant': 'AllTraffic', 'Body': <botocore.eventstream.EventStream object at 0x7f5eddaa9c00>}

Life is a complex and dynamic process that involves the existence of organisms, their interaction with their environment, and the ability to adapt and evolve over time.
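
The parsing step in the loop above can be pulled into a small function that is easier to test offline. This is a sketch, not part of the SageMaker SDK or TGI. The name `extract_token_text` is hypothetical, and it assumes each event line has the `data:{"token": {"text": ...}}` shape shown in the `LineIterator` docstring.

```python
import json

STOP_TOKEN = "<|endoftext|>"
SSE_PREFIX = b"data:"


def extract_token_text(line: bytes, stop_token: str = STOP_TOKEN):
    """Return the token text carried by one event line, or None if there is none.

    Blank separator lines, lines without the 'data:' prefix, and the stop
    token all map to None so the caller can simply skip them.
    """
    if not line.startswith(SSE_PREFIX):
        return None
    data = json.loads(line[len(SSE_PREFIX):].decode("utf-8"))
    text = data.get("token", {}).get("text")
    if text is None or text == stop_token:
        return None
    return text


# Fabricated sample lines in the assumed format.
sample_lines = [
    b'data:{"token": {"text": "Life"}}',
    b"",
    b'data:{"token": {"text": " is"}}',
    b'data:{"token": {"text": "<|endoftext|>"}}',
]
pieces = [t for t in (extract_token_text(l) for l in sample_lines) if t is not None]
print("".join(pieces))
```

Checking for the prefix instead of slicing a fixed five bytes avoids a JSON error if a line arrives that is not a `data:` event.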

## Clean up

In [None]:
# llm.delete_model()
# llm.delete_endpoint()