# Triton on SageMaker with Pruna


## Set up the environment

Install the dependencies required to package the model and run inference against the Triton server.

Also define the IAM role that gives SageMaker access to the model artifacts and the NVIDIA Triton ECR image.

In [1]:
!pip install -qU pip awscli boto3 sagemaker transformers
!pip install nvidia-pyindex
!pip install tritonclient[http]

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pruna 0.2.4 requires opentelemetry-api>=1.30.0, but you have opentelemetry-api 1.26.0 which is incompatible.
pruna 0.2.4 requires opentelemetry-exporter-otlp>=1.29.0, but you have opentelemetry-exporter-otlp 1.26.0 which is incompatible.
pruna 0.2.4 requires opentelemetry-sdk>=1.30.0, but you have opentelemetry-sdk 1.26.0 which is incompatible.
pruna 0.2.4 requires torch==2.7.0, but you have torch 2.6.0 which is incompatible.
pruna 0.2.4 requires torchvision==0.22.0, but you have torchvision 0.21.0 which is incompatible.
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [1]:
#import os
#os.environ['AWS_DEFAULT_REGION'] = "us-east-1"
#os.environ['AWS_ACCESS_KEY_ID'] = ""
#os.environ['AWS_SECRET_ACCESS_KEY'] = ""
#os.environ['AWS_SESSION_TOKEN'] = ""

In [2]:
import boto3, json, sagemaker, time
from sagemaker import get_execution_role

sess = boto3.Session()

sm = sess.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=sess, sagemaker_client=sm)
role = get_execution_role()
client = boto3.client("sagemaker-runtime")


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ubuntu/.config/sagemaker/config.yaml


In [3]:
account_id_map = {
    'us-east-1': '785573368785',
    'us-east-2': '007439368137',
    'us-west-1': '710691900526',
    'us-west-2': '301217895009',
    'eu-west-1': '802834080501',
    'eu-west-2': '205493899709',
    'eu-west-3': '254080097072',
    'eu-north-1': '601324751636',
    'eu-south-1': '966458181534',
    'eu-central-1': '746233611703',
    'ap-east-1': '110948597952',
    'ap-south-1': '763008648453',
    'ap-northeast-1': '941853720454',
    'ap-northeast-2': '151534178276',
    'ap-southeast-1': '324986816169',
    'ap-southeast-2': '355873309152',
    'cn-northwest-1': '474822919863',
    'cn-north-1': '472730292857',
    'sa-east-1': '756306329178',
    'ca-central-1': '464438896020',
    'me-south-1': '836785723513',
    'af-south-1': '774647643957'
}
account_id_map['us-east-1'] = "763104351884"
account_id_map['eu-west-1'] = "763104351884"

In [4]:
region = boto3.Session().region_name
if region not in account_id_map:
    # Raising a bare string is a TypeError in Python 3; raise a real exception instead.
    raise ValueError(f"Unsupported region: {region}")

In [5]:
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
triton_image_uri = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:24.09-py3".format(
    account_id=account_id_map[region], region=region, base=base
)
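
The region check and URI formatting above can be folded into one small function. This is only a sketch: the name `build_triton_image_uri` and the `tag` parameter are illustrative and not part of the SageMaker SDK, and the two account IDs are copied from the map above.

```python
def build_triton_image_uri(region, account_ids, tag="24.09-py3"):
    """Return the SageMaker Triton ECR image URI for a region (hypothetical helper)."""
    if region not in account_ids:
        raise ValueError(f"Unsupported region: {region}")
    # China regions use a different DNS suffix.
    base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    return f"{account_ids[region]}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:{tag}"


# Two entries copied from the account map above, for illustration only.
example_accounts = {"us-east-1": "763104351884", "cn-north-1": "472730292857"}
print(build_triton_image_uri("us-east-1", example_accounts))
```

In the notebook this would be called as `build_triton_image_uri(region, account_id_map)`.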

## PyTorch NLP-Llama-Instruct

As a simple use case, we take a pre-trained Llama model from [Hugging Face](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), compress it with `pruna`, and deploy it on SageMaker with Triton as the model server. (The download cell in this notebook uses the smaller `NousResearch/Llama-3.2-1B` checkpoint.) We use the pre-configured `config.pbtxt` file provided with this repo [here](./triton-serve-pt/bert/config.pbtxt) to specify the model [configuration](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md) that Triton uses to load the model. We then tar the model directory and upload it to S3 so that we can later create a [SageMaker Model](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html).

**Note**: SageMaker expects the model tarball file to have a top level directory with the same name as the model defined in the `config.pbtxt`.

```
llama
├── 1
│   └── model.py
└── config.pbtxt
```
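
The contents of `config.pbtxt` are not shown in this notebook. As a rough sketch only, a configuration for Triton's Python backend could look like the text below. The tensor names and shapes are taken from the inference cells in this notebook; `max_batch_size: 0`, the GPU instance group, and the `EXECUTION_ENV_PATH` entry pointing at `hf_env.tar.gz` are assumptions and may differ from the file shipped with the repo. The helper `write_config` is hypothetical.

```python
import tempfile
from pathlib import Path

# Assumed configuration for a Python-backend model named "llama".
CONFIG_PBTXT = """\
name: "llama"
backend: "python"
max_batch_size: 0

input [
  { name: "INPUT_TEXT" data_type: TYPE_STRING dims: [ 1, 1 ] },
  { name: "MAX_TOKENS" data_type: TYPE_INT32 dims: [ 1, 1 ] }
]
output [
  { name: "OUTPUT_TEXT" data_type: TYPE_STRING dims: [ 1, 1 ] }
]

instance_group [ { kind: KIND_GPU } ]

parameters: {
  key: "EXECUTION_ENV_PATH",
  value: { string_value: "$$TRITON_MODEL_DIRECTORY/hf_env.tar.gz" }
}
"""


def write_config(model_dir):
    """Write the sketch configuration into model_dir and return its path."""
    model_dir = Path(model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    path = model_dir / "config.pbtxt"
    path.write_text(CONFIG_PBTXT)
    return path


# Demonstrate in a temporary directory so nothing in the repo is overwritten.
with tempfile.TemporaryDirectory() as tmp:
    config_path = write_config(Path(tmp) / "llama")
    written = config_path.read_text()
```

The `name` field has to match the top-level directory (`llama`), as the note above states.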

### PyTorch: Packaging model files and uploading to S3

Copy the model into the Triton directory structure

In [6]:
!mkdir -p triton-serve-pt/llama/1/
!cp workspace/model.py triton-serve-pt/llama/1/

Create a dedicated conda environment and move it into the Triton directory structure

In [7]:
!bash workspace/create_hf_env.sh
!mv hf_env.tar.gz triton-serve-pt/llama/

Channels:
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done


    current version: 25.1.1
    latest version: 25.5.1

Please update conda by running

    $ conda update -n base -c defaults conda



## Package Plan ##

  environment location: /home/ubuntu/miniconda3/envs/hf_env

  added / updated specs:
    - python=3.10


The following NEW packages will be INSTALLED:

  _libgcc_mutex      pkgs/main/linux-64::_libgcc_mutex-0.1-main 
  _openmp_mutex      pkgs/main/linux-64::_openmp_mutex-5.1-1_gnu 
  bzip2              pkgs/main/linux-64::bzip2-1.0.8-h5eee18b_6 
  ca-certificates    pkgs/main/linux-64::ca-certificates-2025.2.25-h06a4308_0 
  expat              pkgs/main/linux-64::expat-2.7.1-h6a678d5_0 
  ld_impl_linux-64   pkgs/main/linux-64::ld_impl_linux-64-2.40-h12ee557_0 
  libffi             pkgs/main/linux-64::libffi-3.4.4-h6a678d5_1 
  libgcc-ng          pkgs/main/linux-64::libgcc-ng-11.2.0-h1234567_1 
  libgomp            

Save the model locally

In [14]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from pathlib import Path

model_id = "NousResearch/Llama-3.2-1B"
output_dir = Path("triton-serve-pt/llama/1")  # points to the version folder

# Download and save locally
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

tokenizer.save_pretrained(output_dir / "tokenizer")
model.save_pretrained(output_dir / "model")


Package the model and the conda environment

In [6]:
!tar -C triton-serve-pt/ -czf llama.tar.gz llama
model_uri = sagemaker_session.upload_data(path="llama.tar.gz", key_prefix="triton-serve-pt")
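
Because SageMaker expects a single top-level directory named after the model, it can be worth checking the archive before uploading it. The sketch below uses only the standard library; `top_level_entries` is a hypothetical helper, and the demo builds a tiny stand-in archive rather than reading the real `llama.tar.gz`.

```python
import tarfile
import tempfile
from pathlib import Path


def top_level_entries(tar_path):
    """Return the set of top-level names inside a tar archive."""
    with tarfile.open(tar_path, "r:*") as tar:
        return {
            Path(name).parts[0]
            for name in tar.getnames()
            if name not in ("", ".")
        }


# Build a small archive with the same layout as the model package.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "llama"
    (root / "1").mkdir(parents=True)
    (root / "1" / "model.py").write_text("# placeholder\n")
    (root / "config.pbtxt").write_text('name: "llama"\n')

    archive = Path(tmp) / "llama.tar.gz"
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(root, arcname="llama")

    entries = top_level_entries(archive)

print(entries)  # {'llama'}
```

For the real package, `top_level_entries("llama.tar.gz")` should likewise contain only `llama`.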

### PyTorch: Create SageMaker Endpoint

We start by creating a [SageMaker model](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html) from the model files we uploaded to S3 in the previous step.

In this step we also provide an additional environment variable, `SAGEMAKER_TRITON_DEFAULT_MODEL_NAME`, which specifies the name of the model to be loaded by Triton. **The value of this key should match the folder name in the model package uploaded to S3**. This variable is optional for a single model. For ensemble models, this key **has to be** specified for Triton to start up in SageMaker.

Additionally, customers can set `SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT` and `SAGEMAKER_TRITON_THREAD_COUNT` for optimizing the thread counts.
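
As a sketch of how these variables fit together, the helper below builds the `Environment` mapping for the container definition. `build_triton_environment` is a hypothetical name, and the thread counts shown are arbitrary placeholders rather than tuned values. Values are converted to strings because the `Environment` map takes string values.

```python
def build_triton_environment(model_name, thread_count=None, buffer_manager_thread_count=None):
    """Build the container Environment dict; optional settings are added only if given."""
    env = {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": model_name}
    if thread_count is not None:
        env["SAGEMAKER_TRITON_THREAD_COUNT"] = str(thread_count)
    if buffer_manager_thread_count is not None:
        env["SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT"] = str(buffer_manager_thread_count)
    return env


# Placeholder thread counts, for illustration only.
environment = build_triton_environment("llama", thread_count=4, buffer_manager_thread_count=2)
print(environment)
```

The resulting dict can be passed as the `"Environment"` value of the container definition.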

In [7]:
print(triton_image_uri)
print(model_uri)

763104351884.dkr.ecr.us-east-1.amazonaws.com/sagemaker-tritonserver:24.09-py3
s3://sagemaker-us-east-1-992382637587/triton-serve-pt/llama.tar.gz


In [8]:
sm_model_name = "triton-nlp-llama-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "llama"},
}

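# NOTE: the execution role ARN is hard-coded here. Replace it with your own role,
# or pass the `role` variable obtained from get_execution_role() above.
create_model_response = sm.create_model(
    ModelName=sm_model_name,
    ExecutionRoleArn="arn:aws:iam::992382637587:role/sharedservices-sagemaker-role",
    PrimaryContainer=container,
)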

print("Model Arn: " + create_model_response["ModelArn"])

Model Arn: arn:aws:sagemaker:us-east-1:992382637587:model/triton-nlp-llama-pt-2025-06-11-12-38-59


Using the model above, we create an [endpoint configuration](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html) where we can specify the type and number of instances we want in the endpoint.

In [9]:
endpoint_config_name = "triton-nlp-llama-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g4dn.2xlarge", #"ml.g6e.2xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Endpoint Config Arn: arn:aws:sagemaker:us-east-1:992382637587:endpoint-config/triton-nlp-llama-pt-2025-06-11-12-39-02


Using the above endpoint configuration, we create a new SageMaker endpoint and wait for the deployment to finish. The status changes to **InService** once the deployment succeeds.

In [11]:
endpoint_name = "triton-nlp-llama-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

Endpoint Arn: arn:aws:sagemaker:us-east-1:992382637587:endpoint/triton-nlp-llama-pt-2025-06-11-12-50-50


In [13]:
import boto3
import time

# --- Configuration ---
# `endpoint_name` and `region` are reused from the cells above. Set them explicitly
# here only if you run this cell in a fresh kernel.
region_name = region
poll_interval_seconds = 30  # How often to check status and logs

# --- Initialization ---
# Ensure AWS credentials are configured (e.g., via environment variables, AWS CLI, or IAM role).
# Reuse the boto3 session created earlier in the notebook if it exists; otherwise create one.
if globals().get("sess") is None:
    print("Creating new Boto3 session.")
    sess = boto3.Session(region_name=region_name)

sm_client = sess.client("sagemaker")
logs_client = sess.client("logs")

last_event_timestamps = {}  # To store {stream_name: last_timestamp_processed}
log_group_name = f"/aws/sagemaker/Endpoints/{endpoint_name}"

print(f"Monitoring endpoint: {endpoint_name} in region {region_name}")
print(f"Log group: {log_group_name}")
print(f"Polling every {poll_interval_seconds} seconds.\n")

# --- Monitoring Loop ---
while True:
    try:
        resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
    except Exception as e:
        print(f"Error describing endpoint {endpoint_name}: {e}")
        print("Stopping monitoring.")
        break

    current_time_utc_str = time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime())
    print(f"\n--- {current_time_utc_str} ---")
    print(f"Endpoint Status: {status}")

    if status == "Failed":
        print(f"FailureReason: {resp.get('FailureReason', 'N/A')}")
        break
    if status == "InService":
        print("Endpoint is InService.")
        break
    if status not in ["Creating", "Updating"]: # Other terminal states
        print(f"Endpoint is in a terminal state: {status}. Stopping monitoring.")
        break

    # Fetch and print logs
    try:
        streams_response = logs_client.describe_log_streams(
            logGroupName=log_group_name,
            orderBy='LastEventTime',
            descending=True,
            limit=5  # Check a few most recently active streams
        )

        active_streams_found_this_poll = False
        for stream in streams_response.get('logStreams', []):
            active_streams_found_this_poll = True
            stream_name = stream['logStreamName']
            
            start_time_ms = last_event_timestamps.get(stream_name, 0) + 1
            
            next_token_for_stream_poll = None
            stream_had_new_events_this_poll = False

            # Paginate through events for this stream in this poll interval
            while True:
                event_fetch_args = {
                    'logGroupName': log_group_name,
                    'logStreamName': stream_name,
                    'startTime': start_time_ms,
                    'limit': 100, # Fetch up to 100 events per call
                    'startFromHead': True
                }
                if next_token_for_stream_poll:
                    event_fetch_args['nextToken'] = next_token_for_stream_poll

                try:
                    events_response = logs_client.get_log_events(**event_fetch_args)
                except logs_client.exceptions.ResourceNotFoundException:
                    # Stream might have just been created and not yet available for get_log_events
                    # Or a transient issue.
                    print(f"    Log stream {stream_name} not found during get_log_events. Will retry.")
                    break 
                except Exception as e_get:
                    print(f"    Error getting events for stream {stream_name}: {e_get}")
                    break 

                fetched_events_batch = events_response.get('events', [])
                
                if fetched_events_batch:
                    if not stream_had_new_events_this_poll:
                        print(f"  Logs from stream: {stream_name}")
                        stream_had_new_events_this_poll = True

                    for event in fetched_events_batch:
                        event_ts_ms = event['timestamp']
                        # Only print if newer than last seen for this stream
                        if event_ts_ms >= start_time_ms:
                             event_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(event_ts_ms / 1000))
                             print(f"    {event_time_str} UTC: {event['message'].strip()}")
                    
                    last_event_timestamps[stream_name] = fetched_events_batch[-1]['timestamp']
                
                next_token_for_stream_poll = events_response.get('nextForwardToken')
                if not next_token_for_stream_poll or not fetched_events_batch:
                    break
        
        if not active_streams_found_this_poll and status == "Creating":
             print(f"  No active log streams found yet for {log_group_name}.")

    except logs_client.exceptions.ResourceNotFoundException:
        print(f"  Log group {log_group_name} not found yet. This is normal during initial endpoint creation.")
    except Exception as e_logs:
        print(f"  An error occurred while fetching logs: {str(e_logs)}")

    time.sleep(poll_interval_seconds)

print("\n--- Monitoring finished ---")

Monitoring endpoint: triton-nlp-llama-pt-2025-06-11-12-50-50 in region us-east-1
Log group: /aws/sagemaker/Endpoints/triton-nlp-llama-pt-2025-06-11-12-50-50
Polling every 30 seconds.


--- 2025-06-11 12:51:18 UTC ---
Endpoint Status: Creating
  Log group /aws/sagemaker/Endpoints/triton-nlp-llama-pt-2025-06-11-12-50-50 not found yet. This is normal during initial endpoint creation.

--- 2025-06-11 12:51:48 UTC ---
Endpoint Status: Creating
  Log group /aws/sagemaker/Endpoints/triton-nlp-llama-pt-2025-06-11-12-50-50 not found yet. This is normal during initial endpoint creation.

--- 2025-06-11 12:52:19 UTC ---
Endpoint Status: Creating
  Log group /aws/sagemaker/Endpoints/triton-nlp-llama-pt-2025-06-11-12-50-50 not found yet. This is normal during initial endpoint creation.

--- 2025-06-11 12:52:49 UTC ---
Endpoint Status: Creating
  Log group /aws/sagemaker/Endpoints/triton-nlp-llama-pt-2025-06-11-12-50-50 not found yet. This is normal during initial endpoint creation.

--- 2025-06-11 

In [14]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    print(resp)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: InService
Arn: arn:aws:sagemaker:us-east-1:992382637587:endpoint/triton-nlp-llama-pt-2025-06-11-12-50-50
Status: InService
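
The manual polling loop above can also be written with the waiter that boto3 provides for SageMaker endpoints. The sketch below assumes `sagemaker_client` is a boto3 SageMaker client such as `sm`; the function name and the default delay and attempt counts are illustrative choices. The waiter raises an exception if the endpoint fails or the attempts run out.

```python
def wait_until_in_service(sagemaker_client, endpoint_name, delay_seconds=30, max_attempts=60):
    """Block until the endpoint is InService, then return its status."""
    waiter = sagemaker_client.get_waiter("endpoint_in_service")
    waiter.wait(
        EndpointName=endpoint_name,
        WaiterConfig={"Delay": delay_seconds, "MaxAttempts": max_attempts},
    )
    response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
    return response["EndpointStatus"]


# Usage in this notebook (requires AWS credentials and an existing endpoint):
# status = wait_until_in_service(sm, endpoint_name)
```

The trade-off is that the waiter does not stream CloudWatch logs while waiting, so the longer monitoring cell remains useful when debugging a failing deployment.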


### PyTorch: Run inference

Once the endpoint is running, we can send a sample text for inference using JSON as the payload format. For the inference request format, Triton uses the KFServing community standard [inference protocols](https://github.com/triton-inference-server/server/blob/main/docs/protocol/README.md).

In [15]:
def get_text_payload(text, max_tokens=100):
    payload = {}
    payload["inputs"] = []
    payload["inputs"].append(
        {
            "name": "INPUT_TEXT",
            "shape": [1, 1],
            "datatype": "BYTES",
            "data": [text],
        }
    )
    payload["inputs"].append(
        {
            "name": "MAX_TOKENS",
            "shape": [1, 1],
            "datatype": "INT32",
            "data": [[max_tokens]],
        }
    )
    return payload

In [17]:
text_triton = ["Triton Inference Server provides a cloud and edge inferencing solution optimized for both CPUs and GPUs."]

payload = get_text_payload(text_triton, 100)

response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload)
)

print(json.loads(response["Body"].read().decode("utf8")))

{'model_name': 'llama', 'model_version': '1', 'outputs': [{'name': 'OUTPUT_TEXT', 'datatype': 'BYTES', 'shape': [1, 1], 'data': ['It delivers high-end inference capabilities on-demand. The Inference Engine (Triton-IE) is a general-purpose, open-source framework that offers various features for developers to accelerate AI applications. The Core Library (Triton-Core) provides a set of machine learning primitives implemented in C/C++ to accelerate inference processing in the cloud or edge. The Web UI is a graphical interface for deploying and running inference tasks.\nTriton Inference Platform is the underlying framework to build AI applications. It']}]}


In [19]:
text_triton = ["Pruna AI offers the best compression methods for LLMs."]

payload = get_text_payload(text_triton, 100)

response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload)
)

print(json.loads(response["Body"].read().decode("utf8")))

{'model_name': 'llama', 'model_version': '1', 'outputs': [{'name': 'OUTPUT_TEXT', 'datatype': 'BYTES', 'shape': [1, 1], 'data': ['We apply a suite of quality checking methods, and we ensure the quality of the result. Once the quality is verified, your LLM can be deployed.\nWe are currently offering the highest quality service in the industry. We are looking for potential partners to develop additional models and improve the quality.\nFor questions or further information, please contact us at contact@PrunAI.org\nIn case the model does not have enough capacity, we will train the model further with higher quality questions. The questions may come']}]}
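
The responses above are nested dictionaries, so a small helper can pull out just the generated text. This is a sketch that assumes the response layout printed above; `extract_output_text` is a hypothetical name, and the sample body is a shortened copy of the first response.

```python
import json


def extract_output_text(response_body, output_name="OUTPUT_TEXT"):
    """Return the data list of the named output from a Triton JSON response."""
    result = json.loads(response_body)
    for output in result.get("outputs", []):
        if output.get("name") == output_name:
            return list(output.get("data", []))
    raise KeyError(f"No output named {output_name!r} in response")


# Shortened copy of the first response shown above.
sample_body = json.dumps(
    {
        "model_name": "llama",
        "model_version": "1",
        "outputs": [
            {
                "name": "OUTPUT_TEXT",
                "datatype": "BYTES",
                "shape": [1, 1],
                "data": ["It delivers high-end inference capabilities on-demand."],
            }
        ],
    }
)

texts = extract_output_text(sample_body)
print(texts[0])
```

With a live endpoint, the call would be `extract_output_text(response["Body"].read().decode("utf8"))`.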


### PyTorch: Terminate endpoint and clean up artifacts

In [20]:
# Delete the endpoint first so the instances stop running, then the resources it references.
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)

{'ResponseMetadata': {'RequestId': 'd7a79939-1178-42a8-ba9a-0e4e44970332',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': 'd7a79939-1178-42a8-ba9a-0e4e44970332',
   'content-type': 'application/x-amz-json-1.1',
   'date': 'Wed, 11 Jun 2025 13:19:05 GMT',
   'content-length': '0'},
  'RetryAttempts': 0}}
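
The cell above removes the SageMaker resources but leaves the uploaded tarball in S3. A minimal sketch for removing it follows. It assumes `s3_client` is a boto3 S3 client and that `model_uri` has the `s3://bucket/key` form printed earlier; both helper names are hypothetical.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split an s3://bucket/key URI into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc or not parsed.path.strip("/"):
        raise ValueError(f"Not a valid S3 object URI: {uri}")
    return parsed.netloc, parsed.path.lstrip("/")


def delete_model_artifact(s3_client, model_uri):
    """Delete the object that model_uri points to and return (bucket, key)."""
    bucket, key = split_s3_uri(model_uri)
    s3_client.delete_object(Bucket=bucket, Key=key)
    return bucket, key


# The URI printed earlier in this notebook, used here only to show the split.
bucket, key = split_s3_uri("s3://sagemaker-us-east-1-992382637587/triton-serve-pt/llama.tar.gz")
print(bucket, key)
```

In the notebook this would be called as `delete_model_artifact(sess.client("s3"), model_uri)`.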