# Triton on SageMaker - Deploying on Neuron/Inferentia instance type




[Amazon SageMaker](https://aws.amazon.com/sagemaker/) is a fully managed service for data science and machine learning workflows. It helps data scientists and developers to prepare, build, train, and deploy high-quality ML models quickly by bringing together a broad set of capabilities purpose-built for ML.

Now, [NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/) can be used to serve models for inference in Amazon SageMaker. Thanks to the new NVIDIA Triton container image, you can easily serve ML models and benefit from the performance optimizations, dynamic batching, and multi-framework support provided by NVIDIA Triton. Triton helps maximize the utilization of GPU and CPU, further lowering the cost of inference.

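This notebook was tested on an Amazon SageMaker notebook instance of type `inf2.24xlarge`. It uses helper scripts, such as `gen_triton_model.py` from the Triton `python_backend` repository, to generate the Triton model files (`model.py` and `config.pbtxt`) for Neuron-compiled models.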

## Contents
1. [Introduction to NVIDIA Triton Server](#Introduction-to-NVIDIA-Triton-Server)
1. [Set up the environment](#Set-up-the-environment)
1. [Add utility methods for preparing request payload](#Add-utility-methods-for-preparing-request-payload)
1. [Set up the Tritonserver Container with neuronx libraries](#Set-up-the-Tritonserver-Container-with-neuronx-libraries)
1. [PyTorch-Neuronx: ResNet-50](#PyTorch-Neuronx)
1. [Terminate and cleanup](#Terminate-and-cleanup)
1. [Run a Transformers based OPT-125M model](#Run-a-Transformers-based-OPT-125M-model)
1. [Run a Transformers based GPTJ-6B model](#Run-a-Transformers-based-GPTJ-6B-model)

## Introduction to NVIDIA Triton Server

[NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/) was developed specifically to enable scalable, cost-effective, and easy deployment of models in production. NVIDIA Triton Inference Server is open-source inference serving software that simplifies the inference serving process and provides high inference performance.

Some key features of Triton are:
* **Support for multiple frameworks**: Triton can be used to deploy models from all major frameworks. Triton supports TensorFlow GraphDef, TensorFlow SavedModel, ONNX, PyTorch TorchScript, TensorRT, RAPIDS FIL for tree-based models, and OpenVINO model formats.
* **Model pipelines**: A Triton model ensemble represents a pipeline of one or more models or pre/post-processing logic, and the connection of input and output tensors between them. A single inference request to an ensemble triggers the execution of the entire pipeline.
* **Concurrent model execution**: Multiple models (or multiple instances of the same model) can run simultaneously on the same GPU or on multiple GPUs for different model management needs.
* **Dynamic batching**: For models that support batching, Triton has multiple built-in scheduling and batching algorithms that combine individual inference requests together to improve inference throughput. These scheduling and batching decisions are transparent to the client requesting inference.
* **Diverse CPUs and GPUs**: The models can be executed on CPUs or GPUs for maximum flexibility and to support heterogeneous computing requirements.

**Note**: This initial release of NVIDIA Triton on SageMaker will only support a single model. Future releases will have multi-model support. A minimal `config.pbtxt` configuration file is **required** in the model artifacts. This release doesn't support inferring the model config automatically.
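Because a `config.pbtxt` is required, it helps to see what a minimal one contains. The sketch below renders one as a string using a hypothetical helper, `render_minimal_config`. The field names (`name`, `backend`, `max_batch_size`, `input`, `output`, `data_type`, `dims`) are standard Triton model-configuration fields; the choice of the `python` backend is an assumption, and the tensor names mirror the ResNet example later in this notebook. The configuration actually used here is generated by `gen_triton_model.py` and may contain additional fields.

```python
def render_minimal_config(model_name, backend, max_batch_size, inputs, outputs):
    """Render a minimal Triton config.pbtxt as a string.

    inputs/outputs are lists of (name, data_type, dims) tuples, for example
    ("INPUT__0", "TYPE_FP32", [3, 224, 224]). When max_batch_size > 0,
    dims exclude the batch dimension.
    """

    def tensor_block(kind, tensors):
        entries = []
        for name, data_type, dims in tensors:
            dims_str = ", ".join(str(d) for d in dims)
            entries.append(
                "  {\n"
                f'    name: "{name}"\n'
                f"    data_type: {data_type}\n"
                f"    dims: [ {dims_str} ]\n"
                "  }"
            )
        return f"{kind} [\n" + ",\n".join(entries) + "\n]"

    lines = [
        f'name: "{model_name}"',
        f'backend: "{backend}"',
        f"max_batch_size: {max_batch_size}",
        tensor_block("input", inputs),
        tensor_block("output", outputs),
    ]
    return "\n".join(lines) + "\n"


# Example: a ResNet-style model served by the (assumed) Python backend
example = render_minimal_config(
    "resnet",
    "python",
    4,
    inputs=[("INPUT__0", "TYPE_FP32", [3, 224, 224])],
    outputs=[("OUTPUT__0", "TYPE_FP32", [1000])],
)
print(example)
```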

## Set up the environment

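Install the dependencies required to package the model and run inference with Triton server.

Also define the IAM role that gives SageMaker access to the model artifacts and the NVIDIA Triton ECR image. The SageMaker clients below are created in `us-east-2`, where the inf2 endpoints in this notebook are deployed.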

In [None]:
!pip install -qU pip awscli boto3 sagemaker
!pip install nvidia-pyindex
!pip install tritonclient[http]

In [34]:
import boto3, json, sagemaker, time
from sagemaker import get_execution_role

sm_client = boto3.client(service_name="sagemaker",region_name="us-east-2")
runtime_sm_client = boto3.client("sagemaker-runtime",region_name="us-east-2")
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name="us-east-2"))
role = get_execution_role()

In [2]:
account_id_map = {
    'us-east-1': '785573368785',
    'us-east-2': '007439368137',
    'us-west-1': '710691900526',
    'us-west-2': '301217895009',
    'eu-west-1': '802834080501',
    'eu-west-2': '205493899709',
    'eu-west-3': '254080097072',
    'eu-north-1': '601324751636',
    'eu-south-1': '966458181534',
    'eu-central-1': '746233611703',
    'ap-east-1': '110948597952',
    'ap-south-1': '763008648453',
    'ap-northeast-1': '941853720454',
    'ap-northeast-2': '151534178276',
    'ap-southeast-1': '324986816169',
    'ap-southeast-2': '355873309152',
    'cn-northwest-1': '474822919863',
    'cn-north-1': '472730292857',
    'sa-east-1': '756306329178',
    'ca-central-1': '464438896020',
    'me-south-1': '836785723513',
    'af-south-1': '774647643957'
}

In [31]:
region = boto3.Session().region_name
if region not in account_id_map:
    raise ValueError(f"Unsupported region: {region}")

In [32]:
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
triton_image_uri = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:23.03-py3".format(
    account_id=account_id_map[region], region=region, base=base
)
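The region check and URI construction above can be combined into one small function. This is a sketch, not part of the SageMaker SDK: `build_triton_image_uri` is a hypothetical helper, and the default tag `23.03-py3` simply mirrors the tag used in this notebook. It raises a `ValueError` for regions missing from the account map instead of failing later with a `KeyError`.

```python
def build_triton_image_uri(region, account_id_map, tag="23.03-py3"):
    """Return the SageMaker Triton image URI for a region, or raise for unknown regions."""
    if region not in account_id_map:
        raise ValueError(f"Unsupported region: {region}")
    # China regions use a different DNS suffix for ECR endpoints
    base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    return f"{account_id_map[region]}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:{tag}"


# Example with a two-entry subset of the account map defined above
accounts = {"us-west-2": "301217895009", "cn-north-1": "472730292857"}
print(build_triton_image_uri("us-west-2", accounts))
```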

In [33]:
triton_image_uri

'301217895009.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tritonserver:23.03-py3'

**Note: update the `FROM` base image in the Dockerfile located in the `docker/` folder to match the image URI above.**

## Add utility methods for preparing request payload

The following code downloads a sample image and defines `get_sample_image()`, which resizes and normalizes the image and returns it as a nested list in channel-first (CHW) order, ready to be placed in a request payload for the Triton server.

In [6]:
import numpy as np
from PIL import Image

s3_client = boto3.client('s3')
s3_client.download_file(
    "sagemaker-sample-files",
    "datasets/image/pets/shiba_inu_dog.jpg",
    "shiba_inu_dog.jpg"
)

def get_sample_image():
    image_path = "./shiba_inu_dog.jpg"
    img = Image.open(image_path).convert("RGB")
    img = img.resize((224, 224))
    img = (np.array(img).astype(np.float32) / 255) - np.array(
        [0.485, 0.456, 0.406], dtype=np.float32
    ).reshape(1, 1, 3)
    img = img / np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)
    img = np.transpose(img, (2, 0, 1))
    return img.tolist()
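`get_sample_image()` applies the standard ImageNet per-channel normalization and then moves the channel axis first. The sketch below performs the same arithmetic on an in-memory array so it can be checked without an image file; `normalize_hwc_uint8` is a hypothetical name, and the uniform gray test image is only an illustration.

```python
import numpy as np

# ImageNet per-channel statistics, as used in get_sample_image()
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_hwc_uint8(img_hwc):
    """Scale an HxWx3 uint8 image to [0, 1], normalize each channel, return CHW float32."""
    x = img_hwc.astype(np.float32) / 255.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over the last (channel) axis
    return np.transpose(x, (2, 0, 1))  # HWC -> CHW


gray = np.full((224, 224, 3), 128, dtype=np.uint8)  # uniform gray test image
chw = normalize_hwc_uint8(gray)
print(chw.shape, chw.dtype)
```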

The `tritonclient` package provides utility methods to generate the payload without having to know the details of the specification. We'll use the following methods to convert our inference request into a binary format which provides lower latencies for inference.

In [7]:
import tritonclient.http as httpclient


def _get_sample_image_binary(input_name, output_name):
    inputs = []
    outputs = []
    inputs.append(httpclient.InferInput(input_name, [1, 3, 224, 224], "FP32"))
    input_data = np.array(get_sample_image(), dtype=np.float32)
    input_data = np.expand_dims(input_data, axis=0)
    inputs[0].set_data_from_numpy(input_data, binary_data=True)
    outputs.append(httpclient.InferRequestedOutput(output_name, binary_data=True))
    request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
        inputs, outputs=outputs
    )
    return request_body, header_length


def get_sample_image_binary_pt():
    return _get_sample_image_binary("INPUT__0", "OUTPUT__0")


def get_sample_image_binary_trt():
    return _get_sample_image_binary("input", "output")
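Under the hood, the binary format is a JSON header followed by the raw tensor bytes: each input declares its byte count in `parameters.binary_data_size` instead of carrying a `data` field, and `header_length` tells the server where the JSON ends. The following sketch builds such a body by hand, assuming little-endian, row-major FP32 data. `build_binary_request` is a hypothetical helper for illustration, and the exact JSON produced by `tritonclient` may differ in detail.

```python
import json

import numpy as np


def build_binary_request(input_name, array, output_name):
    """Build a Triton binary-tensor request body: JSON header followed by raw tensor bytes."""
    # Serialize the tensor as little-endian FP32 in row-major (C) order
    raw = np.ascontiguousarray(array, dtype="<f4").tobytes()
    header = {
        "inputs": [
            {
                "name": input_name,
                "shape": list(array.shape),
                "datatype": "FP32",
                # Replaces the "data" field: the server reads this many bytes after the JSON
                "parameters": {"binary_data_size": len(raw)},
            }
        ],
        # Ask the server to return this output as binary data as well
        "outputs": [{"name": output_name, "parameters": {"binary_data": True}}],
    }
    header_bytes = json.dumps(header).encode("utf-8")
    return header_bytes + raw, len(header_bytes)


body_sketch, header_len_sketch = build_binary_request(
    "INPUT__0", np.zeros((1, 3, 224, 224)), "OUTPUT__0"
)
content_type_sketch = (
    f"application/vnd.sagemaker-triton.binary+json;json-header-size={header_len_sketch}"
)
print(header_len_sketch, len(body_sketch))
```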

## Set up the Tritonserver Container with neuronx libraries

In [8]:
!aws ecr get-login-password --region {region} | docker login --username AWS --password-stdin {account_id_map[region]}.dkr.ecr.{region}.{base}

https://docs.docker.com/engine/reference/commandline/login/#credentials-store

Login Succeeded


In [2]:
!mkdir -p docker/mylib

Copy the udev rules from `/lib/udev/rules.d` into the Docker build context; the container needs them to interact with the Neuron hardware.

In [3]:
!cp -R /lib/udev/rules.d/* docker/mylib/

In [7]:
!docker build --no-cache -t tritonserver-neuronx docker/ --build-arg SM_TRITON_IMAGE_URI={triton_image_uri}

[+] Building 0.0s (0/2)
[+] Building 0.2s (5/8)
 => [internal] load build definition from Dockerfile                       0.0s
 => => transferring dockerfile: 557B                                       0.0s
 => [internal] load .dockerignore                                          0.0s
 => => transferring context: 2B                                            0.0s
 => [internal] load metadata for 301217895009.dkr.ecr.us-west-2.amazonaws  0.0s
 => CACHED [1/4] FROM 301217895009.dkr.ecr.us-west-2.amazonaws.com/sagema  0.0s
 => [2/4] RUN mkdir -p /mylib/udev/rules.d/                                0.1s
 => [internal] load build context                                          0.0s
 => => transferring context: 6.32kB                                        0.0s

In [9]:
# Upload the container to ECR for SageMaker to consume
# inf2 instances are available in us-east-2
curr_account_id = boto3.client('sts').get_caller_identity().get('Account')
inf2_region = "us-east-2"

triton_neuronx_image_uri = f"{curr_account_id}.dkr.ecr.{inf2_region}.amazonaws.com/sagemaker-tritonserver:23.03-py3"

In [16]:
!docker tag tritonserver-neuronx:latest {triton_neuronx_image_uri}

In [308]:
!aws ecr get-login-password --region {inf2_region} | docker login --username AWS --password-stdin {curr_account_id}.dkr.ecr.{inf2_region}.amazonaws.com

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
https://docs.docker.com/engine/reference/commandline/login/#credentials-store

Login Succeeded


In [20]:
!docker push {triton_neuronx_image_uri}

The push refers to repository [613283112109.dkr.ecr.us-east-2.amazonaws.com/sagemaker-tritonserver]

23d90fbd: Preparing
b0c84c5c: Preparing
16b88c51: Preparing
2109ad51: Preparing
e02b07d1: Preparing
dd89826b: Preparing
39da0441: Preparing
c24e8a7b: Preparing
9c0c6805: Preparing
6e6cebd2: Preparing
e922ddea: Preparing
60ef4761: Preparing
11e02d36: Preparing
3e32b9e1: Preparing
ca5d6255: Preparing
6f8d460f: Preparing
bf18a086: Preparing
5fc56587: Preparing
474188a6: Preparing
db6c3896: Preparing
b7fd341b: Preparing
232d1291: Preparing
3a4224a1: Preparing
c02687ba: Preparing
e352f364: Preparing
7d3bab63: Preparing
aaf8cc7e: Preparing
6de4f64c: Preparing
b45bef95: Preparing
3d90fbd: Pushed   11.56GB/11.53GB

In [21]:
triton_neuronx_image_uri

'613283112109.dkr.ecr.us-east-2.amazonaws.com/sagemaker-tritonserver:23.03-py3'

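## PyTorch-Neuronx

In this section, we compile a torchvision ResNet-50 model for Inferentia2 with `torch_neuronx.trace`, generate the Triton model files for it, and deploy it to an `ml.inf2.xlarge` endpoint.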

In [22]:
!pip install torch-neuronx transformers-neuronx

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com, https://pypi.ngc.nvidia.com
Collecting transformers-neuronx
  Downloading https://pip.repos.neuron.amazonaws.com/transformers-neuronx/transformers_neuronx-0.3.32-py3-none-any.whl (71 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 71.7/71.7 kB 38.9 MB/s eta 0:00:00
Collecting accelerate (from transformers-neuronx)
  Downloading accelerate-0.19.0-py3-none-any.whl (219 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 219.1/219.1 kB 34.2 MB/s eta 0:00:00
Collecting transformers (from transformers-neuronx)
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 275.4 MB/s eta 0:00:00
Collecting filelock (from transformers->transformers-neuronx)
  Downloading filelock-3.12.0-py3-none-any.whl (10 kB)
Collecting huggingface-

In [348]:
import os
import urllib.request
from PIL import Image

import torch
import torch_neuronx
from torchvision import models
from torchvision.transforms import functional


def get_image(batch_size=1, image_shape=(224, 224)):
    # Get an example input
    filename = "000000039769.jpg"
    if not os.path.exists(filename):
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        urllib.request.urlretrieve(url, filename)
    image = Image.open(filename).convert('RGB')
    image = functional.resize(image, (image_shape))
    image = functional.to_tensor(image)
    image = torch.unsqueeze(image, 0)
    image = torch.repeat_interleave(image, batch_size, 0)
    return (image, )


# Create the model
model = models.resnet50(pretrained=True)
model.eval()

# Get an example input
image = get_image()

# Run inference on CPU
output_cpu = model(*image)

# Compile the model
model_neuron = torch_neuronx.trace(model, image)

# Save the TorchScript for inference deployment
filename = 'model.pt'
torch.jit.save(model_neuron, filename)







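The cell above computes `output_cpu` but does not compare it with the compiled model. One way to sanity-check the Neuron compilation is to compare the logits and the top-k predicted classes, as sketched below. `compare_logits` is a hypothetical helper, the tolerance `atol=1e-2` is an arbitrary assumption, and small numerical differences between CPU and Neuron outputs are expected. In the notebook, you would pass `output_cpu.detach().numpy()` and `model_neuron(*image).detach().numpy()`; synthetic arrays stand in for them here.

```python
import numpy as np


def compare_logits(reference, candidate, k=5, atol=1e-2):
    """Report the max absolute difference and whether the top-k classes agree."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    # argsort is ascending; reverse for descending order and keep the first k class indices
    top_ref = np.argsort(reference, axis=-1)[..., ::-1][..., :k]
    top_cand = np.argsort(candidate, axis=-1)[..., ::-1][..., :k]
    return {
        "max_abs_diff": max_abs_diff,
        "within_tolerance": max_abs_diff <= atol,
        "topk_match": bool(np.array_equal(top_ref, top_cand)),
    }


# Synthetic stand-ins for CPU and Neuron logits with shape [batch, classes]
cpu_logits = np.linspace(-3.0, 3.0, 1000, dtype=np.float32).reshape(1, 1000)
neuron_logits = cpu_logits + 1e-3
print(compare_logits(cpu_logits, neuron_logits))
```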
In [359]:
!mkdir -p triton-pt-inf2/resnet/1
!mv model.pt triton-pt-inf2/resnet/

# Generate a pytorch-neuronx triton model.py and config.pbtxt file
!git clone https://github.com/nskool/python_backend && cd python_backend && git checkout add_inf2_support
!cd python_backend && python3 inferentia/scripts/gen_triton_model.py --inf2 --model_type pytorch --triton_input INPUT__0,FP32,3x224x224 --triton_output OUTPUT__0,FP32,1000 --compiled_model model.pt --neuron_core_range 0:1 --triton_model_dir ../triton-pt-inf2/resnet --enable_dynamic_batching --max_batch_size 4

mv: cannot stat 'model.pt': No such file or directory
Cloning into 'python_backend'...
remote: Enumerating objects: 1536, done.
remote: Counting obj

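The `--triton_input` and `--triton_output` arguments above pack a tensor's name, data type, and per-sample shape into one string such as `INPUT__0,FP32,3x224x224`. The sketch below shows one way to read that format and map it to Triton `TYPE_*` data types. `parse_tensor_spec` is a hypothetical helper, and the interpretation of the format is inferred from the example arguments, not from the script's documentation.

```python
def parse_tensor_spec(spec):
    """Parse 'NAME,DTYPE,D1xD2x...' into (name, triton_data_type, dims)."""
    name, dtype, shape = (part.strip() for part in spec.split(","))
    dims = [int(d) for d in shape.split("x")]
    # Triton model configs prefix data types with TYPE_, e.g. FP32 -> TYPE_FP32
    return name, f"TYPE_{dtype.upper()}", dims


print(parse_tensor_spec("INPUT__0,FP32,3x224x224"))
print(parse_tensor_spec("OUTPUT__0,FP32,1000"))
```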
Move the generated `config.pbtxt` and `model.py` to the model folder, then package the model and upload it to S3.

In [360]:
!tar -C triton-pt-inf2/ -czf triton-pt-inf2.tar.gz resnet
model_uri = sagemaker_session.upload_data(path="triton-pt-inf2.tar.gz", key_prefix="triton-inf2-models")


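The `tar -C triton-pt-inf2/ -czf ... resnet` command places the `resnet` folder at the root of the archive, which is the folder name that Triton (and `SAGEMAKER_TRITON_DEFAULT_MODEL_NAME`) refers to. The sketch below does the same with Python's `tarfile` module and adds a basic layout check first. `package_model_repo` is a hypothetical helper; it only checks for `config.pbtxt` and a version directory `1/`, which is an assumption about the minimum layout rather than a full validation.

```python
import os
import tarfile
import tempfile


def package_model_repo(repo_dir, model_name, out_path):
    """Check a minimal Triton model layout, then archive it with model_name at the root."""
    model_dir = os.path.join(repo_dir, model_name)
    required = [
        os.path.join(model_dir, "config.pbtxt"),  # model configuration
        os.path.join(model_dir, "1"),  # version directory
    ]
    missing = [path for path in required if not os.path.exists(path)]
    if missing:
        raise FileNotFoundError(f"Missing from model repository: {missing}")
    with tarfile.open(out_path, "w:gz") as tar:
        # arcname=model_name mirrors `tar -C repo_dir -czf out_path model_name`
        tar.add(model_dir, arcname=model_name)
    return out_path


# Usage on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "resnet", "1"))
    with open(os.path.join(tmp, "resnet", "config.pbtxt"), "w") as f:
        f.write('name: "resnet"\n')
    archive = package_model_repo(tmp, "resnet", os.path.join(tmp, "model.tar.gz"))
    with tarfile.open(archive, "r:gz") as tar:
        names = sorted(tar.getnames())
print(names)
```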

In [361]:
model_uri

's3://sagemaker-us-east-2-613283112109/triton-inf2-models/triton-pt-inf2.tar.gz'

Create a SageMaker model that uses the Triton Neuron container image and the model artifacts uploaded to S3.

In [362]:
sm_model_name = "triton-resnet-inf2-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_neuronx_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "resnet"},
}

create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Model Arn: arn:aws:sagemaker:us-east-2:613283112109:model/triton-resnet-inf2-pt-2023-06-13-17-03-31


Create endpoint configuration

In [363]:
endpoint_config_name = "triton-resnet-inf2-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.inf2.xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Endpoint Config Arn: arn:aws:sagemaker:us-east-2:613283112109:endpoint-config/triton-resnet-inf2-pt-2023-06-13-17-03-32


In [364]:
endpoint_name = "triton-resnet-inf2-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

Endpoint Arn: arn:aws:sagemaker:us-east-2:613283112109:endpoint/triton-resnet-inf2-pt-2023-06-13-17-03-32


In [365]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: InService
Arn: arn:aws:sagemaker:us-east-2:613283112109:endpoint/triton-resnet-inf2-pt-2023-06-13-17-03-32
Status: InService
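The polling loop above runs until the status changes, but it has no timeout and does not stop with an error if creation fails. The sketch below adds both. `wait_for_endpoint` is a hypothetical helper that takes the describe function as a parameter (for example `sm_client.describe_endpoint`) so it can be exercised without AWS. boto3 also provides a built-in waiter, `sm_client.get_waiter("endpoint_in_service")`, which is a reasonable alternative.

```python
import time

# Statuses that mean "still working"; anything other than these or InService is treated as failure
TRANSIENT_STATUSES = {"Creating", "Updating", "SystemUpdating"}


def wait_for_endpoint(describe, endpoint_name, poll_seconds=60, timeout_seconds=3600, sleep=time.sleep):
    """Poll until the endpoint is InService; raise on a failure status or on timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = describe(EndpointName=endpoint_name)["EndpointStatus"]
        print("Status: " + status)
        if status == "InService":
            return status
        if status not in TRANSIENT_STATUSES:
            raise RuntimeError(f"Endpoint {endpoint_name} ended in status {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint {endpoint_name} still {status} after {timeout_seconds}s")
        sleep(poll_seconds)


# Usage in this notebook (not run here):
# wait_for_endpoint(sm_client.describe_endpoint, endpoint_name)
```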


Run inference with a JSON payload

In [366]:
payload = {
    "inputs": [
        {
            "name": "INPUT__0",
            "shape": [1, 3, 224, 224],
            "datatype": "FP32",
            "data": get_sample_image(),
        }
    ]
}

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload)
)

print(json.loads(response["Body"].read().decode("utf8")))

{'model_name': 'resnet', 'model_version': '1', 'outputs': [{'name': 'OUTPUT__0', 'datatype': 'FP32', 'shape': [1, 1000], 'data': [-1.2339320182800293, 0.26743221282958984, -1.7244561910629272, -2.9731521606445312, -2.9897332191467285, -1.9399539232254028, -2.4892311096191406, -0.3647805154323578, 0.04905950650572777, -1.619198203086853, -3.850961446762085, -1.9598301649093628, -3.635099411010742, -5.3621416091918945, -1.686911940574646, -2.895658493041992, -1.8479063510894775, 0.3577522039413452, -0.6747124791145325, -1.7334718704223633, -4.42026948928833, -2.315218925476074, -3.1044459342956543, -2.2334206104278564, -2.116276741027832, -3.471982717514038, -4.789962291717529, -2.852299928665161, -3.040926694869995, -3.106398105621338, -2.429238796234131, -0.4101293981075287, -2.3326222896575928, -3.5598886013031006, -0.7902666330337524, -4.283840179443359, -1.8385157585144043, -4.4796342849731445, -2.8223791122436523, -2.196667194366455, -0.19892823696136475, -3.900879144668579, -0.906
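The JSON response contains the 1,000 raw ResNet-50 logits for the sample image. To turn them into class predictions, apply a softmax and take the top-k indices, as in the sketch below. `top_k_from_logits` is a hypothetical helper; mapping indices to human-readable ImageNet labels needs a label file that this notebook does not include. To use it here, keep the parsed response in a variable (for example `result = json.loads(...)`) and pass `result["outputs"][0]["data"]`.

```python
import numpy as np


def top_k_from_logits(logits, k=5):
    """Return (class_index, probability) pairs for the k highest-scoring classes."""
    logits = np.asarray(logits, dtype=np.float64).reshape(-1)
    exp = np.exp(logits - logits.max())  # subtract the max for numerical stability
    probs = exp / exp.sum()
    top = np.argsort(probs)[::-1][:k]
    return [(int(i), float(probs[i])) for i in top]


print(top_k_from_logits([0.0, 1.0, 3.0, 2.0], k=2))
```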

Run inference with a binary payload

In [367]:
request_body, header_length = get_sample_image_binary_pt()

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
        header_length
    ),
    Body=request_body,
)

# Parse json header size length from the response
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix) :]

# Read response body
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=int(header_length_str)
)
output0_data = result.as_numpy("OUTPUT__0")
print(output0_data)

[[-1.23393202e+00  2.67432213e-01 -1.72445619e+00 -2.97315216e+00
  -2.98973322e+00 -1.93995392e+00 -2.48923111e+00 -3.64780515e-01
   4.90595065e-02 -1.61919820e+00 -3.85096145e+00 -1.95983016e+00
  -3.63509941e+00 -5.36214161e+00 -1.68691194e+00 -2.89565849e+00
  -1.84790635e+00  3.57752204e-01 -6.74712479e-01 -1.73347187e+00
  -4.42026949e+00 -2.31521893e+00 -3.10444593e+00 -2.23342061e+00
  -2.11627674e+00 -3.47198272e+00 -4.78996229e+00 -2.85229993e+00
  -3.04092669e+00 -3.10639811e+00 -2.42923880e+00 -4.10129398e-01
  -2.33262229e+00 -3.55988860e+00 -7.90266633e-01 -4.28384018e+00
  -1.83851576e+00 -4.47963428e+00 -2.82237911e+00 -2.19666719e+00
  -1.98928237e-01 -3.90087914e+00 -9.06249762e-01 -2.96887517e+00
  -2.12057996e+00 -3.92205238e+00 -1.33538949e+00 -2.36063385e+00
  -5.24708891e+00 -2.65685511e+00 -1.33713758e+00 -1.03983080e+00
  -1.27821445e+00 -1.48820710e+00 -1.74384296e+00 -4.52535003e-01
  -2.02982917e-01 -2.99495125e+00 -2.81305623e+00 -1.37612796e+00
  -9.06762
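The cell above slices the `ContentType` string at a fixed offset, which silently yields a wrong value if the server answers with a different content type (for example, plain JSON). A slightly more defensive parser is sketched below. `parse_json_header_size` is a hypothetical helper that returns `None` when the response is not in the binary+json format, so the caller can fall back to `json.loads`.

```python
def parse_json_header_size(content_type):
    """Extract json-header-size from a SageMaker Triton binary+json content type, or None."""
    media_type, _, params = content_type.partition(";")
    if media_type.strip() != "application/vnd.sagemaker-triton.binary+json":
        return None
    # Parameters are ';'-separated key=value pairs after the media type
    for param in params.split(";"):
        key, _, value = param.strip().partition("=")
        if key == "json-header-size":
            return int(value)
    return None


print(parse_json_header_size("application/vnd.sagemaker-triton.binary+json;json-header-size=123"))
print(parse_json_header_size("application/json"))
```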

## Terminate and cleanup

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=sm_model_name)

## Run a Transformers based OPT-125M model
In this section, we run a pretrained, Neuron-traced OPT-125M model on an `ml.inf2.24xlarge` instance.

Create the OPT-125M model artifacts and upload them to S3.

**Note**: The following commands to compile/trace the model should be run on an inf2 instance.

In [None]:
!pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com

In [None]:
import os
import time
import argparse
import torch
from transformers_neuronx.module import save_pretrained_split
from transformers_neuronx.dtypes import to_torch_dtype
from transformers_neuronx.opt.model import OPTForSampling
from transformers import AutoModelForCausalLM, AutoTokenizer

def amp_callback(model, dtype):
    # cast attention and mlp to low precisions only; layernorms stay as f32
    for block in model.model.decoder.layers:
        block.self_attn.to(dtype)
        block.fc1.to(dtype)
        block.fc2.to(dtype)
    model.lm_head.to(dtype)


def compile(model_name, batch_size, compiler_args, amp='bf16', tp_degree=2, n_positions=2048, unroll=None):
    # This cell only casts, splits, and saves the checkpoint;
    # batch_size, tp_degree, n_positions, and unroll are not used here.
    os.environ["NEURON_CC_FLAGS"] = compiler_args
    # Directory where transformers-neuronx dumps compiler artifacts
    os.environ["NEURONX_DUMP_TO"] = "opt-125m-tp12"

    # Split the OPT model for faster loading
    model_dir = "opt-125m-model"
    if not os.path.exists(model_dir):
        model_cpu = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
        dtype = to_torch_dtype(amp)
        amp_callback(model_cpu, dtype=dtype)
        save_pretrained_split(model_cpu, model_dir)

compile("facebook/opt-125m", 1, "--model-type=transformer", "bf16", 12)

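The fifth positional argument to `compile` is `tp_degree`, the tensor-parallel degree, that is, the number of NeuronCores a model would be sharded across; the artifact folder names (`opt-125m-tp12`, `gpt-j-6b-artifacts-tp4`) reflect the degrees used in this notebook. Each Inferentia2 chip has two NeuronCores, so the degree must fit on the target instance. The sketch below checks this. `check_tp_degree` is a hypothetical helper, and it ignores device memory, which also limits how large a model fits.

```python
# NeuronCores per instance type (Inferentia2 chips x 2 NeuronCores per chip)
NEURON_CORES = {
    "ml.inf2.xlarge": 2,
    "ml.inf2.8xlarge": 2,
    "ml.inf2.24xlarge": 12,
    "ml.inf2.48xlarge": 24,
}


def check_tp_degree(instance_type, tp_degree):
    """Return how many tp_degree-sized NeuronCore groups fit on the instance (memory not considered)."""
    cores = NEURON_CORES.get(instance_type)
    if cores is None:
        raise KeyError(f"Unknown instance type: {instance_type}")
    if tp_degree < 1 or tp_degree > cores:
        raise ValueError(f"tp_degree={tp_degree} does not fit on {instance_type} ({cores} NeuronCores)")
    return cores // tp_degree


print(check_tp_degree("ml.inf2.24xlarge", 12))  # OPT-125M setting in this notebook
print(check_tp_degree("ml.inf2.24xlarge", 4))  # GPT-J-6B setting in this notebook
```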
The `save_pretrained_split()` call saves the split checkpoint to the `opt-125m-model` folder, and `NEURONX_DUMP_TO` sets `opt-125m-tp12` as the folder for the Neuron compiler artifacts. Both folders are copied into the model package below.

A `model.py` and a `config.pbtxt` file are provided in the `inf2_llm/opt-125m/opt` folder. Both files were generated with the `gen_triton_model.py` script and then modified for the Hugging Face LLM used here.

In [None]:
!mkdir -p inf2_llm/opt-125m/opt/opt-125m-model
!mkdir -p inf2_llm/opt-125m/opt/opt-125m-tp12
!cp -R opt-125m-model/* inf2_llm/opt-125m/opt/opt-125m-model
!cp -R opt-125m-tp12/* inf2_llm/opt-125m/opt/opt-125m-tp12

!tar -C inf2_llm/opt-125m/ -hczf triton-transformers-opt-inf2.tar.gz opt
model_uri = sagemaker_session.upload_data(path="triton-transformers-opt-inf2.tar.gz", key_prefix="triton-inf2-models")

Create the SageMaker model, endpoint configuration, and endpoint

In [None]:
sm_model_name = "triton-resnet-inf2-transformers-opt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_neuronx_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "opt"},
}

create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

endpoint_config_name = "triton-resnet-inf2-transformers-opt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.inf2.24xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

endpoint_name = "triton-resnet-inf2-transformers-opt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

Run Inference

In [347]:
from transformers import AutoTokenizer
import torch

prompt = "this summer"

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token
encoded_text = tokenizer.encode(prompt, padding="max_length", max_length=128, truncation=True)
# Run inference here

payload = {
    "inputs": [
        {"name": "INPUT__0", "shape": [1, 128], "datatype": "INT64", "data": encoded_text}
    ]
}

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name, Body=json.dumps(payload)
)

generated_sequence = json.loads(response["Body"].read().decode("utf8"))["outputs"][0]["data"]
output = ' '.join(tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)).encode("utf-8")

print(output)

b' this  summer                                                                                                                               aug ,  or  any  means  that  means  some  kind  of  that  will  take  a  certain  percentage  of  the  total  market  that  means  which  is  very  valuable  in  the  past  few  years .  But  what  is  that  makes  them  so  rich  in  the  past ?  We  have  developed  the  idea  of  the  future  of  global ization ,  and  how  it  will  go  in  a  world  of  huge  changes ,  at  the  same  time  that  the  present  economy  has  an  enormous  potential  for  market  that  they  just  want  us  to  be  in  the  know . \n \n R ic her l le \n We  developed  the  idea  of  market  manipulation  and  the  idea  of  changing  the  market .  The  old  way ,  the  old  way  of  market  manipulation ,  and  now  people  are  creating  a  market  of  very  powerful  changes  in  the  global  climate .  In  the  current  global  climate ,  we  will  face  t
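The OPT payload sends exactly 128 `INT64` token IDs with shape `[1, 128]`, padded when the prompt is shorter. The sketch below builds the same payload structure from a list of token IDs without the tokenizer, making the padding and truncation explicit. `build_token_payload` is a hypothetical helper, the token IDs and pad ID in the example are made up, and right padding is assumed, which is consistent with the prompt appearing first in the decoded output above.

```python
import json


def build_token_payload(token_ids, pad_id, seq_len=128, input_name="INPUT__0"):
    """Right-pad or truncate token IDs to seq_len and wrap them in a Triton JSON payload."""
    ids = list(token_ids)[:seq_len]
    ids = ids + [pad_id] * (seq_len - len(ids))
    return {
        "inputs": [
            {"name": input_name, "shape": [1, seq_len], "datatype": "INT64", "data": ids}
        ]
    }


# Hypothetical token IDs; in the notebook these come from tokenizer.encode(prompt)
payload_sketch = build_token_payload([2, 9226, 1035], pad_id=1, seq_len=8)
print(json.dumps(payload_sketch))
```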

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=sm_model_name)

## Run a Transformers based GPTJ-6B model
In this section, we run a pretrained, Neuron-traced GPT-J-6B model on an `ml.inf2.24xlarge` instance.

### PyTorch: Packaging model files and uploading to s3

Create the GPT-J-6B model artifacts and upload them to S3.

**Note**: The following commands to compile/trace the model should be run on an inf2 instance.

In [None]:
!mkdir -p triton-serve-pt/resnet/1/
!mv -f workspace/model.pt triton-serve-pt/resnet/1/
!tar -C triton-serve-pt/ -czf model.tar.gz resnet
model_uri = sagemaker_session.upload_data(path="model.tar.gz", key_prefix="triton-serve-pt")

In [None]:
!pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com

In [None]:
import os
import time
import argparse
import torch
from transformers_neuronx.module import save_pretrained_split
from transformers_neuronx.dtypes import to_torch_dtype
from transformers_neuronx.gptj.model import GPTJForSampling
from transformers import AutoModelForCausalLM, AutoTokenizer

def amp_callback(model, dtype):
    # cast attention and mlp to low precisions only; layernorms stay as f32
    # GPT-J blocks live under model.transformer.h (OPT uses model.model.decoder.layers)
    for block in model.transformer.h:
        block.attn.to(dtype)
        block.mlp.to(dtype)
    model.lm_head.to(dtype)


def compile(model_name, batch_size, compiler_args, amp='bf16', tp_degree=2, n_positions=2048, unroll=None):
    # This cell only casts, splits, and saves the checkpoint;
    # batch_size, tp_degree, n_positions, and unroll are not used here.
    os.environ["NEURON_CC_FLAGS"] = compiler_args
    # Directory where transformers-neuronx dumps compiler artifacts
    os.environ["NEURONX_DUMP_TO"] = "gpt-j-6b-artifacts-tp4"

    # Split the GPT-J model for faster loading
    model_dir = "EleutherAI-gpt-j-6B-bf16-local"
    if not os.path.exists(model_dir):
        model_cpu = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
        dtype = to_torch_dtype(amp)
        amp_callback(model_cpu, dtype=dtype)
        save_pretrained_split(model_cpu, model_dir)

compile("EleutherAI/gpt-j-6B", 1, "--model-type=transformer", "bf16", 4)

The `save_pretrained_split()` call saves the split checkpoint to the `EleutherAI-gpt-j-6B-bf16-local` folder, and `NEURONX_DUMP_TO` sets `gpt-j-6b-artifacts-tp4` as the folder for the Neuron compiler artifacts. Both folders are copied into the model package below.

A `model.py` and a `config.pbtxt` file are provided in the `inf2_llm/gptj-6b/gptj` folder. Both files were generated with the `gen_triton_model.py` script and then modified for the Hugging Face LLM used here.

In [None]:
!mkdir -p inf2_llm/gptj-6b/gptj/EleutherAI-gpt-j-6B-bf16-local
!mkdir -p inf2_llm/gptj-6b/gptj/gpt-j-6b-artifacts-tp4
!cp -R EleutherAI-gpt-j-6B-bf16-local/* inf2_llm/gptj-6b/gptj/EleutherAI-gpt-j-6B-bf16-local
!cp -R gpt-j-6b-artifacts-tp4/* inf2_llm/gptj-6b/gptj/gpt-j-6b-artifacts-tp4

!tar -C inf2_llm/gptj-6b/ -hczf triton-transformers-gptj-inf2.tar.gz gptj
model_uri = sagemaker_session.upload_data(path="triton-transformers-gptj-inf2.tar.gz", key_prefix="triton-inf2-models")

Create the SageMaker model, endpoint configuration, and endpoint

In [None]:
sm_model_name = "triton-resnet-inf2-transformers-gptj-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_neuronx_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "gptj"},
}

create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

endpoint_config_name = "triton-resnet-inf2-transformers-gptj-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.inf2.24xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

endpoint_name = "triton-inf2-transformers-gptj-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
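The polling loop above keeps waiting only while the status is `Creating`, so a deployment that ends in `Failed` simply falls through to the final `print`. A slightly stricter sketch follows. `wait_for_endpoint` is a hypothetical helper that takes the describe function as a parameter so it can be exercised without AWS, and the one-hour timeout is an arbitrary placeholder. boto3 also provides a built-in waiter for the same purpose: `sm_client.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`.

```python
import time
from typing import Callable, Dict

# EndpointStatus values during which SageMaker is still working on the endpoint
TRANSITIONAL_STATUSES = {"Creating", "Updating", "SystemUpdating", "RollingBack"}


def wait_for_endpoint(
    describe: Callable[..., Dict],
    endpoint_name: str,
    poll_seconds: float = 60,
    timeout_seconds: float = 3600,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Poll `describe(EndpointName=...)` until the endpoint leaves a transitional
    status. Returns "InService"; raises on any other final status or on timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        resp = describe(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
        print("Status: " + status)
        if status not in TRANSITIONAL_STATUSES:
            break
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{endpoint_name} is still {status} after {timeout_seconds} s")
        sleep(poll_seconds)
    if status != "InService":
        reason = resp.get("FailureReason", "no reason given")
        raise RuntimeError(f"{endpoint_name} ended in {status}: {reason}")
    return status
```

In the notebook it would be called as `wait_for_endpoint(sm_client.describe_endpoint, endpoint_name)`.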

### GPT-J: Run inference

Once the endpoint is **InService**, we tokenize a short prompt, pad it to 128 tokens, and send the token IDs to the `gptj` model as a `[1, 128]` `INT64` tensor named `INPUT__0`.

In [None]:
import json

from transformers import AutoTokenizer

prompt = "this summer"

# Tokenize the prompt and pad/truncate it to the 128-token input shape used below
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token
encoded_text = tokenizer.encode(prompt, padding="max_length", max_length=128, truncation=True)

# Run inference: the token IDs are sent as a [1, 128] INT64 tensor
payload = {
    "inputs": [
        {"name": "INPUT__0", "shape": [1, 128], "datatype": "INT64", "data": encoded_text}
    ]
}

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name, Body=json.dumps(payload)
)

# Triton returns the output tensor flattened in "data"; decode the token IDs back into text
generated_sequence = json.loads(response["Body"].read().decode("utf8"))["outputs"][0]["data"]
output = tokenizer.decode(generated_sequence, skip_special_tokens=True)

print(output)

### GPT-J: Terminate endpoint and clean up artifacts

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=sm_model_name)

### PyTorch: Create SageMaker Endpoint

We start off by creating a [SageMaker model](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html) from the model files we uploaded to S3 in the previous step.

In this step we also provide an additional environment variable, `SAGEMAKER_TRITON_DEFAULT_MODEL_NAME`, which specifies the name of the model to be loaded by Triton. **The value of this key should match the folder name in the model package uploaded to S3.** This variable is optional for a single model, but for ensemble models it **must** be specified for Triton to start up in SageMaker.

Additionally, you can set `SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT` and `SAGEMAKER_TRITON_THREAD_COUNT` to tune Triton's thread counts.

**Note**: The current release of Triton (21.08-py3) on SageMaker doesn't support running instances of different models on the same server, except in the case of [ensembles](https://github.com/triton-inference-server/server/blob/main/docs/architecture.md#ensemble-models). Only multiple instances of the same model are supported; these can be specified under the [instance-groups](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#instance-groups) section of the config.pbtxt file.

In [None]:
sm_model_name = "triton-resnet-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "resnet"},
}

create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])
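The container definition passed to `create_model` above can also be assembled by a small helper that adds the optional thread-count variables only when you set them. This is a sketch: `build_triton_container` is a hypothetical name, and any thread counts you pass are placeholders rather than tuned recommendations. SageMaker expects `Environment` values to be strings, so the helper converts the integers.

```python
from typing import Dict, Optional


def build_triton_container(
    image_uri: str,
    model_data_url: str,
    default_model_name: str,
    thread_count: Optional[int] = None,
    buffer_manager_thread_count: Optional[int] = None,
) -> Dict:
    """Build the PrimaryContainer dict for sm_client.create_model.

    Optional thread settings are added only when given; every Environment
    value is stored as a string."""
    env = {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": default_model_name}
    if thread_count is not None:
        env["SAGEMAKER_TRITON_THREAD_COUNT"] = str(thread_count)
    if buffer_manager_thread_count is not None:
        env["SAGEMAKER_TRITON_BUFFER_MANAGER_THREAD_COUNT"] = str(buffer_manager_thread_count)
    return {"Image": image_uri, "ModelDataUrl": model_data_url, "Environment": env}
```

Calling `build_triton_container(triton_image_uri, model_uri, "resnet")` reproduces the dictionary in the cell above; adding `thread_count=8` (a placeholder value) would also set `SAGEMAKER_TRITON_THREAD_COUNT`.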

Using the model above, we create an [endpoint configuration](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html) where we can specify the type and number of instances we want in the endpoint.

In [None]:
endpoint_config_name = "triton-resnet-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g4dn.4xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Using the above endpoint configuration, we create a new SageMaker endpoint and wait for the deployment to finish. The status will change to **InService** once the deployment is successful.

In [None]:
endpoint_name = "triton-resnet-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

### PyTorch: Run inference

Once the endpoint is running, we can use the provided [sample image](./shiba_inu_dog.jpg) to run inference with JSON as the payload format. For the request format, Triton uses the KFServing community-standard [inference protocols](https://github.com/triton-inference-server/server/blob/main/docs/protocol/README.md).
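Each JSON request carries an `inputs` list in which every input has a `name`, a `shape`, a `datatype`, and the tensor `data`. The hypothetical helper below is a sketch that builds such a body from a NumPy array: it flattens the data in row-major order and records the original dimensions in `shape`. It is not necessarily how `get_sample_image()` prepares its data, and the dtype table only covers a few common types.

```python
import json

import numpy as np

# Partial mapping from NumPy dtypes to the protocol's tensor datatype names
V2_DATATYPES = {
    np.dtype(np.float32): "FP32",
    np.dtype(np.float16): "FP16",
    np.dtype(np.int64): "INT64",
    np.dtype(np.int32): "INT32",
}


def build_v2_json_payload(input_name: str, array) -> str:
    """Serialize one tensor as a JSON inference request body.

    The tensor is sent as a flat, row-major list in "data", and its real
    dimensions go in "shape"."""
    array = np.asarray(array)
    payload = {
        "inputs": [
            {
                "name": input_name,
                "shape": list(array.shape),
                "datatype": V2_DATATYPES[array.dtype],
                "data": array.flatten().tolist(),
            }
        ]
    }
    return json.dumps(payload)
```

For instance, `build_v2_json_payload("INPUT__0", np.array([encoded_text], dtype=np.int64))` yields a `[1, 128]` INT64 request equivalent to the GPT-J payload built by hand earlier.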

In [87]:
payload = {
    "inputs": [
        {
            "name": "INPUT__0",
            "shape": [1, 3, 224, 224],
            "datatype": "FP32",
            "data": get_sample_image(),
        }
    ]
}

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload)
)

print(json.loads(response["Body"].read().decode("utf8")))

{'model_name': 'resnet', 'model_version': '1', 'outputs': [{'name': 'OUTPUT__0', 'datatype': 'FP32', 'shape': [1, 1000], 'data': [-1.2339320182800293, 0.26743221282958984, -1.7244561910629272, -2.9731521606445312, -2.9897332191467285, -1.9399539232254028, -2.4892311096191406, -0.3647805154323578, 0.04905950650572777, -1.619198203086853, -3.850961446762085, -1.9598301649093628, -3.635099411010742, -5.3621416091918945, -1.686911940574646, -2.895658493041992, -1.8479063510894775, 0.3577522039413452, -0.6747124791145325, -1.7334718704223633, -4.42026948928833, -2.315218925476074, -3.1044459342956543, -2.2334206104278564, -2.116276741027832, -3.471982717514038, -4.789962291717529, -2.852299928665161, -3.040926694869995, -3.106398105621338, -2.429238796234131, -0.4101293981075287, -2.3326222896575928, -3.5598886013031006, -0.7902666330337524, -4.283840179443359, -1.8385157585144043, -4.4796342849731445, -2.8223791122436523, -2.196667194366455, -0.19892823696136475, -3.900879144668579, -0.906

We can also use binary+json as the payload format to get better performance for the inference call. The specification of this format is provided [here](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_binary_data.md).

**Note:** With the `binary+json` format, we have to specify the length of the request metadata in the header to allow Triton to correctly parse the binary payload. This is done using a custom Content-Type header `application/vnd.sagemaker-triton.binary+json;json-header-size={}`.

Please note that this differs from using the `Inference-Header-Content-Length` header on a stand-alone Triton server, because custom headers are not allowed in SageMaker.
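A binary+json body is a JSON header immediately followed by the raw tensor bytes, and the header length is what goes into `json-header-size`. The minimal sketch below assembles such a body by hand, assuming the `binary_data_size` (input) and `binary_data` (output) parameters from Triton's binary data extension. `build_binary_json_request` is a hypothetical helper and is not necessarily how `get_sample_image_binary_pt()` works; it also assumes the array is in the usual little-endian byte order.

```python
import json
from typing import Tuple

import numpy as np


def build_binary_json_request(
    input_name: str, array: np.ndarray, datatype: str, output_name: str
) -> Tuple[bytes, int]:
    """Build a binary+json request body: a JSON header followed by raw tensor bytes.

    Returns (body, json_header_length)."""
    array = np.ascontiguousarray(array)
    raw = array.tobytes()  # row-major tensor data in the array's own byte order
    header = {
        "inputs": [
            {
                "name": input_name,
                "shape": list(array.shape),
                "datatype": datatype,
                # Number of bytes after the JSON header that belong to this input
                "parameters": {"binary_data_size": len(raw)},
            }
        ],
        # Ask Triton to return this output as raw bytes instead of a JSON list
        "outputs": [{"name": output_name, "parameters": {"binary_data": True}}],
    }
    header_bytes = json.dumps(header).encode("utf-8")
    return header_bytes + raw, len(header_bytes)


body, header_length = build_binary_json_request(
    "INPUT__0", np.zeros((1, 3, 224, 224), dtype=np.float32), "FP32", "OUTPUT__0"
)
content_type = "application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(header_length)
```

The resulting `body` and `content_type` map directly onto the `Body` and `ContentType` arguments of `invoke_endpoint`.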

In [88]:
request_body, header_length = get_sample_image_binary_pt()

response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
        header_length
    ),
    Body=request_body,
)

# Parse the JSON header length from the response Content-Type
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix) :]

# Read response body
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=int(header_length_str)
)
output0_data = result.as_numpy("OUTPUT__0")
print(output0_data)

[[-1.23393202e+00  2.67432213e-01 -1.72445619e+00 -2.97315216e+00
  -2.98973322e+00 -1.93995392e+00 -2.48923111e+00 -3.64780515e-01
   4.90595065e-02 -1.61919820e+00 -3.85096145e+00 -1.95983016e+00
  -3.63509941e+00 -5.36214161e+00 -1.68691194e+00 -2.89565849e+00
  -1.84790635e+00  3.57752204e-01 -6.74712479e-01 -1.73347187e+00
  -4.42026949e+00 -2.31521893e+00 -3.10444593e+00 -2.23342061e+00
  -2.11627674e+00 -3.47198272e+00 -4.78996229e+00 -2.85229993e+00
  -3.04092669e+00 -3.10639811e+00 -2.42923880e+00 -4.10129398e-01
  -2.33262229e+00 -3.55988860e+00 -7.90266633e-01 -4.28384018e+00
  -1.83851576e+00 -4.47963428e+00 -2.82237911e+00 -2.19666719e+00
  -1.98928237e-01 -3.90087914e+00 -9.06249762e-01 -2.96887517e+00
  -2.12057996e+00 -3.92205238e+00 -1.33538949e+00 -2.36063385e+00
  -5.24708891e+00 -2.65685511e+00 -1.33713758e+00 -1.03983080e+00
  -1.27821445e+00 -1.48820710e+00 -1.74384296e+00 -4.52535003e-01
  -2.02982917e-01 -2.99495125e+00 -2.81305623e+00 -1.37612796e+00
  -9.06762
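The cell above relies on `httpclient.InferenceServerClient.parse_response_body` to decode the response. As a hedged sketch of what that decoding involves, the hypothetical helper below (not part of tritonclient) reads the JSON header length from the `Content-Type`, parses the header, and slices each binary output from the bytes that follow. It assumes binary outputs appear in the same order as in the header's `outputs` list and only maps a few datatypes.

```python
import json
from typing import Dict

import numpy as np

BINARY_JSON_PREFIX = "application/vnd.sagemaker-triton.binary+json;json-header-size="
# Partial mapping from protocol datatype names to NumPy dtypes
NP_DTYPES = {"FP32": np.float32, "FP16": np.float16, "INT64": np.int64, "INT32": np.int32}


def parse_binary_json_response(body: bytes, content_type: str) -> Dict[str, np.ndarray]:
    """Split a SageMaker Triton response into its JSON header and output tensors."""
    if content_type.startswith(BINARY_JSON_PREFIX):
        header_length = int(content_type[len(BINARY_JSON_PREFIX):])
    else:
        header_length = len(body)  # plain JSON response: no binary section
    header = json.loads(body[:header_length])
    offset = header_length
    results = {}
    for out in header["outputs"]:
        dtype = NP_DTYPES[out["datatype"]]
        size = out.get("parameters", {}).get("binary_data_size")
        if size is None:
            # Output returned inline as a JSON list
            results[out["name"]] = np.array(out["data"], dtype=dtype).reshape(out["shape"])
        else:
            # Binary output: take the next `size` bytes after the header
            chunk = body[offset : offset + size]
            results[out["name"]] = np.frombuffer(chunk, dtype=dtype).reshape(out["shape"])
            offset += size
    return results
```

Applied to the endpoint's reply, `parse_binary_json_response(response["Body"].read(), response["ContentType"])["OUTPUT__0"]` would give the same `[1, 1000]` array as `result.as_numpy("OUTPUT__0")`.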

### PyTorch: Terminate endpoint and clean up artifacts

In [78]:
sm_client.delete_model(ModelName=sm_model_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_endpoint(EndpointName=endpoint_name)

{'ResponseMetadata': {'RequestId': '717ad98d-c79c-46f8-831a-7c54aae7e877',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '717ad98d-c79c-46f8-831a-7c54aae7e877',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '0',
   'date': 'Mon, 05 Jun 2023 23:18:34 GMT'},
  'RetryAttempts': 0}}