# Deploy LLaVA-v1.5-13B model on Amazon SageMaker

***This notebook works best with the `conda_python3` kernel on a `ml.t3.large` machine***.

---

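In this notebook we download the [LLaVA-v1.5-13B](https://huggingface.co/liuhaotian/llava-v1.5-13b) model from the Hugging Face Hub and deploy it on SageMaker as an asynchronous inference endpoint. We use the DJL Serving DeepSpeed (`djl-deepspeed`) container and deploy the model on a `ml.g5.2xlarge` instance.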

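The downloaded model files are uploaded as-is (not archived) to the project S3 bucket. The inference code in `llava-src` is archived into a `llava-src.tar.gz` file and uploaded to the same bucket; this archive is what the SageMaker model points to.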

In [None]:
# Install the notebook's dependencies into the environment of the running kernel:
# invoking pip through sys.executable uses the same interpreter as this kernel.
import sys
!{sys.executable} -m pip install -r requirements.txt

In [10]:
import os
import time
import boto3
import sagemaker
import globals as g
from pathlib import Path
from sagemaker import image_uris
from utils import get_bucket_name
from sagemaker.s3 import S3Uploader
from sagemaker.utils import name_from_base
from huggingface_hub import snapshot_download

In [None]:
# Show the project-wide constants (imported above as `g`) by printing globals.py
# through the Pygments command-line highlighter.
!pygmentize globals.py

In [4]:
# AWS handles used throughout the notebook: a plain S3 client, the notebook's
# execution role, and the SageMaker session with its low-level client.
s3_client = boto3.client("s3")
role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session()
sm_client = sagemaker_session.sagemaker_client


In [5]:
# Bucket name comes from the project helper, keyed by the stack name in globals.py.
bucket_name: str = get_bucket_name(g.CFN_STACK_NAME)

# S3 URIs always use "/" as the separator, so they are assembled with f-strings
# instead of os.path.join, which would insert the host OS separator.
# Everything for this model lives under s3://<bucket>/<prefix>/<model name>/.
model_dir_name: str = os.path.basename(g.HF_MODEL_ID)
s3_model_root_uri: str = f"s3://{bucket_name}/{g.BUCKET_PREFIX}/{model_dir_name}"

# Location of the model files, and of the packaged inference code archive.
s3_model_uri: str = f"{s3_model_root_uri}/{g.S3_MODEL_PREFIX}"
s3_model_code_uri: str = f"{s3_model_root_uri}/{g.S3_MODEL_CODE_PREFIX}/llava-src.tar.gz"

In [6]:
# Echo the resolved bucket and S3 locations, one per line, as a sanity check.
print(bucket_name, s3_model_uri, s3_model_code_uri, sep="\n")

multimodal-bucket-563851014557
s3://multimodal-bucket-563851014557/multimodal/llava-v1.5-13b/model
s3://multimodal-bucket-563851014557/multimodal/llava-v1.5-13b/code/llava-src.tar.gz


In [8]:
# Download target: a directory named after the model, placed in the parent of the
# current working directory. It is created if it does not exist yet.
local_model_path: str = os.path.join(
    os.path.dirname(os.getcwd()),
    os.path.basename(g.HF_MODEL_ID),
)
Path(local_model_path).mkdir(exist_ok=True)

# Only files matching these patterns are fetched from the model repository.
allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model"]

# The returned path is the local snapshot location that is uploaded to S3 below.
model_download_path = snapshot_download(
    repo_id=g.HF_MODEL_ID, cache_dir=local_model_path, allow_patterns=allow_patterns
)

Fetching 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/154 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.16k [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/33.7k [00:00<?, ?B/s]

pytorch_model-00002-of-00003.bin:   0%|          | 0.00/9.90G [00:00<?, ?B/s]

mm_projector.bin:   0%|          | 0.00/62.9M [00:00<?, ?B/s]

pytorch_model-00003-of-00003.bin:   0%|          | 0.00/6.24G [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

pytorch_model-00001-of-00003.bin:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

In [None]:
%%time
# upload model to s3
S3Uploader.upload(local_path=model_download_path, desired_s3_uri=s3_model_uri)


In [19]:
!rm llava-src.tar.gz
!tar zcvf llava-src.tar.gz ../llava-src --exclude ".ipynb_checkpoints" --exclude "__pycache__"
!aws s3 cp llava-src.tar.gz {s3_model_code_uri}

tar: Removing leading `../' from member names
../llava-src/
../llava-src/model.py
../llava-src/requirements.txt
../llava-src/serving.properties
../llava-src/run_llava_local.py
upload: ./llava-src.tar.gz to s3://multimodal-bucket-563851014557/multimodal/llava-v1.5-13b/code/llava-src.tar.gz


In [20]:
# Resolve the ECR image URI of the DJL DeepSpeed inference container
# (release 0.23.0) for the region configured in globals.py.
framework_name = "djl-deepspeed"
inference_image_uri = image_uris.retrieve(
    framework=framework_name,
    region=g.AWS_REGION,
    version="0.23.0",
)
print(f"Inference container uri: {inference_image_uri}")

Inference container uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.23.0-deepspeed0.9.5-cu118


### SageMaker endpoint

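- Async endpoint: inference results are written to an S3 output location, and each instance handles one invocation at a time.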

In [21]:
# Register a SageMaker model: the inference container image plus the packaged
# inference code archive in S3. name_from_base yields the name printed below.
model_name = name_from_base("llava-djl")
print(model_name)

primary_container = {
    "Image": inference_image_uri,
    "ModelDataUrl": s3_model_code_uri,
}
create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=primary_container,
)

model_arn = create_model_response["ModelArn"]
print(f"Created Model: {model_arn}")

llava-djl-2024-02-03-00-34-24-295
Created Model: arn:aws:sagemaker:us-east-1:563851014557:model/llava-djl-2024-02-03-00-34-24-295


In [22]:
# S3 prefix that receives the asynchronous inference results for this model.
# Assembled with an f-string because S3 URIs always use "/" as the separator,
# whereas os.path.join would insert the host OS separator.
async_output_uri = f"s3://{bucket_name}/{g.BUCKET_PREFIX}/outputs/{model_name}"
print(async_output_uri)

s3://multimodal-bucket-563851014557/multimodal/outputs/llava-djl-2024-02-03-00-34-24-295


In [23]:
# Endpoint configuration for asynchronous inference on a single GPU instance.
instance_type = "ml.g5.2xlarge"

# Both names are derived from the model name so the resources are easy to match up.
endpoint_config_name = f"{model_name}-async-config"
endpoint_name = f"{model_name}-async-endpoint"

# One variant serving the model created above; the container gets up to
# 600 seconds to pass its startup health check.
production_variant = {
    "VariantName": "variant1",
    "ModelName": model_name,
    "InstanceType": instance_type,
    "InitialInstanceCount": 1,
    "ContainerStartupHealthCheckTimeoutInSeconds": 600,
}

# Results go to the S3 output prefix; each instance handles one invocation at a time.
async_inference_config = {
    "OutputConfig": {"S3OutputPath": async_output_uri},
    "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 1},
}

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
    AsyncInferenceConfig=async_inference_config,
)
print(endpoint_config_response)

{'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:563851014557:endpoint-config/llava-djl-2024-02-03-00-34-24-295-async-config', 'ResponseMetadata': {'RequestId': '092be8fe-610c-4922-aebe-1454edad08b9', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '092be8fe-610c-4922-aebe-1454edad08b9', 'content-type': 'application/x-amz-json-1.1', 'content-length': '127', 'date': 'Sat, 03 Feb 2024 00:34:30 GMT'}, 'RetryAttempts': 0}}


In [24]:
# Create the endpoint from the configuration above. Only the returned ARN is
# reported here; the endpoint's provisioning status is not checked in this cell.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

Created Endpoint: arn:aws:sagemaker:us-east-1:563851014557:endpoint/llava-djl-2024-02-03-00-34-24-295-async-endpoint
