# Standard instructions for using the LMI container on SageMaker with Falcon-40B
In this tutorial, you will deploy the LMI (Large Model Inference) container from AWS Deep Learning Containers (DLC) to SageMaker and run inference with it.

Please make sure the following permissions are granted before running the notebook:

- S3 bucket write (upload) access
- SageMaker access

## Step 1: Upgrade the SageMaker SDK and import dependencies

The boto3 and botocore wheels installed here are private preview builds. Your account must be added to the allowlist before you can use the preview streaming API in this notebook.

In [None]:
%pip install sagemaker pip --upgrade  --quiet

In [None]:
# Note: depending on which awscli version is installed in your Jupyter kernel, the following may report a dependency error. This is expected and can be ignored.
%pip install ../botocore-*-py3-none-any.whl ../boto3-*-py3-none-any.whl --force

In [None]:
!aws configure add-model --service-model file://../runtime.sagemaker-2017-05-13.normal.json --service-name sagemaker-runtime-demo

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

boto3_session = boto3.session.Session()
smr = boto3_session.client('sagemaker-runtime-demo')  # runtime client built from the preview service model added above
sm = boto3_session.client('sagemaker')
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session(boto3_session, sagemaker_client=sm, sagemaker_runtime_client=smr)  # SageMaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()  # account ID of the current SageMaker Studio environment

In [None]:
print(f"Role: {role}")

## Step 2: Prepare the model artifacts
The LMI container expects the following artifacts to set up the model:
- serving.properties (required): Defines the model server settings
- model.py (optional): A Python file that defines the core inference logic
- requirements.txt (optional): Any additional pip packages to install

In [None]:
%%writefile serving.properties
engine=MPI
option.model_id=tiiuae/falcon-40b-instruct
option.trust_remote_code=true
option.tensor_parallel_degree=4
option.max_rolling_batch_size=64
option.rolling_batch=auto
option.dtype=fp16
option.output_formatter=jsonlines
option.paged_attention=false
option.enable_streaming=true
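To sanity-check the configuration before packaging it, here is a minimal sketch that parses `serving.properties` text into a dictionary. `parse_properties` is a hypothetical helper, and it assumes simple `key=value` lines (blank lines and `#` comments skipped), which is narrower than the full Java properties syntax. The 8-GPU figure is the GPU count of `ml.g5.48xlarge`, the instance type used later in this notebook.

```python
from typing import Dict


def parse_properties(text: str) -> Dict[str, str]:
    """Parse simple key=value lines (hypothetical helper, not part of LMI)."""
    props: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        key, sep, value = line.partition("=")
        if not sep:
            raise ValueError(f"malformed line: {raw!r}")
        props[key.strip()] = value.strip()
    return props


SERVING_PROPERTIES = """\
engine=MPI
option.model_id=tiiuae/falcon-40b-instruct
option.trust_remote_code=true
option.tensor_parallel_degree=4
option.max_rolling_batch_size=64
option.rolling_batch=auto
option.dtype=fp16
option.output_formatter=jsonlines
option.paged_attention=false
option.enable_streaming=true
"""

props = parse_properties(SERVING_PROPERTIES)
tp_degree = int(props["option.tensor_parallel_degree"])

GPUS_ON_G5_48XLARGE = 8  # ml.g5.48xlarge has 8 NVIDIA A10G GPUs
assert tp_degree <= GPUS_ON_G5_48XLARGE, "tensor_parallel_degree exceeds the GPU count"
print(props["engine"], tp_degree)
```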

In [None]:
%%sh
# -p makes the cell safe to re-run if the directory already exists
mkdir -p mymodel-40b
mv serving.properties mymodel-40b/
tar czvf mymodel-40b.tar.gz mymodel-40b/
rm -rf mymodel-40b
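The packaging cell can also be sketched in Python with the standard `tarfile` module, for example when a shell is not available. `package_model_dir` is a hypothetical helper; like the shell commands, it keeps the model directory as the top-level entry in the archive, and it refuses to package a directory that lacks the required `serving.properties`.

```python
import tarfile
import tempfile
from pathlib import Path
from typing import List


def package_model_dir(model_dir: Path, archive: Path) -> List[str]:
    """Create a .tar.gz of model_dir (hypothetical helper) and return its member names."""
    if not (model_dir / "serving.properties").is_file():
        raise FileNotFoundError("serving.properties is required by the LMI container")
    with tarfile.open(str(archive), "w:gz") as tar:
        # arcname keeps "mymodel-40b/" as the top-level directory, like `tar czvf ... mymodel-40b/`
        tar.add(str(model_dir), arcname=model_dir.name)
    with tarfile.open(str(archive), "r:gz") as tar:
        return sorted(tar.getnames())


# Usage in a throwaway directory so nothing in the working directory is touched.
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "mymodel-40b"
    model_dir.mkdir()
    (model_dir / "serving.properties").write_text("engine=MPI\n")
    members = package_model_dir(model_dir, Path(tmp) / "mymodel-40b.tar.gz")
    print(members)
```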

## Step 3: Start building SageMaker endpoint
In this step, we will build a SageMaker endpoint from scratch.

### Getting the container image URI


In [None]:
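# Look up the DJL DeepSpeed image for this region, then pin the tag to the
# 0.23.0 image (DeepSpeed 0.9.5, CUDA 11.8) used by this notebook.
image_uri = image_uris.retrieve(
    framework="djl-deepspeed", region=sess.boto_session.region_name, version="0.22.1"
)
image_uri = image_uri.rsplit(":", 1)[0] + ":" + "0.23.0-deepspeed0.9.5-cu118"
print(f"Image URI: {image_uri}")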
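Swapping the tag with `split(":")[0]` relies on the URI containing a single `:`. A slightly more defensive sketch splits on the last `:` only, so a registry host with a port would not be truncated. `with_image_tag` is a hypothetical helper, the example URI uses a placeholder account ID, and the helper does not handle digest references (`@sha256:...`).

```python
def with_image_tag(image_uri: str, tag: str) -> str:
    """Return image_uri with its tag replaced by tag (hypothetical helper)."""
    repo, sep, old_tag = image_uri.rpartition(":")
    if not sep or "/" in old_tag:
        # No tag present: either there is no ':' at all, or the last ':'
        # belongs to a host:port prefix such as "localhost:5000/repo".
        repo = image_uri
    return f"{repo}:{tag}"


# Illustrative URI with a placeholder account ID, not a real lookup result.
old_uri = "123456789012.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.22.1-deepspeed0.9.2-cu118"
new_uri = with_image_tag(old_uri, "0.23.0-deepspeed0.9.5-cu118")
print(new_uri)
```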

### Upload the artifacts to S3 and create the SageMaker model

In [None]:
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel-40b.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(sagemaker_session=sess, image_uri=image_uri, model_data=code_artifact, role=role)

### Create the SageMaker endpoint

Specify the instance type to use and the endpoint name.

In [None]:
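instance_type = "ml.g5.48xlarge"  # 8 x NVIDIA A10G GPUs; serving.properties uses tensor_parallel_degree=4
endpoint_name = sagemaker.utils.name_from_base("lmi-model-falcon-40b")  # appends a timestamp to make the name unique

model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    container_startup_health_check_timeout=900,  # seconds; downloading and loading a 40B model takes a while
)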
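Why 4 GPUs? A back-of-the-envelope estimate, assuming Falcon-40B has roughly 40 billion parameters stored in fp16 (2 bytes each) and that tensor parallelism splits the weights evenly, puts the weights alone at about 80 GB, or 20 GB per GPU at `tensor_parallel_degree=4`. Each A10G on `ml.g5.48xlarge` has 24 GB, so that leaves a few GB per GPU for the KV cache and activations, which this sketch ignores. `weight_gb_per_gpu` is a hypothetical helper.

```python
def weight_gb_per_gpu(num_params: float, bytes_per_param: int, tp_degree: int) -> float:
    """Approximate weight memory per GPU in GB (10**9 bytes), ignoring KV cache and activations."""
    return num_params * bytes_per_param / tp_degree / 1e9


A10G_MEMORY_GB = 24  # memory of each NVIDIA A10G GPU on ml.g5.48xlarge

# Assumption: ~40e9 parameters in fp16 (2 bytes per parameter).
per_gpu = weight_gb_per_gpu(40e9, 2, 4)
headroom = A10G_MEMORY_GB - per_gpu
print(f"{per_gpu:.1f} GB of weights per GPU, ~{headroom:.1f} GB left for KV cache and activations")

# With only 2 GPUs the weights alone would not fit on each A10G.
print(weight_gb_per_gpu(40e9, 2, 2) > A10G_MEMORY_GB)
```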

## Step 4: Test and benchmark the inference

![architecture](../images/arch.png)

`InvokeEndpointWithResponseStream` performs the same inference request to the model as `InvokeEndpoint`, but returns the model response as a stream of parts of the full response payload.

This lets models return larger responses, and it reduces the time to first byte for models where there is a significant gap between generating the first and the last byte of the response.
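The core difficulty with streaming is that a `PayloadPart` boundary can fall in the middle of a JSON line. As a minimal alternative to a `BytesIO`-based parser, the sketch below keeps any bytes after the last newline in a plain `bytes` buffer until the rest of the line arrives. `iter_json_lines` and the simulated events are illustrative assumptions that mirror the `{"outputs": [...]}` lines shown in the `Parser` docstring, not real endpoint output.

```python
import json
from typing import Iterable, Iterator


def iter_json_lines(chunks: Iterable[bytes]) -> Iterator[dict]:
    """Yield one parsed JSON object per complete newline-terminated line (hypothetical helper)."""
    pending = b""
    for chunk in chunks:
        pending += chunk
        # Everything before the last b"\n" is complete; the remainder may be a partial line.
        *complete, pending = pending.split(b"\n")
        for line in complete:
            if line.strip():
                yield json.loads(line)
    if pending.strip():
        # Assumption: a final line without a trailing newline is still a full object.
        yield json.loads(pending)


# Simulated events: the second JSON object is split across two PayloadPart events.
fake_stream = [
    {"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n{"outputs": '}},
    {"PayloadPart": {"Bytes": b'[" challenging"]}\n'}},
    {"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}},
]
tokens = [obj["outputs"][0] for obj in iter_json_lines(e["PayloadPart"]["Bytes"] for e in fake_stream)]
print(repr("".join(tokens)))  # ' a challenging problem'
```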

In [None]:
import io


class Parser:
    """
    A helper class for parsing the byte stream input. 
    
    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class handles this by concatenating the bytes written via the 'write' function
    and exposing the complete lines (ending with a '\n' character) in the buffer via the
    'scan_lines' function. It keeps track of the last read position so that bytes that
    were already returned are not exposed again.
    """
    
    def __init__(self):
        self.buff = io.BytesIO()
        self.read_pos = 0
        
    def write(self, content):
        self.buff.seek(0, io.SEEK_END)
        self.buff.write(content)
        
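    def scan_lines(self):
        self.buff.seek(self.read_pos)
        for line in self.buff.readlines():
            # Only consume complete lines. (`line[-1]` is an int for bytes, so
            # comparing it to b'\n' is always unequal; use endswith instead.)
            # A trailing partial line stays in the buffer until a later write.
            if line.endswith(b'\n'):
                self.read_pos += len(line)
                yield line[:-1]

    def reset(self):
        # Discard the buffered bytes as well as the read position.
        self.buff = io.BytesIO()
        self.read_pos = 0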

In [None]:
import json

body = {"inputs": "what is life", "parameters": {"max_new_tokens": 400}}
resp = smr.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, Body=json.dumps(body), ContentType="application/json"
)
event_stream = resp['Body']
parser = Parser()
for event in event_stream:
    if 'PayloadPart' not in event:
        continue  # skip events that carry no payload bytes
    parser.write(event['PayloadPart']['Bytes'])
    for line in parser.scan_lines():
        chunk = json.loads(line)  # separate name so the `resp` response object is not overwritten
        print(chunk.get("outputs")[0], end='')
print()
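For a rough benchmark, two useful numbers are the time to first token (which streaming improves) and the total generation time. This sketch, built around the hypothetical helper `measure_stream`, consumes any iterable of decoded token strings and times both with `time.perf_counter`. The clock starts when iteration begins, so it does not include the time spent in `invoke_endpoint_with_response_stream` before the stream is returned. The fake generator only simulates per-token delays.

```python
import time
from typing import Iterable, Iterator, Optional, Tuple


def measure_stream(tokens: Iterable[str]) -> Tuple[str, Optional[float], float]:
    """Consume a token stream; return (full_text, seconds_to_first_token, total_seconds)."""
    start = time.perf_counter()
    first: Optional[float] = None
    parts = []
    for token in tokens:
        if first is None:
            first = time.perf_counter() - start  # latency until the first token arrived
        parts.append(token)
    return "".join(parts), first, time.perf_counter() - start


def fake_tokens() -> Iterator[str]:
    # Stand-in for a real endpoint stream: sleeps to mimic per-token generation time.
    for tok in [" a", " challenging", " problem"]:
        time.sleep(0.01)
        yield tok


text, ttft, total = measure_stream(fake_tokens())
print(repr(text))  # ' a challenging problem'
print(f"time to first token: {ttft:.3f}s, total: {total:.3f}s")
```

To time the real endpoint, one option is to wrap the streaming loop in a generator that yields `json.loads(line)["outputs"][0]` for each complete line and pass that generator to `measure_stream`.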

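## Clean up the environment

Uncomment and run the following cell to delete the endpoint and the model when you no longer need them.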

In [None]:
# sess.delete_endpoint(endpoint_name)
# model.delete_model()