# Using Long Context MistralLite on SageMaker Endpoints

This notebook walks step by step through deploying the open-source MistralLite model for **long-context** natural language generation with SageMaker. We will build a custom container for long-context inference, deploy the LLM as a SageMaker endpoint, and invoke the deployed endpoint with example prompts.

## Install Dependencies

For this example we will use the [SageMaker SDK](https://sagemaker.readthedocs.io/en/stable/) to create and deploy the model. 

In [None]:
!pip install -U sagemaker==2.192.1

## Build Custom Container
First, we augment the Dockerfile for SageMaker. The following cells append SageMaker-specific commands (environment variables and an entrypoint script) to the base Dockerfile and build the image.

In [None]:
sm_entry_stmt = """
# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/tmp \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80
COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
CMD [ "" ]
"""

In [None]:
with open("../tgi-custom/Dockerfile", "r") as fin:
    docker_content = fin.read()

sm_docker_content = docker_content + sm_entry_stmt

with open("Dockerfile_sm", "w") as fout:
    fout.write(sm_docker_content)
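Appending to the end of the base Dockerfile works because Docker honors only the last `ENTRYPOINT` and `CMD` instructions in a Dockerfile, so the SageMaker entrypoint replaces whatever the base TGI image defined. The hypothetical helper `last_instruction` below is a minimal sketch for checking that on the generated file; it does a simple line scan (skipping comments and backslash-continued lines), not a full Dockerfile parse.

```python
from typing import Optional


def last_instruction(dockerfile: str, keyword: str) -> Optional[str]:
    """Return the arguments of the last `keyword` instruction (e.g. ENTRYPOINT), or None.

    Minimal sketch: scans line by line, skips comments, and ignores lines that
    continue a previous instruction via a trailing backslash.
    """
    result = None
    continuing = False  # True while inside a backslash-continued instruction
    for raw in dockerfile.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if continuing:
            # This line belongs to the previous instruction; it may continue further.
            continuing = line.endswith("\\")
            continue
        parts = line.split(None, 1)
        if parts[0].upper() == keyword.upper():
            result = parts[1] if len(parts) > 1 else ""
        continuing = line.endswith("\\")
    return result
```

For example, `last_instruction(open("Dockerfile_sm").read(), "ENTRYPOINT")` should return the `["./entrypoint.sh"]` form appended above.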

Then we build the image. This can take around 10 minutes, so you may prefer to run it directly in a terminal in case the notebook cell times out.

**Important Note** - Please ensure the `ROLE` has sufficient permissions to push Docker images to Amazon Elastic Container Registry (ECR).

In [None]:
!cp -r ../tgi-custom/vllm ./vllm

In [None]:
REPO_NAME = "mistrallite-tgi110-ecr"

In [None]:
!bash sm_build.sh {REPO_NAME}

## Deploy SageMaker Endpoint
The SageMaker SDK can deploy open-source LLMs with just a few lines of code, powered by [Hugging Face's Text Generation Inference (TGI) container](https://github.com/huggingface/text-generation-inference). Here we point it at the custom TGI-based image built above. Execute the next set of cells to deploy long-context MistralLite on an `ml.g5.2xlarge` real-time inference endpoint on SageMaker.

In [None]:
import boto3
import json
import sagemaker
import time

from sagemaker.huggingface import HuggingFaceModel

def get_aws_region():
    # Get the current AWS region from the default session
    session = boto3.session.Session()
    return session.region_name

def get_aws_account_id():
    # Get the current AWS account ID from the default session
    sts_client = boto3.client("sts")
    response = sts_client.get_caller_identity()
    return response["Account"]

REGION = get_aws_region()
ACCOUNT_ID = get_aws_account_id()

role = sagemaker.get_execution_role()

In [None]:
image_uri = f"{ACCOUNT_ID}.dkr.ecr.{REGION}.amazonaws.com/{REPO_NAME}"
image_uri
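The URI above omits a tag, and Docker resolves an untagged image reference to `latest`. If `sm_build.sh` pushes a different tag, or you simply want the URI to be explicit, a small hypothetical helper like `ecr_image_uri` can build it. The `tag` default and the China-region domain handling are assumptions of this sketch.

```python
def ecr_image_uri(account_id: str, region: str, repo_name: str, tag: str = "latest") -> str:
    """Build a private Amazon ECR image URI with an explicit tag.

    Assumption: the build script pushed the image under `tag` ("latest" by default).
    China regions (cn-*) use the amazonaws.com.cn domain.
    """
    domain = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    return f"{account_id}.dkr.ecr.{region}.{domain}/{repo_name}:{tag}"
```

Usage in this notebook would be `image_uri = ecr_image_uri(ACCOUNT_ID, REGION, REPO_NAME)`.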

In [None]:
model_name = "MistralLite-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

instance_type = "ml.g5.2xlarge"
num_gpu = 1
max_input_length = 24000
max_total_tokens = 24576

hub = {
    'HF_MODEL_ID':'amazon/MistralLite',
    'HF_TASK':'text-generation',
    'SM_NUM_GPUS': json.dumps(num_gpu),
    "MAX_INPUT_LENGTH": json.dumps(max_input_length),
    "MAX_TOTAL_TOKENS": json.dumps(max_total_tokens),
    "MAX_BATCH_PREFILL_TOKENS": json.dumps(max_total_tokens),
    "MAX_BATCH_TOTAL_TOKENS":  json.dumps(max_total_tokens),
    "DTYPE": 'bfloat16',
}

model = HuggingFaceModel(
    name=model_name,
    env=hub,
    role=role,
    image_uri=image_uri
)

print("☕ Spinning up the endpoint. This will take a little while ☕")

predictor = model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=model_name,
)
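The TGI limits above fix the generation budget: a prompt at the full `MAX_INPUT_LENGTH` of 24,000 tokens leaves 24,576 - 24,000 = 576 tokens for generation. The hypothetical helper below is a minimal sketch of that arithmetic, assuming TGI rejects any request whose prompt tokens plus `max_new_tokens` exceed `MAX_TOTAL_TOKENS`.

```python
def generation_headroom(max_input_length: int, max_total_tokens: int, max_new_tokens: int) -> int:
    """Tokens left over when a maximum-length prompt requests `max_new_tokens`.

    Sketch of the budget arithmetic: prompt tokens + max_new_tokens must fit in
    max_total_tokens (an assumption about how TGI validates requests).
    """
    budget = max_total_tokens - max_input_length  # worst-case room for generation
    if budget <= 0:
        raise ValueError("MAX_TOTAL_TOKENS must exceed MAX_INPUT_LENGTH")
    if max_new_tokens > budget:
        raise ValueError(f"max_new_tokens={max_new_tokens} exceeds the {budget}-token budget")
    return budget - max_new_tokens
```

With `max_new_tokens=400`, as in the inference cells that follow, `generation_headroom(24000, 24576, 400)` leaves 176 tokens of slack for a full-length prompt.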

## Perform Inference

There are a couple of ways to invoke the deployed model: through the SageMaker SDK or through the AWS SDK for Python, [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html). Both methods are provided as examples in the following cells.

### LLM Inference via SageMaker SDK
Execute the following cell to invoke the deployed LLM endpoint on a sample prompt using the SageMaker SDK.

In [None]:
input_data = {
  "inputs": "<|prompter|>What are the main challenges to support a long context for LLM?</s><|assistant|>",
  "parameters": {
    "do_sample": False,
    "max_new_tokens": 400,
    "return_full_text": False,
    #"typical_p": 0.2,
    #"temperature":None,
    #"truncate":None,
    #"seed": 1,
  }
}
result = predictor.predict(input_data)[0]["generated_text"]
print(result)
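Every example in this notebook wraps the question in MistralLite's `<|prompter|>...</s><|assistant|>` template. A small helper, sketched below, keeps that template and the default generation parameters in one place; `build_request`, `build_mistrallite_prompt`, and `DEFAULT_PARAMETERS` are illustrative names, not part of any SDK.

```python
from typing import Any, Dict, Optional

# Defaults mirroring the parameters used in this notebook (greedy decoding, 400 new tokens).
DEFAULT_PARAMETERS: Dict[str, Any] = {
    "do_sample": False,
    "max_new_tokens": 400,
    "return_full_text": False,
}


def build_mistrallite_prompt(question: str) -> str:
    """Wrap a user question in the MistralLite prompt template."""
    return f"<|prompter|>{question}</s><|assistant|>"


def build_request(question: str, parameters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Build a TGI-style request body; `parameters` overrides the defaults."""
    params = dict(DEFAULT_PARAMETERS)  # copy so callers never mutate the defaults
    if parameters:
        params.update(parameters)
    return {"inputs": build_mistrallite_prompt(question), "parameters": params}
```

For example, `predictor.predict(build_request("What are the main challenges to support a long context for LLM?"))` sends the same payload as the cell above.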

### LLM Inference via boto3

In [None]:
import boto3
import json

In [None]:
def call_endpoint(client, prompt, endpoint_name, parameters):
    # Use the client passed in rather than creating a new one on every call.
    payload = {"inputs": prompt,
               "parameters": parameters}
    response = client.invoke_endpoint(EndpointName=endpoint_name,
                                      Body=json.dumps(payload),
                                      ContentType="application/json")
    # The response body is a JSON list with one {"generated_text": ...} entry.
    output = json.loads(response["Body"].read().decode())
    result = output[0]["generated_text"]
    return result

client = boto3.client("sagemaker-runtime")
parameters = {
    "do_sample": False,
    "max_new_tokens": 400,
    "return_full_text": False,
    #"typical_p": 0.2,
    #"temperature":None,
    #"truncate":None,
    #"seed": 10,
}
endpoint_name = predictor.endpoint_name
prompt = "<|prompter|>What are the main challenges to support a long context for LLM?</s><|assistant|>"
result = call_endpoint(client, prompt, endpoint_name, parameters)
print(result)
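The request/response handling can be checked offline by passing a stub in place of the boto3 client. The sketch below re-declares a `call_endpoint` that uses the client it is given, plus a hypothetical `FakeRuntimeClient` whose `invoke_endpoint` records the request and returns a body shaped like the TGI response parsed above (an `io.BytesIO` stands in for boto3's streaming body, since both expose `.read()`).

```python
import io
import json


def call_endpoint(client, prompt, endpoint_name, parameters):
    payload = {"inputs": prompt, "parameters": parameters}
    response = client.invoke_endpoint(EndpointName=endpoint_name,
                                      Body=json.dumps(payload),
                                      ContentType="application/json")
    output = json.loads(response["Body"].read().decode())
    return output[0]["generated_text"]


class FakeRuntimeClient:
    """Hypothetical stand-in for the sagemaker-runtime client, for offline tests only."""

    def __init__(self, generated_text):
        self.generated_text = generated_text
        self.last_request = None  # keyword arguments of the most recent call

    def invoke_endpoint(self, **kwargs):
        self.last_request = kwargs
        body = json.dumps([{"generated_text": self.generated_text}]).encode()
        return {"Body": io.BytesIO(body)}
```

Because the client is injected, the same function works unchanged with `boto3.client("sagemaker-runtime")` against the live endpoint.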

### Long-Context Inference with boto3

With boto3, try a long context of over 13,400 tokens, copied from the [Amazon Aurora FAQs](https://aws.amazon.com/rds/aurora/faqs/).

In [None]:
with open("../example_long_ctx.txt", "r") as fin:
    task_instruction = fin.read()
    task_instruction = task_instruction.format(
        my_question="please tell me how does pgvector help with Generative AI and give me some examples."
    )
prompt = f"<|prompter|>{task_instruction}</s><|assistant|>"
result = call_endpoint(client, prompt, endpoint_name, parameters)
print(result)
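Before sending a long document, it can help to sanity-check its size against the endpoint limits. The sketch below uses an assumed heuristic of about 4 characters per token, which is only a rough approximation for English text; for an exact count you would tokenize with the model's own tokenizer (for example via `transformers.AutoTokenizer.from_pretrained("amazon/MistralLite")`). The defaults mirror the limits configured at deployment, and both function names are hypothetical.

```python
import math


def estimate_tokens(text: str, chars_per_token: float = 4.0) -> int:
    """Rough token estimate from character count (heuristic, not the real tokenizer)."""
    return math.ceil(len(text) / chars_per_token)


def fits_in_context(text: str, max_input_length: int = 24000, max_total_tokens: int = 24576,
                    max_new_tokens: int = 400, chars_per_token: float = 4.0) -> bool:
    """Check the estimated prompt size against the endpoint limits configured above."""
    n = estimate_tokens(text, chars_per_token)
    return n <= max_input_length and n + max_new_tokens <= max_total_tokens
```

Calling `fits_in_context(prompt)` before `call_endpoint` flags prompts that are likely too long; alternatively, the commented-out `truncate` parameter asks TGI to truncate oversized inputs server-side.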

### Streaming Responses

To stream the response token by token, execute the following cells. They define a `LineIterator` helper class that reassembles lines from the endpoint's event stream, and a function that invokes the endpoint with a streaming response and prints tokens as they arrive.

In [None]:
import io

class LineIterator:
    """
    A helper class for parsing the byte stream input. 
    
    The model output is a stream of server-sent events, one JSON object per line,
    for example:
    ```
    b'data:{"token": {"text": " a", "special": false, ...}, ...}\n'
    b'data:{"token": {"text": " challenging", "special": false, ...}, ...}\n'
    b'data:{"token": {"text": " problem", "special": false, ...}, ...}\n'
    ...
    ```
    
    While each PayloadPart event from the event stream usually contains a byte array
    with a full JSON object, this is not guaranteed, and some JSON objects may be split
    across PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'data:{"token": '}}
    {'PayloadPart': {'Bytes': b'{"text": " problem", ...}, ...}\n'}}
    ```
    
    This class handles this by concatenating the bytes received from the event stream
    into a buffer and returning complete lines (ending with a '\n' character) from it
    on each iteration. It tracks the last read position so that bytes already
    returned are not exposed again.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
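            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # Return any trailing bytes that lack a final newline; without this,
                # the loop would spin forever on an unterminated last line.
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    self.read_pos += len(line)
                    return line
                raise
            if 'PayloadPart' not in chunk:
                # chunk is a dict, so format it instead of concatenating it to a str.
                print(f"Unknown event type: {chunk}")
                continue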
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])
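The same buffering can be written more compactly as a generator. The sketch below, with the hypothetical name `iter_lines`, is an alternative to `LineIterator` rather than anything from an SDK: it accumulates `PayloadPart` bytes and yields each complete line, including lines whose JSON was split across events, and flushes an unterminated final line at the end.

```python
from typing import Iterable, Iterator


def iter_lines(event_stream: Iterable[dict]) -> Iterator[bytes]:
    """Yield complete lines (without the trailing newline) from PayloadPart events.

    Hypothetical generator-based alternative to LineIterator.
    """
    buffer = b""
    for event in event_stream:
        if "PayloadPart" not in event:
            print(f"Unknown event type: {event}")
            continue
        buffer += event["PayloadPart"]["Bytes"]
        # Emit every complete line currently held in the buffer.
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            yield line
    # Flush a final line that arrived without a terminating newline.
    if buffer:
        yield buffer
```

A generator avoids tracking read positions by hand, at the cost of re-slicing the buffer for each line; with token-sized chunks that cost is negligible.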

In [None]:
def call_endpoint_streaming(client, prompt, endpoint_name, parameters):
    # Use the client passed in rather than creating a new one on every call.
    payload = {"inputs": prompt,
               "parameters": parameters,
               "stream": True}
    response = client.invoke_endpoint_with_response_stream(EndpointName=endpoint_name,
                                                           Body=json.dumps(payload),
                                                           ContentType="application/json")
    output = ""
    event_stream = response["Body"]
    start_json = b"{"
    for line in LineIterator(event_stream):
        # Each line is a server-sent event; parse the JSON after any "data:" prefix.
        if line != b"" and start_json in line:
            data = json.loads(line[line.find(start_json):].decode("utf-8"))
            # Skip special tokens such as the end-of-sequence marker.
            if not data["token"]["special"]:
                print(data["token"]["text"], end="")
                output += data["token"]["text"]
    return output
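The per-line parsing in `call_endpoint_streaming` can be pulled out into a pure function, which is easier to test without a live endpoint. The sketch below assumes each streamed line looks like `data:{"token": {"text": ..., "special": ...}, ...}`, the shape the function above already relies on; `extract_token_text` is a hypothetical name.

```python
import json
from typing import Optional


def extract_token_text(line: bytes) -> Optional[str]:
    """Return the token text from one streamed line, or None if there is nothing to print.

    Assumes lines shaped like b'data:{"token": {"text": ..., "special": ...}, ...}'.
    """
    start = line.find(b"{")
    if start == -1:
        return None  # blank line or no JSON payload
    data = json.loads(line[start:].decode("utf-8"))
    token = data.get("token", {})
    if token.get("special", False):
        return None  # e.g. the end-of-sequence token
    return token.get("text")
```

Inside the streaming loop, `text = extract_token_text(line)` followed by `if text: print(text, end="")` would replace the inline parsing.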

In [None]:
client = boto3.client("sagemaker-runtime")
parameters = {
    "do_sample": False,
    "max_new_tokens": 400,
    "return_full_text": False,
    #"typical_p": 0.2,
    #"temperature":None,
    #"truncate":None,
    #"seed": 10,
}
endpoint_name = predictor.endpoint_name
prompt = "<|prompter|>What are the main challenges to support a long context for LLM?</s><|assistant|>"
result = call_endpoint_streaming(client, prompt, endpoint_name, parameters)

### Long-Context Streaming
With the new class and function, try a long context of over 13,400 tokens, copied from the [Amazon Aurora FAQs](https://aws.amazon.com/rds/aurora/faqs/), this time with a streamed response.

In [None]:
with open("../example_long_ctx.txt", "r") as fin:
    task_instruction = fin.read()
    task_instruction = task_instruction.format(
        my_question="please tell me how does pgvector help with Generative AI and give me some examples."
    )
prompt = f"<|prompter|>{task_instruction}</s><|assistant|>"
result = call_endpoint_streaming(client, prompt, endpoint_name, parameters)

## Cleanup

After you've finished using the endpoint, delete the model and the endpoint to avoid incurring unnecessary costs.

In [None]:
predictor.delete_model()
predictor.delete_endpoint()