# LLaVA stateful inference with SageMaker

## Contents

This notebook runs on a SageMaker notebook instance with the `conda_pytorch_p310` kernel. It demonstrates how to use TorchServe to deploy a LLaVA model for stateful inference on SageMaker.

## Step 0: Upgrade SageMaker and import dependencies

In [129]:
!python --version && aws --version

Python 3.10.14
aws-cli/1.34.4 Python/3.10.14 Linux/5.10.223-212.873.amzn2.x86_64 botocore/1.35.13


In [130]:
!pip install -Uq pip
!pip install -Uq sagemaker
!pip install torch-model-archiver
!pip install -Uq botocore
!pip install -Uq boto3



In [131]:
# Patch boto3 to add support for Session ID
import os
import shutil
import importlib
import botocore

patch_path = os.path.join(botocore.__path__[0], "data/sagemaker-runtime/2017-05-13/")
shutil.copy("./boto/service-2.json.gz", patch_path)
importlib.reload(botocore)

<module 'botocore' from '/home/ec2-user/anaconda3/envs/python3/lib/python3.10/site-packages/botocore/__init__.py'>
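Overwriting a file inside the installed `botocore` package is hard to undo once the original is gone. The following is a minimal sketch, assuming a hypothetical helper named `patch_with_backup`, that keeps a `.bak` copy of the existing service model before copying the patched one over it. The demo runs against temporary directories, not the real `botocore` data path.

```python
import shutil
import tempfile
from pathlib import Path


def patch_with_backup(patched_file, target_dir):
    """Copy patched_file into target_dir, backing up any file it replaces.

    Hypothetical helper for illustration; returns the backup path, or None
    if there was nothing to back up.
    """
    patched_file = Path(patched_file)
    target = Path(target_dir) / patched_file.name
    backup = None
    if target.exists():
        backup = target.with_name(target.name + ".bak")
        # Keep the first backup so repeated runs never overwrite the original.
        if not backup.exists():
            shutil.copy2(target, backup)
    shutil.copy2(patched_file, target)
    return backup


# Demo with throwaway directories standing in for the real paths.
with tempfile.TemporaryDirectory() as tmp:
    src_dir = Path(tmp, "boto")
    dst_dir = Path(tmp, "data")
    src_dir.mkdir()
    dst_dir.mkdir()
    (src_dir / "service-2.json.gz").write_bytes(b"patched")
    (dst_dir / "service-2.json.gz").write_bytes(b"original")

    backup_path = patch_with_backup(src_dir / "service-2.json.gz", dst_dir)
    patched_bytes = (dst_dir / "service-2.json.gz").read_bytes()
    backup_bytes = backup_path.read_bytes()
```

In the notebook, the call would look like `patch_with_backup("./boto/service-2.json.gz", patch_path)`, and restoring the original is a matter of copying the `.bak` file back.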

In [132]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path

In [150]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

boto3_session = boto3.session.Session(region_name="us-west-2")
# SageMaker runtime client, used to invoke the endpoint
smr = boto3.client("sagemaker-runtime", region_name="us-west-2")
# SageMaker client, used to manage models and endpoints
sm = boto3.client("sagemaker", region_name="us-west-2")
# Execution role for the endpoint
role = sagemaker.get_execution_role()
# SageMaker session for interacting with the AWS APIs
sess = sagemaker.session.Session(boto3_session, sagemaker_client=sm, sagemaker_runtime_client=smr)
# Region of the current session
region = sess.boto_region_name
# Account ID of the current session
account = sess.account_id()

# Configuration:
bucket_name = sess.default_bucket()
prefix = "torchserve"
output_path = f"s3://{bucket_name}/{prefix}"
model_name = "llava-sm"
print(f'account={account}, region={region}, role={role}, output_path={output_path}')

account=043632497353, region=us-west-2, role=arn:aws:iam::043632497353:role/service-role/SageMaker-stateful-inference-testing, output_path=s3://sagemaker-us-west-2-043632497353/torchserve


## Step 1: Build a BYOD TorchServe Docker container and push it to Amazon ECR

1. Create an ECR repo: https://docs.aws.amazon.com/AmazonECR/latest/userguide/repository-create.html
2. Get Base Image: https://github.com/aws/deep-learning-containers/blob/master/available_images.md

In [151]:
baseimage = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker"
reponame = "llava-stateful"
versiontag = "1.0"
print("Use the command printed below to run ./build_and_push.sh in a terminal. You get better feedback in a terminal.")
print(f"cd docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}")

Use the command printed below to run ./build_and_push.sh in a terminal. You get better feedback in a terminal.
cd docker && ./build_and_push.sh llava-stateful 1.0 763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker us-west-2 043632497353


In [None]:
%%capture build_output

# Build our own docker image
!cd docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}
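The contents of `docker/build_and_push.sh` are not shown in this notebook. As a rough sketch of what such a script usually does with these five arguments, the helper below assembles the typical command sequence without running anything. The function name, the `BASE_IMAGE` build argument, and the placeholder account ID are assumptions for illustration; the real script may differ.

```python
def build_and_push_commands(reponame, versiontag, baseimage, region, account):
    """Return the shell commands a build-and-push script typically runs.

    Assumption: the real docker/build_and_push.sh may differ; the BASE_IMAGE
    build argument name in particular is hypothetical.
    """
    registry = f"{account}.dkr.ecr.{region}.amazonaws.com"
    base_registry = baseimage.split("/", 1)[0]
    image = f"{registry}/{reponame}:{versiontag}"
    login = (
        "aws ecr get-login-password --region {r} "
        "| docker login --username AWS --password-stdin {reg}"
    )
    return [
        # Authenticate against the registry that hosts the base image.
        login.format(r=region, reg=base_registry),
        # Authenticate against your own registry for the push.
        login.format(r=region, reg=registry),
        f"docker build -t {image} --build-arg BASE_IMAGE={baseimage} .",
        f"docker push {image}",
    ]


commands = build_and_push_commands(
    "llava-stateful",
    "1.0",
    "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
    "pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker",
    "us-west-2",
    "123456789012",  # placeholder account ID
)
for command in commands:
    print(command)
```

Printing the commands first makes it easier to spot a wrong region or account before a long image build starts.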

In [152]:
# URI of the custom image pushed to ECR by build_and_push.sh
container = f"{account}.dkr.ecr.{region}.amazonaws.com/{reponame}:{versiontag}"
print(container)

043632497353.dkr.ecr.us-west-2.amazonaws.com/llava-stateful:1.0


## Step 2: Build TorchServe Model Artifacts and Upload to S3

In [153]:
# Remove any previous build of the model artifacts
!rm -rf code/{model_name}

In [154]:
!cd code && torch-model-archiver --model-name {model_name} --version 1.0 --handler handler/custom_handler.py --config-file handler/model-config.yaml --archive-format no-archive --extra-files handler/ -f

In [155]:
!cd code && aws s3 cp {model_name} {output_path}/{model_name} --recursive

upload: llava-sm/custom_handler.py to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/custom_handler.py
upload: llava-sm/__init__.py to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/__init__.py
upload: llava-sm/MAR-INF/MANIFEST.json to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/MAR-INF/MANIFEST.json
upload: llava-sm/llava/model/language_model/llava_mistral.py to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/llava/model/language_model/llava_mistral.py
upload: llava-sm/llava/mm_utils.py to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/llava/mm_utils.py
upload: llava-sm/llava/model/llava_arch.py to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/llava/model/llava_arch.py
upload: llava-sm/llava/__init__.py to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/llava/__init__.py
upload: llava-sm/data_types.py to s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/data_types.py
upload: llava-sm/llava/model/con

In [156]:
s3_uri = f"{output_path}/{model_name}/"
print(s3_uri)

s3://sagemaker-us-west-2-043632497353/torchserve/llava-sm/
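The trailing slash in `s3_uri` matters when the artifacts are passed as an uncompressed S3 prefix, as the model definition in the next step does. Below is a small sketch, assuming a hypothetical helper named `uncompressed_model_data`, that normalizes the URI to end in exactly one slash and builds the same `model_data` mapping used later.

```python
def uncompressed_model_data(s3_uri):
    """Build a model_data mapping for an uncompressed S3 prefix.

    Hypothetical helper; it mirrors the dictionary used when the model
    is created and normalizes the URI to end with a single slash.
    """
    scheme = "s3://"
    if not s3_uri.startswith(scheme):
        raise ValueError(f"not an S3 URI: {s3_uri!r}")
    path = s3_uri[len(scheme):].strip("/")
    if not path:
        raise ValueError("the S3 URI has no bucket name")
    return {
        "S3DataSource": {
            "S3Uri": f"{scheme}{path}/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    }


# The bucket name here is a placeholder.
model_data = uncompressed_model_data("s3://example-bucket/torchserve/llava-sm")
print(model_data["S3DataSource"]["S3Uri"])
```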


## Step 3: Create SageMaker Endpoint

### 3.1 Create Model

In [157]:
from datetime import datetime

instance_type = "ml.g6.48xlarge"
endpoint_name = sagemaker.utils.name_from_base(model_name)

model = Model(
    name=model_name + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
    # Enable SageMaker uncompressed model artifacts via "S3DataType": "S3Prefix"
    model_data={
        "S3DataSource": {
                "S3Uri": s3_uri,
                "S3DataType": "S3Prefix",
                "CompressionType": "None",
        }
    },
    image_uri=container,
    role=role,
    sagemaker_session=sess,
    env={
        # TorchServe configuration file
        "TS_CONFIG_FILE": "/home/model-server/config.properties",
        # Disable token authorization for REST APIs
        "TS_DISABLE_TOKEN_AUTHORIZATION": "true", 
        # Headers to indicate Session ID
        "TS_HEADER_KEY_SEQUENCE_ID": "X-Amzn-SageMaker-Session-Id",
        "TS_REQUEST_SEQUENCE_ID": "X-Amzn-SageMaker-Session-Id",
        # Headers to indicate closed session
        "TS_HEADER_KEY_SEQUENCE_END": "X-Amzn-SageMaker-Closed-Session-Id",
        "TS_REQUEST_SEQUENCE_END": "X-Amzn-SageMaker-Closed-Session-Id",
        # Enable system metrics aggregation
        "TS_DISABLE_SYSTEM_METRICS": "false" 
    },
)
print(model)

<sagemaker.model.Model object at 0x7f1c6df7bcd0>
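Every value in `env` above is a string, including `"true"` and `"false"`. The container environment is a string-to-string map, so a Python `True` would likely be rejected when the model is created. A quick check before creating the model can catch that early. The helper name `validate_env` is an assumption for illustration.

```python
def validate_env(env):
    """Return a copy of env after checking every key and value is a string."""
    bad = {
        key: value
        for key, value in env.items()
        if not isinstance(key, str) or not isinstance(value, str)
    }
    if bad:
        raise TypeError(f"environment entries must be strings: {bad}")
    return dict(env)


env = {
    "TS_DISABLE_TOKEN_AUTHORIZATION": "true",
    "TS_DISABLE_SYSTEM_METRICS": "false",
}
checked = validate_env(env)

# A boolean value is caught before any API call is made.
try:
    validate_env({"TS_DISABLE_TOKEN_AUTHORIZATION": True})
    rejected = False
except TypeError:
    rejected = True
```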


### 3.2 Deploy Model and Create Endpoint

In [158]:
model.deploy(
    initial_instance_count=1, # increase the number of instances based on your load
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    #volume_size=512, # increase the size to store large model
    model_data_download_timeout=3600, 
    container_startup_health_check_timeout=3600, 
)

--------------------------!
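`model.deploy` waits for the endpoint by default, which is what the progress line above reflects. If the kernel is restarted during deployment, the status can be polled separately. The sketch below assumes a hypothetical helper, `wait_for_endpoint`, that takes any callable shaped like the SageMaker client's `describe_endpoint`. The demo uses a fake so it runs without AWS access.

```python
import time


def wait_for_endpoint(describe, endpoint_name, poll_seconds=30,
                      timeout_seconds=3600, sleep=time.sleep):
    """Poll until the endpoint is InService, fails, or the timeout elapses.

    `describe` is expected to accept EndpointName and return a mapping
    with an "EndpointStatus" key.
    """
    waited = 0
    while True:
        status = describe(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"endpoint {endpoint_name} failed to deploy")
        if waited >= timeout_seconds:
            raise TimeoutError(f"endpoint {endpoint_name} still {status}")
        sleep(poll_seconds)
        waited += poll_seconds


# Demo with a fake describe call that reports InService on the third poll.
fake_statuses = iter(["Creating", "Creating", "InService"])
describe_calls = []


def fake_describe(EndpointName):
    describe_calls.append(EndpointName)
    return {"EndpointStatus": next(fake_statuses)}


final_status = wait_for_endpoint(
    fake_describe, "llava-sm-demo", sleep=lambda seconds: None
)
```

Against the real service, the call would be `wait_for_endpoint(sm.describe_endpoint, endpoint_name)`. The boto3 waiter `sm.get_waiter("endpoint_in_service")` covers the same need with less code.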

### 3.3 Create a Predictor

In [159]:
predictor = sagemaker.predictor.Predictor(
    endpoint_name=model.endpoint_name,
    sagemaker_session=sess
)
print(predictor)

Predictor: {'endpoint_name': 'llava-sm-2024-09-06-04-12-18-525', 'sagemaker_session': <sagemaker.session.Session object at 0x7f1c6df7a8c0>, 'serializer': <sagemaker.base_serializers.IdentitySerializer object at 0x7f1d22ce7ca0>, 'deserializer': <sagemaker.base_deserializers.BytesDeserializer object at 0x7f1d22d08400>}


In [160]:
# Optional: reattach to an existing endpoint instead of deploying a new one
# predictor = sagemaker.predictor.Predictor(
#     endpoint_name='llava-sm-2024-09-04-06-35-10-354',
#     sagemaker_session=sess
# )
# print(predictor)

## Step 4: Run Inference

In [161]:
# Add the handler module path to sys.path
import os, sys

demo_data_path = os.path.join(os.getcwd(), "code/handler")
if demo_data_path not in sys.path:
    sys.path.append(demo_data_path)

In [162]:
# Install dependencies
!pip install torch dataclasses_json



### 4.1 Open Session 1

In [163]:
image_url="https://images.pexels.com/photos/1519753/pexels-photo-1519753.jpeg"

In [164]:
%%time
from data_types import (
    BaseRequest,
    CloseSessionRequest,
    StartSessionRequest,
    TextPromptRequest,
    OpenSessionResponse,
    TextPromptResponse,
    CloseSessionResponse
)

ts_request_sequence_id = "SessionId"


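def send_and_check_request(r, seq_id):
    """Send a request within a session and return the first line of the response body."""
    response = smr.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=r.to_json(),
        ContentType="application/json",
        # "NEW_SESSION" opens a session; otherwise pass an existing session ID
        SessionId=seq_id,
    )
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200, f"Sending request failed: {r}"
    return response["Body"].readlines()[0]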

open_request = StartSessionRequest(
    type="start_session",
    path=image_url,
)

open_response = send_and_check_request(open_request, "NEW_SESSION")
open_response = OpenSessionResponse.from_json(open_response)
print(open_response)
assert open_response.session_id.startswith("ts-seq-")

ReadTimeoutError: Read timeout on endpoint URL: "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/llava-sm-2024-09-06-04-12-18-525/invocations"
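The `ReadTimeoutError` above is raised on the client side: botocore stops waiting for a response after 60 seconds by default. One option is to create the runtime client with explicit timeouts, as in the sketch below. The values and the helper name `make_runtime_client` are illustrative assumptions, and a longer client timeout may not help if the endpoint itself enforces a shorter invocation limit.

```python
# Client-side timeout settings, in seconds. The values are illustrative.
timeout_settings = {
    "connect_timeout": 10,
    "read_timeout": 300,
    # Disable automatic retries so a slow request is not sent twice.
    "retries": {"max_attempts": 0},
}


def make_runtime_client(region_name, settings=timeout_settings):
    """Create a sagemaker-runtime client with explicit timeouts."""
    # Imported here so the settings can be inspected without boto3 installed.
    import boto3
    from botocore.config import Config

    return boto3.client(
        "sagemaker-runtime",
        region_name=region_name,
        config=Config(**settings),
    )
```

In this notebook, `smr = make_runtime_client("us-west-2")` would replace the earlier `boto3.client("sagemaker-runtime", ...)` call before rerunning the request.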

In [None]:
open_response.session_id
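The `data_types` module is loaded from `code/handler` and is not shown in this notebook. To make the request and response shapes easier to follow, here is a sketch of what two of those classes might look like, using only the standard library. The class names ending in `Sketch` are hypothetical, and the fields are limited to the ones used in the cells above; the real classes may carry more.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class StartSessionRequestSketch:
    """Hypothetical stand-in for StartSessionRequest."""
    type: str
    path: str

    def to_json(self):
        return json.dumps(asdict(self))


@dataclass
class OpenSessionResponseSketch:
    """Hypothetical stand-in for OpenSessionResponse."""
    session_id: str

    @classmethod
    def from_json(cls, payload):
        # Accepts str or bytes; raises TypeError on unexpected fields.
        return cls(**json.loads(payload))


request = StartSessionRequestSketch(
    type="start_session",
    path="https://example.com/image.jpeg",  # placeholder URL
)
body = request.to_json()

# A made-up response body in the shape the notebook's assertion expects.
reply = OpenSessionResponseSketch.from_json(b'{"session_id": "ts-seq-example"}')
```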

### 4.2 Send Text Prompt 1

In [None]:
%%time
text_prompt_request1 = TextPromptRequest(
    type="send_text_prompt",
    session_id=open_response.session_id,
    prompt_text="describe the picture"
)

text_prompt_response1 = send_and_check_request(text_prompt_request1, open_response.session_id)
text_prompt_response1 = TextPromptResponse.from_json(text_prompt_response1)
print(text_prompt_response1.response_text)
assert text_prompt_response1.response_text

### 4.3 Send Text Prompt 2

In [None]:
%%time
text_prompt_request2 = TextPromptRequest(
    type="send_text_prompt",
    session_id=open_response.session_id,
    prompt_text="is there a mountain in the picture, describe it"
)

text_prompt_response2 = send_and_check_request(text_prompt_request2, open_response.session_id)
text_prompt_response2 = TextPromptResponse.from_json(text_prompt_response2)
print(text_prompt_response2.response_text)
assert text_prompt_response2.response_text

### 4.4 Close session

In [None]:
# close session
close_request = CloseSessionRequest(
    type="close_session",
    session_id=open_response.session_id,
)
    
close_response = send_and_check_request(
    close_request, open_response.session_id
)

close_response = CloseSessionResponse.from_json(close_response)
assert close_response.success
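The open, prompt, and close steps above run as separate cells, so a failure in the middle leaves the session open on the endpoint. A context manager is one way to make sure the close request is always sent. The sketch below assumes a hypothetical `send(payload, session_id)` callable that returns a dictionary, and it demonstrates the flow with a fake instead of a live endpoint.

```python
from contextlib import contextmanager


@contextmanager
def inference_session(send, image_url):
    """Open a session, yield its ID, and always close it afterwards.

    `send(payload, session_id)` is a hypothetical transport returning a dict.
    """
    opened = send({"type": "start_session", "path": image_url}, "NEW_SESSION")
    session_id = opened["session_id"]
    try:
        yield session_id
    finally:
        send({"type": "close_session", "session_id": session_id}, session_id)


# Fake transport that records the request types it receives.
calls = []


def fake_send(payload, session_id):
    calls.append(payload["type"])
    return {"session_id": "ts-seq-demo", "success": True}


try:
    with inference_session(fake_send, "https://example.com/image.jpeg") as sid:
        fake_send(
            {"type": "send_text_prompt", "session_id": sid,
             "prompt_text": "describe the picture"},
            sid,
        )
        raise RuntimeError("simulated failure mid-session")
except RuntimeError:
    pass

print(calls)
```

The close request is still sent after the simulated failure, which is the property the separate cells cannot guarantee.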

In [149]:
# Clean up: delete the endpoint, its configuration, and the model to avoid ongoing charges
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
model.delete_model()
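If one of the three delete calls raises, for example because a resource was already removed, the calls after it never run. A small sketch of a more tolerant cleanup follows, assuming a hypothetical helper named `run_cleanup` that runs every step and reports failures at the end. The demo uses stand-in functions in place of the real delete calls.

```python
def run_cleanup(steps):
    """Run every (label, callable) step and collect failures instead of stopping."""
    failures = []
    for label, step in steps:
        try:
            step()
        except Exception as exc:
            failures.append((label, str(exc)))
    return failures


# Stand-ins for the real delete calls; the second one fails on purpose.
deleted = []


def fail_config():
    raise RuntimeError("already deleted")


failures = run_cleanup([
    ("endpoint", lambda: deleted.append("endpoint")),
    ("endpoint config", fail_config),
    ("model", lambda: deleted.append("model")),
])
print(failures)
```

With the notebook's objects, the steps would be `lambda: sess.delete_endpoint(endpoint_name)`, `lambda: sess.delete_endpoint_config(endpoint_name)`, and `model.delete_model`.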