# Utilizing SageMaker Inference Components To Host Multiple LLMs

In this example we use SageMaker Inference Components to host both the Llama 7B model we deployed in the SME Lab and a Flan-T5 model. With Inference Components you can bring multiple containers onto a single endpoint. In this case we have two different models, each with its own container/model server. You can also optionally bring your own container.

The flow for creating Inference Components is a little different from creating a traditional SageMaker Endpoint.

![creation-flow](images/ic-arch.png)

Think of an Inference Component (IC) as a combination of two factors:

- <b>SageMaker Model Object</b>: Model data + container selection
- <b>Hardware Resources</b>: Dedicated Compute you are assigning to that Model (GPUs, Inferentia2, CPU).
    - <b>Copy Count</b>: Number of copies of a model. You can set an AutoScaling policy per model based on copy count.
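
To make the hardware side concrete, here is a minimal sketch of the accelerator arithmetic for the two components created later in this notebook. The `ComponentRequest` class and `remaining_accelerators` function are hypothetical helpers, not part of any SDK. The figure of 8 GPUs for `ml.g5.48xlarge` comes from the comments further down. The sketch only counts devices on one instance and ignores CPU and memory requirements.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ComponentRequest:
    """Hypothetical record of one inference component's accelerator needs."""
    name: str
    accelerators_per_copy: int
    copy_count: int = 1

    @property
    def total_accelerators(self) -> int:
        # Every copy reserves its own set of accelerators.
        return self.accelerators_per_copy * self.copy_count


def remaining_accelerators(accelerators_per_instance, requests):
    """Return the free accelerators on one instance, or raise if the requests do not fit."""
    used = sum(r.total_accelerators for r in requests)
    if used > accelerators_per_instance:
        raise ValueError(
            f"Requested {used} accelerators but the instance has {accelerators_per_instance}"
        )
    return accelerators_per_instance - used


# The two components from this notebook: Llama 7B reserves 4 GPUs, Flan-T5 reserves 2.
requests = [
    ComponentRequest("llama7b-ic", accelerators_per_copy=4),
    ComponentRequest("flant5-ic", accelerators_per_copy=2),
]
free = remaining_accelerators(8, requests)
print(f"Free GPUs on one instance: {free}")
```

With one copy of each model, two GPUs stay free, which is enough for one more Flan-T5 copy but not for a second Llama copy on the same instance.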
    

To understand scaling at a per model level please reference this [example](https://github.com/aws/amazon-sagemaker-examples/blob/main/inference/generativeai/llm-workshop/lab-inference-components-with-scaling/2c_meta-llama2-7b-lmi-autoscaling.ipynb).
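
As a rough illustration of per-model scaling, the sketch below builds the request parameters for an Application Auto Scaling target-tracking policy on copy count. The helper `build_copy_scaling_requests`, the policy name, and the default limits and target value are illustrative assumptions. The resource ID format, scalable dimension, and predefined metric name are written as I understand them from the AWS documentation, so check them against the linked example before use. The function only builds dictionaries; the actual API calls are shown as comments.

```python
def build_copy_scaling_requests(
    inference_component_name,
    min_copies=1,
    max_copies=4,
    target_invocations_per_copy=5.0,
):
    """Build kwargs for register_scalable_target and put_scaling_policy (illustrative values)."""
    resource_id = f"inference-component/{inference_component_name}"
    dimension = "sagemaker:inference-component:DesiredCopyCount"

    scalable_target = {
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": dimension,
        "MinCapacity": min_copies,
        "MaxCapacity": max_copies,
    }
    scaling_policy = {
        "PolicyName": f"{inference_component_name}-copy-scaling",  # hypothetical name
        "PolicyType": "TargetTrackingScaling",
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": dimension,
        "TargetTrackingScalingPolicyConfiguration": {
            "PredefinedMetricSpecification": {
                "PredefinedMetricType": "SageMakerInferenceComponentInvocationsPerCopy",
            },
            "TargetValue": target_invocations_per_copy,
        },
    }
    return scalable_target, scaling_policy


scalable_target, scaling_policy = build_copy_scaling_requests("my-inference-component")

# Applying the policy once the inference component exists (requires AWS credentials):
# aas_client = boto3.client("application-autoscaling")
# aas_client.register_scalable_target(**scalable_target)
# aas_client.put_scaling_policy(**scaling_policy)
```

Each copy added by the policy reserves the same hardware as the first one, so the maximum copy count should stay within what the endpoint's instances can hold.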

## Setup & Endpoint Creation

To get started we create a persistent SageMaker Endpoint and enable managed AutoScaling at the endpoint level. Instance scaling is then taken care of for us at the endpoint level, and you can additionally enable AutoScaling policies at the model/container level based on the number of invocations per copy.

In [None]:
#!pip install sagemaker --upgrade

In [None]:
import boto3
import sagemaker
import time
from time import gmtime, strftime

#Setup
client = boto3.client(service_name="sagemaker")
runtime = boto3.client(service_name="sagemaker-runtime")
boto_session = boto3.session.Session()
s3 = boto_session.resource('s3')
region = boto_session.region_name
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()
print(f"Role ARN: {role}")
print(f"Region: {region}")

# client setup
s3_client = boto3.client("s3")
sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

In [None]:
# endpoint config name
epc_name = "ic-epc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f"Endpoint Config Name: {epc_name}")

# Container Parameters, increase health check for LLMs: 
variant_name = "AllTraffic"
instance_type = "ml.g5.48xlarge"
model_data_download_timeout_in_seconds = 3600
container_startup_health_check_timeout_in_seconds = 3600

# Setting up managed AutoScaling at endpoint level
initial_instance_count = 1
max_instance_count = 4
print(f"Initial instance count: {initial_instance_count}")
print(f"Max instance count: {max_instance_count}")

# Endpoint Config Creation
endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=epc_name,
    ExecutionRoleArn=role,
    ProductionVariants=[
        {
            "VariantName": variant_name,
            "InstanceType": instance_type,
            "InitialInstanceCount": initial_instance_count,
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
            "ManagedInstanceScaling": {
                "Status": "ENABLED",
                "MinInstanceCount": initial_instance_count,
                "MaxInstanceCount": max_instance_count,
            },
            # can set to least outstanding or random: https://aws.amazon.com/blogs/machine-learning/minimize-real-time-inference-latency-by-using-amazon-sagemaker-routing-strategies/
            "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        }
    ],
)

print("Endpoint Configuration Arn: " + endpoint_config_response["EndpointConfigArn"])

In [None]:
#Endpoint Creation
endpoint_name = "ic-ep" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=epc_name,
)
print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
#Monitor creation
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
    time.sleep(15)
print(describe_endpoint_response)
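
The polling loop above runs until the status changes, however long that takes. Below is a minimal sketch of the same pattern with a timeout, written so it can be reused for both endpoints and inference components. The helper name `wait_while_status` and its parameters are hypothetical, and the default timeout is an arbitrary choice.

```python
import time


def wait_while_status(
    describe,
    status_key,
    in_progress="Creating",
    poll_seconds=15,
    timeout_seconds=3600,
    sleep=time.sleep,
):
    """Call describe() until response[status_key] leaves in_progress; return the last response."""
    deadline = time.monotonic() + timeout_seconds
    response = describe()
    while response[status_key] == in_progress:
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still '{in_progress}' after {timeout_seconds} seconds")
        sleep(poll_seconds)
        response = describe()
    return response


# Usage with the clients defined in this notebook (requires AWS credentials):
# response = wait_while_status(
#     lambda: client.describe_endpoint(EndpointName=endpoint_name),
#     status_key="EndpointStatus",
# )
# print(response["EndpointStatus"])
```

The same helper would work for inference components by passing `client.describe_inference_component` and `"InferenceComponentStatus"` as the status key.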

## Inference Components Creation

### Inference Component 1: Llama 7B via LMI Container

Here we take our single-model optimized Llama 7B example and reuse the same container to create our Inference Component. We create a SageMaker Model object and the IC inherits the metadata from this object. The new API call we are dealing with is the [create_inference_component API call](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_inference_component.html).

In [None]:
!rm -rf code_llama2_7b_fp16
!mkdir -p code_llama2_7b_fp16

In [None]:
%%writefile code_llama2_7b_fp16/serving.properties
engine=MPI
option.tensor_parallel_degree=4
option.rolling_batch=trtllm
option.paged_attention = true
option.max_rolling_batch_prefill_tokens = 16080
option.max_rolling_batch_size=64
option.model_loading_timeout = 900
option.model_id = s3://sagemaker-example-files-prod-us-east-1/models/llama-2/fp16/7B/

In [None]:
image_uri = sagemaker.image_uris.retrieve(
        framework="djl-tensorrtllm",
        region=sagemaker_session.boto_session.region_name,
        version="0.26.0"
    )

In [None]:
!rm -f model.tar.gz
!tar czvf model.tar.gz code_llama2_7b_fp16

In [None]:
s3_code_prefix = "hf-large-model-djl/meta-llama/Llama-2-7b-fp16/code"
s3_code_artifact = sagemaker_session.upload_data("model.tar.gz", bucket, s3_code_prefix)

In [None]:
print(f"Model data is stored: {s3_code_artifact}")

In [None]:
from sagemaker.utils import name_from_base

llama_model_name = name_from_base("Llama-2-7b-fp16-mpi")
print(llama_model_name)

create_model_response = sm_client.create_model(
    ModelName=llama_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": image_uri, "ModelDataUrl": s3_code_artifact},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
llama7b_ic_name = "llama7b-ic" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
variant_name = "AllTraffic"

# llama inference component creation
create_llama_ic_response = sm_client.create_inference_component(
    InferenceComponentName=llama7b_ic_name,
    EndpointName=endpoint_name,
    VariantName=variant_name,
    Specification={
        "ModelName": llama_model_name,
        "ComputeResourceRequirements": {
            # matches tensor_parallel_degree=4 in serving.properties (LMI container), reserving 4 GPUs (g5.48xlarge has 8 GPUs)
            "NumberOfAcceleratorDevicesRequired": 4,
            "NumberOfCpuCoresRequired": 1,
            "MinMemoryRequiredInMb": 1024,
        },
    },
    # can setup autoscaling for copies, each copy will retain the hardware you have allocated
    RuntimeConfig={"CopyCount": 1},
)

print("IC Llama Arn: " + create_llama_ic_response["InferenceComponentArn"])

In [None]:
describe_ic_llama_response = client.describe_inference_component(
    InferenceComponentName=llama7b_ic_name)

while describe_ic_llama_response["InferenceComponentStatus"] == "Creating":
    describe_ic_llama_response = client.describe_inference_component(InferenceComponentName=llama7b_ic_name)
    print(describe_ic_llama_response["InferenceComponentStatus"])
    time.sleep(100)
print(describe_ic_llama_response)

#### Sample Inference

This is the same REST API call as for a regular endpoint; you just specify the name of the inference component that should serve the request.
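
Since every component on the endpoint is invoked the same way, the call can be wrapped in a small helper. The sketch below is one possible shape: `invoke_component` and `extract_generated_text` are hypothetical names, and the parsing assumes the response body is either a JSON object with a `generated_text` key (as the Llama cell below expects) or a list whose first element has that key (an assumption about what some containers return).

```python
import json


def extract_generated_text(raw_body):
    """Return generated_text from a JSON object, or from the first item of a JSON list."""
    parsed = json.loads(raw_body)
    if isinstance(parsed, list):
        parsed = parsed[0]
    return parsed["generated_text"]


def invoke_component(runtime_client, endpoint_name, component_name, payload):
    """Send a JSON payload to one inference component and return its generated text."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        InferenceComponentName=component_name,  # routes the request to this component
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    return extract_generated_text(response["Body"].read().decode())


# Usage with the names defined in this notebook (requires a running endpoint):
# text = invoke_component(
#     boto3.client("sagemaker-runtime"),
#     endpoint_name,
#     llama7b_ic_name,
#     {"inputs": "Who is Roger Federer?", "parameters": {"max_new_tokens": 128}},
# )
```

Passing the runtime client in as an argument keeps the helper easy to exercise with a stub client before pointing it at a live endpoint.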

In [None]:
payload = {"inputs": "Who is Roger Federer?", 
           "parameters": {"max_new_tokens":128, "do_sample":True}}

import json

runtime_client = boto3.client('sagemaker-runtime')
content_type = "application/json"

response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=llama7b_ic_name, #specify IC name
    ContentType=content_type,
    Body=json.dumps(payload))
result = json.loads(response['Body'].read().decode())['generated_text']
print(result)

### Inference Component 2: Flan-T5 via TGI Container

For our second Inference Component we use the HuggingFace Text Generation Inference (TGI) container to pull down the Flan-T5 model directly from the HuggingFace Hub. To understand which model server/container to use for your LLM hosting, and the tradeoffs involved, please refer to the following [article](https://aws.plainenglish.io/four-different-ways-to-host-large-language-models-on-amazon-sagemaker-4d1b027812b5).

In [None]:
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
import json

# utilizing huggingface TGI container
image_uri = get_huggingface_llm_image_uri("huggingface",version="1.1.0")
print(f"TGI Image: {image_uri}")

# Flan T5 TGI Model
flant5_model = {"Image": image_uri, "Environment": {"HF_MODEL_ID": "google/flan-t5-xxl"}}
flant5_model_name = "flant5-model" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f"Flan Model Name: {flant5_model_name}")

In [None]:
# create model object for flan t5
create_flan_model_response = sm_client.create_model(
    ModelName=flant5_model_name,
    ExecutionRoleArn=role,
    Containers=[flant5_model],
)
print("Flan Model Arn: " + create_flan_model_response["ModelArn"])

In [None]:
flant5_ic_name = "flant5-ic" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
variant_name = "AllTraffic"

# flan inference component creation
create_flan_ic_response = sm_client.create_inference_component(
    InferenceComponentName=flant5_ic_name,
    EndpointName=endpoint_name,
    VariantName=variant_name,
    Specification={
        "ModelName": flant5_model_name,
        "ComputeResourceRequirements": {
            # enables tensor parallel via TGI, reserving 2 GPUs (g5.48xlarge has 8 GPUs)
            "NumberOfAcceleratorDevicesRequired": 2,
            "NumberOfCpuCoresRequired": 1,
            "MinMemoryRequiredInMb": 1024,
        },
    },
    # can setup autoscaling for copies
    RuntimeConfig={"CopyCount": 1},
)

print("IC Flan Arn: " + create_flan_ic_response["InferenceComponentArn"])

In [None]:
describe_ic_flan_response = client.describe_inference_component(
    InferenceComponentName=flant5_ic_name)

while describe_ic_flan_response["InferenceComponentStatus"] == "Creating":
    describe_ic_flan_response = client.describe_inference_component(InferenceComponentName=flant5_ic_name)
    print(describe_ic_flan_response["InferenceComponentStatus"])
    time.sleep(100)
print(describe_ic_flan_response)

#### Sample Inference

In [None]:
import json

payload = "What is the capital of the United States?"
response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=flant5_ic_name, #specify IC name
    ContentType="application/json",
    Accept="application/json",
    Body=json.dumps(
        {
            "inputs": payload,
            "parameters": {
                "early_stopping": True,
                "length_penalty": 2.0,
                "max_new_tokens": 50,
                "temperature": 1,
                "min_length": 10,
                "no_repeat_ngram_size": 3,
                },
        }
    ),
)
result = json.loads(response["Body"].read().decode())
result