# Stream LLM Response 

## Import Libraries

In [1]:
import os
import json
import boto3
import io
import sagemaker
from sagemaker.session import Session
from sagemaker.base_deserializers import StreamDeserializer
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri



In [2]:
# Record the SDK versions this notebook was executed with.
print(
    f"boto3 version: {boto3.__version__}",
    f"sagemaker version: {sagemaker.__version__}",
    sep="\n",
)

boto3 version: 1.28.40
sagemaker version: 2.177.0


## Notebook Setup

In [3]:
# Notebook configuration. "PROFILE_NAME", "REGION_NAME" and "ROLE_NAME" are
# placeholders: replace them with your own values before running the notebook.
PROFILE_NAME = "PROFILE_NAME"  # profile name passed to the boto3 session
REGION = "REGION_NAME"  # region name passed to the boto3 session
ENDPOINT_NAME = "falcon-7b-instruct-streaming-endpoint"  # name of the endpoint that is deployed and invoked below
ROLE = "ROLE_NAME"  # role passed to HuggingFaceModel

In [4]:
# One boto3 session (profile + region) backs both the SageMaker SDK session
# and the low-level runtime client used for streaming invocations.
boto_session = boto3.Session(region_name=REGION, profile_name=PROFILE_NAME)
sg_session = Session(boto_session=boto_session)
smr = boto_session.client(service_name="sagemaker-runtime")

## Deploy Falcon Model on SageMaker

In [5]:
# get the huggingface llm image
# Look up the URI of the Hugging Face LLM inference container, version 0.9.3
llm_img = get_huggingface_llm_image_uri("huggingface", version="0.9.3", session=sg_session)

print(f"TGI Container: {llm_img}")

TGI Container: 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.0.1-tgi0.9.3-gpu-py39-cu118-ubuntu20.04


In [6]:
# define the model deployment configuration
# Settings handed to the model as its `env`; every value is stored as a string.
deploy_config = {
    "HF_MODEL_ID": "tiiuae/falcon-7b-instruct",  # model_id from hf.co/models
    "SM_NUM_GPUS": str(1),  # Number of GPU used per replica
    "MAX_INPUT_LENGTH": str(3072),  # Max length of input text
    "MAX_TOTAL_TOKENS": str(4096),  # Max length of the generation (including input text)
    "MAX_BATCH_TOTAL_TOKENS": str(8192),  # Limits the number of tokens that can be processed in parallel during the generation
    # "HF_MODEL_QUANTIZE": "bitsandbytes",  # comment in to quantize
}

In [7]:
# create HuggingFaceModel with the image uri
# Bundle the container image, its environment and the execution role into a deployable model
llm_model = HuggingFaceModel(
    image_uri=llm_img,
    env=deploy_config,
    role=ROLE,
    sagemaker_session=sg_session,
)

In [8]:
# Deploy model to an endpoint
instance_type = "ml.g5.2xlarge"
# NOTE(review): the SageMaker SDK documents this timeout in seconds, so 300 is
# 5 minutes (an earlier comment here said 10 minutes). Raise it if the model
# needs longer to load.
health_check_timeout = 300

# Create the endpoint; the returned predictor is kept in `llm_endpoint` so the
# model and endpoint can be deleted at the end of the notebook.
llm_endpoint = llm_model.deploy(
    endpoint_name=ENDPOINT_NAME,
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout, # time allowed for the container to load the model and pass its health check
    deserializer=StreamDeserializer()
)

# llm_endpoint.deserializer=StreamDeserializer()

---------------!

## Helper Class and Function

In [9]:
class LineIterator:
    r"""
    Iterate over the newline-delimited lines of a response event stream.

    Every event pulled from `stream` is expected to be a dict of the form
    ``{'PayloadPart': {'Bytes': b'...'}}``. A line is not guaranteed to arrive
    inside a single event; it may be split across several, for example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    To cope with that, each payload is appended to an internal buffer and only
    complete lines are yielded, with the trailing '\n' removed. `read_pos`
    tracks how much of the buffer has already been handed out, so no bytes are
    yielded twice.

    Events without a 'PayloadPart' key are reported on stdout and skipped.
    If the stream ends while the buffer still holds bytes that are not
    newline-terminated, that remainder is yielded once as the final line.
    """

    def __init__(self, stream):
        """Wrap `stream`, an iterable of event dicts."""
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next line as bytes; raise StopIteration when exhausted."""
        while True:
            # Try to read one line starting from the first unread byte.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b'\n'):
                self.read_pos += len(line)
                return line[:-1]

            # No complete line buffered yet: pull the next event.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if line:
                    # The stream ended mid-line. Hand out the remainder once
                    # rather than retrying forever on an exhausted iterator.
                    self.read_pos += len(line)
                    return line
                raise

            if 'PayloadPart' not in chunk:
                # `chunk` is a dict, so it is formatted rather than concatenated.
                print(f'Unknown event type: {chunk}')
                continue

            # Always append at the end; the read position is restored by the
            # seek at the top of the loop.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])
            

def invoke_stream_endpoint(endpoint_name, query, stop_token="<|endoftext|>", max_new_tokens=400):
    """
    Send `query` to a streaming endpoint and print the generated text as it arrives.

    The request goes through the module-level `smr` runtime client, so this
    function works against any already-deployed endpoint and does not need the
    predictor object returned by `deploy()`.

    Args:
        endpoint_name: Name of the endpoint to invoke.
        query: Prompt text, sent as the request's "inputs".
        stop_token: Token text that is left out of the printed output.
        max_new_tokens: Value sent as "max_new_tokens" in the request
            parameters. Defaults to 400.

    Returns:
        None. Token text is written to stdout without a trailing newline.
    """
    body = {
        "inputs": query,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "return_full_text": False
        },
        "stream": True
    }

    resp = smr.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(body),
        ContentType='application/json')

    start_json = b'{'
    for line in LineIterator(resp['Body']):
        # A line may carry a prefix before the JSON object; parse from the
        # first '{' and skip lines that contain none (including empty ones).
        json_start = line.find(start_json)
        if json_start == -1:
            continue
        data = json.loads(line[json_start:].decode('utf-8'))
        token_text = data['token']['text']
        if token_text != stop_token:
            # Flush so each token shows up as soon as it is received.
            print(token_text, end='', flush=True)

## Test the Streaming Endpoint

In [10]:
# Stream a response for a recipe-list prompt
diet_prompt = "How to cook good meal for diet? can you give me a list of 10 recipes?"
invoke_stream_endpoint(endpoint_name=ENDPOINT_NAME, query=diet_prompt)


Sure! Here are 10 recipes that are healthy and delicious:
1. Grilled salmon with lemon and capers
2. Quinoa and vegetable stir-fry
3. Baked sweet potato with black beans and salsa
4. Chicken and vegetable soup
5. Egg and vegetable omelette
6. Greek yogurt and berry parfait
7. Grilled chicken with avocado and tomato
8. Lentil soup with spinach and carrots
9. Roasted vegetables with chickpeas
10. Turkey and vegetable chili

In [11]:
# Stream a response for a travel-recommendation prompt
london_prompt = "Can you give me best 10 places to visit in London please?"
invoke_stream_endpoint(endpoint_name=ENDPOINT_NAME, query=london_prompt)


Sure! Here are some of the best places to visit in London:
1. Tower of London
2. Big Ben
3. London Eye
4. Buckingham Palace
5. Trafalgar Square
6. Hyde Park
7. The British Museum
8. National Gallery
9. Tate Modern
10. St. Paul's Cathedral

## Kill Endpoint

In [13]:
# Remove the model and the endpoint through the predictor returned by deploy().
# NOTE(review): delete_model() runs before delete_endpoint(); presumably the
# order matters because the model lookup relies on the endpoint's
# configuration — confirm against the SageMaker SDK before reordering.
llm_endpoint.delete_model()
llm_endpoint.delete_endpoint()