## Triton Embeddings Onnx Model MME Example

In this example we take the following sample embeddings model, make 30 copies of it, and load them into a SageMaker multi-model endpoint (MME) for GPU inference with Triton Inference Server. Adjust this for a realistic use case, which would typically involve different models (and therefore different model sizes).

<b>Sample Model</b>: https://huggingface.co/sentence-transformers/msmarco-bert-base-dot-v5

### Setup & Local Inference

In [None]:
# Install/upgrade the AWS SDKs and the Triton HTTP client package used in this notebook.
# NOTE(review): transformers, torch, onnx and Pillow are imported by later cells but not
# installed here — this assumes the kernel already provides them.
!pip install -qU pip awscli boto3 sagemaker
!pip install nvidia-pyindex --quiet
!pip install tritonclient[http] --quiet

In [None]:
# imports
import boto3, json, sagemaker, time
from sagemaker import get_execution_role
import numpy as np
from PIL import Image
import tritonclient.http as httpclient

# variables
s3_client = boto3.client("s3")
# UTC timestamp suffix so every run creates uniquely named SageMaker resources
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# sagemaker variables
# NOTE(review): get_execution_role() assumes this runs in a SageMaker notebook/Studio environment
role = get_execution_role()
sm_client = boto3.client(service_name="sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")
sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
bucket = sagemaker_session.default_bucket()
prefix = "onnx-mme-embed"
# endpoint variables (model, endpoint config and endpoint names share prefix + timestamp)
sm_model_name = f"{prefix}-mdl-{ts}"
endpoint_config_name = f"{prefix}-epc-{ts}"
endpoint_name = f"{prefix}-ep-{ts}"
# Initial S3 prefix for model artifacts; reassigned later to the mme-onnx/ prefix actually used
model_data_url = f"s3://{bucket}/{prefix}/"

# AWS account that hosts the SageMaker Triton (MME) container image in each region.
# Verify against the AWS deep-learning-containers image list before adding regions.
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}

region = boto3.Session().region_name
# Fail fast with a real exception (a bare string cannot be raised in Python 3)
# when the session has no region or the image isn't published there.
if region not in account_id_map:
    raise ValueError(
        f"UNSUPPORTED REGION: {region!r} (supported: {sorted(account_id_map)})"
    )

# China regions use a different AWS domain suffix for ECR.
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"

# Triton image being used; the latest available images are listed at:
# https://github.com/aws/deep-learning-containers/blob/master/available_images.md
mme_triton_image_uri = (
    f"{account_id_map[region]}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.12-py3"
)

## Local Inference

In [None]:
from transformers import AutoTokenizer, AutoModel
import torch

# Load the tokenizer and the PyTorch model from the Hugging Face Hub for a local reference run
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/msmarco-bert-base-dot-v5")
model = AutoModel.from_pretrained("sentence-transformers/msmarco-bert-base-dot-v5")

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the unmasked tokens of each sequence.

    Args:
        model_output: Model output exposing ``last_hidden_state``.
        attention_mask: Mask with 1 for real tokens and 0 for padding, matching
            the leading dimensions of ``last_hidden_state``.

    Returns:
        Tensor holding one mean-pooled embedding per sequence.
    """
    hidden = model_output.last_hidden_state
    # Broadcast the mask across the hidden dimension so padded positions contribute zero.
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    # Clamp keeps an all-padding row from dividing by zero.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


# Query to embed locally (reference result for the deployed model)
query = "How many people live in London?"
encoded_input = tokenizer(query, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings without tracking gradients (inference only)
with torch.no_grad():
    model_output = model(**encoded_input, return_dict=True)
# Attention-mask-aware mean pooling: one embedding vector per input sentence
embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

### Export Onnx

Reference: https://huggingface.co/docs/transformers/v4.29.1/serialization

In [None]:
from pathlib import Path
import transformers
from transformers.onnx import FeaturesManager
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from transformers import AutoTokenizer, AutoModel
import torch

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/msmarco-bert-base-dot-v5")
model = AutoModel.from_pretrained("sentence-transformers/msmarco-bert-base-dot-v5")

# Look up the ONNX export config class for this model architecture
# (per its name, check_supported_model_or_raise errors for unsupported models).
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model)
onnx_config = model_onnx_config(model.config)

# Export to ./model.onnx with ONNX opset 13. The graph's input/output names must match
# config.pbtxt (input_ids, token_type_ids, attention_mask -> last_hidden_state);
# presumably onnx_inputs/onnx_outputs hold those names — compare before deploying.
# NOTE(review): transformers.onnx is deprecated in newer transformers releases in favour
# of Hugging Face Optimum — confirm against the installed transformers version.
onnx_inputs, onnx_outputs = transformers.onnx.export(
        preprocessor=tokenizer,
        model=model,
        config=onnx_config,
        opset=13,
        output=Path("model.onnx")
)

In [None]:
%%sh
# Stage the exported model; -p keeps this cell re-runnable when workspace/ already exists.
mkdir -p workspace
mv model.onnx workspace/

### Triton Setup

Adjust config.pbtxt to match the input and output names, data types, and shapes that your ONNX model expects. We recommend testing the config locally with Docker before deploying to SageMaker.

In [None]:
# Triton model repository: triton-serve-onnx/<model_name>/ holds config.pbtxt and version dirs
!mkdir -p triton-serve-onnx/sentence/

In [None]:
%%writefile triton-serve-onnx/sentence/config.pbtxt
# Triton config for the exported ONNX sentence-embedding model.
# If "name" is given it must match the model repository directory (triton-serve-onnx/sentence),
# otherwise Triton rejects the config when loading the model (e.g. in a local Docker test).
name: "sentence"
platform: "onnxruntime_onnx"
# NOTE(review): max_batch_size is not set (defaults to 0), so the dims below include the
# batch dimension explicitly — confirm this is intended together with dynamic_batching.
# Input names/types must match the exported ONNX graph and the client payload.
input: [
    {
        name: "input_ids"
        data_type: TYPE_INT64
        dims: [ -1, -1 ]
    },
    {
        name: "token_type_ids"
        data_type: TYPE_INT64
        dims: [ -1, -1 ]
    },
    {
        name: "attention_mask"
        data_type: TYPE_INT64
        dims: [ -1, -1 ]
    }
]
# Token embeddings; mean pooling is done by the client, not by Triton.
output [
  {
    name: "last_hidden_state"
    data_type: TYPE_FP32
    dims: [ -1, -1, 768 ]
  }
]
# One model instance on the GPU.
instance_group {
  count: 1
  kind: KIND_GPU
}
dynamic_batching {
}

### SageMaker Endpoint Creation

SageMaker expects the model artifacts in `model.tar.gz` format. Creating a SageMaker endpoint takes three steps:

- <b>SageMaker Model</b>: Points towards the model data and any inference scripts.
- <b>SageMaker Endpoint Config</b>: Defines the production variants and the hardware (instance type and count) for the endpoint.
- <b>SageMaker Endpoint</b>: The REST endpoint you invoke. With a multi-model endpoint, you select the model for each request with `TargetModel`, which is sent as a request header.

In [None]:
# Place the ONNX file in version directory 1 of the "sentence" model, then package the model
# directory so the tarball's top-level entry is sentence/ (config.pbtxt + 1/model.onnx).
!mkdir -p triton-serve-onnx/sentence/1/
!cp -f workspace/model.onnx triton-serve-onnx/sentence/1/
!tar -C triton-serve-onnx/ -czf model.tar.gz sentence

In [None]:
# Upload the same tarball 30 times under distinct keys to simulate 30 models behind the MME.
s3 = boto3.client('s3')
for copy_idx in range(30):
    s3.upload_file("model.tar.gz", bucket, f"mme-onnx/onnx-{copy_idx}.tar.gz")

In [None]:
# Point the MME at the prefix the tarballs were uploaded to (replaces the earlier value);
# TargetModel names at invocation time are relative to this prefix.
model_data_url = f"s3://{bucket}/mme-onnx/"
model_data_url

In [None]:
# List the uploaded tarballs: 30 copies of one model here; use your real model tarballs in practice.
!aws s3 ls {model_data_url}

In [None]:
# Single Triton container in multi-model mode, serving every tarball under ModelDataUrl.
container = dict(
    Image=mme_triton_image_uri,
    ModelDataUrl=model_data_url,
    Mode="MultiModel",
)

In [None]:
# Register the SageMaker Model: the Triton MME container plus the S3 model prefix, run under `role`.
create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

In [None]:
# Endpoint config: one production variant on a single ml.g5.2xlarge (GPU) instance taking all traffic.
create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.g5.2xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

In [None]:
# Start endpoint creation; the next cell polls until it leaves the "Creating" state.
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
# Poll the endpoint every 90 s until it leaves the "Creating" state.
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(90)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop here instead of letting later invocation cells fail with a confusing error
# when creation did not succeed (e.g. status "Failed").
if status != "InService":
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {status}: "
        f"{resp.get('FailureReason', 'no FailureReason reported')}"
    )

### Sample Inference

In [None]:
# prepare client payload
from transformers import AutoTokenizer, AutoModel
# Tokenizer used client-side to build the Triton request payload (same model as the export)
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/msmarco-bert-base-dot-v5")

def tokenize_text(text, tok=None):
    """Build the JSON inference request body for the Triton endpoint.

    Args:
        text: A string or list of strings to tokenize.
        tok: Optional Hugging Face tokenizer. Defaults to the module-level
            ``tokenizer``, so existing calls behave as before. Pass another
            tokenizer when the MME hosts models with different vocabularies.

    Returns:
        ``{"inputs": [...]}`` with one entry per model input (input_ids,
        token_type_ids, attention_mask). Each entry has a name, a shape as a
        plain list of ints, datatype "INT64", and nested-list data. The names
        must match the inputs declared in config.pbtxt.
    """
    if tok is None:
        tok = tokenizer
    encoded = tok(text, padding=True, truncation=True, return_tensors='pt')
    inputs = []
    for name in ("input_ids", "token_type_ids", "attention_mask"):
        tensor = getattr(encoded, name)
        inputs.append(
            {
                "name": name,
                # list(...) keeps the payload plain JSON types rather than torch.Size
                "shape": list(tensor.shape),
                "datatype": "INT64",
                "data": tensor.tolist(),
            }
        )
    return {"inputs": inputs}
# Build (and display) a sample request payload for one sentence
sampPayload = tokenize_text(["This is a test"])
sampPayload

In [None]:
# Invoke one model behind the MME; TargetModel selects which tarball SageMaker loads and serves.
response = runtime_sm_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/octet-stream",
    Body=json.dumps(sampPayload),
    TargetModel="onnx-5.tar.gz", # key relative to model_data_url; onnx-0 ... onnx-29 were uploaded
)

response = json.loads(response["Body"].read().decode("utf8"))
# Data of the first output tensor (last_hidden_state).
# NOTE(review): presumably flattened — check response["outputs"][0]["shape"] to reshape.
output = response["outputs"][0]["data"]

In [None]:
%%time
# Invoke each of the 30 model copies behind the MME in order (onnx-0 ... onnx-29);
# the first call to each copy includes SageMaker loading it onto the instance.
for i in range(30):
    target_model = f"onnx-{i}.tar.gz"
    response = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/octet-stream",
        TargetModel=target_model,
        Body=json.dumps(sampPayload),
    )
    print(f"Target Model Invoked: {target_model}")
    response = json.loads(response["Body"].read().decode("utf8"))
    output = response["outputs"][0]["data"]