# Deploy LLama2 7b Model with high performance on SageMaker using Sagemaker LMI and TensorRT-LLM



In this notebook, we explore how to host a LLama2 large language model with FP16 precision on SageMaker using the large model inference container (LMI). We use TRT-LLM as the model serving solution in this example that is bundled in the LMI container. TRT-LLM a high-performance model serving solution that can be used to optize inference performance of many models. 


In this example we also use model parallelism can help deploy large models that would normally be too large for a single GPU. With model parallelism, we partition and distribute a model across multiple GPUs. Each GPU holds a different part of the model, resolving the memory capacity issue for the largest deep learning models with billions of parameters. 

SageMaker has rolled out LMI container which now provides users with the ability to leverage the managed serving capabilities and help to provide the un-differentiated heavy lifting.

In this notebook, we deploy https://huggingface.co/TheBloke/Llama-2-7b-fp16 model on a ml.g5.2xlarge instance. 

# Licence agreement
 - View license information https://huggingface.co/meta-llama before using the model.
 - This notebook is a sample notebook and not intended for production use. Please refer to the licence at https://github.com/aws/mit-0. 

In [2]:
!pip install sagemaker boto3 huggingface_hub --upgrade #--quiet

Collecting sagemaker
  Obtaining dependency information for sagemaker from https://files.pythonhosted.org/packages/74/e9/7c99298b535fc3cb8cd50efe4978c2fd115a9eba22f3cf7bf1921ad6b53d/sagemaker-2.200.1-py2.py3-none-any.whl.metadata
  Downloading sagemaker-2.200.1-py2.py3-none-any.whl.metadata (13 kB)
Collecting boto3
  Obtaining dependency information for boto3 from https://files.pythonhosted.org/packages/d4/e8/3ccebf1cb78702beb8bb2e535525ff350d0f597ff794baeeab5206800a2c/boto3-1.34.1-py3-none-any.whl.metadata
  Downloading boto3-1.34.1-py3-none-any.whl.metadata (6.6 kB)
Collecting huggingface_hub
  Obtaining dependency information for huggingface_hub from https://files.pythonhosted.org/packages/05/09/1945ca6ba3ad8ad6e2872ba682ce8d68c5e63c8e55458ed8ab4885709f1d/huggingface_hub-0.19.4-py3-none-any.whl.metadata
  Downloading huggingface_hub-0.19.4-py3-none-any.whl.metadata (14 kB)
Collecting urllib3<1.27 (from sagemaker)
  Obtaining dependency information for urllib3<1.27 from https://files

In [2]:
import sagemaker
import jinja2
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


In [3]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts

In [4]:
model_bucket = sess.default_bucket()  # bucket to house model artifacts
s3_code_prefix = "hf-large-model-djl/meta-llama/Llama-2-7b-fp16/code"  # folder within bucket where code artifact will go

s3_model_prefix = "hf-large-model-djl/meta-llama/Llama-2-7b-fp16/model"  # folder within bucket where model artifact will go
region = sess._region_name
account_id = sess.account_id()

s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

jinja_env = jinja2.Environment()

### Define a variable to contain the s3url of the location that has the model

In [5]:
# Define a variable to contain the s3url of the location that has the model. For demo purpose, we use Llama-2-13b-fp16 model artifacts from our S3 bucket
pretrained_model_location = f"s3://sagemaker-example-files-prod-{region}/models/llama-2/fp16/7B/"

## Create SageMaker compatible Model artifact,  upload Model to S3 and bring your own inference script.

SageMaker Large Model Inference containers can be used to host models without providing your own inference code. This is extremely useful when there is no custom pre-processing of the input data or postprocessing of the model's predictions.

SageMaker needs the model artifacts to be in a Tarball format. In this example, we provide the following files - serving.properties.

The tarball is in the following format:

```
code
├──── 
│   └── serving.properties
```

    serving.properties is the configuration file that can be used to configure the model server.


#### Create serving.properties 
This is a configuration file to indicate to DJL Serving which model parallelization and inference optimization libraries you would like to use. Depending on your need, you can set the appropriate configuration.

Here is a list of settings that we use in this configuration file -

    engine: The engine for DJL to use. In this case, we have set it to MPI.
    option.model_id: The model id of a pretrained model hosted inside a model repository on huggingface.co (https://huggingface.co/models) or S3 path to the model artefacts. 
    option.tensor_parallel_degree: Set to the number of GPU devices over which Accelerate needs to partition the model. This parameter also controls the no of workers per model which will be started up when DJL serving runs. As an example if we have a 4 GPU machine and we are creating 4 partitions then we will have 1 worker per model to serve the requests.

For more details on the configuration options and an exhaustive list, you can refer the documentation - https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-configuration.html.



In [6]:
!rm -rf code_llama2_7b_fp16
!mkdir -p code_llama2_7b_fp16

In [7]:
%%writefile code_llama2_7b_fp16/serving.properties
engine=MPI
#option.model_id=TheBloke/Llama-2-13B-fp16
option.model_id = {{model_id}}
option.tensor_parallel_degree=4
option.max_rolling_batch_size=128
option.rolling_batch=trtllm

Writing code_llama2_7b_fp16/serving.properties


In [8]:
# we plug in the appropriate model location into our `serving.properties`
template = jinja_env.from_string(Path("code_llama2_7b_fp16/serving.properties").open().read())
Path("code_llama2_7b_fp16/serving.properties").open("w").write(
    template.render(model_id=pretrained_model_location)
)
!pygmentize code_llama2_7b_fp16/serving.properties | cat -n

     1	[36mengine[39;49;00m=[33mMPI[39;49;00m[37m[39;49;00m
     2	[37m#option.model_id=TheBloke/Llama-2-13B-fp16[39;49;00m[37m[39;49;00m
     3	[36moption.model_id[39;49;00m[37m [39;49;00m=[37m [39;49;00m[33ms3://sagemaker-example-files-prod-us-east-1/models/llama-2/fp16/7B/[39;49;00m[37m[39;49;00m
     4	[36moption.tensor_parallel_degree[39;49;00m=[33m4[39;49;00m[37m[39;49;00m
     5	[36moption.max_rolling_batch_size[39;49;00m=[33m128[39;49;00m[37m[39;49;00m
     6	[36moption.rolling_batch[39;49;00m=[33mtrtllm[39;49;00m[37m[39;49;00m


**Image URI for the DJL container is being used here**

In [9]:
try:
    inference_image_uri = image_uris.retrieve(
        framework="djl-tensorrtllm", region=region, version="0.25.0"
    )
except FileNotFoundError:
    inference_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.25.0-tensorrtllm0.5.0-cu122"
print(f"Image going to be used is ---- > {inference_image_uri}")

Image going to be used is ---- > 763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.25.0-tensorrtllm0.5.0-cu122


**Create the Tarball and then upload to S3 location**

In [10]:
!rm model.tar.gz
!tar czvf model.tar.gz code_llama2_13b_fp16

tar: code_llama2_13b_fp16: Cannot stat: No such file or directory
tar: Exiting with failure status due to previous errors


In [11]:
s3_code_artifact = sess.upload_data("model.tar.gz", bucket, s3_code_prefix)

### To create the end point the steps are:

1. Create the Model using the Image container and the Model Tarball uploaded earlier
2. Create the endpoint config using the following key parameters

    a) Instance Type is ml.g5.12xlarge 
    
    b) ContainerStartupHealthCheckTimeoutInSeconds is 3600 to ensure health check starts after the model is ready    
3. Create the end point using the endpoint config created    


#### Create the Model
Use the image URI for the DJL container and the s3 location to which the tarball was uploaded.

The container downloads the model into the `/tmp` space on the instance because SageMaker maps the `/tmp` to the Amazon Elastic Block Store (Amazon EBS) volume that is mounted when we specify the endpoint creation parameter VolumeSizeInGB. 
It leverages `s5cmd`(https://github.com/peak/s5cmd) which offers a very fast download speed and hence extremely useful when downloading large models.

For instances like p4dn, which come pre-built with the volume instance, we can continue to leverage the `/tmp` on the container. The size of this mount is large enough to hold the model.


In [19]:
from sagemaker.utils import name_from_base

model_name = name_from_base(f"Llama-2-7b-fp16-mpi")
print(model_name)

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image_uri,
        "ModelDataUrl": s3_code_artifact,
        "Environment": {"MODEL_LOADING_TIMEOUT": "3600"},
    },
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

Llama-2-7b-fp16-mpi-2023-12-14-22-59-02-878
Created Model: arn:aws:sagemaker:us-east-1:717117019124:model/llama-2-7b-fp16-mpi-2023-12-14-22-59-02-878


In [20]:
endpoint_config_name = f"{model_name}-config"
endpoint_name = f"{model_name}-endpoint"

endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            "InstanceType": "ml.g5.2xlarge",
            "InitialInstanceCount": 1,
            "ModelDataDownloadTimeoutInSeconds": 3600,
            "ContainerStartupHealthCheckTimeoutInSeconds": 3600,
        },
    ],
)
endpoint_config_response

{'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:717117019124:endpoint-config/llama-2-7b-fp16-mpi-2023-12-14-22-59-02-878-config',
 'ResponseMetadata': {'RequestId': '71699353-97af-49ee-a10e-c565bfe63b77',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '71699353-97af-49ee-a10e-c565bfe63b77',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '131',
   'date': 'Thu, 14 Dec 2023 22:59:03 GMT'},
  'RetryAttempts': 0}}

In [21]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=f"{endpoint_name}", EndpointConfigName=endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

Created Endpoint: arn:aws:sagemaker:us-east-1:717117019124:endpoint/llama-2-7b-fp16-mpi-2023-12-14-22-59-02-878-endpoint


### This step can take ~ 20 min or longer so please be patient

In [22]:
import time

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: InService
Arn: arn:aws:sagemaker:us-east-1:717117019124:endpoint/llama-2-7b-fp16-mpi-2023-12-14-22-59-02-878-endpoint
Status: InService


#### While you wait for the endpoint to be created, you can read more about:
- [Deep Learning containers for large model inference](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints-large-model-dlc.html)

#### Leverage the Boto3 to invoke the endpoint. 

This is a generative model so we pass in a Text as a prompt and Model will complete the sentence and return the results.

You can pass a prompt as input to the model. This done by setting inputs to a prompt. The model then returns a result for each prompt. The text generation can be configured using appropriate parameters.
These parameters need to be passed to the endpoint as a dictionary of kwargs. Refer this documentation - https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig for more details.

The below code sample illustrates the invocation of the endpoint using a text prompt and also sets some parameters

In [23]:
%%time
smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(
        {
            "inputs": "The diamondback terrapin was the first reptile to be",
            "parameters": {
                "do_sample": True,
                "max_new_tokens": 256,
                "temperature": 0.3,
            },
        }
    ),
    ContentType="application/json",
)["Body"].read().decode("utf8")

ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received client error (400) from primary with message "{
  "code": 400,
  "type": "BadRequestException",
  "message": "Parameter model_name is required."
}
". See https://us-east-1.console.aws.amazon.com/cloudwatch/home?region=us-east-1#logEventViewer:group=/aws/sagemaker/Endpoints/Llama-2-7b-fp16-mpi-2023-12-14-22-59-02-878-endpoint in account 717117019124 for more information.

## Clean Up

In [22]:
# - Delete the end point
sm_client.delete_endpoint(EndpointName=endpoint_name)

{'ResponseMetadata': {'RequestId': '13ae3821-94bb-440d-adcb-47e088a126f4',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '13ae3821-94bb-440d-adcb-47e088a126f4',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Thu, 14 Dec 2023 22:28:44 GMT'},
  'RetryAttempts': 0}}

In [23]:
# - In case the end point failed we still want to delete the model
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=model_name)

{'ResponseMetadata': {'RequestId': '28a37799-0f39-4d4b-b895-f5088889f9eb',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '28a37799-0f39-4d4b-b895-f5088889f9eb',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Thu, 14 Dec 2023 22:28:44 GMT'},
  'RetryAttempts': 0}}