# Using TorchServe on SageMaker ml.inf2.24xlarge with Llama 2 13B


## Step 1: Upgrade SageMaker and import dependencies

The wheel installed here is a private preview wheel. Your account must be added to the allowlist before you can run this function.

In [None]:
!python --version

In [None]:
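# Install the latest AWS CLI v2 if it is not installed
!curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
!unzip -q -o awscliv2.zip
!sudo ./aws/install --update
# Link the CLI into the notebook kernel's environment ("current" points at the installed version)
!ln -sf /usr/local/aws-cli/v2/current/bin/aws /home/ec2-user/anaconda3/envs/python3/bin/aws
!aws --version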

In [None]:
# Install into the kernel's environment; boto3/botocore must be recent enough for response streaming
%pip install --upgrade --quiet numpy pillow sagemaker botocore boto3

In [None]:
# Note: the following may fail depending on which AWS CLI is installed in your Jupyter kernel; that is OK
#%pip install botocore-*-py3-none-any.whl boto3-*-py3-none-any.whl --force

In [1]:
# Register the preview SageMaker Runtime service model under the name "sagemaker-runtime-demo"
!aws configure add-model --service-model file://runtime.sagemaker-2017-05-13.normal.json --service-name sagemaker-runtime-demo

In [2]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

boto3_session = boto3.session.Session(region_name="us-west-2")
# Create every client from the same session so they all use the same region
smr = boto3_session.client('sagemaker-runtime-demo')  # preview runtime client registered above
sm = boto3_session.client('sagemaker')
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session(boto3_session, sagemaker_client=sm, sagemaker_runtime_client=smr)  # SageMaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region of the current session
account_id = sess.account_id()  # account ID of the current session
print(f'account={account_id}, region={region}, role={role}')

account=084495728311, region=us-west-2, role=arn:aws:iam::084495728311:role/service-role/AmazonSageMaker-ExecutionRole-20230505T104760


## Step 2: Build a custom (BYOC) TorchServe Docker container and push it to Amazon ECR

In [None]:
# Inspect the Dockerfile, which installs our own dependencies on top of the base image
!cat workspace/docker/Dockerfile

In [None]:
%%capture build_output

baseimage = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:1.13.1-neuronx-py310-sdk2.12.0-ubuntu20.04"
reponame = "neuronx212"
versiontag = "ts"

# Build the custom Docker image and push it to Amazon ECR
!cd workspace/docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account_id}
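
The build script receives the repository name, tag, base image, Region, and account ID. Assuming it pushes the image under the standard private ECR URI format `<account>.dkr.ecr.<region>.amazonaws.com/<repo>:<tag>` (an assumption; check `build_and_push.sh` for the exact tag it applies), the value to pass as `image_uri` in Step 4 can be composed with a small helper. `ecr_image_uri` is a hypothetical name used only for illustration.

```python
import re


def ecr_image_uri(account_id: str, region: str, repo: str, tag: str) -> str:
    """Compose a private Amazon ECR image URI (hypothetical helper).

    Assumes the standard commercial-partition registry host
    <account_id>.dkr.ecr.<region>.amazonaws.com.
    """
    # AWS account IDs are 12 digits
    if not re.fullmatch(r"\d{12}", account_id):
        raise ValueError(f"expected a 12-digit AWS account ID, got {account_id!r}")
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repo}:{tag}"


# Values used in this notebook
print(ecr_image_uri("084495728311", "us-west-2", "neuronx212", "ts"))
```

In the notebook itself, `ecr_image_uri(account_id, region, reponame, versiontag)` yields the same string from the variables defined above.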

## [WIP] Step 3: AOT Pre-Compile the Model

### [WIP] Precompile the model on a local EC2 instance
Follow the [instructions](https://github.com/pytorch/serve/pull/2458/files#diff-bc416c811f749ead11ab8f100d5a4198fa453adc995b0760272563971638307d) on a local EC2 instance to precompile the model and save the compiled artifacts in the `neuron_cache` directory.

#### Main steps
* Download llama-2-13b from Hugging Face
* Save the model as split checkpoints compatible with `transformers-neuronx`
* Create the model artifacts locally on the EC2 instance
* [Start TorchServe on the local EC2 instance](https://docs.google.com/document/d/1mfHTvc65bD9rbx0TBdYhxM5DjzQIZkWAt8v7y1pBCKc/edit?usp=sharing) and load llama-2-13b to generate the `neuron_cache` directory

#### Note: Turn on the Neuron cache in the `initialize` function of `inf2_handler.py`
```python
os.environ["NEURONX_CACHE"] = "on"
os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"
```
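
The two environment variables above are presumably read when `transformers-neuronx` compiles the model, so they should be set before the model is loaded. A sketch of where they could go in a TorchServe handler's `initialize`, assuming the standard `context.system_properties` dictionary that carries `model_dir`. `SketchHandler` and `enable_neuron_cache` are hypothetical names, not the actual code in `inf2_handler.py`.

```python
import os


def enable_neuron_cache(model_dir: str) -> str:
    """Turn on the Neuron compilation cache and point it at <model_dir>/neuron_cache."""
    cache_dir = os.path.join(model_dir, "neuron_cache")
    os.environ["NEURONX_CACHE"] = "on"
    os.environ["NEURONX_DUMP_TO"] = cache_dir
    return cache_dir


class SketchHandler:
    """Hypothetical stand-in for the handler class in inf2_handler.py."""

    def initialize(self, ctx):
        # TorchServe exposes the extracted model directory via system_properties
        model_dir = ctx.system_properties.get("model_dir")
        # Set the cache variables first, before any model loading or compilation
        self.cache_dir = enable_neuron_cache(model_dir)
        # ... load the split checkpoints and compile the model here ...
```

Keeping the environment setup in its own function makes it easy to call from whichever handler class you actually deploy.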

### Upload model artifacts to S3

The model artifacts are available at `s3://torchserve/mar_files/llama-2-13b/`; they support Llama 2 13B on NeuronX with batch size 1. Copy them to your own SageMaker S3 bucket.
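
From a shell, `aws s3 sync` copies a whole prefix. In Python, a sketch that copies every object under the source prefix with a boto3 S3 client, assuming you have read access to `s3://torchserve/mar_files/llama-2-13b/` and write access to the destination. `split_s3_uri` and `copy_prefix` are hypothetical helpers, and `<your-bucket>` in the usage comment is a placeholder. The client is passed in so the function can be exercised without AWS credentials.

```python
def split_s3_uri(uri: str):
    """Split 's3://bucket/some/prefix/' into ('bucket', 'some/prefix/')."""
    if not uri.startswith("s3://"):
        raise ValueError(f"not an S3 URI: {uri}")
    bucket, _, key = uri[len("s3://"):].partition("/")
    return bucket, key


def copy_prefix(s3_client, src_uri: str, dst_uri: str):
    """Copy every object under src_uri to dst_uri, preserving relative keys."""
    src_bucket, src_prefix = split_s3_uri(src_uri)
    dst_bucket, dst_prefix = split_s3_uri(dst_uri)
    copied = []
    # list_objects_v2 returns at most 1000 keys per call, so paginate
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=src_bucket, Prefix=src_prefix):
        for obj in page.get("Contents", []):
            dst_key = dst_prefix + obj["Key"][len(src_prefix):]
            # Managed copy; handles large objects with multipart copy
            s3_client.copy({"Bucket": src_bucket, "Key": obj["Key"]}, dst_bucket, dst_key)
            copied.append(dst_key)
    return copied


# Usage (requires AWS credentials):
# import boto3
# copy_prefix(boto3.client("s3"),
#             "s3://torchserve/mar_files/llama-2-13b/",
#             "s3://<your-bucket>/torchserve/llama-2-13b/")
```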

## Step 4: Start building SageMaker endpoint
In this step, we build a SageMaker endpoint from scratch.

### Create SageMaker endpoint

Specify the instance type and the endpoint name.
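
The model name below is built by appending a timestamp to a base name. A sketch of that pattern that also checks the result against SageMaker's resource-naming rule, taken here as an assumption (at most 63 characters, alphanumerics and hyphens, starting and ending with an alphanumeric; confirm against the CreateModel/CreateEndpoint API reference). `timestamped_name` is hypothetical; `sagemaker.utils.name_from_base`, used for the endpoint name, already truncates the base and appends a timestamp for you.

```python
import re
from datetime import datetime, timezone
from typing import Optional

# Assumed rule: starts and ends with an alphanumeric, hyphens allowed in between
_NAME_RE = re.compile(r"^[a-zA-Z0-9]([-a-zA-Z0-9]*[a-zA-Z0-9])?$")
_MAX_LEN = 63


def timestamped_name(base: str, now: Optional[datetime] = None) -> str:
    """Append a UTC timestamp to base, truncating base so the result fits 63 chars."""
    now = now or datetime.now(timezone.utc)
    suffix = now.strftime("%Y-%m-%d-%H-%M-%S")  # 19 characters
    # Leave room for the "-" separator and the suffix
    trimmed = base[: _MAX_LEN - len(suffix) - 1].rstrip("-")
    name = f"{trimmed}-{suffix}"
    if len(name) > _MAX_LEN or not _NAME_RE.match(name):
        raise ValueError(f"invalid SageMaker resource name: {name!r}")
    return name
```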

In [57]:
from datetime import datetime

instance_type = "ml.inf2.24xlarge"
endpoint_name = sagemaker.utils.name_from_base("ts-inf2-llama2-13b")

model = Model(
    name="torchserve-inf2-llama2-13b-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
    model_data={
        "S3DataSource": {
                # In this example, s3://torchserve/mar_files/llama-2-13b/ was copied to the default
                # SageMaker bucket: s3://sagemaker-<region>-<account_id>/torchserve/llama-2-13b/
                "S3Uri": f"s3://sagemaker-{region}-{account_id}/torchserve/llama-2-13b/",
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
        }
    },
    # Custom image built and pushed in Step 2 (assumes build_and_push.sh tags it as {reponame}:{versiontag})
    image_uri=f"{account_id}.dkr.ecr.{region}.amazonaws.com/{reponame}:{versiontag}",
    role=role,
    sagemaker_session=sess,
)
print(model)

<sagemaker.model.Model object at 0x7f61cc7dae30>


In [58]:
model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    volume_size=512, # increase the size to store large model
    model_data_download_timeout=3600, # increase the timeout to download large model
    container_startup_health_check_timeout=3600,  # increase the timeout to load the large model
)

Your model is not compiled. Please compile your model before using Inferentia.


--------------!

## Run the inference

In [59]:
import io


class Parser:
    """
    A helper class for parsing the byte stream input. 
    
    The output of the model will be in the following format:
    ```
    b'{"outputs": [" a"]}\n'
    b'{"outputs": [" challenging"]}\n'
    b'{"outputs": [" problem"]}\n'
    ...
    ```
    
    While usually each PayloadPart event from the event stream will contain a byte array 
    with a full json, this is not guaranteed and some of the json objects may be split across
    PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```
    
    This class accounts for this by concatenating the bytes written via the 'write' function
    and exposing the lines (ending with a '\n' character) in the buffer via the 'scan_lines'
    function. It tracks the last read position so that bytes that were already returned
    are not exposed again.
    """
    
    def __init__(self):
        self.buff = io.BytesIO()
        self.read_pos = 0
        
    def write(self, content):
        self.buff.seek(0, io.SEEK_END)
        self.buff.write(content)
        
    def scan_lines(self):
        self.buff.seek(self.read_pos)
        for line in self.buff.readlines():
            if line[-1] != b'\n':
                self.read_pos += len(line)
                yield line[:-1]
                
    def reset(self):
        self.read_pos = 0

In [63]:
import json

body = "Today the weather is really nice and I am planning on".encode('utf-8')
resp = smr.invoke_endpoint_with_response_stream(EndpointName=endpoint_name, Body=body, ContentType="application/json")
event_stream = resp['Body']
parser = Parser()
for event in event_stream:
    parser.write(event['PayloadPart']['Bytes'])
    for line in parser.scan_lines():
        print(line.decode("utf-8"), end=' ')

Today the weather is really nice and I am planning on going to the beach. I am going to take my camera and take some pictures of the beach. I am going to take pictures of the sand, the water, and the people. I am also going to take pictures of the sunset. I am really excited to go to the beach and take pictures. The beach is a great place to take pictures. The sand, the water, and the people are all great subjects for pictures. The sunset is also a great subject for pictures 
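
In Python 3, indexing a `bytes` object returns an `int`, so the check `line[-1] != b'\n'` in `Parser.scan_lines` is always true: every line, complete or partial, is yielded as soon as it arrives, with its last byte removed. The output above reads like plain text chunks rather than newline-delimited JSON. If your handler streams plain text, a simpler option is to decode each `PayloadPart` directly. The sketch below assumes that format; `iter_text` is a hypothetical helper, and it uses an incremental UTF-8 decoder so that a multi-byte character split across two events is not garbled.

```python
import codecs
from typing import Iterable, Iterator


def iter_text(event_stream: Iterable[dict]) -> Iterator[str]:
    """Yield decoded text from an invoke_endpoint_with_response_stream event stream."""
    # Buffers incomplete UTF-8 sequences until the rest of the bytes arrive
    decoder = codecs.getincrementaldecoder("utf-8")()
    for event in event_stream:
        part = event.get("PayloadPart")
        if part is None:
            continue  # skip events that carry no payload bytes
        text = decoder.decode(part["Bytes"])
        if text:
            yield text
    # Flush the remainder; raises UnicodeDecodeError if the stream ended mid-character
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


# Usage with the endpoint above:
# resp = smr.invoke_endpoint_with_response_stream(EndpointName=endpoint_name, Body=body, ContentType="application/json")
# for text in iter_text(resp["Body"]):
#     print(text, end="")
```

Printing with `end=""` keeps the spacing the model produced instead of inserting an extra space between chunks.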

## Clean up the environment

In [None]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()