# Deploy DeepSeek-V3.1 on a SageMaker AI endpoint using the LMI container

DeepSeek-V3.1 is a hybrid model that supports both thinking mode and non-thinking mode. Compared to the previous version, this upgrade brings improvements in multiple aspects:

- **Hybrid thinking mode**: One model supports both thinking mode and non-thinking mode by changing the chat template.

- **Smarter tool calling**: Through post-training optimization, the model's performance in tool usage and agent tasks has significantly improved.

- **Higher thinking efficiency**: DeepSeek-V3.1-Think achieves comparable answer quality to DeepSeek-R1-0528, while responding more quickly.
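Because the mode is selected through the chat template, switching is a per-request option rather than a different model. Below is a minimal sketch of building both request bodies. It uses the OpenAI-style `messages` schema and the `chat_template_kwargs` field that this notebook sends to the LMI endpoint. `build_chat_payload` is a hypothetical helper name.

```python
import json


def build_chat_payload(prompt: str, thinking: bool = False) -> dict:
    """Build a chat request body for the LMI endpoint.

    The only difference between the two modes is the `thinking` flag, which is
    forwarded to DeepSeek-V3.1's chat template via `chat_template_kwargs`.
    """
    return {
        "messages": [{"role": "user", "content": prompt}],
        "chat_template_kwargs": {"thinking": thinking},
    }


non_thinking = build_chat_payload("Name popular places to visit in London?")
thinking = build_chat_payload("What is bigger 9.11 or 9.8?", thinking=True)
print(json.dumps(thinking, indent=2))
```

Omitting `chat_template_kwargs` entirely, as the first request in this notebook does, also produces non-thinking output.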

DeepSeek-V3.1 is post-trained on top of DeepSeek-V3.1-Base, which is built upon the original V3 base checkpoint through a two-phase long-context extension approach, following the methodology outlined in the original DeepSeek-V3 report. DeepSeek expanded the dataset by collecting additional long documents and substantially extending both training phases: the 32K extension phase was increased 10-fold to 630B tokens, and the 128K extension phase was extended by 3.3x to 209B tokens. Additionally, DeepSeek-V3.1 is trained using the UE8M0 FP8 scale data format to ensure compatibility with microscaling data formats.
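UE8M0 describes how the FP8 scaling factors are stored. Below is a minimal sketch of the number format, following the E8M0 scale type in the OCP Microscaling (MX) specification: an unsigned 8-bit biased exponent (bias 127) with no mantissa, so every scale is a power of two. This illustrates the format only and is not DeepSeek's training or inference code. The FP8 E4M3 element format and the round-up rule are assumptions made for the example.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (assumed element format)


def encode_ue8m0(scale: float) -> int:
    """Encode a positive scale as a UE8M0 byte, rounding up to a power of two.

    UE8M0 keeps only an 8-bit biased exponent (bias 127): no sign bit and no
    mantissa, so each code represents the scale 2**(code - 127).
    Rounding up (rather than to nearest) is an assumption; it keeps
    |x| / scale within the FP8 range for every value in the block.
    """
    if not (scale > 0 and math.isfinite(scale)):
        raise ValueError("UE8M0 only represents positive, finite scales")
    code = math.ceil(math.log2(scale)) + 127
    if not 0 <= code <= 254:  # code 255 is reserved for NaN
        raise ValueError(f"scale {scale} is outside the UE8M0 range")
    return code


def decode_ue8m0(code: int) -> float:
    """Decode a UE8M0 byte back into the power-of-two scale it represents."""
    if code == 255:
        return math.nan
    return 2.0 ** (code - 127)


# Example: choose a scale for a block whose largest magnitude is 10.0
amax = 10.0
code = encode_ue8m0(amax / FP8_E4M3_MAX)
scale = decode_ue8m0(code)
print(code, scale, amax / scale)  # 122 0.03125 320.0
```

Storing only an exponent makes the scale one byte and makes rescaling a pure exponent shift, at the cost of rounding each scale to a power of two.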

Hugging Face repo: [deepseek-ai/DeepSeek-V3.1](https://huggingface.co/deepseek-ai/DeepSeek-V3.1)

## Setup

```python
%pip install sagemaker --upgrade --quiet --no-warn-conflicts
```

```python
import json

import boto3
import sagemaker

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # region name of the current SageMaker Studio environment
account_id = sess.account_id()

sm_client = boto3.client("sagemaker", region_name=region)  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime", region_name=region)  # client to interact with SageMaker endpoints

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")
print(f"sagemaker version: {sagemaker.__version__}")
```

## Deployment

We use the SageMaker AI Large Model Inference (LMI) container; see the [list of available LMI images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#large-model-inference-containers) for more information.

In this notebook, we deploy the model directly from the Hugging Face Hub. You can also deploy the model from Amazon S3: point `HF_MODEL_ID` to the S3 prefix where the model weights are stored.
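Below is a minimal sketch of switching the model source. It assumes the S3 prefix holds the model files in the same layout as the Hugging Face repository. `build_model_env` is a hypothetical helper, and `my-bucket` is a placeholder.

```python
from urllib.parse import urlparse


def build_model_env(model_source: str) -> dict:
    """Return the HF_MODEL_ID setting for a Hub model ID or an S3 prefix."""
    parsed = urlparse(model_source)
    # A Hub ID such as "org/model" has no URL scheme; an S3 source needs a bucket and a prefix
    if parsed.scheme == "s3" and not (parsed.netloc and parsed.path.strip("/")):
        raise ValueError("expected s3://<bucket>/<prefix> pointing at the model weights")
    return {"HF_MODEL_ID": model_source}


hub_env = build_model_env("deepseek-ai/DeepSeek-V3.1")
s3_env = build_model_env("s3://my-bucket/models/DeepSeek-V3.1/")  # hypothetical bucket and prefix
print(hub_env)
print(s3_env)
```

Either dictionary can take the place of `common_env` when it is merged with `lmi_env`.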

Make sure your account has a service quota for at least one `ml.p5en.48xlarge` instance for endpoint usage.
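You can check the quota programmatically before deploying. Below is a minimal sketch using the Service Quotas API through boto3. It assumes the quota is named `ml.p5en.48xlarge for endpoint usage`, following the `<instance type> for endpoint usage` pattern shown in the Service Quotas console. Both helper names are hypothetical, and the sample record is hand-written.

```python
from typing import Any, Dict, List, Optional


def find_endpoint_quota(quotas: List[Dict[str, Any]], instance_type: str) -> Optional[float]:
    """Return the quota value for `<instance_type> for endpoint usage`, or None."""
    target = f"{instance_type} for endpoint usage"  # assumed quota naming pattern
    for quota in quotas:
        if quota.get("QuotaName") == target:
            return quota.get("Value")
    return None


def fetch_sagemaker_quotas(region_name: str) -> List[Dict[str, Any]]:
    """List SageMaker quotas for this account via the Service Quotas API."""
    import boto3  # imported here so find_endpoint_quota works without AWS access

    client = boto3.client("service-quotas", region_name=region_name)
    quotas: List[Dict[str, Any]] = []
    for page in client.get_paginator("list_service_quotas").paginate(ServiceCode="sagemaker"):
        quotas.extend(page["Quotas"])
    return quotas


# Offline example with a hand-written record; in the notebook you would call
# find_endpoint_quota(fetch_sagemaker_quotas(region), "ml.p5en.48xlarge")
sample_quotas = [{"QuotaName": "ml.p5en.48xlarge for endpoint usage", "Value": 1.0}]
print(find_endpoint_quota(sample_quotas, "ml.p5en.48xlarge"))
```

A value of `0.0` means you need to request an increase before deploying. `None` means no quota matched the assumed name, so confirm the exact name in the Service Quotas console.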

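```python
# LMI image tag: DJL Serving 0.33.0, LMI 15.0.0, CUDA 12.8
CONTAINER_VERSION = "0.33.0-lmi15.0.0-cu128"
# 763104351884 is the AWS Deep Learning Containers registry account for most regions (it may differ in opt-in regions)
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:{CONTAINER_VERSION}"
instance_type = "ml.p5en.48xlarge"
model_name = sagemaker.utils.name_from_base("model-ds")
endpoint_name = model_name
timeout = 1800  # seconds allowed for downloading weights, loading the model, and passing the health check

common_env = {
    "HF_MODEL_ID": "deepseek-ai/DeepSeek-V3.1",  # Hugging Face Hub ID (or an S3 prefix)
}
lmi_env = {
    "SERVING_FAIL_FAST": "true",  # exit instead of hanging if the model fails to load
    "OPTION_ASYNC_MODE": "true",  # run the async vLLM handler selected by OPTION_ENTRYPOINT
    "OPTION_ROLLING_BATCH": "disable",  # batching is handled by the async vLLM engine instead
    "OPTION_MAX_MODEL_LEN": "8192",  # max tokens per sequence (prompt + completion)
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",  # shard the model across all GPUs on the instance
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_TRUST_REMOTE_CODE": "true",
}
env = common_env | lmi_env  # dict union requires Python 3.9+
```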
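The `OPTION_*` variables configure the same settings that LMI can also read from a `serving.properties` file. Below is a minimal sketch that renders them in that form. It assumes, based on our reading of the LMI configuration docs, that `OPTION_<NAME>` corresponds to `option.<name>` in lower case. `env_to_serving_properties` is a hypothetical helper, and `example_env` is a copy of the settings above.

```python
def env_to_serving_properties(env: dict) -> str:
    """Render OPTION_* environment variables as serving.properties lines.

    Assumption: LMI treats OPTION_<NAME> like option.<name> (lower-cased).
    Variables without the prefix, such as HF_MODEL_ID or SERVING_FAIL_FAST,
    are configured differently and are skipped here.
    """
    prefix = "OPTION_"
    lines = [
        f"option.{key[len(prefix):].lower()}={value}"
        for key, value in sorted(env.items())
        if key.startswith(prefix)
    ]
    return "\n".join(lines)


# A copy of the LMI settings used in this notebook
example_env = {
    "SERVING_FAIL_FAST": "true",
    "OPTION_ASYNC_MODE": "true",
    "OPTION_ROLLING_BATCH": "disable",
    "OPTION_MAX_MODEL_LEN": "8192",
    "OPTION_TENSOR_PARALLEL_DEGREE": "max",
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_TRUST_REMOTE_CODE": "true",
}
print(env_to_serving_properties(example_env))
```

Environment variables keep the whole configuration in the notebook. A `serving.properties` file is the alternative when packaging configuration alongside model artifacts.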

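```python
# Register the model: the LMI image plus the environment that configures it
lmi_model = sagemaker.Model(
    image_uri=inference_image,
    env=env,
    role=role,
    name=model_name,
    sagemaker_session=sess,
)

# Create the endpoint config and endpoint; by default this blocks until the endpoint is InService
lmi_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=timeout,
    endpoint_name=endpoint_name,
)

# Predictor with JSON serialization, as an alternative to calling smr_client directly
llm = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)
```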

### Inference test

#### Standard (non-thinking) request

```python
payload = {
    "messages": [
        {"role": "user", "content": "Name popular places to visit in London?"}
    ],
}

res = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)

response = json.loads(res["Body"].read().decode("utf8"))
content = response["choices"][0]["message"]["content"]
usage = response["usage"]

print("###############\n" + content + "\n###############\n")
print(usage)
```

```text
###############
Of course! London is packed with iconic landmarks, world-class museums, and vibrant neighborhoods. Here is a list of popular places to visit, categorized to help you plan your trip.

### **The Absolute Icons (The "Must-Do" List)**
These are the landmarks that define London's skyline and history.

1.  **The Houses of Parliament & Big Ben:** The stunning Gothic revival building that houses the UK government. The Elizabeth Tower (commonly known as Big Ben) is one of the most recognizable symbols in the world.
2.  **Buckingham Palace:** The official London residence of the King. The **Changing of the Guard** ceremony is a major draw (check schedules online).
3.  **The Tower of London:** A historic castle on the River Thames. Explore nearly 1,000 years of history, see the Crown Jewels, and be entertained by the Yeoman Warders (Beefeaters).
4.  **Tower Bridge:** The famous bascule and suspension bridge next to the Tower of London. You can walk across for free or pay to go ins
```
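The invoke-and-parse steps in the request cell repeat for every non-streaming call. Below is a minimal sketch of a reusable wrapper. `invoke_chat` and `FakeRuntimeClient` are hypothetical names. The fake client mimics only the `Body` field of a real `invoke_endpoint` response, so the parsing can be checked offline.

```python
import io
import json


def invoke_chat(client, endpoint_name: str, payload: dict):
    """Send a chat request to a SageMaker endpoint and return (content, usage)."""
    res = client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    response = json.loads(res["Body"].read().decode("utf-8"))
    return response["choices"][0]["message"]["content"], response.get("usage", {})


class FakeRuntimeClient:
    """Offline stand-in for boto3's sagemaker-runtime client (testing only)."""

    def invoke_endpoint(self, **kwargs):
        body = {
            "choices": [{"message": {"role": "assistant", "content": "Hello!"}}],
            "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
        }
        return {"Body": io.BytesIO(json.dumps(body).encode("utf-8"))}


content, usage = invoke_chat(FakeRuntimeClient(), "my-endpoint", {"messages": []})
print(content, usage)
```

In the notebook you would call `invoke_chat(smr_client, endpoint_name, payload)` instead.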

#### Switching to "thinking" mode

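Passing `"thinking": True` in `chat_template_kwargs` switches the chat template to thinking mode. The model then writes its reasoning first, closes it with the `</think>` token, and gives the final answer after it. Note the `</think>` token in the output.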

```python
payload = {
    "messages": [
        {"role": "user", "content": "What is bigger 9.11 or 9.8?"}
    ],
    "chat_template_kwargs": {
        "thinking": True
    }
}

res = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
)

response = json.loads(res["Body"].read().decode("utf8"))
content = response["choices"][0]["message"]["content"]
usage = response["usage"]

print("###############\n" + content + "\n###############\n")
print(usage)
```

```text
###############
First, the question is: "What is bigger 9.11 or 9.8?" I need to compare these two numbers.

Both numbers are decimals. 9.11 and 9.8. I should think of them in terms of place value to compare them properly.

The number 9.11 has digits: 9 (units), 1 (tenths), and 1 (hundredths). So, it's 9 + 0.1 + 0.01 = 9.11.

The number 9.8 has digits: 9 (units) and 8 (tenths). So, it's 9 + 0.8 = 9.8.

To compare them, I should look at the whole number part first. Both have the same whole number, which is 9. So, I need to compare the decimal parts.

The decimal part of 9.11 is 0.11, and the decimal part of 9.8 is 0.8.

Now, 0.8 is the same as 0.80, which is greater than 0.11 because 80 hundredths is greater than 11 hundredths.

I can also think in terms of tenths: 9.8 has 8 tenths, while 9.11 has 1 tenth and 1 hundredth, so 8 tenths is greater than 1 tenth.

Therefore, 9.8 is greater than 9.11.

So, the bigger number is 9.8.

I should make sure there's no trick here. Sometimes people mi
```
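In thinking mode, the reasoning and the final answer come back in the same `content` string, separated by `</think>`. Below is a minimal sketch for splitting them. `split_reasoning` is a hypothetical helper, and the sample string is synthetic.

```python
from typing import Optional, Tuple


def split_reasoning(text: str, marker: str = "</think>") -> Tuple[Optional[str], str]:
    """Split a completion into (reasoning, answer) at the first `</think>` marker.

    Returns (None, text) when the marker is absent: either the request used
    non-thinking mode, or generation stopped before the reasoning finished,
    so check which mode you requested before trusting the "answer".
    """
    reasoning, found, answer = text.partition(marker)
    if not found:
        return None, text.strip()
    return reasoning.strip(), answer.strip()


sample = "There are three r's in s-t-r-a-w-b-e-r-r-i-e-s.</think>The word has **3** r's."
reasoning, answer = split_reasoning(sample)
print("REASONING:", reasoning)
print("ANSWER:", answer)
```

Splitting on the first marker keeps any later text, including a literal `</think>` inside the answer, in the answer part.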

#### Helper functions

```python
import io
import json
import time

from IPython.display import clear_output


class LineIterator:
    """Reassemble newline-delimited lines from a SageMaker response stream.

    A single PayloadPart may contain a partial line or several lines, so bytes
    are buffered until a complete, newline-terminated line is available.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # Return a trailing line without a final newline instead of looping forever
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                print(f"Unknown event type: {chunk}")
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])


def stream_response(endpoint_name, inputs, thinking=False):
    """Stream a chat completion, reprinting the partial response as chunks arrive."""
    body = {
        "messages": [
            {"role": "user", "content": inputs}
        ],
        "chat_template_kwargs": {
            "thinking": thinking
        },
        "stream": True
    }

    resp = smr_client.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(body),
        ContentType="application/json",
    )

    event_stream = resp["Body"]
    start_json = b"{"
    full_response = ""
    start_time = time.time()
    token_count = 0

    for line in LineIterator(event_stream):
        # Lines are typically b'data: {...}'; blank lines and b'data: [DONE]' contain no JSON
        if line != b"" and start_json in line:
            data = json.loads(line[line.find(start_json):].decode("utf-8"))
            if not data.get("choices"):
                continue  # skip chunks without choices (e.g. usage-only chunks)
            # "content" can be missing or None in role or finish chunks
            token_text = data["choices"][0]["delta"].get("content") or ""
            full_response += token_text
            token_count += 1  # counts streamed chunks as an approximation of tokens

            # Approximate tokens per second since the request was sent
            elapsed_time = time.time() - start_time
            tps = token_count / elapsed_time if elapsed_time > 0 else 0

            # Clear the output and reprint everything
            clear_output(wait=True)
            print(full_response)
            print(f"\nTokens per Second: {tps:.2f}", end="")

    print("\n")  # Add a newline after the response is complete

    return full_response
```
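`stream_response` locates the JSON payload by searching each line for `{`. Below is a slightly stricter sketch of the same parsing step. It assumes the OpenAI-style server-sent-events framing that vLLM-based servers typically emit: `data: {...}` lines followed by a final `data: [DONE]`. `parse_stream_line` is a hypothetical helper, and the sample lines are synthetic.

```python
import json
from typing import Optional


def parse_stream_line(line: bytes) -> Optional[str]:
    """Return the text delta carried by one streamed line, or None if it has none."""
    text = line.decode("utf-8").strip()
    if not text.startswith("data:"):
        return None  # blank separator lines or other SSE fields
    data = text[len("data:"):].strip()
    if data == "[DONE]":
        return None  # end-of-stream sentinel
    chunk = json.loads(data)
    choices = chunk.get("choices") or []
    if not choices:
        return None  # e.g. a usage-only chunk
    # Role and finish chunks may omit "content" or set it to None
    return choices[0].get("delta", {}).get("content") or ""


sample_lines = [
    b'data: {"choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}}]}',
    b"",
    b'data: {"choices": [{"index": 0, "delta": {"content": "Hello"}}]}',
    b'data: {"choices": [{"index": 0, "delta": {"content": " world"}}]}',
    b'data: {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}',
    b"data: [DONE]",
]
text = "".join(piece for piece in map(parse_stream_line, sample_lines) if piece)
print(text)  # Hello world
```

Checking the `data:` prefix and the `[DONE]` sentinel explicitly avoids misreading a line that happens to contain a `{` for some other reason.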

#### Streaming invocation in "thinking" mode

```python
inputs = "How many 'r' in word 'strawberries'"
output = stream_response(endpoint_name, inputs, thinking=True)
```

```text
First, the question is: "How many 'r' in word 'strawberries'?" I need to count the number of times the letter 'r' appears in the word "strawberries".

Let me write down the word: s-t-r-a-w-b-e-r-r-i-e-s.

Now, I'll go through each letter and count the 'r's.

Starting from the beginning:

- s: not r

- t: not r

- r: yes, that's one r

- a: not r

- w: not r

- b: not r

- e: not r

- r: yes, that's the second r

- r: yes, that's the third r

- i: not r

- e: not r

- s: not r

So, I have three 'r's in "strawberries".

I should double-check to make sure I didn't miss any. The word has 12 letters, and the 'r's are at positions 3, 8, and 9. Yes, that seems correct.

Therefore, the answer should be 3.</think>The word "strawberries" contains **3** instances of the letter 'r'. Here's a breakdown for clarity:

- The word is: s-t-r-a-w-b-e-r-r-i-e-s
- The 'r's appear at positions 3, 8, and 9.

Tokens per Second: 33.88
```
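The `Tokens per Second` figure from `stream_response` counts streamed chunks, not tokenizer tokens, and its elapsed time includes the wait for the first chunk. A chunk is often close to one token with vLLM-based servers, but that is not guaranteed. Below is a sketch that separates latency from throughput, assuming you record `time.time()` each time a chunk is parsed. `stream_metrics` is a hypothetical helper, and the timestamps are synthetic.

```python
from typing import List, Optional, Tuple


def stream_metrics(request_start: float, arrival_times: List[float]) -> Tuple[Optional[float], float]:
    """Return (time to first chunk, chunks per second after the first chunk).

    Both numbers are in streamed chunks, not tokenizer tokens.
    """
    if not arrival_times:
        return None, 0.0
    time_to_first_chunk = arrival_times[0] - request_start
    decode_window = arrival_times[-1] - arrival_times[0]
    # Chunks after the first one, divided by the time it took to receive them
    rate = (len(arrival_times) - 1) / decode_window if decode_window > 0 else 0.0
    return time_to_first_chunk, rate


ttfc, rate = stream_metrics(0.0, [0.5, 0.75, 1.0, 1.25, 1.5])
print(f"time to first chunk: {ttfc:.2f}s, steady-state rate: {rate:.2f} chunks/s")
```

Reporting the two numbers separately shows whether a slow response comes from prompt processing (a long time to first chunk) or from generation itself.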



## Cleanup

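```python
# Delete the endpoint, the endpoint config that deploy() created under the same name, and the model
sess.delete_endpoint(endpoint_name)
sess.delete_endpoint_config(endpoint_name)
sess.delete_model(model_name)
```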