# Deploy Llama-2-7b-chat on Amazon SageMaker using Large Model Inference (LMI) containers

## Resources
- [Deep Learning Containers](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-dlc.html)
- [Deep Java Library - Large Model Inference](https://docs.djl.ai/docs/serving/serving/docs/large_model_inference.html)

## Step 1: Setup

In [None]:
# Upgrade the SageMaker Python SDK in this kernel's environment (IPython %pip magic).
# NOTE(review): if `sagemaker` was already imported in this kernel, restart it so the upgraded version is used.
%pip install --upgrade --quiet sagemaker

In [None]:
import json
import boto3
import sagemaker

In [None]:
role = sagemaker.get_execution_role()  # IAM role passed as ExecutionRoleArn when creating models
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # default S3 bucket to house artifacts
region = sess.boto_region_name  # region of the current session (public accessor)

# Pin the low-level clients to the session's region so every call in this notebook
# targets the same region as the role, bucket and image URIs resolved from `region`.
sm_client = boto3.client("sagemaker", region_name=region)  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime", region_name=region)  # client to interact with SageMaker Endpoints

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"boto3 version: {boto3.__version__}")
print(f"sagemaker version: {sagemaker.__version__}")

## Step 2: Endpoint Deployment (LMI - vLLM)

In [None]:
# DJL Serving LMI release to deploy
version = "0.27.0"

# Container image URI (for the current region) of the LMI DeepSpeed container,
# which this notebook uses with the vLLM rolling-batch backend.
deepspeed_image = sagemaker.image_uris.retrieve(
    framework="djl-deepspeed",
    region=region,
    version=version,
)
print(f"DeepSpeed image for vLLM is ----> {deepspeed_image}")

In [None]:
#
# Register Llama-2-7b-chat as a SageMaker model, served by vLLM inside the LMI DeepSpeed container
#

instance_type = "ml.g5.2xlarge"
model_name = "Llama-2-7b-chat-hf-vLLM"

# LMI serving properties, handed to the container as environment variables
vllm_config = dict(
    SERVING_LOAD_MODELS="test::Python=/opt/ml/model",  # model "test", Python engine, /opt/ml/model
    OPTION_MODEL_ID="TheBloke/Llama-2-7B-Chat-fp16",  # model repository to load
    OPTION_ROLLING_BATCH="vllm",  # rolling-batch backend
    OPTION_TENSOR_PARALLEL_DEGREE="max",  # shard across all GPUs on the instance
    OPTION_MAX_ROLLING_BATCH_SIZE="32",
    OPTION_MAX_INPUT_LEN="1024",
    OPTION_MAX_OUTPUT_LEN="2048",
    OPTION_MAX_MODEL_LEN="2048",
    OPTION_DTYPE="fp16",
)

image_uri = deepspeed_image
env = vllm_config

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": image_uri, "Environment": env},
)
model_arn = create_model_response["ModelArn"]
print(f"Created Model: {model_arn}")

In [None]:
endpoint_config_name = f"{model_name}-EP-config"
health_check_timeout = 600  # seconds SageMaker waits for the container to pass its startup health check

# Single-instance production variant. LEAST_OUTSTANDING_REQUESTS routing only
# matters once the variant has more than one instance.
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName="variant1",
            ModelName=model_name,
            InstanceType=instance_type,
            InitialInstanceCount=1,
            ContainerStartupHealthCheckTimeoutInSeconds=health_check_timeout,
            RoutingConfig=dict(RoutingStrategy="LEAST_OUTSTANDING_REQUESTS"),
        ),
    ],
)
endpoint_config_response

In [None]:
#
# Create the endpoint from the endpoint config; provisioning continues
# asynchronously after this call returns.
#
endpoint_name = f"{model_name}-EP"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
#
# Block until the endpoint finishes deploying (SDK helper that polls the endpoint status);
# the returned endpoint description is shown as the cell output.
#
sess.wait_for_endpoint(endpoint=endpoint_name)

## Step 3: Run Inference (vLLM)

In [None]:
#
# Send a single request to the endpoint and print the generated text
#
prompt = """You are an helpful Assistant, called Jarvis. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Jarvis:"""

# generation parameters for the LLM
# NOTE(review): no stop sequence is set, so the model may keep writing further "User:" turns.
params = {"max_new_tokens": 256, "temperature": 0.1}

# request body: prompt under "inputs", generation settings under "parameters"
payload = {
    "inputs": prompt,
    "parameters": params,
}

response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)

# the response body is JSON; the completion is under "generated_text"
assistant = json.loads(response_model["Body"].read().decode("utf8"))["generated_text"]
print(assistant)

## Step 3.2: Test inference performance (vLLM)

In [None]:
#
# Measure end-to-end latency of sequential, non-streaming requests
#
import time
import numpy as np

# define payload (same prompt and generation parameters as the single-request test)
prompt = """You are an helpful Assistant, called Jarvis. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Jarvis:"""

params = {"max_new_tokens": 256, "temperature": 0.1}

payload = {
    "inputs": prompt,
    "parameters": params,
}
request_body = json.dumps(payload)  # identical for every request, so serialize once

num_requests = 10
results = []  # per-request latency in milliseconds
for _ in range(num_requests):
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start = time.perf_counter()
    response_model = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=request_body,
        ContentType="application/json",
    )
    # Read the streaming body so the measurement includes receiving the full
    # response and the underlying connection can be reused.
    response_model["Body"].read()
    results.append((time.perf_counter() - start) * 1000)

print(f"\nModel latency over {num_requests} sequential requests:\n")
print(f"P50: {np.percentile(results, 50):.1f} ms")
print(f"P90: {np.percentile(results, 90):.1f} ms")
print(f"P95: {np.percentile(results, 95):.1f} ms")
print(f"Average: {np.average(results):.1f} ms")

## Step 4: Cleanup

In [None]:
# Tear down in dependency order: endpoint, then its config, then the model.
sess.delete_endpoint(endpoint_name=endpoint_name)
sess.delete_endpoint_config(endpoint_config_name=endpoint_config_name)
sess.delete_model(model_name=model_name)

## Step 5: Endpoint Deployment (LMI - TensorRT-LLM)

In [None]:
# DJL Serving LMI release to deploy
version = "0.27.0"

# Container image URI (for the current region) of the LMI TensorRT-LLM container
trtllm_image = sagemaker.image_uris.retrieve(
    framework="djl-tensorrtllm",
    region=region,
    version=version,
)
print(f"TensorRT-LLM image is ----> {trtllm_image}")

In [None]:
#
# Register Llama-2-7b-chat as a SageMaker model, served by the LMI TensorRT-LLM container
#
instance_type = "ml.g5.16xlarge"  # required for TensorRT-LLM Just In Time Compilation
model_name = "Llama-2-7b-chat-hf-TRTLLM"

# LMI serving properties, handed to the container as environment variables
trtllm_config = dict(
    SERVING_LOAD_MODELS="test::MPI=/opt/ml/model",  # model "test", MPI engine, /opt/ml/model
    OPTION_MODEL_ID="TheBloke/Llama-2-7B-Chat-fp16",  # model repository to load
    OPTION_TENSOR_PARALLEL_DEGREE="max",  # shard across all GPUs on the instance
    OPTION_ROLLING_BATCH="trtllm",  # rolling-batch backend
    OPTION_MAX_INPUT_LEN="1024",
    OPTION_MAX_OUTPUT_LEN="2048",
    OPTION_MAX_ROLLING_BATCH_SIZE="2",
)

image_uri = trtllm_image
env = trtllm_config

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": image_uri, "Environment": env},
)
model_arn = create_model_response["ModelArn"]

print(f"Created Model: {model_arn}")

In [None]:
endpoint_config_name = f"{model_name}-EP-config"
# Seconds SageMaker waits for the container to pass its startup health check; larger
# than the vLLM deployment to leave room for TensorRT-LLM just-in-time compilation.
health_check_timeout = 1200

# Single-instance production variant. LEAST_OUTSTANDING_REQUESTS routing only
# matters once the variant has more than one instance.
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName="variant1",
            ModelName=model_name,
            InstanceType=instance_type,
            InitialInstanceCount=1,
            ContainerStartupHealthCheckTimeoutInSeconds=health_check_timeout,
            RoutingConfig=dict(RoutingStrategy="LEAST_OUTSTANDING_REQUESTS"),
        ),
    ],
)
endpoint_config_response

In [None]:
#
# Create the endpoint from the endpoint config; provisioning continues
# asynchronously after this call returns.
#
endpoint_name = f"{model_name}-EP"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
#
# Block until the endpoint finishes deploying (SDK helper that polls the endpoint status);
# the returned endpoint description is shown as the cell output.
#
sess.wait_for_endpoint(endpoint=endpoint_name)

## Step 6: Run Inference (TensorRT-LLM)

In [None]:
#
# Send a single request to the endpoint and print the generated text
#
prompt = """You are an helpful Assistant, called Jarvis. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Jarvis:"""

# generation parameters for the LLM
# NOTE(review): no stop sequence is set, so the model may keep writing further "User:" turns.
params = {"max_new_tokens": 256, "temperature": 0.1}

# request body: prompt under "inputs", generation settings under "parameters"
payload = {
    "inputs": prompt,
    "parameters": params,
}

response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)

# the response body is JSON; the completion is under "generated_text"
assistant = json.loads(response_model["Body"].read().decode("utf8"))["generated_text"]
print(assistant)

## Step 6.2: Test inference performance (TensorRT-LLM)

In [None]:
#
# Measure end-to-end latency of sequential, non-streaming requests
#
import time
import numpy as np

# define payload (same prompt and generation parameters as the single-request test)
prompt = """You are an helpful Assistant, called Jarvis. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Jarvis:"""

params = {"max_new_tokens": 256, "temperature": 0.1}

payload = {
    "inputs": prompt,
    "parameters": params,
}
request_body = json.dumps(payload)  # identical for every request, so serialize once

num_requests = 10
results = []  # per-request latency in milliseconds
for _ in range(num_requests):
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start = time.perf_counter()
    response_model = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=request_body,
        ContentType="application/json",
    )
    # Read the streaming body so the measurement includes receiving the full
    # response and the underlying connection can be reused.
    response_model["Body"].read()
    results.append((time.perf_counter() - start) * 1000)

print(f"\nModel latency over {num_requests} sequential requests:\n")
print(f"P50: {np.percentile(results, 50):.1f} ms")
print(f"P90: {np.percentile(results, 90):.1f} ms")
print(f"P95: {np.percentile(results, 95):.1f} ms")
print(f"Average: {np.average(results):.1f} ms")

## Step 7: Cleanup

In [None]:
# Tear down in dependency order: endpoint, then its config, then the model.
sess.delete_endpoint(endpoint_name=endpoint_name)
sess.delete_endpoint_config(endpoint_config_name=endpoint_config_name)
sess.delete_model(model_name=model_name)

## Step 8: Deploy an AWQ-quantized Llama-2-7b-chat model using vLLM

In [None]:
# DJL Serving LMI release to deploy
version = "0.27.0"

# Container image URI (for the current region) of the LMI DeepSpeed container,
# which this notebook uses with the vLLM rolling-batch backend.
deepspeed_image = sagemaker.image_uris.retrieve(
    framework="djl-deepspeed",
    region=region,
    version=version,
)
print(f"DeepSpeed image with vLLM is ----> {deepspeed_image}")

In [None]:
#
# Register the AWQ-quantized Llama-2-7b-chat as a SageMaker model,
# served by vLLM inside the LMI DeepSpeed container
#

instance_type = "ml.g5.2xlarge"
model_name = "Llama-2-7b-chat-hf-AWQ"

# LMI serving properties, handed to the container as environment variables
vllm_config = dict(
    SERVING_LOAD_MODELS="test::Python=/opt/ml/model",  # model "test", Python engine, /opt/ml/model
    OPTION_MODEL_ID="TheBloke/Llama-2-7B-Chat-AWQ",  # AWQ-quantized model repository to load
    OPTION_ROLLING_BATCH="vllm",  # rolling-batch backend
    OPTION_TENSOR_PARALLEL_DEGREE="max",  # shard across all GPUs on the instance
    OPTION_MAX_ROLLING_BATCH_SIZE="2",
    OPTION_MAX_INPUT_LEN="1024",
    OPTION_MAX_OUTPUT_LEN="2048",
    OPTION_MAX_MODEL_LEN="2048",
    OPTION_QUANTIZE="awq",  # quantization scheme of the weights above
    OPTION_DTYPE="auto",
)

image_uri = deepspeed_image
env = vllm_config

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={"Image": image_uri, "Environment": env},
)
model_arn = create_model_response["ModelArn"]
print(f"Created Model: {model_arn}")

In [None]:
endpoint_config_name = f"{model_name}-EP-config"
health_check_timeout = 1200  # seconds SageMaker waits for the container to pass its startup health check

# Single-instance production variant. LEAST_OUTSTANDING_REQUESTS routing only
# matters once the variant has more than one instance.
endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName="variant1",
            ModelName=model_name,
            InstanceType=instance_type,
            InitialInstanceCount=1,
            ContainerStartupHealthCheckTimeoutInSeconds=health_check_timeout,
            RoutingConfig=dict(RoutingStrategy="LEAST_OUTSTANDING_REQUESTS"),
        ),
    ],
)
endpoint_config_response

In [None]:
#
# Create the endpoint from the endpoint config; provisioning continues
# asynchronously after this call returns.
#
endpoint_name = f"{model_name}-EP"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print(f"Created Endpoint: {create_endpoint_response['EndpointArn']}")

In [None]:
#
# Block until the endpoint finishes deploying (SDK helper that polls the endpoint status);
# the returned endpoint description is shown as the cell output.
#
sess.wait_for_endpoint(endpoint=endpoint_name)

## Step 9: Inference (vLLM - AWQ)

In [None]:
#
# Send a single request to the endpoint and print the generated text
#
prompt = """You are an helpful Assistant, called Jarvis. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Jarvis:"""

# generation parameters for the LLM
# NOTE(review): no stop sequence is set, so the model may keep writing further "User:" turns.
params = {"max_new_tokens": 256, "temperature": 0.1}

# request body: prompt under "inputs", generation settings under "parameters"
payload = {
    "inputs": prompt,
    "parameters": params,
}

response_model = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)

# the response body is JSON; the completion is under "generated_text"
assistant = json.loads(response_model["Body"].read().decode("utf8"))["generated_text"]
print(assistant)

In [None]:
#
# Measure end-to-end latency of sequential, non-streaming requests
#
import time
import numpy as np

# define payload (same prompt and generation parameters as the single-request test)
prompt = """You are an helpful Assistant, called Jarvis. Knowing everyting about AWS.

User: Can you tell me something about Amazon SageMaker?
Jarvis:"""

params = {"max_new_tokens": 256, "temperature": 0.1}

payload = {
    "inputs": prompt,
    "parameters": params,
}
request_body = json.dumps(payload)  # identical for every request, so serialize once

num_requests = 10
results = []  # per-request latency in milliseconds
for _ in range(num_requests):
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start = time.perf_counter()
    response_model = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=request_body,
        ContentType="application/json",
    )
    # Read the streaming body so the measurement includes receiving the full
    # response and the underlying connection can be reused.
    response_model["Body"].read()
    results.append((time.perf_counter() - start) * 1000)

print(f"\nModel latency over {num_requests} sequential requests:\n")
print(f"P50: {np.percentile(results, 50):.1f} ms")
print(f"P90: {np.percentile(results, 90):.1f} ms")
print(f"P95: {np.percentile(results, 95):.1f} ms")
print(f"Average: {np.average(results):.1f} ms")

## Step 10: Cleanup

In [None]:
# Tear down in dependency order: endpoint, then its config, then the model.
sess.delete_endpoint(endpoint_name=endpoint_name)
sess.delete_endpoint_config(endpoint_config_name=endpoint_config_name)
sess.delete_model(model_name=model_name)