# **Deploy Llama 3.x through vLLM on SageMaker Endpoint using LMI container from DJL.**

## Use DJL with the SageMaker Python SDK
- The SageMaker Python SDK lets you host models on Amazon SageMaker using the Deep Java Library. <BR>
- Deep Java Library (DJL) Serving is a high-performance, general-purpose, standalone model serving solution provided by DJL. DJL Serving supports loading models trained with a variety of frameworks. <BR>
- With the SageMaker Python SDK, you can host large models with DJL Serving using backends such as DeepSpeed and HuggingFace Accelerate. <BR>
- For the supported versions of DJL Serving, see the [AWS documentation](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/deep-learning-containers-images.html). <BR>
- Using the latest supported version is recommended, because that is where development effort is focused. <BR>
- For general information on using the SageMaker Python SDK, see [Using the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/v2.139.0/overview.html#using-the-sagemaker-python-sdk).
    
REF: [BLOG] [Deploy LLM with vLLM on SageMaker in only 13 lines of code](https://mrmaheshrajput.medium.com/deploy-llm-with-vllm-on-sagemaker-in-only-13-lines-of-code-1601f780c0cf)

## 1. Deploy model on SageMaker

In [1]:
import boto3
import sagemaker
from sagemaker import get_execution_role

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


- [Available DLC (Deep Learning Containers)](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)

In [2]:
role = get_execution_role()
region = boto3.Session().region_name
sagemaker_session = sagemaker.session.Session()
sm_client = boto3.client("sagemaker", region_name=region)
sm_runtime_client = boto3.client("sagemaker-runtime")
sm_autoscaling_client = boto3.client("application-autoscaling")

### Setup Configuration


 - [[DOC] DJL for serving](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/index.html)

In [5]:
%store -r model_s3_path
model_id = model_s3_path["S3DataSource"]["S3Uri"]
print("model_id: ", model_id)

model_id:  s3://sagemaker-us-west-2-419974056037/llama3-2-8b-naver-news-2024-10-01-03-30-2024-10-01-03-30-02-918/output/model/


In [6]:
container_uri = sagemaker.image_uris.retrieve(
    framework="djl-lmi", version="0.29.0", region=region
)
container_uri

'763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124'

In [8]:
instance_type = "ml.g5.12xlarge"
container_startup_health_check_timeout = 900

endpoint_name = sagemaker.utils.name_from_base("Meta-Llama-3-2-8B-Instruct")

print (f'container_uri: {container_uri}')
print (f'container_startup_health_check_timeout: {container_startup_health_check_timeout}')
print (f'instance_type: {instance_type}')
print (f'endpoint_name: {endpoint_name}')

container_uri: 763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.29.0-lmi11.0.0-cu124
container_startup_health_check_timeout: 900
instance_type: ml.g5.12xlarge
endpoint_name: Meta-Llama-3-2-8B-Instruct-2024-10-01-03-34-09-302


### Create model with env variables


- Target model: [DeepSeek-Coder-V2-Lite-Instruct](https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct)

- **[Backend for attention computation in vLLM](https://docs.vllm.ai/en/latest/serving/env_vars.html)**
    - Available options:
        - "TORCH_SDPA": use torch.nn.MultiheadAttention
        - "FLASH_ATTN": use FlashAttention
        - "XFORMERS": use XFormers
        - "ROCM_FLASH": use ROCmFlashAttention
        - "FLASHINFER": use flashinfer

- **'"OPTION_DISABLE_FLASH_ATTN": "false"'** applies to HF Accelerate with Seq-Scheduler
    - It is ignored when using the vLLM backend

- [[DOC] DJL-Container and Model Configurations (info. about properties)](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/deployment_guide/configurations.html)
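
A minimal sketch of how the `OPTION_*` environment variables relate to `serving.properties` keys, assuming the naming convention described in the LMI configuration guide (an `OPTION_FOO_BAR` variable corresponds to `option.foo_bar`). The helper name `env_to_serving_properties` is hypothetical and not part of DJL or the SageMaker SDK. `HF_MODEL_ID` is treated here as the counterpart of `option.model_id`, and variables without the `OPTION_` prefix (such as `VLLM_ATTENTION_BACKEND`) are skipped because they are read as plain environment variables.

```python
def env_to_serving_properties(env):
    """Translate LMI-style environment variables into serving.properties lines."""
    lines = []
    for key, value in env.items():
        if key == "HF_MODEL_ID":
            lines.append(f"option.model_id={value}")
        elif key.startswith("OPTION_"):
            name = key[len("OPTION_"):].lower()
            lines.append(f"option.{name}={value}")
        # Anything else has no serving.properties equivalent in this sketch.
    return lines


example_env = {
    "HF_MODEL_ID": "s3://my-bucket/my-model/",  # hypothetical location
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "VLLM_ATTENTION_BACKEND": "XFORMERS",
}
print("\n".join(env_to_serving_properties(example_env)))
```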

In [9]:
deploy_env = {
    "HF_MODEL_ID": model_id,
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "2",
    "OPTION_DTYPE":"fp16",
    "OPTION_TRUST_REMOTE_CODE": "true",
    "OPTION_MAX_MODEL_LEN": "8192",
    "VLLM_ATTENTION_BACKEND": "XFORMERS",
    #"OPTION_DISABLE_FLASH_ATTN": "false", ## HF Accelerate with Seq-Scheduler
    "HF_TOKEN": "<your token>"
}
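
As a rough sanity check on `OPTION_TENSOR_PARALLEL_DEGREE` and `OPTION_DTYPE`, the sketch below estimates how much GPU memory the model weights alone need per device. It assumes roughly 8 billion parameters (inferred from the endpoint name, not measured), 2 bytes per parameter for fp16, and that `max` resolves to the 4 GPUs of an `ml.g5.12xlarge`. KV cache and activation memory are ignored, and `weights_gib_per_gpu` is a hypothetical helper.

```python
def weights_gib_per_gpu(num_params, bytes_per_param=2, tp_degree=4):
    """Approximate per-GPU weight memory in GiB when sharded evenly."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / tp_degree / 2**30


# Assumed values: ~8e9 parameters, fp16, 4-way tensor parallelism.
per_gpu = weights_gib_per_gpu(8e9, bytes_per_param=2, tp_degree=4)
print(f"approx. weight memory per GPU: {per_gpu:.2f} GiB")
```

Each A10G GPU on this instance type has 24 GB of memory, so under these assumptions the weights leave most of each device free for the KV cache.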

In [10]:
model = sagemaker.Model(
    image_uri=container_uri, 
    role=role,
    env=deploy_env
)

### Deploy model

In [11]:
model.deploy(
    instance_type=instance_type,
    initial_instance_count=1,
    endpoint_name=endpoint_name,
    container_startup_health_check_timeout=container_startup_health_check_timeout
)

-------------!

## 2. Invocation (Generate Text using the endpoint)

### Get a predictor for your endpoint

In [12]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

### Make a prediction with your endpoint

- **question candidates**
    - write a quick sort algorithm in python.
    - Write a piece of quicksort code in C++.
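
A small helper can keep request bodies consistent across prompts. This is a sketch, assuming the `inputs`/`parameters` request schema used in this notebook. `build_payload` is a hypothetical name, and `temperature` and `top_p` are sampling parameters assumed to be accepted by the container alongside `do_sample` and `max_new_tokens`.

```python
def build_payload(prompt, max_new_tokens=256, do_sample=True, **sampling):
    """Build a request body in the inputs/parameters shape."""
    if not prompt or not prompt.strip():
        raise ValueError("prompt must be a non-empty string")
    parameters = {"do_sample": do_sample, "max_new_tokens": max_new_tokens}
    # Extra keyword arguments are passed through as sampling parameters.
    parameters.update(sampling)
    return {"inputs": prompt, "parameters": parameters}


payload = build_payload(
    "write a quick sort algorithm in python.",
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
)
print(payload)
```

The resulting dictionary can be passed to `predictor.predict(payload)` once the predictor is created.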

In [13]:
outputs = predictor.predict(
    {
        "inputs": "write a quick sort algorithm in python and description",
        "parameters": {"do_sample": True, "max_new_tokens": 2048},
    }
)

print(outputs["generated_text"])

 of each section

```python
def quicksort(arr):
```
**Section 1:**  `def quicksort(arr):` 
This is the function definition for quicksort. `arr` is the input list that we want to sort.

```python
if len(arr) <= 1:
```
*   This section handles the base case for the recursion. If the length of the array is 1 or less, it's already sorted, so we can simply return it.

```python
return arr
```
*   This line passes the test when the array has just one element, as it will return the same array and is considered sorted according to the problem definition.

```python
else:
```
**Section 2:** `else:`
This section executes if the length of the array is more than 1. It starts to explore the possible partitioning.

```python
pivot = arr[len(arr) // 2]
```
```

*   We select the third element (technically the (n // 2) th one for even number lengths) as the pivot element. This pivot is used to split the array around it. The choice of this pivot can impact performance.

```python
left = [x for x in arr

- **With chat template**
    - [DJL Chat Completions API Schema](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html)

In [14]:
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works! anyway, write a quick sort algorithm in python and description"},
]

result = predictor.predict(
    {"messages": chat, "max_tokens": 1024}
)
result

{'id': 'chatcmpl-140116466701296',
 'object': 'chat.completion',
 'created': 1727754102,
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': ' Sounds like a fun challenge!\n\n**Quick Sort Algorithm in Python**\n\nHere\'s a quick sort algorithm implemented in Python:\n\n```python\ndef quick_sort(arr):\n    """\n    Sorts an array using the quick sort algorithm.\n\n    Args:\n        arr (list): The input array to be sorted.\n\n    Returns:\n        list: A new sorted array.\n    """\n    if len(arr) <= 1:\n        return arr\n    pivot = arr[len(arr) // 2]\n    left = [x for x in arr if x < pivot]\n    middle = [x for x in arr if x == pivot]\n    right = [x for x in arr if x > pivot]\n    return quick_sort(left) + middle + quick_sort(right)\n```\n\n**Description**\n\nThe quick sort algorithm is a divide-and-conquer algorithm that works by selecting a pivot element from the input array and partitioning the other elements into two sub-arrays, according to whether 
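
The response printed above nests the generated text under `choices[0]["message"]["content"]`. A minimal sketch for pulling it out, with `extract_chat_text` as a hypothetical helper. The sample dictionary is made up to mirror the shape shown above, and the `delta` branch assumes streamed chunks carry partial text under `choices[0]["delta"]["content"]`, as the streaming cell later in this notebook does.

```python
def extract_chat_text(response):
    """Return the assistant text from a chat completion or a streamed chunk."""
    choices = response.get("choices") or []
    if not choices:
        return ""
    first = choices[0]
    # Full responses use "message"; streamed chunks are assumed to use "delta".
    body = first.get("message") or first.get("delta") or {}
    return body.get("content") or ""


sample = {
    "object": "chat.completion",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": " Sounds like a fun challenge!"},
        }
    ],
}
print(extract_chat_text(sample))
```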

## 3. Streaming output from the endpoint


In [15]:
import json
import random 

In [16]:
# Prompts covering a variety of coding tasks
prompts = [
    "write a quick sort algorithm in python.",
    "Write a Python function to implement a binary search algorithm.",
    "Create a JavaScript function to flatten a nested array.",
    "Implement a simple REST API using Flask in Python.",
    "Write a SQL query to find the top 5 customers by total purchase amount.",
    "Create a React component for a todo list with basic CRUD operations.",
    "Implement a depth-first search algorithm for a graph in C++.",
    "Write a bash script to find and delete files older than 30 days.",
    "Create a Python class to represent a deck of cards with shuffle and deal methods.",
    "Write a regular expression to validate email addresses.",
    "Implement a basic CI/CD pipeline using GitHub Actions."
]

def generate_payload():
    # Pick a prompt at random
    prompt = random.choice(prompts)
    
    # Build the JSON payload
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 400,
            # "return_full_text": False  # This does not work with Phi3
        },
        "stream": True,
    }
    
    # Serialize to a JSON string and encode as bytes
    return json.dumps(body).encode('utf-8')
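
The streaming cell that follows retries `json.loads` on a growing text buffer, which can stall if a single chunk happens to contain more than one JSON document. A more defensive sketch, assuming the container emits one JSON object per line (newline-delimited JSON). `iter_json_lines` is a hypothetical helper, and the fake events stand in for `resp['Body']`.

```python
import json


def iter_json_lines(events):
    """Yield one parsed JSON object per complete line in a response stream."""
    buffer = b""
    for event in events:
        part = event.get("PayloadPart")
        if not part:
            continue
        buffer += part["Bytes"]
        # Split at the byte level so multi-byte characters that straddle
        # two chunks are decoded only once the full line has arrived.
        *lines, buffer = buffer.split(b"\n")
        for line in lines:
            if line.strip():
                yield json.loads(line.decode("utf-8"))
    # Handle a final line that was not newline-terminated.
    if buffer.strip():
        yield json.loads(buffer.decode("utf-8"))


# Two chunks: the second JSON object is split across them.
fake_events = [
    {"PayloadPart": {"Bytes": b'{"token": {"text": "Hel"}}\n{"token": {"te'}},
    {"PayloadPart": {"Bytes": b'xt": "lo"}}\n'}},
]
text = "".join(obj["token"]["text"] for obj in iter_json_lines(fake_events))
print(text)
```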

In [17]:
%%time
# Invoke the endpoint
resp = sm_runtime_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, 
    # Body=json.dumps(body), 
    Body=generate_payload(), 
    
    ContentType="application/json"
)
print("Generated response:")
print("-" * 40)

buffer = ""
for event in resp['Body']:
    if 'PayloadPart' in event:
        chunk = event['PayloadPart']['Bytes'].decode()
        buffer += chunk
        try:
            # Try to parse the buffer as JSON
            data = json.loads(buffer)
            if 'token' in data:
                print(data['token']['text'], end='', flush=True)
            buffer = ""  # Clear the buffer after successful parsing
        except json.JSONDecodeError:
            # If parsing fails, keep the buffer for the next iteration
            pass

print("\n" + "-" * 40)

Generated response:
----------------------------------------
 The API will have two endpoints: one to create a new user and another to retrieve all users.

### Requirements

* Python 3.8+
* Flask 2.0+

### API Endpoints

#### 1. Create a new user

*   **Endpoint:** `/users`
*   **Method:** `POST`
*   **Request Body:** `{"name": "string", "email": "string"}`
*   **Response:** `{"id": int, "name": "string", "email": "string"}`

#### 2. Retrieve all users

*   **Endpoint:** `/users`
*   **Method:** `GET`
*   **Response:** `[{"id": int, "name": "string", "email": "string"}, {"id": int, "name": "string", "email": "string"}]`

### Code

```python
from flask import Flask, request, jsonify

app = Flask(__name__)

# In-memory data store for demonstration purposes
users = [
    {"id": 1, "name": "John Doe", "email": "john@example.com"},
    {"id": 2, "name": "Jane Doe", "email": "jane@example.com"},
]

# Create a new user
@app.route('/users', methods=['POST'])
def create_user():
    """Create a 

- **With chat template**
    - [DJL Chat Completions API Schema](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html)

In [18]:
# Chat history to send through the chat template
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works! anyway, write a quick sort algorithm in python and description"},
]

def generate_payload():
    # Build the JSON payload for a streamed chat completion
    body = {
        "messages": chat,
        "max_tokens": 1024,
        "stream": True,
    }
    
    # Serialize to a JSON string and encode as bytes
    return json.dumps(body).encode('utf-8')

In [19]:
%%time
# Invoke the endpoint
resp = sm_runtime_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, 
    # Body=json.dumps(body), 
    Body=generate_payload(), 
    
    ContentType="application/json"
)
print("Generated response:")
print("-" * 40)

buffer = ""
for event in resp['Body']:
    if 'PayloadPart' in event:
        chunk = event['PayloadPart']['Bytes'].decode()
        buffer += chunk
        try:
            # Try to parse the buffer as JSON
            data = json.loads(buffer)
            if 'choices' in data:
                # Some chunks may carry no text (for example, a role-only delta)
                delta = data['choices'][0].get('delta', {})
                print(delta.get('content') or '', end='', flush=True)
            buffer = ""  # Clear the buffer after successful parsing
        except json.JSONDecodeError:
            # If parsing fails, keep the buffer for the next iteration
            pass

print("\n" + "-" * 40)

Generated response:
----------------------------------------
 Ah, excellent choice!

**Quick Sort Algorithm in Python**

**Implementation**
----------------
```python
def quick_sort(arr):
    """
    Recursively sorts an array using the quick sort algorithm.

    Args:
        arr (list): The array to be sorted.

    Returns:
        list: The sorted array.
    """
    if len(arr) <= 1:
        # Base case: If the array has one or zero elements, it's already sorted.
        return arr
    pivot = arr[len(arr) // 2]
    # Divide the array into three lists: elements less than the pivot,
    # elements equal to the pivot, and elements greater than the pivot.
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    # Recursively sort the left and right lists, and concatenate the results.
    return quick_sort(left) + middle + quick_sort(right)

# Example usage
arr = [5, 2, 9, 1, 7, 3]
sorted_arr = quick_sort(arr)
pri

## 4. Real-time Inference Autoscaling


In [None]:
import pprint

In [None]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)

# SageMaker expects resource id to be provided with the following structure
resource_id = f"endpoint/{endpoint_name}/variant/{resp['ProductionVariants'][0]['VariantName']}"

# Scaling configuration
scaling_config_response = sm_autoscaling_client.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount", 
    MinCapacity=1,
    MaxCapacity=2
)

In [None]:
# Create Scaling Policy
policy_name = f"scaling-policy-{endpoint_name}"
scaling_policy_response = sm_autoscaling_client.put_scaling_policy(
    PolicyName=policy_name,
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0, # Target average invocations per instance per minute
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance",
        },
        "ScaleInCooldown": 600, # Cooldown in seconds after a scale-in activity
        "ScaleOutCooldown": 60 # Cooldown in seconds after a scale-out activity
    }
)
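
To get a feel for the `TargetValue`, the arithmetic below approximates the instance count that target tracking would aim for at a given request rate. This is a simplified model, not the exact Application Auto Scaling algorithm: it assumes the policy tries to keep invocations per instance per minute near the target and clamps the result to the registered `MinCapacity` and `MaxCapacity`. `desired_instances` is a hypothetical helper.

```python
import math


def desired_instances(invocations_per_minute, target_per_instance=5.0,
                      min_capacity=1, max_capacity=2):
    """Approximate the instance count needed to stay near the target."""
    needed = math.ceil(invocations_per_minute / target_per_instance)
    return max(min_capacity, min(max_capacity, needed))


for rate in (3, 8, 40):
    print(rate, desired_instances(rate))
```

With the values used in this notebook, any sustained rate above 5 invocations per minute would push the endpoint toward its maximum of 2 instances.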

In [None]:
response = sm_autoscaling_client.describe_scaling_policies(ServiceNamespace="sagemaker")

pp = pprint.PrettyPrinter(indent=4, depth=4)
for i in response["ScalingPolicies"]:
    pp.pprint(i["PolicyName"])
    print("")
    if("TargetTrackingScalingPolicyConfiguration" in i):
        pp.pprint(i["TargetTrackingScalingPolicyConfiguration"])

In [None]:
# Prompts covering a variety of coding tasks
prompts = [
    "write a quick sort algorithm in python.",
    "Write a Python function to implement a binary search algorithm.",
    "Create a JavaScript function to flatten a nested array.",
    "Implement a simple REST API using Flask in Python.",
    "Write a SQL query to find the top 5 customers by total purchase amount.",
    "Create a React component for a todo list with basic CRUD operations.",
    "Implement a depth-first search algorithm for a graph in C++.",
    "Write a bash script to find and delete files older than 30 days.",
    "Create a Python class to represent a deck of cards with shuffle and deal methods.",
    "Write a regular expression to validate email addresses.",
    "Implement a basic CI/CD pipeline using GitHub Actions."
]

def generate_payload():
    # Pick a prompt at random
    prompt = random.choice(prompts)
    
    # Build the JSON payload
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 400,
            # "return_full_text": False  # This does not work with Phi3
        },
        "stream": True,
    }
    
    # Serialize to a JSON string and encode as bytes
    return json.dumps(body).encode('utf-8')

In [None]:
%%time
import time

request_duration = 250
end_time = time.time() + request_duration
print(f"Endpoint will be tested for {request_duration} seconds")

while time.time() < end_time:
    payload = generate_payload()
    # Invoke the endpoint
    response = sm_runtime_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=payload,
        ContentType="application/json"
    )
    # Drain the stream so the request completes before the next one is sent
    for _ in response["Body"]:
        pass

In [None]:
# Check the instance counts after the endpoint gets more load
response = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_status = response["EndpointStatus"]
request_duration = 250
end_time = time.time() + request_duration
print(f"Waiting up to {request_duration} seconds for the instance count to increase. Rerun this cell if the count does not change.")
while time.time() < end_time:
    response = sm_client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_status = response["EndpointStatus"]
    instance_count = response["ProductionVariants"][0]["CurrentInstanceCount"]
    print(f"Status: {endpoint_status}")
    print(f"Current Instance count: {instance_count}")
    if (endpoint_status=="InService") and (instance_count>1):
        break
    else:
        time.sleep(15)

In [None]:
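# Look up the endpoint configuration name before deleting the endpoint
endpoint_config_name = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointConfigName"]

# Delete endpoint
sm_client.delete_endpoint(EndpointName=endpoint_name)

# Delete endpoint configuration
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

# Delete model (model.name is set by the SageMaker SDK during deploy)
sm_client.delete_model(ModelName=model.name)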
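
The autoscaling resources registered in this section are separate from the endpoint itself, so it may be worth removing them as well, ideally before the endpoint is deleted. A sketch under the assumption that the `policy_name` and `resource_id` values from the cells above are still available. `remove_autoscaling` is a hypothetical wrapper around the Application Auto Scaling client's `delete_scaling_policy` and `deregister_scalable_target` calls.

```python
def remove_autoscaling(autoscaling_client, resource_id, policy_name):
    """Delete the scaling policy, then deregister the scalable target."""
    common = {
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": "sagemaker:variant:DesiredInstanceCount",
    }
    autoscaling_client.delete_scaling_policy(PolicyName=policy_name, **common)
    autoscaling_client.deregister_scalable_target(**common)


# Usage with the clients defined earlier in this notebook (not run here):
# remove_autoscaling(sm_autoscaling_client, resource_id, policy_name)
```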