# Reducing Inference Costs on DeepSeek-R1-Distill-Llama-8B with SageMaker Inference's Scale to Zero Capability

This notebook demonstrates how to scale a SageMaker endpoint in to zero instances during idle periods, which removes the previous requirement of keeping at least one instance running.

❗This notebook works well on an `ml.t3.medium` instance with the `PyTorch 2.2.0 Python 3.10 CPU optimized` kernel in **SageMaker Studio Classic** or the `Python3` kernel in **JupyterLab**.

## Set up Environment

In [None]:
%%capture --no-stderr

!pip install -U pip
!pip install -U "sagemaker>=2.239.0"
!pip install -U "transformers>=4.47.0"

In [None]:
import boto3
import sagemaker


role = sagemaker.get_execution_role()
boto_region = boto3.Session().region_name
sagemaker_session = sagemaker.session.Session(boto_session=boto3.Session(region_name=boto_region))

## Set up your SageMaker Real-time Endpoint

### Create the SageMaker endpoint

#### Deploy using DJL-Inference Container

The [Deep Java Library (DJL) Large Model Inference (LMI)](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-container-docs.html) containers are specialized Docker containers designed to facilitate the deployment of large language models (LLMs) on Amazon SageMaker. These containers integrate a model server with optimized inference libraries, providing a comprehensive solution for serving LLMs.

In [None]:
# You can also get the inference image URI programmatically using sagemaker.image_uris.retrieve
# djllmi_inference_image_uri = sagemaker.image_uris.retrieve(
#     framework="djl-inference",
#     region=boto_region,
#     version="0.31.0-lmi13.0.0-cu124"
# )

djllmi_inference_image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.31.0-lmi13.0.0-cu124"
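
The URI above is pinned to `us-east-1`. As a minimal sketch, the same URI can be built for another region. The helper name `djl_lmi_image_uri` is hypothetical, and it assumes the `763104351884` registry account hosts the image in your region. Some regions use a different registry account, so `sagemaker.image_uris.retrieve` is the safer choice when in doubt.

```python
def djl_lmi_image_uri(
    region: str,
    tag: str = "0.31.0-lmi13.0.0-cu124",
    account: str = "763104351884",  # assumption: not valid for every AWS region
) -> str:
    """Build a DJL LMI image URI for the given region (hypothetical helper)."""
    if not region:
        raise ValueError("region must be a non-empty AWS region name")
    return f"{account}.dkr.ecr.{region}.amazonaws.com/djl-inference:{tag}"


# In the notebook you could pass boto_region instead of a literal region name.
print(djl_lmi_image_uri("us-east-1"))
```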

In [None]:
from sagemaker.utils import name_from_base

model_name = name_from_base("deepseek-r1-distill-llama3-8b", short=True)

deepseek_lmi_model = sagemaker.Model(
    image_uri=djllmi_inference_image_uri,
    env={
        "HF_MODEL_ID": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        "OPTION_MAX_MODEL_LEN": "10000",
        "OPTION_GPU_MEMORY_UTILIZATION": "0.95",
        "OPTION_ENABLE_STREAMING": "false",
        "OPTION_ROLLING_BATCH": "auto",
        "OPTION_MODEL_LOADING_TIMEOUT": "3600",
        "OPTION_PAGED_ATTENTION": "false",
        "OPTION_DTYPE": "fp16",
    },
    role=role,
    name=model_name
)

In [None]:
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

resources_required = ResourceRequirements(
    requests={
        "num_cpus": 2,
        "memory": 1024,
        "num_accelerators": 1,
        "copies": 1, # specify the number of initial copies (default is 1)
    }
)

We begin by creating an endpoint with **MinInstanceCount** set to **0**. This allows the endpoint to scale in all the way down to zero instances when not in use.

In [None]:
from sagemaker.enums import EndpointType

endpoint_name = name_from_base("deepseek-r1-distill-llama3-8b-scale-to-zero-aas-ep", short=True)

instance_type = "ml.g5.2xlarge"
model_data_download_timeout_in_seconds = 3600
container_startup_health_check_timeout_in_seconds = 3600

min_instance_count = 0 # Minimum instance must be set to 0
max_instance_count = 3

deepseek_lmi_model.deploy(
    instance_type=instance_type,
    initial_instance_count=1,
    accept_eula=True,
    endpoint_name=endpoint_name,
    model_data_download_timeout=model_data_download_timeout_in_seconds,
    container_startup_health_check_timeout=container_startup_health_check_timeout_in_seconds,
    resources=resources_required,
    managed_instance_scaling={
        "Status": "ENABLED",
        "MinInstanceCount": min_instance_count,
        "MaxInstanceCount": max_instance_count,
    },
    endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
    routing_config={"RoutingStrategy": "LEAST_OUTSTANDING_REQUESTS"},
)

print(f"Your DJL-LMI Model Endpoint: {endpoint_name} is now deployed! 🚀")

### Create a Predictor with SageMaker Endpoint name

In [None]:
from sagemaker import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer


predictor = Predictor(
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

In [None]:
sagemaker_client = boto3.client("sagemaker", region_name=boto_region)

response = sagemaker_client.list_inference_components(EndpointNameEquals=predictor.endpoint_name)
inference_component_name = response['InferenceComponents'][0]['InferenceComponentName']

### Inference with SageMaker SDK

The SageMaker Python SDK simplifies inference through the `sagemaker.Predictor` class.

The `DeepSeek-R1-Distill-Llama-8B` variant uses the Llama 3.1 8B prompt format, shown below:

```text
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2024
Today Date: 29 Jan 2025

You are a helpful assistant that thinks and reasons before answering.

<|eot_id|>
<|start_header_id|>user<|end_header_id|>
How many R are in STRAWBERRY? Keep your answer and explanation short!
<|eot_id|>

<|start_header_id|>assistant<|end_header_id|>
```

In [None]:
from typing import List, Dict
from datetime import datetime


def format_messages(messages: List[Dict[str, str]]) -> str:
    """
    Format messages for Llama 3+ chat models.

    The model only supports 'system', 'user' and 'assistant' roles, starting with 'system', then 'user' and
    alternating (u/a/u/a/u...). The last message must be from 'user'.
    """
    # auto assistant suffix
    # messages.append({"role": "assistant"})

    output = "<|begin_of_text|>"
    # Adding an inferred prefix
    system_prefix = f"\n\nCutting Knowledge Date: December 2024\nToday Date: {datetime.now().strftime('%d %b %Y')}\n\n"
    for i, entry in enumerate(messages):
        output += f"<|start_header_id|>{entry['role']}<|end_header_id|>"
        if entry['role'] == 'system':
            output += f"{system_prefix}{entry['content']}<|eot_id|>"
        elif entry['role'] != 'system' and 'content' in entry:
            output += f"\n\n{entry['content']}<|eot_id|>"
    output += "<|start_header_id|>assistant<|end_header_id|>\n"
    return output

def send_prompt(predictor, initial_args, messages, parameters):
    # convert u/a format
    frmt_input = format_messages(messages)
    payload = {
        "inputs": frmt_input,
        "parameters": parameters
    }
    response = predictor.predict(
        initial_args=initial_args,
        data=payload)
    return response

### Test the endpoint with a sample prompt

Now we can invoke our endpoint with sample text to test its functionality and see the model's output.

In [None]:
%%time

messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that thinks and reasons before answering."
    },
    {
        "role": "user",
        "content": "How many R are in STRAWBERRY? Keep your answer and explanation short!"
    }
]

response = send_prompt(
    predictor=predictor,
    initial_args={
        'InferenceComponentName': inference_component_name
    },
    messages=messages,
    parameters={
        "temperature": 0.6,
        "max_new_tokens": 512
    }
)

print(response['generated_text'])

Okay, so I need to figure out how many times the letter R appears in the word "STRAWBERRY." Let me start by writing out the word and looking at each letter one by one. S, T, R, A, W, B, E, R, R, Y. Hmm, I see an R right there in the third position. Then, after that, I see two more Rs at the end: R and R. So that's three Rs in total. Wait, let me count again to make sure I didn't miss any. S, T, R, A, W, B, E, R, R, Y. Yep, that's three Rs. I don't think I missed any other letters. So the answer should be three Rs.
</think>

The letter R appears three times in STRAWBERRY.

Step-by-step explanation:
1. Write out the word: S, T, R, A, W, B, E, R, R, Y.
2. Identify each R: The third letter is R, and the eighth and ninth letters are also R.
3. Count the Rs: There are three Rs in total.

Answer: 3
CPU times: user 12.7 ms, sys: 3.78 ms, total: 16.4 ms
Wall time: 9.04 s
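
The generated text above contains the model's reasoning, then a `</think>` marker, then the final answer. The following is a minimal sketch of separating the two; the helper name `split_reasoning` is hypothetical and it assumes the marker appears at most once in the output.

```python
from typing import Tuple


def split_reasoning(generated_text: str, marker: str = "</think>") -> Tuple[str, str]:
    """Return (reasoning, answer); reasoning is empty if the marker is absent."""
    reasoning, sep, answer = generated_text.partition(marker)
    if not sep:
        # No marker found: treat the whole text as the answer.
        return "", generated_text.strip()
    return reasoning.strip(), answer.strip()


sample = (
    "S, T, R, A, W, B, E, R, R, Y. So the answer should be three Rs.\n"
    "</think>\n\n"
    "The letter R appears three times in STRAWBERRY."
)
reasoning, answer = split_reasoning(sample)
print(answer)
```

In the notebook, `split_reasoning(response['generated_text'])` would apply the same split to the endpoint response.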


## Automatically Scale To Zero

### Scaling policies

Once the endpoint is deployed and InService, you can then add the necessary scaling policies:

- A [target tracking policy](https://docs.aws.amazon.com/autoscaling/application/userguide/application-auto-scaling-target-tracking.html) that scales the copy count of the inference component between 1 and n, and scales it in to zero when the endpoint is idle.
- A [step scaling policy](https://docs.aws.amazon.com/autoscaling/application/userguide/application-auto-scaling-step-scaling-policies.html) that will allow the endpoint to scale out from zero.

These policies work together to provide cost-effective scaling - the endpoint can scale to zero when idle and automatically scale out as needed to handle incoming requests.

### Scaling policy for inference components copies (target tracking)

We start by creating the target tracking policy that scales the CopyCount of our inference component.

#### Register a new autoscaling target

After you create your SageMaker endpoint and inference components, you register a new auto scaling target for Application Auto Scaling.
In the following code block, you set **MinCapacity** to **0**, which is required for your endpoint to scale down to zero.

In [None]:
aas_client = sagemaker_session.boto_session.client("application-autoscaling", region_name=boto_region)
cloudwatch_client = sagemaker_session.boto_session.client("cloudwatch", region_name=boto_region)

In [None]:
# Autoscaling parameters
resource_id = f"inference-component/{inference_component_name}"
service_namespace = "sagemaker"
scalable_dimension = "sagemaker:inference-component:DesiredCopyCount"

min_copy_count = min_instance_count
max_copy_count = max_instance_count

aas_client.register_scalable_target(
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    MinCapacity=min_copy_count,
    MaxCapacity=max_copy_count,
)

#### Configure Target Tracking Scaling Policy

Once you have registered your new scalable target, the next step is to define your target tracking policy.
In the code example that follows, we set the TargetValue to 5.
This setting instructs the auto scaling system to increase capacity when the number of concurrent requests per model copy reaches or exceeds 5.
Here we use the more granular `PredefinedMetricType` `SageMakerInferenceComponentConcurrentRequestsPerCopyHighResolution` to monitor and react to changes in inference traffic more accurately. See this [blog post](https://aws.amazon.com/blogs/machine-learning/amazon-sagemaker-inference-launches-faster-auto-scaling-for-generative-ai-models/) for more information.

In [None]:
aas_client.describe_scalable_targets(
    ServiceNamespace=service_namespace,
    ResourceIds=[resource_id],
    ScalableDimension=scalable_dimension,
)

# The policy name for the target tracking policy
target_tracking_policy_name = f"Target-tracking-policy-deepseek-r1-distill-llama3-8b-scale-to-zero-aas-{inference_component_name}"

aas_client.put_scaling_policy(
    PolicyName=target_tracking_policy_name,
    PolicyType="TargetTrackingScaling",
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    TargetTrackingScalingPolicyConfiguration={
        "PredefinedMetricSpecification": {
            "PredefinedMetricType": "SageMakerInferenceComponentConcurrentRequestsPerCopyHighResolution",
        },
        # Low TPS + load TPS
        "TargetValue": 5,  # you need to adjust this value based on your use case
        "ScaleInCooldown": 300,  # default
        "ScaleOutCooldown": 300,  # default
    }
)
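
To build intuition for choosing `TargetValue`, the sketch below estimates how many copies would be needed to bring concurrent requests per copy back to the target. This is only a rough mental model and an assumption on our part: the exact Application Auto Scaling algorithm is not reproduced here, and cooldowns and alarm evaluation delays are ignored. The helper name `estimate_desired_copies` is hypothetical.

```python
import math


def estimate_desired_copies(
    current_copies: int,
    concurrent_per_copy: float,
    target_value: float,
    min_copies: int,
    max_copies: int,
) -> int:
    """Rough estimate: spread total concurrency so each copy sits at the target."""
    if target_value <= 0:
        raise ValueError("target_value must be positive")
    total_concurrency = current_copies * concurrent_per_copy
    desired = math.ceil(total_concurrency / target_value)
    # Clamp to the registered MinCapacity / MaxCapacity.
    return max(min_copies, min(max_copies, desired))


# 1 copy handling 12 concurrent requests with a target of 5 per copy
print(estimate_desired_copies(1, 12, 5, min_copies=0, max_copies=3))
```

With zero copies there is no per-copy concurrency to measure, which is why a separate policy is needed to scale out from zero.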

Application Auto Scaling creates two CloudWatch alarms per scaling target. The first triggers scale-out actions after 30 seconds (using 3 sub-minute data points), while the second triggers scale-in after 15 minutes (using 90 sub-minute data points). In practice, the scaling action usually starts 1–2 minutes later than these thresholds, because the endpoint needs time to publish metrics to CloudWatch and Application Auto Scaling needs time to react.
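
To see which alarms were created for the target tracking policy, you can read them back from the policy description. The following is a minimal sketch; the helper name `target_tracking_alarm_names` is hypothetical, and it relies on the `Alarms` list returned by `describe_scaling_policies`.

```python
from typing import List


def target_tracking_alarm_names(
    aas_client,
    policy_name: str,
    service_namespace: str,
    resource_id: str,
    scalable_dimension: str,
) -> List[str]:
    """Return the names of the CloudWatch alarms attached to a scaling policy."""
    resp = aas_client.describe_scaling_policies(
        PolicyNames=[policy_name],
        ServiceNamespace=service_namespace,
        ResourceId=resource_id,
        ScalableDimension=scalable_dimension,
    )
    names = []
    for policy in resp.get("ScalingPolicies", []):
        for alarm in policy.get("Alarms", []):
            names.append(alarm["AlarmName"])
    return names


# Usage in this notebook (requires AWS credentials):
# target_tracking_alarm_names(
#     aas_client, target_tracking_policy_name,
#     service_namespace, resource_id, scalable_dimension,
# )
```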

### Scale out from zero policy (step scaling policy)

To enable your endpoint to scale out from zero instances, do the following:

#### Configure Step Scaling Policy

Create a step scaling policy that defines when and how to scale out from zero. When triggered, this policy adds 1 model copy, which lets SageMaker provision the instances required to handle incoming requests after an idle period. The following code defines the policy. Here it scales out from 0 to 1 model copy (`"ScalingAdjustment": 1`); adjust `ScalingAdjustment` as required for your use case.

In [None]:
# The policy name for the step scaling policy
step_scaling_policy_name = f"Step-scaling-policy-{inference_component_name}"

aas_client.put_scaling_policy(
    PolicyName=step_scaling_policy_name,
    PolicyType="StepScaling",
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
    StepScalingPolicyConfiguration={
        "AdjustmentType": "ChangeInCapacity",
        "MetricAggregationType": "Maximum",
        "Cooldown": 60,
        "StepAdjustments":
          [
             {
               "MetricIntervalLowerBound": 0,
               "ScalingAdjustment": 1
             }
          ]
    },
)

In [None]:
resp = aas_client.describe_scaling_policies(
    PolicyNames=[step_scaling_policy_name],
    ServiceNamespace=service_namespace,
    ResourceId=resource_id,
    ScalableDimension=scalable_dimension,
)

step_scaling_policy_arn = resp['ScalingPolicies'][0]['PolicyARN']
print(f"step_scaling_policy_arn: {step_scaling_policy_arn}")

#### Create the CloudWatch alarm that will trigger our policy

Finally, create a CloudWatch alarm with the metric **NoCapacityInvocationFailures**. When triggered, the alarm initiates the previously defined scaling policy. For more information about the **NoCapacityInvocationFailures** metric, see [documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/monitoring-cloudwatch.html#cloudwatch-metrics-inference-component).

We have also set the following:

- EvaluationPeriods to 1
- DatapointsToAlarm to 1
- ComparisonOperator to GreaterThanOrEqualToThreshold
 
This results in a wait of about 1 minute before the step scaling policy triggers.

In [None]:
# The alarm name for the step scaling alarm
step_scaling_alarm_name = f"step-scaling-alarm-{inference_component_name}"

cloudwatch_client.put_metric_alarm(
    AlarmName=step_scaling_alarm_name,
    AlarmActions=[step_scaling_policy_arn],  # ARN of the step scaling policy retrieved above
    MetricName='NoCapacityInvocationFailures',
    Namespace='AWS/SageMaker',
    Statistic='Maximum',
    Dimensions=[
        {
            'Name': 'InferenceComponentName',
            'Value': inference_component_name  # name of the inference component created above
        }
        }
    ],
    Period=30, # Set a lower period
    EvaluationPeriods=1,
    DatapointsToAlarm=1,
    Threshold=1,
    ComparisonOperator='GreaterThanOrEqualToThreshold',
    TreatMissingData='missing'
)
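
After the alarm is created, its state can be read back to confirm that it exists and to watch it move to `ALARM` when an invocation fails for lack of capacity. This is a minimal sketch; the helper name `alarm_state` is hypothetical.

```python
from typing import Optional


def alarm_state(cloudwatch_client, alarm_name: str) -> Optional[str]:
    """Return the alarm's StateValue (OK, ALARM, INSUFFICIENT_DATA) or None if missing."""
    resp = cloudwatch_client.describe_alarms(AlarmNames=[alarm_name])
    alarms = resp.get("MetricAlarms", [])
    return alarms[0]["StateValue"] if alarms else None


# Usage in this notebook (requires AWS credentials):
# print(alarm_state(cloudwatch_client, step_scaling_alarm_name))
```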

### Testing the behaviour

Notice the `MinInstanceCount: 0` setting in the Endpoint configuration, which allows the endpoint to scale down to zero instances. With the scaling policy, CloudWatch alarm, and minimum instances set to zero, your SageMaker Inference Endpoint will now be able to automatically scale down to zero instances when not in use, helping you optimize your costs and resource utilization.

### Inference Component (IC) copy count scales in to zero

We'll pause for a few minutes without making any invocations to our model. Based on our target tracking policy, when our SageMaker endpoint doesn't receive requests for about 10 to 15 minutes, it automatically scales the number of model copies in to zero.

In [None]:
import sys
import time

time.sleep(600)
start_time = time.time()
while True:
    desc = sagemaker_client.describe_inference_component(InferenceComponentName=inference_component_name)
    status = desc["InferenceComponentStatus"]
    print(status)
    sys.stdout.flush()
    if status in ["InService", "Failed"]:
        break
    time.sleep(10)

total_time = time.time() - start_time
print(f"\nTotal time taken: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")

desc = sagemaker_client.describe_inference_component(InferenceComponentName=inference_component_name)
print(desc)
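
The component status alone may not show how many copies are running. As a minimal sketch, the copy count can be polled directly from `RuntimeConfig.CurrentCopyCount` in the `describe_inference_component` response. The helper name `wait_for_copy_count` and its timeout defaults are assumptions for illustration.

```python
import time


def wait_for_copy_count(
    sagemaker_client,
    inference_component_name: str,
    target: int = 0,
    timeout_s: float = 1800,
    poll_s: float = 30,
    sleep=time.sleep,
) -> bool:
    """Poll until CurrentCopyCount equals target; return False on timeout."""
    deadline = time.monotonic() + timeout_s
    while True:
        desc = sagemaker_client.describe_inference_component(
            InferenceComponentName=inference_component_name
        )
        current = desc.get("RuntimeConfig", {}).get("CurrentCopyCount")
        if current == target:
            return True
        if time.monotonic() >= deadline:
            return False
        sleep(poll_s)


# Usage in this notebook (requires AWS credentials):
# wait_for_copy_count(sagemaker_client, inference_component_name, target=0)
```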

### Endpoint's instances scale in to zero

After a few additional minutes of inactivity, SageMaker automatically terminates all underlying instances of the endpoint, eliminating all associated costs.

In [None]:
# after about 1 minute, instances will scale in to 0
time.sleep(60)

# wait for the endpoint update to finish; CurrentInstanceCount in the returned description should be 0
sagemaker_session.wait_for_endpoint(endpoint_name)
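
To check the instance count explicitly rather than reading it from the printed description, the following sketch sums `CurrentInstanceCount` across the endpoint's production variants. The helper name `current_instance_count` is hypothetical.

```python
def current_instance_count(sagemaker_client, endpoint_name: str) -> int:
    """Sum CurrentInstanceCount over all production variants of the endpoint."""
    desc = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    variants = desc.get("ProductionVariants", [])
    return sum(variant.get("CurrentInstanceCount", 0) for variant in variants)


# Usage in this notebook (requires AWS credentials):
# print(current_instance_count(sagemaker_client, endpoint_name))
```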

### Invoke the endpoint with a sample prompt

If we try to invoke our endpoint while instances are scaled down to zero, we get a validation error: `An error occurred (ValidationError) when calling the InvokeEndpoint operation: Inference Component has no capacity to process this request. ApplicationAutoScaling may be in-progress (if configured) or try to increase the capacity by invoking UpdateInferenceComponentRuntimeConfig API`.

In [None]:
print(time.strftime("%H:%M:%S"))

messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that thinks and reasons before answering."
    },
    {
        "role": "user",
        "content": "How many R are in STRAWBERRY? Keep your answer and explanation short!"
    }
]

response = send_prompt(
    predictor=predictor,
    initial_args={
        'InferenceComponentName': inference_component_name
    },
    messages=messages,
    parameters={
        "temperature": 0.6,
        "max_new_tokens": 512
    }
)

print(response['generated_text'])
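
Because the first request after a scale-in fails with the no-capacity error described above, a client may want to retry until a copy is back in service. The following is a minimal sketch; the helper name `invoke_with_retry`, the attempt count, and the wait interval are assumptions. It inspects the `response` attribute that botocore's `ClientError` carries and matches on the error code and the "no capacity" text from the message quoted above.

```python
import time


def invoke_with_retry(invoke, max_attempts: int = 10, wait_s: float = 30, sleep=time.sleep):
    """Call invoke() and retry only while the endpoint reports no capacity."""
    for attempt in range(1, max_attempts + 1):
        try:
            return invoke()
        except Exception as exc:
            response = getattr(exc, "response", None)
            error = response.get("Error", {}) if isinstance(response, dict) else {}
            no_capacity = (
                error.get("Code") == "ValidationError"
                and "no capacity" in error.get("Message", "")
            )
            # Re-raise anything that is not a no-capacity error, or the last failure.
            if not no_capacity or attempt == max_attempts:
                raise
            sleep(wait_s)


# Usage in this notebook (requires AWS credentials):
# response = invoke_with_retry(lambda: send_prompt(
#     predictor=predictor,
#     initial_args={'InferenceComponentName': inference_component_name},
#     messages=messages,
#     parameters={"temperature": 0.6, "max_new_tokens": 512},
# ))
```

Retrying only on this specific error keeps genuine failures, such as malformed payloads, visible immediately.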

### Scale out from zero kicks in

However, after about 1 minute our step scaling policy should kick in. SageMaker will then start provisioning a new instance and deploy our inference component model copy to handle requests. This demonstrates the endpoint's ability to automatically scale out from zero when needed.

In [None]:
# after about 1 minute, instances will scale out from zero to one
time.sleep(60)

# wait for the endpoint update to finish; CurrentInstanceCount in the returned description should be at least 1
sagemaker_session.wait_for_endpoint(endpoint_name)

In [None]:
import sys
import time


start_time = time.time()
while True:
    desc = sagemaker_client.describe_inference_component(InferenceComponentName=inference_component_name)
    status = desc["InferenceComponentStatus"]
    print(status)
    sys.stdout.flush()
    if status in ["InService", "Failed"]:
        break
    time.sleep(30)

total_time = time.time() - start_time
print(f"\nTotal time taken: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")

desc = sagemaker_client.describe_inference_component(InferenceComponentName=inference_component_name)
print(desc)

### Verify that our endpoint has successfully scaled out from zero

In [None]:
messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that thinks and reasons before answering."
    },
    {
        "role": "user",
        "content": "How many R are in STRAWBERRY? Keep your answer and explanation short!"
    }
]

response = send_prompt(
    predictor=predictor,
    initial_args={
        'InferenceComponentName': inference_component_name
    },
    messages=messages,
    parameters={
        "temperature": 0.6,
        "max_new_tokens": 512
    }
)

print(response['generated_text'])

Okay, so I need to figure out how many times the letter R appears in the word "STRAWBERRY." Let me start by writing out the word and looking at each letter one by one. S, T, R, A, W, B, E, R, R, Y. Hmm, I see an R right there in the third position. Then, after that, I see two more Rs at the end: R and R. So that's three Rs in total. Wait, let me count again to make sure I didn't miss any. S, T, R, A, W, B, E, R, R, Y. Yep, that's three Rs. I don't think I missed any other letters. So the answer should be three Rs.
</think>

The letter R appears three times in STRAWBERRY.

Step-by-step explanation:
1. Write out the word: S, T, R, A, W, B, E, R, R, Y.
2. Identify each R: The third letter is R, and the eighth and ninth letters are also R.
3. Count the Rs: There are three Rs in total.

Answer: 3


### Optionally clean up the environment

- Deregister scalable target
- Delete CloudWatch alarms
- Delete scaling policies

In [None]:
try:
    # Deregister the scalable target for AAS
    aas_client.deregister_scalable_target(
        ServiceNamespace="sagemaker",
        ResourceId=resource_id,
        ScalableDimension=scalable_dimension,
    )
    print(f"Scalable target for {resource_id} deregistered. ✅")
except aas_client.exceptions.ObjectNotFoundException:
    print(f"Scalable target for {resource_id} not found.")

print("---" * 10)

# Delete CloudWatch alarms created for Step scaling policy
try:
    cloudwatch_client.delete_alarms(AlarmNames=[step_scaling_alarm_name])
    print(f"Deleted CloudWatch step scaling scale-out alarm {step_scaling_alarm_name} ✅")
except cloudwatch_client.exceptions.ResourceNotFoundException:
    print(f"CloudWatch scale-out alarm {step_scaling_alarm_name} not found.")


# Delete step scaling policies
print("---" * 10)

try:
    aas_client.delete_scaling_policy(
        PolicyName=step_scaling_policy_name,
        ServiceNamespace="sagemaker",
        ResourceId=resource_id,
        # Must match the dimension the policy was created with
        ScalableDimension=scalable_dimension,
    )
    print(f"Deleted scaling policy {step_scaling_policy_name} ✅")
except aas_client.exceptions.ObjectNotFoundException:
    print(f"Scaling policy {step_scaling_policy_name} not found.")

- Delete inference component
- Delete endpoint
- Delete endpoint config

In [None]:
sagemaker_client.delete_inference_component(InferenceComponentName=inference_component_name)
predictor.delete_model()
predictor.delete_endpoint()

## References

- [✍🏻 (AWS Machine Learning Blog) Unlock cost savings with the new scale down to zero feature in SageMaker Inference (2024-12-02)](https://aws.amazon.com/blogs/machine-learning/unlock-cost-savings-with-the-new-scale-down-to-zero-feature-in-amazon-sagemaker-inference/)
- [💻 Unlock Cost Savings with New Scale-to-Zero Feature in SageMaker Inference](https://github.com/aws-samples/sagemaker-genai-hosting-examples/blob/main/scale-to-zero-endpoint/llama3-8b-scale-to-zero-autoscaling.ipynb)
- [💻 Deploy DeepSeek R1 Large Language Model from HuggingFace Hub on Amazon SageMaker](https://github.com/aws-samples/sagemaker-genai-hosting-examples/blob/main/Deepseek/DeepSeek-R1-Llama8B-LMI-TGI-Deploy.ipynb)
- [Available AWS Deep Learning Containers (DLC) images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)