# SageMaker Ollama endpoint example

## 1. Prerequisites

This bring-your-own-container (BYOC) example builds an endpoint Docker image and stores it in your private ECR repository (for example `sagemaker_endpoint/ollama-x86_64`). First, define the following variables.

In [None]:
import os

MODEL_ID = "qwen3:0.6b"
INSTANCE_TYPE = "ml.c7g.4xlarge"
ARCHITECTURE = "arm64"

INSTANCE_TYPE = "ml.c7i.4xlarge"
INSTANCE_TYPE = "ml.g5.2xlarge"

ARCHITECTURE = "x86_64"

REPO_TAG = "0.12.5"
REPO_NAMESPACE = "sagemaker_endpoint/ollama-" + ARCHITECTURE

ACCOUNT = !aws sts get-caller-identity --query Account --output text
REGION = !aws configure get region
REGION = REGION[0] if REGION else os.environ.get("AWS_REGION")

ACCOUNT = ACCOUNT[0]


ARTIFACTS_PREFIX = "ollama"

CONTAINER = f"{ACCOUNT}.dkr.ecr.{REGION}.amazonaws.com/{REPO_NAMESPACE}:{REPO_TAG}"
print(CONTAINER)

In [None]:
# Upgrade the SageMaker Python SDK, which the following cells import and use.
%pip install -U sagemaker

In [None]:
import os
import re
import json
from datetime import datetime
import time

import boto3
import sagemaker

# IAM role this notebook runs as; it is passed to create_model as the model's execution role.
role = sagemaker.get_execution_role()

sagemaker_client = boto3.client("sagemaker")

# The role ARN's last path segment is the role name shown in the IAM console.
role_name = role.split("/")[-1]
role_console_url = "https://us-east-1.console.aws.amazon.com/iam/home?#/roles/details/" + role_name
print("Your execution role details page:")
print(role_console_url)

The build runs on AWS CodeBuild, so make sure your execution role has sufficient permissions.

Here is an example policy you can add to your role on the role details page linked above:
##### role policy:

``` json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Action": [
                "iam:CreateRole",
                "iam:GetRole",
                "iam:PutRolePolicy",
                "iam:PassRole"
            ],
            "Resource": [
                "arn:aws:iam::*:role/CodeBuildServiceRole-*"
            ]
        },
        {
            "Effect": "Allow",
            "Action": [
                "codebuild:CreateProject",
                "codebuild:UpdateProject",
                "codebuild:BatchGetProjects",
                "codebuild:StartBuild"
            ],
            "Resource": "*"
        },
        {
            "Effect": "Allow",
            "Action": [
                "s3:CreateBucket",
                "s3:ListBucket",
                "s3:PutObject",
                "s3:GetObject"
            ],
            "Resource": [
                "arn:aws:s3:::sagemaker-*",
                "arn:aws:s3:::sagemaker-*/*"
            ]
        },
        {
            "Effect": "Allow",
            "Action": [
                "sts:GetCallerIdentity"
            ],
            "Resource": "*"
        }
    ]
}
```

## 2. Build the container

The endpoint startup code is in `app/`. The script builds the image and pushes it to ECR.

**The Docker image only needs to be built once**; after that, the same image can be shared when deploying other endpoints.

In [None]:
cmd = f"REPO_NAMESPACE={REPO_NAMESPACE} REPO_TAG={REPO_TAG} ARCHITECTURE={ARCHITECTURE} ./build-deploy.sh"
print("Runging:\n", cmd)
!{cmd}

## 3. Deploy on SageMaker

Define the model and deploy it on SageMaker.


### 3.1. create model

In [None]:
# SageMaker resource names allow only letters, digits and hyphens, so map every other
# character (":", ".", "/", "_", ...) in "<model>--<instance>" to "-".
# For the default settings this yields "qwen3-0-6b--ml-g5-2xlarge".
model_name = re.sub(r"[^a-zA-Z0-9-]", "-", f"{MODEL_ID}--{INSTANCE_TYPE}")

endpoint_model_name = sagemaker.utils.name_from_base(ARTIFACTS_PREFIX + "-" + model_name, short=True)

create_model_response = sagemaker_client.create_model(
    ModelName=endpoint_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": CONTAINER,
        # Environment for the container; presumably consumed by the startup code in app/
        # (and OLLAMA_NUM_PARALLEL by the Ollama server) -- TODO confirm.
        "Environment": {
            "OLLAMA_MODEL_ID": MODEL_ID,
            "OLLAMA_NUM_PARALLEL": "4",
        },
    },
)
print(create_model_response)
print("endpoint_model_name:", endpoint_model_name)

### 3.2. create endpoint config

In [None]:
# Endpoint configuration with a single production variant that hosts the model above.
endpoint_config_name = sagemaker.utils.name_from_base(f"{ARTIFACTS_PREFIX}-{model_name}", short=True)

production_variant = {
    "VariantName": "Variant1",
    "ModelName": endpoint_model_name,
    "InstanceType": INSTANCE_TYPE,
    "InitialInstanceCount": 1,
    # Give the container up to 20 minutes to start and pass its health check
    # (presumably to cover fetching/loading the model at startup -- TODO confirm).
    "ContainerStartupHealthCheckTimeoutInSeconds": 1200,
}

endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
print(endpoint_config_response)
print("endpoint_config_name:", endpoint_config_name)

### 3.3. create endpoint

In [None]:
endpoint_name = sagemaker.utils.name_from_base(ARTIFACTS_PREFIX + "-" + model_name, short=True)

create_endpoint_response = sagemaker_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)
print(create_endpoint_response)
print("endpoint_name:", endpoint_name)

# Poll once a minute until the endpoint leaves the "Creating" state.
while True:
    endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    status = endpoint_description["EndpointStatus"]
    if status != "Creating":
        break
    print(datetime.now().strftime("%Y%m%d-%H:%M:%S") + " status: " + status)
    time.sleep(60)
print("Endpoint:", endpoint_name, status)

if status == "Failed":
    # Stop here with the reason instead of letting the invoke cells below fail obscurely.
    raise RuntimeError(
        f"Endpoint {endpoint_name} failed: "
        f"{endpoint_description.get('FailureReason', '<no FailureReason reported>')}"
    )

## 4. Test

You can invoke your model with SageMaker runtime.

In [None]:
# Conversation sent by the test requests below: a single user turn.
messages = [dict(role="user", content="Hi, who are you!")]

# Upper bound on generated tokens per request.
max_tokens = 4096

### 4.1 Messages API, non-streaming mode

In [None]:
# Client for invoking deployed SageMaker endpoints.
sagemaker_runtime = boto3.client("runtime.sagemaker")

# Chat-completions request; with stream=False the full reply comes back in one body.
payload = dict(model=MODEL_ID, messages=messages, max_tokens=max_tokens, stream=False)
response = sagemaker_runtime.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=json.dumps(payload),
)

completion = json.loads(response["Body"].read())
print(completion["choices"][0]["message"]["content"])

### 4.2 Messages API, streaming mode

In [None]:
import codecs

payload = {
    "model": MODEL_ID,
    "messages": messages,
    "max_tokens": max_tokens,
    "stream": True
}

response = sagemaker_runtime.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Body=json.dumps(payload)
)

# Each complete server-sent event is a line of the form "data: {json}\n".
sse_data_pattern = re.compile(r'(^|\n)data:\s*(\{.+?\})\n')
# A multi-byte UTF-8 character can be split across two PayloadParts; decoding each
# part on its own would raise UnicodeDecodeError, so decode incrementally.
utf8_decoder = codecs.getincrementaldecoder("utf-8")()

buffer = ""
for event in response['Body']:
    buffer += utf8_decoder.decode(event.get("PayloadPart", {}).get("Bytes", b""))
    last_idx = 0
    for match in sse_data_pattern.finditer(buffer):
        # The matched line is complete (it ends in "\n"), so consume it even if it
        # does not parse; otherwise a bad line would be re-scanned on every chunk.
        last_idx = match.end()
        try:
            data = json.loads(match.group(2))
            delta = data["choices"][0]["delta"]
        except (json.JSONDecodeError, KeyError, IndexError):
            # Chunks without a delta (e.g. an empty "choices" list) carry no text.
            continue
        for key in ("reasoning", "reasoning_content", "content"):
            if delta.get(key):
                print(delta[key], end="", flush=True)
    # Keep only the unterminated tail for the next chunk.
    buffer = buffer[last_idx:]
print()

### 4.3 Speed test

In [None]:
import codecs

sagemaker_runtime = boto3.client('runtime.sagemaker')

messages = [{
    "role": "user",
    "content": "write a poem about Shanghai"
}]

payload = {
    "model": MODEL_ID,
    "messages": messages,
    "max_tokens": 4096,
    "temperature": 0.0,
    "stream": True,
    # Request a final chunk with token usage so output speed can be computed.
    "stream_options": {"include_usage": True},
}

time_start = time.time()

response = sagemaker_runtime.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Body=json.dumps(payload)
)

# Each complete server-sent event is a line of the form "data: {json}\n".
sse_data_pattern = re.compile(r'(^|\n)data:\s*(\{.+?\})\n')
# Decode incrementally: a multi-byte UTF-8 character may span two PayloadParts.
utf8_decoder = codecs.getincrementaldecoder("utf-8")()

buffer = ""
first_token_latency = None  # seconds until the first generated text arrives
input_tokens = None         # filled from the usage chunk, if the server sends one
output_tokens = None
for event in response['Body']:
    buffer += utf8_decoder.decode(event.get("PayloadPart", {}).get("Bytes", b""))
    last_idx = 0
    for match in sse_data_pattern.finditer(buffer):
        # The line is complete, so consume it even if it fails to parse.
        last_idx = match.end()
        try:
            data = json.loads(match.group(2))
        except json.JSONDecodeError:
            continue
        usage = data.get("usage")
        if usage is not None:
            input_tokens = usage.get("prompt_tokens")
            output_tokens = usage.get("completion_tokens")
            continue
        choices = data.get("choices") or []
        delta = (choices[0].get("delta") or {}) if choices else {}
        for key in ("reasoning", "reasoning_content", "content"):
            text = delta.get(key)
            if text:
                if first_token_latency is None:
                    first_token_latency = time.time() - time_start
                print(text, end="", flush=True)
    # Keep only the unterminated tail for the next chunk.
    buffer = buffer[last_idx:]
print()

total_time = time.time() - time_start

print("\n" + "=" * 50)
print("Input_tokens", input_tokens, "Output_tokens", output_tokens)
if first_token_latency is None:
    print("First token latency n/a (no text was generated)")
else:
    print(f"First token latency {first_token_latency:.3f} seconds")
print(f"Total latency {total_time:.3f} seconds")
# Output speed is measured from the first token to the end of the stream.
generation_time = total_time - (first_token_latency or 0.0)
if output_tokens and generation_time > 0:
    print(f"Output speed {output_tokens / generation_time:.3f} tokens/seconds")
else:
    print("Output speed n/a (no usage data or zero generation time)")
print("=" * 50)

## 5. Clean

You can delete the endpoint, the endpoint configuration, and the model with these functions.

In [None]:
def delete_endpoint(endpoint_name):
    """Delete a SageMaker endpoint and wait until it no longer exists.

    Args:
        endpoint_name: Name of the endpoint to delete.

    Errors are printed rather than raised, so a cleanup cell can continue with the
    remaining resources.
    """
    try:
        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
    except sagemaker_client.exceptions.ClientError as e:
        print(f"Error deleting endpoint: {e}")
        return
    print(f"Endpoint '{endpoint_name}' deletion initiated.")

    # Poll until describe_endpoint reports that the endpoint cannot be found.
    while True:
        try:
            sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
        except sagemaker_client.exceptions.ClientError as e:
            # SageMaker reports a missing endpoint as ValidationException; any other
            # error (throttling, permissions, ...) does not prove the endpoint is gone.
            if e.response.get("Error", {}).get("Code") == "ValidationException":
                print(f"Endpoint '{endpoint_name}' has been deleted.")
            else:
                print(f"Error while waiting for endpoint deletion: {e}")
            return
        print("Waiting for endpoint to be deleted...")
        time.sleep(30)

def delete_endpoint_config(endpoint_config_name):
    """Delete a SageMaker endpoint configuration and report the outcome.

    Args:
        endpoint_config_name: Name of the endpoint configuration to delete.

    A ClientError is printed instead of propagating.
    """
    request = {"EndpointConfigName": endpoint_config_name}
    try:
        sagemaker_client.delete_endpoint_config(**request)
    except sagemaker_client.exceptions.ClientError as err:
        print(f"Error deleting endpoint configuration: {err}")
    else:
        print(f"Endpoint configuration '{endpoint_config_name}' has been deleted.")

def delete_model(model_name):
    """Delete a SageMaker model and report the outcome.

    Args:
        model_name: Name of the model to delete.

    A ClientError is printed instead of propagating.
    """
    request = {"ModelName": model_name}
    try:
        sagemaker_client.delete_model(**request)
    except sagemaker_client.exceptions.ClientError as err:
        print(f"Error deleting model: {err}")
    else:
        print(f"Model '{model_name}' has been deleted.")

        
# delete_endpoint(endpoint_name)

# delete_endpoint_config(endpoint_config_name)

# delete_model(endpoint_model_name)