# Deploy Falcon 40B on Amazon SageMaker using LMI and vLLM

## Resources
- [Falcon-40B model card](https://huggingface.co/tiiuae/falcon-40b)
- [LMI Configuration Documentation](https://docs.djl.ai/docs/serving/serving/docs/lmi/configurations_large_model_inference_containers.html)
- [DJL-Demo Samples](https://github.com/deepjavalibrary/djl-demo/tree/2a5152f578f5954b8b68acdee18eed4e2a75c81f/aws/sagemaker/large-model-inference/sample-llm)
- [vLLM documentation](https://docs.vllm.ai/en/latest/)

## Step 1: Setup

In [27]:
# %pip install --upgrade --quiet sagemaker

In [28]:
import sagemaker
import boto3
import json
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

boto3 version: 1.34.39
sagemaker version: 2.209.0


In [29]:
role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # AWS Region of the session (public property instead of the private _region_name)

sm_client = boto3.client("sagemaker")
smr_client = boto3.client("sagemaker-runtime")

## Step 2: Create a model, endpoint configuration and endpoint

Retrieve the ECR image URI for the DJL Large Model Inference (LMI) container. `sagemaker.image_uris.retrieve` looks up the URI from the framework name (`djl-deepspeed`), the AWS Region, and the framework version, so the image that matches your environment is selected dynamically. This notebook uses the DeepSpeed-based LMI image with vLLM rolling batching, which is enabled below through `OPTION_ROLLING_BATCH`.

See the list of available Large Model Inference DLCs [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers).

In [30]:
version = "0.26.0"
inference_image_uri = sagemaker.image_uris.retrieve(
    "djl-deepspeed", region=region, version=version
)
print(f"Image going to be used is ----> {inference_image_uri}")

Image going to be used is ----> 763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.26.0-deepspeed0.12.6-cu121
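The URI printed above follows the usual ECR layout `<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>`, and in this example the tag bundles the DJL Serving version, the DeepSpeed version, and the CUDA version. A minimal sketch for splitting such a URI apart, assuming that layout and a tagged (not digest-pinned) image; `parse_ecr_image_uri` is a hypothetical helper, not part of the SageMaker SDK:

```python
from typing import NamedTuple


class EcrImage(NamedTuple):
    account: str
    region: str
    repository: str
    tag: str


def parse_ecr_image_uri(uri: str) -> EcrImage:
    """Split <account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>.

    Hypothetical helper; assumes the image is referenced by tag, not digest.
    """
    registry, _, rest = uri.partition("/")
    repository, sep, tag = rest.rpartition(":")
    if not sep:
        raise ValueError(f"URI has no tag: {uri}")
    parts = registry.split(".")
    # Expected host parts: [account, "dkr", "ecr", region, "amazonaws", "com", ...]
    if len(parts) < 6 or parts[1:3] != ["dkr", "ecr"]:
        raise ValueError(f"Not an ECR registry host: {registry}")
    return EcrImage(account=parts[0], region=parts[3], repository=repository, tag=tag)


image = parse_ecr_image_uri(
    "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.26.0-deepspeed0.12.6-cu121"
)
print(image.region, image.repository)
print(image.tag.split("-"))  # DJL version, DeepSpeed version, CUDA version
```

In the notebook you could pass `inference_image_uri` directly instead of the literal string.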


In [31]:
model_name = sagemaker.utils.name_from_base("falcon40b-lmi-vllm")
print(model_name)

env = {
    "SERVING_LOAD_MODELS": "test::Python=/opt/ml/model",
    "OPTION_MODEL_ID": "tiiuae/falcon-40b",
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_TENSOR_PARALLEL_DEGREE": "8",
}

create_model_response = sm_client.create_model(
    ModelName = model_name,
    ExecutionRoleArn = role,
    PrimaryContainer = {
        "Image": inference_image_uri, 
        "Environment": env,
    },
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

falcon40b-lmi-vllm-2024-02-26-22-26-24-860
Created Model: arn:aws:sagemaker:us-west-2:461312420708:model/falcon40b-lmi-vllm-2024-02-26-22-26-24-860
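LMI reads its settings either from a `serving.properties` file or from environment variables, and the LMI configuration documentation linked above describes an `OPTION_<NAME>` variable as the environment form of the `option.<name>` property. A minimal sketch of that mapping, useful for checking what the container will see; the helper is hypothetical, the conversion rule is an assumption based on that convention, and `SERVING_*` variables such as `SERVING_LOAD_MODELS` (which configure DJL Serving itself) are skipped:

```python
def env_to_serving_properties(env: dict[str, str]) -> str:
    """Render OPTION_* environment variables as serving.properties lines.

    Assumes the LMI convention OPTION_<NAME> -> option.<name> (lower-cased);
    all other variables are ignored in this sketch.
    """
    prefix = "OPTION_"
    lines = []
    for key, value in sorted(env.items()):
        if key.startswith(prefix):
            lines.append(f"option.{key[len(prefix):].lower()}={value}")
    return "\n".join(lines)


# Same values as the `env` dict used in create_model above.
example_env = {
    "SERVING_LOAD_MODELS": "test::Python=/opt/ml/model",
    "OPTION_MODEL_ID": "tiiuae/falcon-40b",
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_TENSOR_PARALLEL_DEGREE": "8",
}
print(env_to_serving_properties(example_env))
```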


The cells below deploy the model to a SageMaker endpoint for real-time inference. `instance_type` sets the instance that hosts the endpoint, and the endpoint name is derived from `model_name`. Both the model data download timeout and the container startup health check timeout are set to 1200 seconds, because downloading the Falcon-40B weights and loading them across all 8 GPUs takes a while.
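A back-of-the-envelope check for the `ml.g5.48xlarge` choice: the instance has 8 NVIDIA A10G GPUs with 24 GB each, and with `OPTION_TENSOR_PARALLEL_DEGREE=8` the weights are sharded across all of them. The sketch below assumes roughly 40 billion parameters stored in 16-bit precision (2 bytes each) and counts weights only; activations, the KV cache, and CUDA overhead also need room. The helper name is hypothetical.

```python
def weight_gb_per_gpu(num_params: float, bytes_per_param: int, tensor_parallel_degree: int) -> float:
    """Approximate per-GPU memory for model weights only (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / tensor_parallel_degree / 1e9


GPU_MEMORY_GB = 24  # A10G memory per GPU on ml.g5.48xlarge
per_gpu = weight_gb_per_gpu(40e9, 2, 8)  # ~40B params, fp16/bf16, tensor parallel degree 8
print(f"~{per_gpu:.1f} GB of weights per GPU, "
      f"leaving ~{GPU_MEMORY_GB - per_gpu:.1f} GB for KV cache and overhead")
```

Without tensor parallelism the same weights would need about 80 GB, far more than a single 24 GB GPU offers.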

In [32]:
endpoint_config_name = f"{model_name}-config"

In [33]:
# Set variant name and instance type for hosting
variant_name = "AllTraffic"
instance_type = "ml.g5.48xlarge"
model_data_download_timeout_in_seconds = 1200
container_startup_health_check_timeout_in_seconds = 1200

initial_instance_count = 1
max_instance_count = 2 # will use for managed instance scaling later
print(f"Initial instance count: {initial_instance_count}")
print(f"Max instance count: {max_instance_count}")

sm_client.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ExecutionRoleArn = role,
    ProductionVariants = [
        {
            "VariantName": variant_name,
            "InstanceType": instance_type,
            "InitialInstanceCount": initial_instance_count,
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
            "ManagedInstanceScaling": {
                "Status": "ENABLED",
                "MinInstanceCount": initial_instance_count,
                "MaxInstanceCount": max_instance_count,
            },
            "RoutingConfig": {"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
        }
    ]
)

Initial instance count: 1
Max instance count: 2


{'EndpointConfigArn': 'arn:aws:sagemaker:us-west-2:461312420708:endpoint-config/falcon40b-lmi-vllm-2024-02-26-22-26-24-860-config',
 'ResponseMetadata': {'RequestId': 'a7f162c3-144d-4815-b643-72015b12d05c',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'a7f162c3-144d-4815-b643-72015b12d05c',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '130',
   'date': 'Mon, 26 Feb 2024 22:26:25 GMT'},
  'RetryAttempts': 0}}

In [34]:
endpoint_name = f"{model_name}-endpoint"

In [35]:
create_endpoint_response = sm_client.create_endpoint(
    EndpointName = endpoint_name, EndpointConfigName = endpoint_config_name
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

Created Endpoint: arn:aws:sagemaker:us-west-2:461312420708:endpoint/falcon40b-lmi-vllm-2024-02-26-22-26-24-860-endpoint


### This step can take ~10 minutes or longer, so please be patient
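`sess.wait_for_endpoint`, used in the next cell, blocks by repeatedly describing the endpoint until it leaves the `Creating` state. The same polling pattern can be sketched as a small SDK-independent helper. Here `describe` is any zero-argument callable that returns a dict with a status field (for example `lambda: sm_client.describe_endpoint(EndpointName=endpoint_name)`); the helper name, defaults, and status handling are assumptions, not the SDK's implementation.

```python
import time
from typing import Callable


def wait_for_status(
    describe: Callable[[], dict],
    status_key: str = "EndpointStatus",
    target: str = "InService",
    failed: tuple[str, ...] = ("Failed",),
    poll_seconds: float = 30.0,
    max_attempts: int = 120,
) -> dict:
    """Poll describe() until desc[status_key] == target.

    Raises RuntimeError on a failure status and TimeoutError after max_attempts polls.
    """
    for _ in range(max_attempts):
        desc = describe()
        status = desc[status_key]
        if status == target:
            return desc
        if status in failed:
            raise RuntimeError(f"Resource entered {status}: {desc.get('FailureReason', '')}")
        print("-", end="", flush=True)  # progress marker, similar to the SDK's output
        time.sleep(poll_seconds)
    raise TimeoutError(f"{status_key} did not reach {target} after {max_attempts} polls")
```

For the inference component later on, the same helper could be called with `status_key="InferenceComponentStatus"` and a callable wrapping `sm_client.describe_inference_component`.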

In [36]:
#
# Using helper function to wait for the endpoint to be ready
#
sess.wait_for_endpoint(endpoint_name)

-----!

{'EndpointName': 'falcon40b-lmi-vllm-2024-02-26-22-26-24-860-endpoint',
 'EndpointArn': 'arn:aws:sagemaker:us-west-2:461312420708:endpoint/falcon40b-lmi-vllm-2024-02-26-22-26-24-860-endpoint',
 'EndpointConfigName': 'falcon40b-lmi-vllm-2024-02-26-22-26-24-860-config',
 'ProductionVariants': [{'VariantName': 'AllTraffic',
   'CurrentInstanceCount': 1,
   'DesiredInstanceCount': 1,
   'ManagedInstanceScaling': {'Status': 'ENABLED',
    'MinInstanceCount': 1,
    'MaxInstanceCount': 2},
   'RoutingConfig': {'RoutingStrategy': 'LEAST_OUTSTANDING_REQUESTS'}}],
 'EndpointStatus': 'InService',
 'CreationTime': datetime.datetime(2024, 2, 26, 22, 26, 26, 189000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2024, 2, 26, 22, 29, 17, 267000, tzinfo=tzlocal()),
 'ResponseMetadata': {'RequestId': 'c15a5b08-a007-438a-bf31-4b80e5a7cecd',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'c15a5b08-a007-438a-bf31-4b80e5a7cecd',
   'content-type': 'application/x-amz-json-1.1',
 

In [42]:
inference_component_name = f"{model_name}-ic"

In [45]:
print(f"Test inference component name: {inference_component_name}")

initial_copy_count = 1
max_copy_count_per_instance = 1  # each copy needs all 8 GPUs of an ml.g5.48xlarge; used later for autoscaling

variant_name = "AllTraffic"

min_memory_required_in_mb = 1024 
number_of_accelerator_devices_required = 8

sm_client.create_inference_component(
    InferenceComponentName = inference_component_name,
    EndpointName = endpoint_name,
    VariantName = variant_name,
    Specification={
        "ModelName": model_name,
        "StartupParameters": {
            "ModelDataDownloadTimeoutInSeconds": model_data_download_timeout_in_seconds,
            "ContainerStartupHealthCheckTimeoutInSeconds": container_startup_health_check_timeout_in_seconds,
        },
        "ComputeResourceRequirements": {
            "MinMemoryRequiredInMb": min_memory_required_in_mb,
            "NumberOfAcceleratorDevicesRequired": number_of_accelerator_devices_required,
        },
    },
    RuntimeConfig={
        "CopyCount": initial_copy_count,
    },
)

Test inference component name: falcon40b-lmi-vllm-2024-02-26-22-26-24-860-ic


{'InferenceComponentArn': 'arn:aws:sagemaker:us-west-2:461312420708:inference-component/falcon40b-lmi-vllm-2024-02-26-22-26-24-860-ic',
 'ResponseMetadata': {'RequestId': '6c9e67b4-fa2c-4ad8-b016-ef48bd64b028',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '6c9e67b4-fa2c-4ad8-b016-ef48bd64b028',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '134',
   'date': 'Mon, 26 Feb 2024 23:15:35 GMT'},
  'RetryAttempts': 0}}

### This step can take ~10 minutes or longer, so please be patient

In [46]:
sess.wait_for_inference_component(inference_component_name)

-----------------------------!

{'InferenceComponentName': 'falcon40b-lmi-vllm-2024-02-26-22-26-24-860-ic',
 'InferenceComponentArn': 'arn:aws:sagemaker:us-west-2:461312420708:inference-component/falcon40b-lmi-vllm-2024-02-26-22-26-24-860-ic',
 'EndpointName': 'falcon40b-lmi-vllm-2024-02-26-22-26-24-860-endpoint',
 'EndpointArn': 'arn:aws:sagemaker:us-west-2:461312420708:endpoint/falcon40b-lmi-vllm-2024-02-26-22-26-24-860-endpoint',
 'VariantName': 'AllTraffic',
 'Specification': {'ModelName': 'falcon40b-lmi-vllm-2024-02-26-22-26-24-860',
  'StartupParameters': {'ModelDataDownloadTimeoutInSeconds': 1200,
   'ContainerStartupHealthCheckTimeoutInSeconds': 1200},
  'ComputeResourceRequirements': {'NumberOfAcceleratorDevicesRequired': 8.0,
   'MinMemoryRequiredInMb': 1024}},
 'RuntimeConfig': {'DesiredCopyCount': 1, 'CurrentCopyCount': 1},
 'CreationTime': datetime.datetime(2024, 2, 26, 23, 15, 35, 574000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2024, 2, 26, 23, 25, 25, 982000, tzinfo=tzlocal()),
 'Infe

## Step 3: Invoke the Endpoint

In [47]:
%%time

response_model = smr_client.invoke_endpoint(
    EndpointName = endpoint_name,
    InferenceComponentName = inference_component_name,
    Body = json.dumps(
        {
            "inputs": "What is AWS re:invent? Where does it happen every year?", 
            "parameters": {"max_new_tokens": 256, "do_sample": True}
        }
    ),
    ContentType = "application/json",
)

response_model["Body"].read().decode("utf8")

CPU times: user 15.4 ms, sys: 57 µs, total: 15.4 ms
Wall time: 8.6 s


'{"generated_text": " How to get registered for the same.\\nAWS re:Invent is a four-day conference in Las Vegas, where Amazon web services will announce new features, evidence for new technologies, products and solutions to solve the challenges and prepare their users for future demands.\\nIf anyone is willing to attend and doesn\'t know the process to get registered, if you can take a look at, once you land on this website follow the steps mentioned there. If you have any queries related to the AWS solution then you can also visit Amoli website here you will find solutions related to AWS. Amoli can also help you in implementing Digital Transformation in modern era with the help of AWS solutions.\\n4 Answers\\nAWS re:Invent is a four-day conference in Las Vegas, where Amazon web services will announce new features (or showcase new... The event acquires millions of dollars worth of annual attendee business from restaurants, hotels, and other tourism services benefiting the local economy
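The request body is the `{"inputs": ..., "parameters": {...}}` JSON used above, and the response body decodes to a JSON object with a `generated_text` field. A minimal sketch of pure helpers for building the payload and extracting the text, so the `invoke_endpoint` call can be wrapped and tested separately. The helper names are hypothetical, the response shape is assumed from the output above, and extra keyword arguments (such as `temperature`) are passed through unchanged; check the LMI documentation for which parameters the vLLM backend accepts.

```python
import json


def build_payload(prompt: str, max_new_tokens: int = 256, do_sample: bool = True, **extra) -> str:
    """Serialize a text-generation request in the shape used by this notebook."""
    parameters = {"max_new_tokens": max_new_tokens, "do_sample": do_sample, **extra}
    return json.dumps({"inputs": prompt, "parameters": parameters})


def extract_generated_text(body: bytes) -> str:
    """Decode an invoke_endpoint response body and return its generated_text field."""
    return json.loads(body.decode("utf8"))["generated_text"]


payload = build_payload("What is AWS re:invent?", max_new_tokens=64, temperature=0.7)
print(payload)
```

With the clients from Step 1, a call could look like `extract_generated_text(smr_client.invoke_endpoint(EndpointName=endpoint_name, InferenceComponentName=inference_component_name, Body=build_payload(prompt), ContentType="application/json")["Body"].read())`.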

## (Optional) Step 4: Define and test autoscaling policy

Define a target-tracking scaling policy that adjusts the desired copy count of the inference component.

**Please note:**
- Each new inference component (IC) copy has to download, load, and warm up the model before it can serve traffic, so scaling out takes minutes rather than seconds.
- Because the endpoint was created with managed instance scaling, SageMaker automatically launches additional instances when needed to host the requested number of inference component copies.
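How many copies fit on one instance follows from the inference component's `NumberOfAcceleratorDevicesRequired`: with 8 GPUs per `ml.g5.48xlarge` and 8 GPUs requested per copy, only one copy fits per instance, so managed instance scaling needs one instance per copy. A sketch of that arithmetic, assuming accelerators are the binding constraint (memory and CPU requirements are ignored); the helper names are hypothetical.

```python
import math


def copies_per_instance(gpus_per_instance: int, gpus_per_copy: int) -> int:
    """Copies that fit on one instance when GPUs are the limiting resource."""
    return gpus_per_instance // gpus_per_copy


def instances_needed(desired_copies: int, gpus_per_instance: int, gpus_per_copy: int) -> int:
    """Instances that managed instance scaling would need to host desired_copies."""
    per_instance = copies_per_instance(gpus_per_instance, gpus_per_copy)
    if per_instance == 0:
        raise ValueError("A single copy needs more GPUs than one instance has")
    return math.ceil(desired_copies / per_instance)


# ml.g5.48xlarge has 8 GPUs; this notebook requests 8 per copy.
print(copies_per_instance(8, 8))  # one copy per instance
print(instances_needed(2, 8, 8))  # two copies need two instances
print(instances_needed(2, 8, 4))  # a hypothetical 4-GPU copy would fit twice per instance
```

With `max_instance_count = 2`, this puts the useful upper bound at 2 copies for this configuration.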

In [None]:
aas_client = sess.boto_session.client("application-autoscaling")

In [None]:
max_copy_count = max_copy_count_per_instance * max_instance_count
print(f"Initial copy count: {initial_copy_count}")
print(f"Max copy count: {max_copy_count}")

In [None]:
# Autoscaling parameters
resource_id = f"inference-component/{inference_component_name}"
service_namespace = "sagemaker"
scalable_dimension = "sagemaker:inference-component:DesiredCopyCount"

In [None]:
aas_client.register_scalable_target(
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    MinCapacity=initial_copy_count,
    MaxCapacity=max_copy_count,
)

In [None]:
# Sanity check
#aas_client.describe_scalable_targets(
#    ServiceNamespace=service_namespace,
#    ResourceIds=[resource_id],
#    ScalableDimension=scalable_dimension,
#)

In [None]:
#
# Scaling policy
#
aas_client.put_scaling_policy(
    PolicyName=endpoint_name,
    PolicyType="TargetTrackingScaling",
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    TargetTrackingScalingPolicyConfiguration={
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerInferenceComponentInvocationsPerCopy",
        },
        "TargetValue": 1,  # you need to adjust this value based on your use case
        "ScaleInCooldown": 60,
        "ScaleOutCooldown": 300,
        "DisableScaleIn": False
    },
)
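Target tracking keeps `SageMakerInferenceComponentInvocationsPerCopy` near `TargetValue` by changing `DesiredCopyCount`. The sketch below is a rough proportional approximation for reasoning about what a given `TargetValue` implies, not the service's actual algorithm: the new copy count is taken as the current count scaled by metric/target, rounded up, and clamped to the registered `MinCapacity`/`MaxCapacity`. The function name and formula are assumptions for illustration.

```python
import math


def approx_desired_copies(current_copies: int, metric_per_copy: float, target_value: float,
                          min_capacity: int, max_capacity: int) -> int:
    """Rough proportional estimate of the copy count target tracking aims for."""
    total_load = current_copies * metric_per_copy  # approximate load across all copies
    wanted = math.ceil(total_load / target_value) if total_load > 0 else min_capacity
    return max(min_capacity, min(max_capacity, wanted))  # clamp to the scalable target's range


# One copy seeing ~3 invocations per copy with TargetValue=1 would ask for 3 copies,
# but the result is clamped to MaxCapacity.
print(approx_desired_copies(1, 3.0, 1.0, min_capacity=1, max_capacity=2))
```

A low `TargetValue` such as 1 makes the policy scale out on very little traffic, which is useful for testing but usually too aggressive for production.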

In [None]:
# Sanity check
#aas_client.describe_scaling_policies(
#    PolicyNames=[endpoint_name],
#    ServiceNamespace=service_namespace,
#    ResourceId=resource_id,
#    ScalableDimension=scalable_dimension,
#)

In [None]:
#
# Initial state
#
endpoint_desc = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(f"EndpointStatus: {endpoint_desc['EndpointStatus']}")
print(f"\tCurrentInstanceCount: {endpoint_desc['ProductionVariants'][0]['CurrentInstanceCount']}")
print(f"\tDesiredInstanceCount: {endpoint_desc['ProductionVariants'][0]['DesiredInstanceCount']}")

ic_desc = sm_client.describe_inference_component(InferenceComponentName=inference_component_name)
print(f"InferenceComponentStatus: {ic_desc['InferenceComponentStatus']}")
print(f"\tCurrentCopyCount: {ic_desc['RuntimeConfig']['CurrentCopyCount']}")
print(f"\tDesiredCopyCount: {ic_desc['RuntimeConfig']['DesiredCopyCount']}")

In [None]:
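#
# Optional: set the copy count manually to measure scaling timing without load testing.
# With 8 GPUs per copy and at most 2 instances, no more than 2 copies can be placed.
#
#sm_client.update_inference_component(
#    InferenceComponentName = inference_component_name,
#    RuntimeConfig = {
#        'CopyCount': 2
#    }
#)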

In [None]:
%pip install --quiet locust

In [None]:
# Adjust the number of users and workers in distributed.sh to increase traffic
# (in Locust, the number of users should be a multiple of the number of workers)
!cat distributed.sh

In [None]:
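#
# We recommend running this command in a terminal because it generates a lot of output.
# To run it here instead, delete these comment lines and uncomment the two lines below;
# %%bash must be the first line of the cell.
#
#%%bash -s "$endpoint_name/$inference_component_name"
#./distributed.sh $1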

In [None]:
# Helper record used by the polling loop below to capture
# endpoint and inference component status changes
import time
from dataclasses import dataclass
from datetime import datetime

@dataclass
class AutoscalingStatus:
    status_name: str  # endpoint status or inference component status
    start_time: datetime  # when was the status changed
    current_instance_count: int
    desired_instance_count: int
    current_copy_count: int
    desired_copy_count: int

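Helper code to illustrate scale-out and scale-in timings: the next cell polls the endpoint and the inference component once per second and prints a record whenever their combined status changes. Interrupt the cell execution when you are done.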

In [None]:
statuses = []

while True:
    endpoint_desc = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = endpoint_desc['EndpointStatus']
    current_instance_count = endpoint_desc['ProductionVariants'][0]['CurrentInstanceCount']
    desired_instance_count = endpoint_desc['ProductionVariants'][0]['DesiredInstanceCount']
    ic_desc = sm_client.describe_inference_component(InferenceComponentName=inference_component_name)
    ic_status = ic_desc['InferenceComponentStatus']
    current_copy_count = ic_desc['RuntimeConfig']['CurrentCopyCount']
    desired_copy_count = ic_desc['RuntimeConfig']['DesiredCopyCount']
    status_name = f"{status}_{ic_status}"
    if not statuses or statuses[-1].status_name != status_name:
        statuses.append(AutoscalingStatus(
            status_name=status_name,
            start_time=datetime.utcnow(),
            current_instance_count=current_instance_count,
            desired_instance_count=desired_instance_count,
            current_copy_count=current_copy_count,
            desired_copy_count=desired_copy_count,
        ))
        print(statuses[-1])
    time.sleep(1)
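After the loop is interrupted, `statuses` records when each combined `EndpointStatus_InferenceComponentStatus` value first appeared. A small sketch for turning such a timeline into per-phase durations; it takes plain `(status_name, start_time)` pairs so it does not depend on the `AutoscalingStatus` class, the function name is hypothetical, and the timestamps in the example are illustrative, not measured results.

```python
from datetime import datetime, timedelta


def phase_durations(timeline: list[tuple[str, datetime]],
                    end_time: datetime) -> list[tuple[str, timedelta]]:
    """Return how long each status lasted, given ordered (status, start_time) pairs.

    The last phase is measured up to end_time (for example, when polling stopped).
    """
    durations = []
    for i, (name, start) in enumerate(timeline):
        stop = timeline[i + 1][1] if i + 1 < len(timeline) else end_time
        durations.append((name, stop - start))
    return durations


# Illustrative timeline only, not real measurements.
t0 = datetime(2024, 2, 26, 23, 30, 0)
timeline = [
    ("InService_InService", t0),
    ("Updating_Updating", t0 + timedelta(seconds=90)),
    ("InService_InService", t0 + timedelta(minutes=12)),
]
for name, duration in phase_durations(timeline, end_time=t0 + timedelta(minutes=15)):
    print(f"{name}: {duration}")
```

In the notebook, `phase_durations([(s.status_name, s.start_time) for s in statuses], datetime.utcnow())` would apply it to the collected records; `datetime.utcnow()` matches the naive UTC timestamps stored by the loop.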

## Step 5: Autoscaling cleanup

In [None]:
aas_client.delete_scaling_policy(
    PolicyName=endpoint_name,
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
)

In [None]:
aas_client.deregister_scalable_target(
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
)

## Step 6: Clean up the environment

In [48]:
sess.delete_inference_component(inference_component_name, wait = True)

In [49]:
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_config_name)
sess.delete_model(model_name)

In [None]:
#
# Helper code - find my IP to use in locust_script.py (localhost does not work)
#

#import socket
#s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
#s.connect(("8.8.8.8", 80))
#print(s.getsockname()[0])
#s.close()