# Triton on SageMaker - NLP BERT

### TODO: Add Sagemaker, Triton and NGC introduction

## Contents
1. [Basic: PyTorch NLP-Bert](#PyTorch-NLP-Bert)
2. [Advanced: TensorRT NLP-Bert](#TensorRT-NLP-Bert)
3. [Conclusion](#Conclusion)

In [None]:
!pip install -qU pip awscli boto3 sagemaker transformers==4.9.1
!pip install nvidia-pyindex
!pip install tritonclient[http]

In [None]:
import boto3, json, sagemaker, time
from sagemaker import get_execution_role

sess = boto3.Session()
sm = sess.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=sess)
role = get_execution_role()
client = boto3.client("sagemaker-runtime")

In [None]:
triton_image_uri = "195202947636.dkr.ecr.us-west-2.amazonaws.com/tritonserver:21.07-py3"

The `tritonclient` package provides utility methods to generate the payload without having to know the details of the specification.

In [None]:
import tritonclient.http as httpclient
from transformers import BertTokenizer
import numpy as np


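def tokenize_text(text):
    enc = BertTokenizer.from_pretrained("bert-base-uncased")
    # Pad or truncate to exactly 128 tokens so the result matches the [1, 128] input shape.
    encoded_text = enc(text, padding="max_length", truncation=True, max_length=128)
    return encoded_text["input_ids"], encoded_text["attention_mask"]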


def _get_sample_tokenized_text_binary(text, input_names, output_names):
    inputs = []
    outputs = []
    inputs.append(httpclient.InferInput(input_names[0], [1, 128], "INT32"))
    inputs.append(httpclient.InferInput(input_names[1], [1, 128], "INT32"))
    indexed_tokens, attention_mask = tokenize_text(text)

    indexed_tokens = np.array(indexed_tokens, dtype=np.int32)
    indexed_tokens = np.expand_dims(indexed_tokens, axis=0)
    inputs[0].set_data_from_numpy(indexed_tokens, binary_data=True)

    attention_mask = np.array(attention_mask, dtype=np.int32)
    attention_mask = np.expand_dims(attention_mask, axis=0)
    inputs[1].set_data_from_numpy(attention_mask, binary_data=True)

    outputs.append(httpclient.InferRequestedOutput(output_names[0], binary_data=True))
    outputs.append(httpclient.InferRequestedOutput(output_names[1], binary_data=True))
    request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
        inputs, outputs=outputs
    )
    return request_body, header_length


def get_sample_tokenized_text_binary_pt(text):
    return _get_sample_tokenized_text_binary(
        text, ["INPUT__0", "INPUT__1"], ["OUTPUT__0", "1634__1"]
    )


def get_sample_tokenized_text_binary_trt(text):
    return _get_sample_tokenized_text_binary(text, ["token_ids", "attn_mask"], ["output", "1634"])

In [None]:
!docker run --gpus=all --rm -it \
            -v `pwd`/workspace:/workspace nvcr.io/nvidia/pytorch:21.07-py3 \
            /bin/bash generate_models.sh

## PyTorch NLP-Bert

For a simple use case, we take the pre-trained BERT model from [Hugging Face](https://huggingface.co/transformers/model_doc/bert.html) and deploy it on SageMaker with Triton as the model server. The script that exports the model is [pt_exporter.py](./workspace/pt_exporter.py); it runs as part of the `generate_models.sh` script. After the model is serialized, we package it in the layout that Triton and SageMaker expect. The pre-configured [`config.pbtxt`](./triton-serve-pt/bert/config.pbtxt) file provided with this repo specifies the model [configuration](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md) that Triton uses to load the model. We then tar the model directory and upload it to S3 so that we can later create a [SageMaker Model](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html).
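
The packaging step can be sketched in Python as well. This is a minimal illustration, assuming the usual Triton repository layout of `<model_name>/config.pbtxt` plus a numbered version folder holding the model file. The helper name `package_model` and the placeholder files are hypothetical and not part of this repo.

```python
import tarfile
import tempfile
from pathlib import Path


def package_model(repo_root, model_name, archive_path):
    """Tar <repo_root>/<model_name> so the archive's top-level entry is the model folder."""
    with tarfile.open(str(archive_path), "w:gz") as tar:
        tar.add(str(Path(repo_root) / model_name), arcname=model_name)
    # Re-open the archive and list what was written.
    with tarfile.open(str(archive_path), "r:gz") as tar:
        return sorted(tar.getnames())


# Demo with placeholder files; a real archive would hold the exported model.pt.
with tempfile.TemporaryDirectory() as tmp:
    version_dir = Path(tmp) / "repo" / "bert" / "1"
    version_dir.mkdir(parents=True)
    (version_dir / "model.pt").write_bytes(b"placeholder")
    (version_dir.parent / "config.pbtxt").write_text("# placeholder config\n")
    members = package_model(Path(tmp) / "repo", "bert", Path(tmp) / "model.tar.gz")

print(members)
```

Listing the members before uploading is a cheap way to confirm that the model folder, and not its parent, sits at the top of the archive.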

### Packaging model files and uploading to s3

In [None]:
!mkdir -p triton-serve-pt/bert/1/
!mv -f workspace/model.pt triton-serve-pt/bert/1/
!tar -C triton-serve-pt/ -czf model.tar.gz bert
model_uri = sagemaker_session.upload_data(path="model.tar.gz", key_prefix="triton-serve-pt")

### Create SageMaker Endpoint

We start off by creating a [SageMaker model](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateModel.html) from the model files we uploaded to S3 in the previous step.

In this step we also set an additional environment variable, `SAGEMAKER_TRITON_DEFAULT_MODEL_NAME`, which specifies the name of the model to be loaded by Triton. **The value of this key should match the folder name in the model package uploaded to S3**. This variable is optional when the package contains a single model. For ensemble models, this key **has to be** specified for Triton to start up in SageMaker.
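
The matching rule can be checked before the model is created. The sketch below works on a list of archive member names, such as `tarfile.getnames()` would return. The function names are hypothetical, and it assumes member names have no leading `./` prefix.

```python
def top_level_folders(member_names):
    """Return the distinct first path components of the archive member names."""
    return sorted({name.strip("/").split("/")[0] for name in member_names if name.strip("/")})


def check_default_model_name(member_names, environment):
    """Raise ValueError if SAGEMAKER_TRITON_DEFAULT_MODEL_NAME matches no top-level folder."""
    folders = top_level_folders(member_names)
    name = environment.get("SAGEMAKER_TRITON_DEFAULT_MODEL_NAME")
    if name is None:
        return folders  # the variable is optional for a single model
    if name not in folders:
        raise ValueError(f"{name!r} not found among top-level folders {folders}")
    return folders


members = ["bert/config.pbtxt", "bert/1/model.pt"]
environment = {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "bert"}
print(check_default_model_name(members, environment))
```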

*Note*: The current release of Triton (21.06-py3) on SageMaker doesn't support running instances of different models on the same server, except in the case of [ensembles](https://github.com/triton-inference-server/server/blob/main/docs/architecture.md#ensemble-models). Only multiple instances of the same model are supported, which can be specified under the [instance-groups](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md#instance-groups) section of the `config.pbtxt` file.

In [None]:
sm_model_name = "triton-nlp-bert-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "bert"},
}

create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Using the model above, we create an [endpoint configuration](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html) where we can specify the type and number of instances we want in the endpoint.

In [None]:
endpoint_config_name = "triton-nlp-bert-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g4dn.4xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Using the above endpoint configuration, we create a new SageMaker endpoint and wait for the deployment to finish. The status changes to **InService** once the deployment is successful.
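
Waiting for the deployment can also be written as a small helper that gives up after a time budget and reports any terminal status other than **InService**. This is only a sketch: `wait_for_endpoint` and its parameters are hypothetical names, and the status lookup is passed in as a callable so the logic can be tried without an AWS account.

```python
import time


def wait_for_endpoint(describe, poll_seconds=60, timeout_seconds=1800, sleep=time.sleep):
    """Poll describe() until the status leaves "Creating" or the sleep budget is used up."""
    waited = 0  # counts time spent sleeping, not wall-clock time
    status = describe()["EndpointStatus"]
    while status == "Creating":
        if waited >= timeout_seconds:
            raise TimeoutError(f"Endpoint still Creating after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds
        status = describe()["EndpointStatus"]
    if status != "InService":
        raise RuntimeError(f"Endpoint ended in status {status}")
    return status


# Demo with a fake status sequence and a no-op sleep.
fake_statuses = iter(["Creating", "Creating", "InService"])
result = wait_for_endpoint(lambda: {"EndpointStatus": next(fake_statuses)}, sleep=lambda s: None)
print(result)
```

In this notebook the callable would be something like `lambda: sm.describe_endpoint(EndpointName=endpoint_name)`.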

In [None]:
endpoint_name = "triton-nlp-bert-pt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

### Run inference

Once the endpoint is running, we can send sample text to it for inference, using JSON as the payload format. For the inference request format, Triton uses the KFServing community standard [inference protocols](https://github.com/triton-inference-server/server/blob/main/docs/protocol/README.md).
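
A JSON response in this protocol lists each output with a `name`, `datatype`, `shape`, and `data` field. The sketch below reads such a response with the standard library only. It assumes `data` arrives as a flat, row-major list, and the sample response is made up for illustration rather than captured from an endpoint.

```python
import json
from math import prod


def read_outputs(response_text):
    """Map each output name to (datatype, shape, flat data) after checking element counts."""
    body = json.loads(response_text)
    outputs = {}
    for out in body["outputs"]:
        expected = prod(out["shape"])
        if len(out["data"]) != expected:
            raise ValueError(f"{out['name']}: expected {expected} values, got {len(out['data'])}")
        outputs[out["name"]] = (out["datatype"], out["shape"], out["data"])
    return outputs


# Made-up response with one small FP32 output.
sample = json.dumps(
    {
        "model_name": "bert",
        "model_version": "1",
        "outputs": [
            {"name": "OUTPUT__0", "datatype": "FP32", "shape": [1, 2, 3], "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]}
        ],
    }
)
outputs = read_outputs(sample)
print(outputs["OUTPUT__0"][1])
```

From here, `np.array(data).reshape(shape)` would turn an output into an array.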

In [None]:
text = "Triton Inference Server provides a cloud and edge inferencing solution optimized for both CPUs and GPUs."
input_ids, attention_mask = tokenize_text(text)

payload = {
    "inputs": [
        {"name": "INPUT__0", "shape": [1, 128], "datatype": "INT32", "data": input_ids},
        {"name": "INPUT__1", "shape": [1, 128], "datatype": "INT32", "data": attention_mask},
    ]
}

response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload)
)

print(json.loads(response["Body"].read().decode("utf8")))

We can also use binary+json as the payload format to get better performance for the inference call. The specification of this format is provided [here](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_binary_data.md).

**Note:** With the `binary+json` format, we have to specify the length of the request metadata in the header to allow Triton to correctly parse the binary payload. This is done using a custom Content-Type header `application/vnd.sagemaker-triton.binary+json;json-header-size={}`.

Please note that this differs from using the `Inference-Header-Content-Length` header on a stand-alone Triton server, since custom headers are not allowed in SageMaker.
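
To make the role of the header size concrete, here is a hand-rolled sketch of the body layout: a JSON header followed by the raw tensor bytes, in input order. It is an illustration only, since `tritonclient` builds the real request. The helper names are hypothetical, and the little-endian INT32 encoding reflects our reading of the binary data extension.

```python
import json
import struct

CONTENT_TYPE_PREFIX = "application/vnd.sagemaker-triton.binary+json;json-header-size="


def build_binary_request(named_int32_inputs):
    """Return (body, json_header_size) for INT32 inputs of shape [1, n]."""
    inputs, blobs = [], []
    for name, values in named_int32_inputs:
        blob = struct.pack(f"<{len(values)}i", *values)  # little-endian int32 (assumed)
        inputs.append(
            {
                "name": name,
                "shape": [1, len(values)],
                "datatype": "INT32",
                "parameters": {"binary_data_size": len(blob)},
            }
        )
        blobs.append(blob)
    header = json.dumps({"inputs": inputs}).encode("utf-8")
    return header + b"".join(blobs), len(header)


def header_size_from_content_type(content_type):
    """Extract the JSON header size from the custom Content-Type value."""
    if not content_type.startswith(CONTENT_TYPE_PREFIX):
        raise ValueError(f"Unexpected content type: {content_type}")
    return int(content_type[len(CONTENT_TYPE_PREFIX):])


body, size = build_binary_request([("INPUT__0", [101, 102, 0, 0]), ("INPUT__1", [1, 1, 0, 0])])
content_type = CONTENT_TYPE_PREFIX + str(size)

# The receiver splits the body at the header size carried in the Content-Type.
split_at = header_size_from_content_type(content_type)
header = json.loads(body[:split_at])
payload = body[split_at:]
first_input = struct.unpack("<4i", payload[:16])
print(header["inputs"][0]["name"], first_input)
```

Without the size, the receiver could not tell where the JSON ends and the tensor bytes begin, which is why it has to travel with the request.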

In [None]:
text = "Amazon SageMaker helps data scientists and developers to prepare, build, train, and deploy high-quality machine learning (ML) models quickly by bringing together a broad set of capabilities purpose-built for ML."
request_body, header_length = get_sample_tokenized_text_binary_pt(text)

response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
        header_length
    ),
    Body=request_body,
)

# Parse json header size length from the response
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix) :]

# Read response body
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=int(header_length_str)
)
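output0_data = result.as_numpy("OUTPUT__0")
# Use the same output name that get_sample_tokenized_text_binary_pt requested;
# as_numpy returns None for a name that is not in the response.
output1_data = result.as_numpy("1634__1")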
print(output0_data)
print(output1_data)

### Terminate endpoint and clean up artifacts

In [None]:
# Delete in the same order as the TensorRT section: endpoint, endpoint config, model.
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)

## TensorRT NLP-Bert

Another way to improve performance is to convert the PyTorch BERT model to a TensorRT plan and use it natively to run inference on Triton. Using the [onnx_exporter.py](./workspace/onnx_exporter.py) script and `trtexec`, we create a TensorRT plan from the pre-trained PyTorch BERT model. This was already done by the `generate_models.sh` script that we ran earlier in this notebook. We'll package the model and the provided `config.pbtxt` according to the Triton model specification and upload it to S3 to create a SageMaker model and endpoint.

### Packaging model files and uploading to s3

In [None]:
!mkdir -p triton-serve-trt/bert/1/
!mv -f workspace/model_bs16.plan triton-serve-trt/bert/1/model.plan
!tar -C triton-serve-trt/ -czf model.tar.gz bert
model_uri = sagemaker_session.upload_data(path="model.tar.gz", key_prefix="triton-serve-trt")

### Create SageMaker Endpoint

In [None]:
sm_model_name = "triton-nlp-bert-trt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": triton_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "bert"},
}

create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

In [None]:
endpoint_config_name = "triton-nlp-bert-trt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g4dn.4xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

In [None]:
endpoint_name = "triton-nlp-bert-trt-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

### Run inference

Once the endpoint is running, we can run inference using either a JSON payload or a binary+json payload, as described in the standard PyTorch deployment section.

In [None]:
text = "Triton Inference Server provides a cloud and edge inferencing solution optimized for both CPUs and GPUs."
input_ids, attention_mask = tokenize_text(text)

payload = {
    "inputs": [
        {"name": "token_ids", "shape": [1, 128], "datatype": "INT32", "data": input_ids},
        {"name": "attn_mask", "shape": [1, 128], "datatype": "INT32", "data": attention_mask},
    ]
}
response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/octet-stream", Body=json.dumps(payload)
)

print(json.loads(response["Body"].read().decode("utf8")))

In [None]:
text = "Amazon SageMaker helps data scientists and developers to prepare, build, train, and deploy high-quality machine learning (ML) models quickly by bringing together a broad set of capabilities purpose-built for ML."
request_body, header_length = get_sample_tokenized_text_binary_trt(text)

response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/vnd.sagemaker-triton.binary+json;json-header-size={}".format(
        header_length
    ),
    Body=request_body,
)

# Parse json header size length from the response
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
header_length_str = response["ContentType"][len(header_length_prefix) :]

# Read response body
result = httpclient.InferenceServerClient.parse_response_body(
    response["Body"].read(), header_length=int(header_length_str)
)
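output0_data = result.as_numpy("output")
# Use the same output name that get_sample_tokenized_text_binary_trt requested;
# as_numpy returns None for a name that is not in the response.
output1_data = result.as_numpy("1634")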
print(output0_data)
print(output1_data)

### Terminate endpoint and clean up artifacts

In [None]:
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)

### TODO: Add comparison charts for base PT case vs TRT compiled model