## Utilizing SageMaker Inference Components to Host Multiple LLMs on a Single Endpoint

In this example we use SageMaker Inference Components (ICs) to host both a Falcon and a Flan model on a single endpoint. Unlike traditional SageMaker Real-Time Endpoints, we follow the flow Endpoint Config -> Endpoint -> IC (1...n). In this case we create the endpoint and then add the Falcon and Flan models as their own ICs. ICs are similar to SageMaker Model objects in that we can define the model data and container information; the difference is that we can also enable AutoScaling at the IC level.

### Setup

In [None]:
#!pip install sagemaker --upgrade

In [None]:
import boto3
import sagemaker
import time
from time import gmtime, strftime

# Setup: AWS clients, boto3 session, and execution role used by the cells below.

# Dedicated boto3 session; the region and the S3 resource are derived from it.
boto_session = boto3.session.Session()
region = boto_session.region_name
s3 = boto_session.resource("s3")

# `client` issues the create/describe calls, `runtime` issues invoke_endpoint.
client = boto3.client("sagemaker")
runtime = boto3.client("sagemaker-runtime")

# SageMaker SDK session and the IAM role passed to the create_* calls.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

print(f"Role ARN: {role}")
print(f"Region: {region}")

### Create Endpoint Config and Endpoint

In [None]:
# Endpoint config name, suffixed with the current UTC timestamp
epc_name = "ic-epc" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
print(f"Endpoint Config Name: {epc_name}")

# Container parameters: long download / health-check timeouts for LLMs
variant_name = "AllTraffic"
instance_type = "ml.g5.24xlarge"
model_data_download_timeout_in_seconds = 3600
container_startup_health_check_timeout_in_seconds = 3600

# Instance-count bounds for managed instance scaling
initial_instance_count = 1
max_instance_count = 2
print(f"Initial instance count: {initial_instance_count}")
print(f"Max instance count: {max_instance_count}")

# Endpoint Config Creation
endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=epc_name,
    ExecutionRoleArn=role,
    ProductionVariants=[
        {
            "VariantName": variant_name,
            "InstanceType": instance_type,
            # Driven by the variable above (previously a hard-coded 1) so the
            # initial count stays in step with MinInstanceCount when it is changed.
            "InitialInstanceCount": initial_instance_count,
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
            "ManagedInstanceScaling": {
                "Status": "ENABLED",
                "MinInstanceCount": initial_instance_count,
                "MaxInstanceCount": max_instance_count,
            },
            # can set to least outstanding or random: https://aws.amazon.com/blogs/machine-learning/minimize-real-time-inference-latency-by-using-amazon-sagemaker-routing-strategies/
            "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        }
    ],
)

print("Endpoint Configuration Arn: " + endpoint_config_response["EndpointConfigArn"])

In [None]:
# Endpoint creation: the name carries the current UTC timestamp as a suffix,
# and the endpoint is built from the endpoint config created above.
endpoint_name = f"ic-ep{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
create_endpoint_response = client.create_endpoint(
    EndpointConfigName=epc_name,
    EndpointName=endpoint_name,
)
print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

In [None]:
# Monitor creation: poll until the endpoint leaves the "Creating" state.
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    # Sleep *before* re-polling: this avoids an immediate duplicate describe call
    # and a wasted wait after the status has already left "Creating".
    time.sleep(15)
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
    print(describe_endpoint_response["EndpointStatus"])
# Full description of the endpoint in its first non-"Creating" state.
print(describe_endpoint_response)

### Inference Component Creation

First we define the SageMaker Model objects, which hold our container and model data information; the Inference Component takes this metadata directly from the SageMaker Model object.

In [None]:
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
import json

# Hugging Face TGI container image shared by both models
image_uri = get_huggingface_llm_image_uri("huggingface", version="1.1.0")
print(f"TGI Image: {image_uri}")

# Flan-T5 XXL: container definition and timestamped model name
flant5_model = {
    "Image": image_uri,
    "Environment": {"HF_MODEL_ID": "google/flan-t5-xxl"},
}
flant5_model_name = f"flant5-model{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
print(f"Flan Model Name: {flant5_model_name}")

# Falcon 7B: container definition and timestamped model name.
# note: falcon 7b takes just one GPU, sharding is not supported
falcon7b_model = {
    "Image": image_uri,
    "Environment": {"HF_MODEL_ID": "tiiuae/falcon-7b"},
}
falcon7b_model_name = f"falcon7b-model{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
print(f"Falcon Model Name: {falcon7b_model_name}")

In [None]:
# Register the Flan-T5 container definition as a SageMaker Model object
create_flan_model_response = client.create_model(
    Containers=[flant5_model],
    ExecutionRoleArn=role,
    ModelName=flant5_model_name,
)
print(f"Flan Model Arn: {create_flan_model_response['ModelArn']}")

# Register the Falcon container definition as a SageMaker Model object
create_falcon_model_response = client.create_model(
    Containers=[falcon7b_model],
    ExecutionRoleArn=role,
    ModelName=falcon7b_model_name,
)
print(f"Falcon Model Arn: {create_falcon_model_response['ModelArn']}")

In [None]:
# Timestamped name for the Flan-T5 inference component and the variant it targets
flant5_ic_name = f"flant5-ic{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
variant_name = "AllTraffic"

# flan inference component creation
create_flan_ic_response = client.create_inference_component(
    EndpointName=endpoint_name,
    VariantName=variant_name,
    InferenceComponentName=flant5_ic_name,
    Specification={
        # The IC reuses the container/model metadata of this SageMaker Model object
        "ModelName": flant5_model_name,
        "ComputeResourceRequirements": {
            # enables tensor parallel via TGI, reserving 2 GPUs (g5.24xlarge has 4 GPUs)
            "NumberOfAcceleratorDevicesRequired": 2,
            "NumberOfCpuCoresRequired": 1,
            "MinMemoryRequiredInMb": 1024,
        },
    },
    # can setup autoscaling for copies
    RuntimeConfig={"CopyCount": 1},
)

print(f"IC Flan Arn: {create_flan_ic_response['InferenceComponentArn']}")

In [None]:
# Poll the Flan-T5 inference component until it leaves the "Creating" state.
describe_ic_flan_response = client.describe_inference_component(
    InferenceComponentName=flant5_ic_name)

while describe_ic_flan_response["InferenceComponentStatus"] == "Creating":
    # Sleep *before* re-polling: this avoids an immediate duplicate describe call
    # and a wasted wait after the status has already left "Creating".
    time.sleep(30)
    describe_ic_flan_response = client.describe_inference_component(InferenceComponentName=flant5_ic_name)
    print(describe_ic_flan_response["InferenceComponentStatus"])
# Full description of the component in its first non-"Creating" state.
print(describe_ic_flan_response)

In [None]:
import json

# Prompt sent to the Flan-T5 inference component
payload = "What is the capitol of the United States?"

# Generation parameters sent alongside the prompt in the JSON request body
generation_parameters = {
    "early_stopping": True,
    "length_penalty": 2.0,
    "max_new_tokens": 50,
    "temperature": 1,
    "min_length": 10,
    "no_repeat_ngram_size": 3,
}
request_body = json.dumps({"inputs": payload, "parameters": generation_parameters})

response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=flant5_ic_name,  # specify IC name
    ContentType="application/json",
    Accept="application/json",
    Body=request_body,
)

# Read the response body, decode it to text, and parse the JSON for display
result = json.loads(response["Body"].read().decode())
result

In [None]:
# Timestamped name for the Falcon inference component and the variant it targets
falcon_ic_name = f"falcon-ic{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"
variant_name = "AllTraffic"

# falcon inference component creation
create_falcon_ic_response = client.create_inference_component(
    EndpointName=endpoint_name,
    VariantName=variant_name,
    InferenceComponentName=falcon_ic_name,
    Specification={
        # The IC reuses the container/model metadata of this SageMaker Model object
        "ModelName": falcon7b_model_name,
        "ComputeResourceRequirements": {
            # For falcon 7b only one GPU is needed: https://github.com/huggingface/text-generation-inference/issues/418#issuecomment-1579186709
            "NumberOfAcceleratorDevicesRequired": 1,
            "NumberOfCpuCoresRequired": 1,
            "MinMemoryRequiredInMb": 1024,
        },
    },
    # can setup autoscaling for copies
    RuntimeConfig={"CopyCount": 1},
)

print(f"IC Falcon Arn: {create_falcon_ic_response['InferenceComponentArn']}")

In [None]:
# Poll the Falcon inference component until it leaves the "Creating" state.
describe_ic_falcon_response = client.describe_inference_component(
    InferenceComponentName=falcon_ic_name)

while describe_ic_falcon_response["InferenceComponentStatus"] == "Creating":
    # Sleep *before* re-polling: this avoids an immediate duplicate describe call
    # and a wasted wait after the status has already left "Creating".
    time.sleep(60)
    describe_ic_falcon_response = client.describe_inference_component(InferenceComponentName=falcon_ic_name)
    print(describe_ic_falcon_response["InferenceComponentStatus"])
# Full description of the component in its first non-"Creating" state.
print(describe_ic_falcon_response)

In [None]:
import json

# Prompt sent to the Falcon inference component
payload = "What is the capitol of the United States?"

# Generation parameters sent alongside the prompt in the JSON request body
generation_parameters = {
    "early_stopping": True,
    "length_penalty": 2.0,
    "max_new_tokens": 50,
    "temperature": 1,
    "min_length": 10,
    "no_repeat_ngram_size": 3,
}
request_body = json.dumps({"inputs": payload, "parameters": generation_parameters})

response = runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=falcon_ic_name,  # specify IC name
    ContentType="application/json",
    Accept="application/json",
    Body=request_body,
)

# Read the response body, decode it to text, and parse the JSON for display
result = json.loads(response["Body"].read().decode())
result