# **Deploy DeepSeek-Coder-V2 with vLLM on SageMaker Endpoint using LMI container from DJL.**

## Use DJL with the SageMaker Python SDK
- The SageMaker Python SDK lets you host models on Amazon SageMaker using the Deep Java Library (DJL). <BR>
- Deep Java Library (DJL) Serving is a high-performance, general-purpose, standalone model serving solution provided by DJL. DJL Serving supports loading models trained with a variety of frameworks. <BR>
- With the SageMaker Python SDK, you can host large models on DJL Serving using backends such as DeepSpeed and HuggingFace Accelerate. <BR>
- For information on the supported versions of DJL Serving, see the [AWS documentation](https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/deep-learning-containers-images.html). <BR>
- We recommend using the latest supported version, because that is where development effort is focused. <BR>
- For general information about using the SageMaker Python SDK, see [Using the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/v2.139.0/overview.html#using-the-sagemaker-python-sdk).
    
REF: [BLOG] [Deploy LLM with vLLM on SageMaker in only 13 lines of code](https://mrmaheshrajput.medium.com/deploy-llm-with-vllm-on-sagemaker-in-only-13-lines-of-code-1601f780c0cf)

In [17]:
# ! pip install sagemaker -U


## 1. Deploy model on SageMaker

In [1]:
import boto3
import sagemaker
from sagemaker import get_execution_role



sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


- [Available DLC (Deep Learning Containers)](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)

In [2]:
role = get_execution_role()
region = boto3.Session().region_name
sagemaker_session = sagemaker.session.Session()
sm_client = boto3.client("sagemaker", region_name=region)
sm_runtime_client = boto3.client("sagemaker-runtime")
sm_autoscaling_client = boto3.client("application-autoscaling")

### Setup Configuration


 - [[DOC] DJL for serving](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/index.html)

In [3]:
model_id = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
#model_id = "deepseek-ai/DeepSeek-Coder-V2-Instruct"

### Find the LMI container image with the SageMaker SDK

In [19]:
container_uri = sagemaker.image_uris.retrieve(
    framework="djl-lmi", version="0.30.0", region=region
)
container_uri

'763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.30.0-lmi12.0.0-cu124'

### LMI container Image:  v1.0-djl-0.32.0-inf-lmi-14.0.0
* Release date: Feb 8, 2025
    * https://github.com/aws/deep-learning-containers/releases/tag/v1.0-djl-0.32.0-inf-lmi-14.0.0
* Docker Image
    * 763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.32.0-lmi14.0.0-cu124
    * 763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.32.0-lmi14.0.0-cu124

As of 2025-02-12, the container version above has not yet been added to the SageMaker SDK.
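
Because the SDK lookup does not return this release, the image URI has to be written out by hand. The two URIs listed above differ only in the region, so a small helper can build the string for whichever region the session uses. This is a minimal sketch: `build_lmi_image_uri` is a hypothetical helper, not part of the SageMaker SDK, and it assumes the registry account `763104351884` shown above, which may not apply to every AWS region.

```python
def build_lmi_image_uri(region: str, tag: str = "0.32.0-lmi14.0.0-cu124") -> str:
    """Build a djl-inference image URI for the given region (hypothetical helper)."""
    # Assumption: registry account copied from the URIs listed above.
    account = "763104351884"
    return f"{account}.dkr.ecr.{region}.amazonaws.com/djl-inference:{tag}"


print(build_lmi_image_uri("us-west-2"))
```

Passing the `region` variable defined earlier keeps the image in the same region as the endpoint.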

In [None]:
container_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.32.0-lmi14.0.0-cu124"

In [35]:


if model_id == "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct":
    instance_type = "ml.g5.12xlarge"
    container_startup_health_check_timeout = 900
elif model_id == "deepseek-ai/DeepSeek-Coder-V2-Instruct":
    instance_type = "ml.p4de.24xlarge"
    container_startup_health_check_timeout = 1800
else:
    # Fail early instead of hitting a NameError further down
    raise ValueError(f"No instance settings defined for model_id: {model_id}")

endpoint_name = sagemaker.utils.name_from_base("DeepSeek-Coder-V2-Instruct")

print(f'container_uri: {container_uri}')
print(f'container_startup_health_check_timeout: {container_startup_health_check_timeout}')
print(f'instance_type: {instance_type}')
print(f'endpoint_name: {endpoint_name}')

container_uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.32.0-lmi14.0.0-cu124
container_startup_health_check_timeout: 900
instance_type: ml.g5.12xlarge
endpoint_name: DeepSeek-Coder-V2-Instruct-2025-02-12-12-07-00-775


### Create model with env variables


- Target model: [DeepSeek-Coder-V2-Lite-Instruct](https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct)

- **[Backend for attention computation in vLLM](https://docs.vllm.ai/en/latest/serving/env_vars.html)**
    - Available options:
        - "TORCH_SDPA": use torch.nn.MultiheadAttention
        - "FLASH_ATTN": use FlashAttention
        - "XFORMERS": use XFormers
        - "ROCM_FLASH": use ROCmFlashAttention
        - "FLASHINFER": use flashinfer

- **`"OPTION_DISABLE_FLASH_ATTN": "false"`** applies to the HF Accelerate backend with Seq-Scheduler
    - It is ignored when using the vLLM backend

- [[DOC] DJL-Container and Model Configurations (info. about properties)](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/deployment_guide/configurations.html)

In [36]:
deploy_env = {
    "HF_MODEL_ID": model_id,
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_MAX_ROLLING_BATCH_SIZE": "2",
    "OPTION_DTYPE":"fp16",
    "OPTION_TRUST_REMOTE_CODE": "true",
    "OPTION_MAX_MODEL_LEN": "8192",
    "VLLM_ATTENTION_BACKEND": "XFORMERS",
    #"OPTION_DISABLE_FLASH_ATTN": "false", ## HF Accelerate with Seq-Scheduler, 
}
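
Per the DJL configuration guide linked above, each `OPTION_*` environment variable corresponds to an `option.*` key in a `serving.properties` file, and `HF_MODEL_ID` corresponds to `option.model_id`. The sketch below shows that mapping for a few of the settings used here. `env_to_serving_properties` is a hypothetical helper written for illustration, and it skips variables without the `OPTION_` prefix (such as `VLLM_ATTENTION_BACKEND`) on the assumption that they are read by vLLM rather than by DJL Serving.

```python
def env_to_serving_properties(env: dict) -> str:
    """Render LMI environment variables as serving.properties lines (illustrative only)."""
    lines = []
    for key, value in env.items():
        if key == "HF_MODEL_ID":
            lines.append(f"option.model_id={value}")
        elif key.startswith("OPTION_"):
            # e.g. OPTION_MAX_MODEL_LEN -> option.max_model_len
            lines.append(f"option.{key[len('OPTION_'):].lower()}={value}")
        # Anything else (e.g. VLLM_ATTENTION_BACKEND) is not a DJL option; skip it.
    return "\n".join(lines)


example_env = {
    "HF_MODEL_ID": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_MAX_MODEL_LEN": "8192",
    "VLLM_ATTENTION_BACKEND": "XFORMERS",
}
print(env_to_serving_properties(example_env))
```

Reading the variables this way makes it easier to look each one up by its `option.*` name in the configuration guide.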

In [37]:
model = sagemaker.Model(
    image_uri=container_uri, 
    role=role,
    env=deploy_env
)

### Deploy model

In [38]:
model.deploy(
    instance_type=instance_type,
    initial_instance_count=1,
    endpoint_name=endpoint_name,
    container_startup_health_check_timeout=container_startup_health_check_timeout
)

-----------------!

## 2. Invocation (Generate Text using the endpoint)

### Get a predictor for your endpoint

In [39]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

### Make a prediction with your endpoint

- **question candidates**
    - write a quick sort algorithm in python.
    - Write a piece of quicksort code in C++.
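
Requests to the endpoint in this notebook use two JSON body shapes: a plain completion body with `inputs` and `parameters`, and a chat body with `messages`. The helpers below are a minimal sketch of both. The function names are hypothetical, and only fields that appear in this notebook (`inputs`, `parameters`, `max_new_tokens`, `messages`, `max_tokens`, `stream`) are used.

```python
import json


def completion_body(prompt: str, max_new_tokens: int = 2048, stream: bool = False) -> dict:
    """Body for the `inputs`/`parameters` schema used with plain prompts."""
    body = {"inputs": prompt, "parameters": {"max_new_tokens": max_new_tokens}}
    if stream:
        body["stream"] = True
    return body


def chat_body(messages: list, max_tokens: int = 1024, stream: bool = False) -> dict:
    """Body for the chat schema, where `messages` is a list of role/content dicts."""
    body = {"messages": messages, "max_tokens": max_tokens}
    if stream:
        body["stream"] = True
    return body


# Encode a body as bytes, as the boto3 runtime client expects for `Body`.
payload = json.dumps(completion_body("Write a piece of quicksort code in C++.")).encode("utf-8")
print(payload)
```

With `sagemaker.Predictor` and a `JSONSerializer`, the dict can be passed to `predict` directly; the byte encoding is only needed when calling the boto3 runtime client.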

In [40]:
outputs = predictor.predict(
    {
        "inputs": "write a quick sort algorithm in python and description",
        "parameters": {"do_sample": True, "max_new_tokens": 2048},
    }
)

print(outputs["generated_text"])


Headers: Quick Sort Algorithm in Python

Description: 

Quick Sort is a divide and conquer algorithm. It works by selecting a 'pivot' element from the array and partitioning the other elements into two sub-arrays, according to whether they are less than or greater than the pivot. The sub-arrays are then sorted recursively. The whole array is sorted as a result.

The key process in quickSort is partition(). Target of partitions is, given an array and an element x of array as pivot, put x at correct position in sorted array and put all smaller elements before x, and all greater after x. All this should be done in linear time.


Here is the Python code:

```python
def partition(arr, low, high):
    i = (low - 1)         # index of smaller element
    pivot = arr[high]     # pivot

    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]

    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    return (i + 1)


def quickSort(ar

- **With chat template**
    - [DJL Chat Completions API Schema](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html)

In [41]:
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works! anyway, write a quick sort algorithm in python and description"},
]

result = predictor.predict(
    {"messages": chat, "max_tokens": 1024}
)
result

{'id': 'chatcmpl-139779376032848',
 'object': 'chat.completion',
 'created': 1739366007,
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': ' Sure! Quick Sort is a widely used and efficient sorting algorithm that uses the divide-and-conquer strategy to sort elements. Here\'s a Python implementation of the Quick Sort algorithm along with a brief description:\n\n```python\ndef quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    else:\n        pivot = arr[len(arr) // 2]\n        left = [x for x in arr if x < pivot]\n        middle = [x for x in arr if x == pivot]\n        right = [x for x in arr if x > pivot]\n        return quick_sort(left) + middle + quick_sort(right)\n\n# Example usage:\narr = [3, 6, 8, 10, 1, 2, 1]\nprint("Original array:", arr)\nsorted_arr = quick_sort(arr)\nprint("Sorted array:", sorted_arr)\n```\n\n### Description of Quick Sort:\n\n1. **Base Case**: If the array has 0 or 1 elements, it is already sorted.\n2. **Pivot Selection*

## 3. Streaming output from the endpoint


In [42]:
import json
import random 

In [43]:
# List of prompts covering a variety of coding tasks
prompts = [
    "write a quick sort algorithm in python.",
    "Write a Python function to implement a binary search algorithm.",
    "Create a JavaScript function to flatten a nested array.",
    "Implement a simple REST API using Flask in Python.",
    "Write a SQL query to find the top 5 customers by total purchase amount.",
    "Create a React component for a todo list with basic CRUD operations.",
    "Implement a depth-first search algorithm for a graph in C++.",
    "Write a bash script to find and delete files older than 30 days.",
    "Create a Python class to represent a deck of cards with shuffle and deal methods.",
    "Write a regular expression to validate email addresses.",
    "Implement a basic CI/CD pipeline using GitHub Actions."
]

def generate_payload():
    # Pick a prompt at random
    prompt = random.choice(prompts)
    
    # Build the JSON payload
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 400,
            # "return_full_text": False  # This does not work with Phi3
        },
        "stream": True,
    }
    
    # Serialize to a JSON string and encode it as bytes
    return json.dumps(body).encode('utf-8')

In [44]:
%%time
# Invoke the endpoint
resp = sm_runtime_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, 
    Body=generate_payload(), 
    ContentType="application/json"
)
print("Generated response:")
print("-" * 40)

buffer = ""
for event in resp['Body']:
    if 'PayloadPart' in event:
        chunk = event['PayloadPart']['Bytes'].decode()
        buffer += chunk
        try:
            # Try to parse the buffer as JSON
            data = json.loads(buffer)
            if 'token' in data:
                print(data['token']['text'], end='', flush=True)
            buffer = ""  # Clear the buffer after successful parsing
        except json.JSONDecodeError:
            # If parsing fails, keep the buffer for the next iteration
            pass

print("\n" + "-" * 40)

Generated response:
----------------------------------------


Here's a simple implementation of a depth-first search (DFS) algorithm for a graph in C++. This implementation uses an adjacency list to represent the graph and a vector to keep track of visited nodes.

```cpp
#include <iostream>
#include <vector>
#include <stack>

using namespace std;

// Function to add an edge to the graph
void addEdge(vector<int> adj[], int u, int v) {
    adj[u].push_back(v);
}

// DFS function
void DFS(vector<int> adj[], int start, vector<bool>& visited) {
    stack<int> s;
    s.push(start);

    while (!s.empty()) {
        int node = s.top();
        s.pop();

        if (!visited[node]) {
            cout << node << " ";
            visited[node] = true;
        }

        for (auto it = adj[node].rbegin(); it != adj[node].rend(); ++it) {
            if (!visited[*it]) {
                s.push(*it);
            }
        }
    }
}

// Driver code
int main() {
    int V = 5; // Number of vertices
 

- **With chat template**
    - [DJL Chat Completions API Schema](https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html)

In [45]:
# Chat history sent through the chat template
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    {"role": "user", "content": "I'd like to show off how chat templating works! anyway, write a quick sort algorithm in python and description"},
]

result = predictor.predict(
    {"messages": chat, "max_tokens": 1024}
)

def generate_payload():
    # Build the JSON payload for the chat completions schema
    body = {
        "messages": chat,
        "max_tokens": 1024,
        "stream": True,
    }
    
    # Serialize to a JSON string and encode it as bytes
    return json.dumps(body).encode('utf-8')

In [46]:
%%time
# Invoke the endpoint
resp = sm_runtime_client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name, 
    Body=generate_payload(), 
    ContentType="application/json"
)
print("Generated response:")
print("-" * 40)

buffer = ""
for event in resp['Body']:
    if 'PayloadPart' in event:
        chunk = event['PayloadPart']['Bytes'].decode()
        buffer += chunk
        try:
            # Try to parse the buffer as JSON
            data = json.loads(buffer)
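            if 'choices' in data:
                # Some chunks may carry no content in the delta, so fall back to ''
                delta = data['choices'][0].get('delta', {})
                print(delta.get('content') or '', end='', flush=True)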
            buffer = ""  # Clear the buffer after successful parsing
        except json.JSONDecodeError:
            # If parsing fails, keep the buffer for the next iteration
            pass

print("\n" + "-" * 40)

Generated response:
----------------------------------------
 Sure! Below is a Python implementation of the Quick Sort algorithm, followed by a brief description of how it works.

### Quick Sort Algorithm in Python

```python
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    else:
        pivot = arr[len(arr) // 2]
        left = [x for x in arr if x < pivot]
        middle = [x for x in arr if x == pivot]
        right = [x for x in arr if x > pivot]
        return quick_sort(left) + middle + quick_sort(right)

# Example usage:
arr = [3, 6, 8, 10, 1, 2, 1]
print("Original array:", arr)
sorted_arr = quick_sort(arr)
print("Sorted array:", sorted_arr)
```

### Description of Quick Sort

Quick Sort is a highly efficient, recursive sorting algorithm that follows the divide-and-conquer strategy. Here’s how it works:

1. **Choose a Pivot**: Select an element from the array as the pivot. The choice of pivot can vary; common strategies include picking the first element, the las
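
Both streaming loops in this section call `json.loads` on an accumulating buffer. That works when each chunk ends on a complete JSON object, but it would fail if a single chunk carried two objects at once. A more defensive variant, sketched below, uses `json.JSONDecoder.raw_decode` to pull every complete object out of the buffer and keep any unfinished tail. The helper name `extract_json_objects` and the sample chunks are made up for illustration; the chunk shapes mirror the `token` and `choices[0].delta` fields that the loops above read.

```python
import json

_decoder = json.JSONDecoder()


def extract_json_objects(buffer: str):
    """Return (objects, remainder): every complete JSON value in `buffer` plus the unparsed tail."""
    objects = []
    pos = 0
    while pos < len(buffer):
        # raw_decode does not skip leading whitespace, so skip it here.
        while pos < len(buffer) and buffer[pos].isspace():
            pos += 1
        if pos >= len(buffer):
            break
        try:
            obj, pos = _decoder.raw_decode(buffer, pos)
        except json.JSONDecodeError:
            break  # incomplete object: wait for more data
        objects.append(obj)
    return objects, buffer[pos:]


# Simulated chunks: two objects in one chunk, then one object split across two chunks.
chunks = [
    '{"token": {"text": "Hello"}}\n{"token": {"text": ","}}\n',
    '{"token": {"te',
    'xt": " world"}}\n',
]

buffer = ""
pieces = []
for chunk in chunks:
    buffer += chunk
    objects, buffer = extract_json_objects(buffer)
    for data in objects:
        if "token" in data:
            pieces.append(data["token"]["text"])
        elif "choices" in data:
            pieces.append(data["choices"][0].get("delta", {}).get("content") or "")

text = "".join(pieces)
print(text)
```

In the real loop, `chunk` would come from `event['PayloadPart']['Bytes'].decode()` instead of the simulated list.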

## 4. Real-time Inference Autoscaling


In [None]:
import pprint

In [None]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)

# SageMaker expects resource id to be provided with the following structure
resource_id = f"endpoint/{endpoint_name}/variant/{resp['ProductionVariants'][0]['VariantName']}"

# Scaling configuration
scaling_config_response = sm_autoscaling_client.register_scalable_target(
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount", 
    MinCapacity=1,
    MaxCapacity=2
)

In [None]:
# Create Scaling Policy
policy_name = f"scaling-policy-{endpoint_name}"
scaling_policy_response = sm_autoscaling_client.put_scaling_policy(
    PolicyName=policy_name,
    ServiceNamespace="sagemaker",
    ResourceId=resource_id,
    ScalableDimension="sagemaker:variant:DesiredInstanceCount",
    PolicyType="TargetTrackingScaling",
    TargetTrackingScalingPolicyConfiguration={
        "TargetValue": 5.0,  # Target average invocations per instance per minute
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerVariantInvocationsPerInstance",
        },
        "ScaleInCooldown": 600,  # Seconds to wait after a scale-in activity before another can start
        "ScaleOutCooldown": 60,  # Seconds to wait after a scale-out activity before another can start
    }
)
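
With `TargetValue` set to 5.0, target tracking tries to keep `SageMakerVariantInvocationsPerInstance` (invocations per instance per minute) close to 5 by adding or removing instances within the registered range of 1 to 2. As a rough mental model, and not the exact algorithm Application Auto Scaling runs, the desired instance count is about the total invocations per minute divided by the target, rounded up and clamped to the min/max capacity. The function below is a hypothetical illustration of that arithmetic.

```python
import math


def estimate_instance_count(
    invocations_per_minute: float,
    target_per_instance: float = 5.0,
    min_capacity: int = 1,
    max_capacity: int = 2,
) -> int:
    """Approximate the instance count target tracking would settle on (simplified model)."""
    needed = math.ceil(invocations_per_minute / target_per_instance)
    # Clamp to the capacity range registered for the scalable target.
    return max(min_capacity, min(max_capacity, needed))


for rate in (3, 8, 40):
    print(rate, "invocations/min ->", estimate_instance_count(rate), "instance(s)")
```

Under this model, any sustained load above 5 invocations per minute should push the endpoint toward its maximum of 2 instances.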

In [None]:
response = sm_autoscaling_client.describe_scaling_policies(ServiceNamespace="sagemaker")

pp = pprint.PrettyPrinter(indent=4, depth=4)
for i in response["ScalingPolicies"]:
    pp.pprint(i["PolicyName"])
    print("")
    if("TargetTrackingScalingPolicyConfiguration" in i):
        pp.pprint(i["TargetTrackingScalingPolicyConfiguration"])

In [None]:
# List of prompts covering a variety of coding tasks
prompts = [
    "write a quick sort algorithm in python.",
    "Write a Python function to implement a binary search algorithm.",
    "Create a JavaScript function to flatten a nested array.",
    "Implement a simple REST API using Flask in Python.",
    "Write a SQL query to find the top 5 customers by total purchase amount.",
    "Create a React component for a todo list with basic CRUD operations.",
    "Implement a depth-first search algorithm for a graph in C++.",
    "Write a bash script to find and delete files older than 30 days.",
    "Create a Python class to represent a deck of cards with shuffle and deal methods.",
    "Write a regular expression to validate email addresses.",
    "Implement a basic CI/CD pipeline using GitHub Actions."
]

def generate_payload():
    # Pick a prompt at random
    prompt = random.choice(prompts)
    
    # Build the JSON payload
    body = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 400,
            # "return_full_text": False  # This does not work with Phi3
        },
        "stream": True,
    }
    
    # Serialize to a JSON string and encode it as bytes
    return json.dumps(body).encode('utf-8')

In [None]:
%%time
import time

request_duration = 250
end_time = time.time() + request_duration
print(f"Endpoint will be tested for {request_duration} seconds")

while time.time() < end_time:
    payload = generate_payload()
    # Invoke the endpoint
    response = sm_runtime_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name, 
        Body=payload,
        ContentType="application/json"
    )
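
To judge whether the loop above generates enough traffic to cross the scaling target, it helps to know how many invocations per minute it produced. The sketch below wraps the same run-until-deadline pattern in a hypothetical `run_load` helper that counts calls and reports the rate. Here `invoke` stands in for the `invoke_endpoint_with_response_stream` call, and the clock is injectable so the example runs without an endpoint.

```python
import itertools
import time


def run_load(invoke, duration_s, clock=time.monotonic):
    """Call `invoke` repeatedly for about `duration_s` seconds and report the observed rate."""
    start = clock()
    calls = 0
    while clock() - start < duration_s:
        invoke()
        calls += 1
    elapsed = clock() - start
    per_minute = calls * 60.0 / elapsed if elapsed > 0 else 0.0
    return {"calls": calls, "elapsed_s": elapsed, "per_minute": per_minute}


# Demo with a fake clock that advances one "second" per reading and a no-op invoke.
ticks = itertools.count()
stats = run_load(lambda: None, duration_s=3, clock=lambda: next(ticks))
print(stats)
```

Against the real endpoint, `invoke` could be a small function that calls `sm_runtime_client.invoke_endpoint_with_response_stream` with `generate_payload()`; the reported `per_minute` value can then be compared with the `TargetValue` of the scaling policy.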

In [None]:
# Check the instance counts after the endpoint gets more load
response = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_status = response["EndpointStatus"]
request_duration = 250
end_time = time.time() + request_duration
print(f"Waiting up to {request_duration} seconds for the instance count to increase. Rerun this cell if the count does not change.")
while time.time() < end_time:
    response = sm_client.describe_endpoint(EndpointName=endpoint_name)
    endpoint_status = response["EndpointStatus"]
    instance_count = response["ProductionVariants"][0]["CurrentInstanceCount"]
    print(f"Status: {endpoint_status}")
    print(f"Current Instance count: {instance_count}")
    if (endpoint_status=="InService") and (instance_count>1):
        break
    else:
        time.sleep(15)

In [None]:
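# Look up the resource names created by model.deploy()
# (assumes the `model` object from the deploy step is still in memory)
model_name = model.name
endpoint_config_name = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointConfigName"]

# Delete model
sm_client.delete_model(ModelName=model_name)

# Delete endpoint configuration
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

# Delete endpoint
sm_client.delete_endpoint(EndpointName=endpoint_name)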