# Mistral 7B TensorRT LLM Deployment
In this notebook we deploy Mistral 7B using the Amazon Large Model Inference (LMI) container powered by [NVIDIA TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Other engines and backends are also available; refer to the [tuning guide](https://docs.djl.ai/docs/serving/serving/docs/lmi/tuning_guides/trtllm_tuning_guide.html) for the different options and how to tune your configuration. With the TensorRT-LLM container, rolling batch is enabled by default. We use a g5.12xlarge with a tensor parallel degree of 4, which you can adjust depending on your hardware. Mistral 7B can also be hosted on a g5.2xlarge if you opt for a smaller instance type.

### Table of Contents
- Setup & Endpoint Creation
- Load Testing & Enabling AutoScaling
- Cleanup

### Credits/References
- [LMI Configuration Documentation](https://docs.djl.ai/docs/serving/serving/docs/lmi/configurations_large_model_inference_containers.html)
- [DJL-Demo Samples](https://github.com/deepjavalibrary/djl-demo/tree/2a5152f578f5954b8b68acdee18eed4e2a75c81f/aws/sagemaker/large-model-inference/sample-llm)

## Setup & Endpoint Creation
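The tensor parallel degree written to `serving.properties` in this section has to match the number of GPUs on the chosen instance. As a minimal sketch, the helper below renders the same properties for a given instance type. The function name `render_serving_properties` and the `GPUS_PER_INSTANCE` table are hypothetical, and the GPU counts are assumptions that should be checked against the EC2 documentation for your instance family.

```python
# Assumed GPU counts for a few G5 instance types; verify before relying on them.
GPUS_PER_INSTANCE = {
    "ml.g5.2xlarge": 1,
    "ml.g5.12xlarge": 4,
    "ml.g5.48xlarge": 8,
}


def render_serving_properties(
    instance_type,
    model_id="mistralai/Mistral-7B-v0.1",
    max_rolling_batch_size=16,
):
    """Return serving.properties text with the tensor parallel degree set to the GPU count."""
    gpus = GPUS_PER_INSTANCE[instance_type]
    lines = [
        "engine=MPI",
        f"option.tensor_parallel_degree={gpus}",
        f"option.model_id={model_id}",
        f"option.max_rolling_batch_size={max_rolling_batch_size}",
        "option.rolling_batch=auto",
    ]
    return "\n".join(lines) + "\n"


print(render_serving_properties("ml.g5.12xlarge"))
```

Writing the returned string to `serving.properties` would replace the hand-edited file when you switch instance types.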

In [None]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
region = sess.boto_region_name  # region of the current SageMaker session (public attribute)
account_id = sess.account_id() 

In [None]:
%%writefile serving.properties
engine=MPI
option.tensor_parallel_degree=4
option.model_id=mistralai/Mistral-7B-v0.1
option.max_rolling_batch_size=16
option.rolling_batch=auto

In [None]:
%%sh
mkdir mymodel
mv serving.properties mymodel/
tar czvf mymodel.tar.gz mymodel/
rm -rf mymodel

In [None]:
# retrieve the TensorRT-LLM LMI container image
image_uri = image_uris.retrieve(
    framework="djl-tensorrtllm",
    region=sess.boto_session.region_name,
    version="0.26.0",
)

In [None]:
s3_code_prefix = "large-model-lmi/code"
bucket = sess.default_bucket()  # bucket to house artifacts
code_artifact = sess.upload_data("mymodel.tar.gz", bucket, s3_code_prefix)
print(f"S3 Code or Model tar ball uploaded to --- > {code_artifact}")

model = Model(image_uri=image_uri, model_data=code_artifact, role=role)

In [None]:
# endpoint creation can take around 10 minutes while the model artifacts are pulled from the Hugging Face Hub
instance_type = "ml.g5.12xlarge"
endpoint_name = sagemaker.utils.name_from_base("lmi-trt-mistral")

model.deploy(initial_instance_count=1,
             instance_type=instance_type,
             endpoint_name=endpoint_name
            )

# our requests and responses will be in JSON format, so we specify both the serializer and the deserializer
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=serializers.JSONSerializer(),
    deserializer=deserializers.JSONDeserializer(),
)

In [None]:
# inference via sagemaker python SDK
predictor.predict(
    {"inputs": "Who is Roger Federer?"})

In [None]:
# boto3 inference sample
import json
runtime_client = boto3.client('sagemaker-runtime')
content_type = "application/json"
payload = {"inputs": "Who is Roger Federer?"} #optionally add any parameters for your model

# sample inference
response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType=content_type,
    Body=json.dumps(payload))
result = json.loads(response['Body'].read().decode())['generated_text']
print(result)

## Load Testing & Enabling AutoScaling

### Load Testing

For load testing we use the open source Python framework Locust. Locust simulates concurrent users to generate traffic; for a deeper guide, please refer to this [blog](https://aws.amazon.com/blogs/machine-learning/best-practices-for-load-testing-amazon-sagemaker-real-time-inference-endpoints/). The test uses two scripts that we provide:

- <b>distributed.sh</b>: Controls the number of users and workers to increase traffic (TPS)
- <b>locust_script.py</b>: Python script that defines the task to test, in this case our invoke_endpoint REST API call.
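To make the idea of simulated concurrent users concrete, here is a simplified stand-in built on the standard library. It is not Locust and not the contents of the provided scripts; `run_load` is a hypothetical helper that calls a request function from several threads and reports latency in milliseconds. The `time.sleep` call stands in for a real request.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from statistics import median


def run_load(request_fn, users=4, requests_per_user=5):
    """Call request_fn from `users` threads and return latency stats in milliseconds."""

    def one_user(user_index):
        latencies = []
        for _ in range(requests_per_user):
            start = time.perf_counter()
            request_fn()
            latencies.append((time.perf_counter() - start) * 1000)
        return latencies

    # Each thread plays the role of one concurrent user.
    with ThreadPoolExecutor(max_workers=users) as pool:
        per_user = list(pool.map(one_user, range(users)))

    all_latencies = sorted(ms for user in per_user for ms in user)
    return {
        "requests": len(all_latencies),
        "median_ms": median(all_latencies),
        "max_ms": all_latencies[-1],
    }


# Stand-in request; against a live endpoint this would wrap invoke_endpoint.
stats = run_load(lambda: time.sleep(0.01), users=2, requests_per_user=3)
print(stats)
```

Locust adds distributed workers, ramp-up control, and CSV reporting on top of this basic pattern, which is why we use it for the actual test.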

In [None]:
!pip install locust

In [None]:
!which locust

In [None]:
!cat distributed.sh #adjust users and workers to increase traffic, users are a multiple of the workers in locust

In [None]:
%%bash -s "$endpoint_name"
./distributed.sh $1

We can look at the metrics generated by Locust to understand our end-to-end latency. We also look at the built-in CloudWatch metrics in the SageMaker UI to understand hardware utilization and invocation metrics (container latency, etc.). In particular, we look at the maximum GPU utilization (400% is available with 4 GPUs) and the invocations per minute generated by the Locust test. Understanding these metrics helps us define a prescriptive AutoScaling policy. To learn more about the CloudWatch metrics integrated with SageMaker Real-Time Inference, please refer to the following [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/monitoring-cloudwatch.html).

<div style="display: flex;">
    <img src="images/invocations.png" alt="Invocations" style="width: 50%;">
    <img src="images/hardware-utilization.png" alt="Hardware" style="width: 50%;">
</div>
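The same invocation numbers can also be pulled programmatically instead of being read from the console. The following is a sketch, assuming the default `AllTraffic` variant and a one-minute period; the helper names `invocations_query` and `invocations_per_minute` are hypothetical.

```python
from datetime import datetime, timedelta, timezone


def invocations_query(endpoint_name, variant_name="AllTraffic", minutes=30):
    """Build get_metric_statistics arguments for per-minute endpoint invocations."""
    end = datetime.now(timezone.utc)
    return {
        "Namespace": "AWS/SageMaker",
        "MetricName": "Invocations",
        "Dimensions": [
            {"Name": "EndpointName", "Value": endpoint_name},
            {"Name": "VariantName", "Value": variant_name},
        ],
        "StartTime": end - timedelta(minutes=minutes),
        "EndTime": end,
        "Period": 60,  # one datapoint per minute
        "Statistics": ["Sum"],
    }


def invocations_per_minute(endpoint_name):
    """Return (timestamp, invocation count) pairs, oldest first. Requires AWS credentials."""
    import boto3  # imported lazily so the query builder works without AWS access

    cloudwatch = boto3.client("cloudwatch")
    response = cloudwatch.get_metric_statistics(**invocations_query(endpoint_name))
    points = sorted(response["Datapoints"], key=lambda p: p["Timestamp"])
    return [(p["Timestamp"], p["Sum"]) for p in points]
```

Calling `invocations_per_minute(endpoint_name)` after the Locust run should give roughly the series shown in the invocations chart.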

In [None]:
import pandas as pd
locust_data = pd.read_csv('results_stats.csv')
# print the first two rows of the Locust summary statistics
for index, row in locust_data.head(n=2).iterrows():
    print(index, row)

### AutoScaling

You can also enable AutoScaling at the endpoint level on Amazon SageMaker. Before configuring AutoScaling, we recommend load testing a single instance behind the endpoint; this helps you determine how much traffic a single instance can handle. Once this has been determined and the appropriate instance type is chosen, you can define your scaling policy with Managed AutoScaling. For a deeper dive into AutoScaling with SageMaker Inference, refer to this [blog](https://towardsdatascience.com/autoscaling-sagemaker-real-time-endpoints-b1b6e6731c59). <b>Please also ensure that you have requested the service limits needed for the scaling you set for your endpoint. In this case 4 g5.12xlarge instances are needed.</b>
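One way to turn the single-instance load test into a scaling target, following the approach described in the SageMaker autoscaling documentation, is to take the peak requests per second one instance sustained, apply a safety factor, and multiply by 60 because the invocations-per-instance metric is measured per minute. The sketch below assumes that formula; the function name and the example numbers are hypothetical and are not measurements from this notebook.

```python
def invocations_per_instance_target(max_rps, safety_factor=0.5):
    """Per-minute invocation target for one instance, derived from its peak requests per second."""
    if max_rps <= 0:
        raise ValueError("max_rps must be positive")
    if not 0 < safety_factor <= 1:
        raise ValueError("safety_factor must be in (0, 1]")
    # The CloudWatch metric is per minute, hence the factor of 60.
    return max_rps * safety_factor * 60


# Hypothetical example: one instance sustained 2 requests per second under load.
print(invocations_per_instance_target(2.0))
```

A lower safety factor makes the endpoint scale out earlier, which trades cost for headroom.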

We set up a Managed AutoScaling policy via Application Auto Scaling using the Boto3 SDK. The policy should also be reflected in the SageMaker Endpoint UI:

![autoscaling](images/autoscaling-setup.png)

In [None]:
# AutoScaling client
asg = boto3.client('application-autoscaling')

# The scalable resource is the endpoint variant, identified by its resource ID.
# The default VariantName is AllTraffic; adjust for your use case.
resource_id = f"endpoint/{predictor.endpoint_name}/variant/AllTraffic"

# scaling configuration: allow between 1 and 4 instances
response = asg.register_scalable_target(
    ServiceNamespace='sagemaker',
    ResourceId=resource_id,
    ScalableDimension='sagemaker:variant:DesiredInstanceCount',
    MinCapacity=1,
    MaxCapacity=4
)

# Target tracking scaling policy
response = asg.put_scaling_policy(
    PolicyName=f'Request-ScalingPolicy-{endpoint_name}',
    ServiceNamespace='sagemaker',
    ResourceId=resource_id,
    ScalableDimension='sagemaker:variant:DesiredInstanceCount',
    PolicyType='TargetTrackingScaling',
    TargetTrackingScalingPolicyConfiguration={
        'TargetValue': 5.0,  # target of 5 invocations per instance per minute
        'PredefinedMetricSpecification': {
            'PredefinedMetricType': 'SageMakerVariantInvocationsPerInstance',
        },
        'ScaleInCooldown': 300,  # seconds to wait after a scale-in activity before the next one
        'ScaleOutCooldown': 60  # seconds to wait after a scale-out activity before the next one
    }
)

Let's send requests for 15 minutes to see the endpoint scale out as defined. We should see the endpoint updating to four instances.

![Updating Endpoint](images/updating-endpoint.png)
![Updated Endpoint](images/updated-endpoint.png)
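Besides watching the console, the instance count can be polled from code. The request loop in this section blocks the notebook while it runs, so a sketch like the one below would need to run from a second notebook or terminal. The helper name `wait_for_instance_count` is hypothetical, and it assumes the endpoint has a single production variant.

```python
import time


def wait_for_instance_count(sm_client, endpoint_name, target_count, timeout_s=1800, poll_s=30):
    """Poll describe_endpoint until the first variant reaches target_count instances."""
    deadline = time.time() + timeout_s
    while True:
        description = sm_client.describe_endpoint(EndpointName=endpoint_name)
        variant = description["ProductionVariants"][0]  # assumes a single variant
        current = variant["CurrentInstanceCount"]
        print(f"status={description['EndpointStatus']} instances={current}")
        if current >= target_count:
            return current
        if time.time() >= deadline:
            raise TimeoutError(f"{endpoint_name} did not reach {target_count} instances")
        time.sleep(poll_s)


# Usage against a live endpoint (requires AWS credentials):
# wait_for_instance_count(boto3.client("sagemaker"), endpoint_name, target_count=4)
```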

In [None]:
import time
request_duration = 60 * 15 # 15 minutes
end_time = time.time() + request_duration
print(f"test will run for {request_duration} seconds")
while time.time() < end_time:
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType=content_type,
        Body=json.dumps(payload))

## Cleanup
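Before deleting the endpoint, you may also want to remove the scaling policy and scalable target registered in the AutoScaling section explicitly. The sketch below assumes the policy name and `AllTraffic` variant used in this notebook; the helper name `remove_autoscaling` is hypothetical.

```python
def remove_autoscaling(autoscaling_client, endpoint_name, variant_name="AllTraffic"):
    """Delete the scaling policy and deregister the endpoint variant as a scalable target."""
    resource_id = f"endpoint/{endpoint_name}/variant/{variant_name}"
    common = {
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": "sagemaker:variant:DesiredInstanceCount",
    }
    # Policy name matches the one created with put_scaling_policy in this notebook.
    autoscaling_client.delete_scaling_policy(
        PolicyName=f"Request-ScalingPolicy-{endpoint_name}", **common
    )
    autoscaling_client.deregister_scalable_target(**common)
    return resource_id


# Usage with the Application Auto Scaling client created earlier (requires AWS credentials):
# remove_autoscaling(asg, endpoint_name)
```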

In [None]:
predictor.delete_endpoint()