# MME GPU SageMaker Real-Time Inference

In this example we take a BERT NLP model, make hundreds of copies of it, and run inference against them with SageMaker Multi-Model Endpoints. We run this notebook on a conda_pytorch_p310 kernel on a classic SageMaker Notebook Instance.

In [None]:
!pip install transformers

## Local BERT Inference & Model Saving

In [None]:
import torch
from transformers import BertModel, BertTokenizer
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The same pretrained checkpoint name is used for the tokenizer and the model.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_name)
# torchscript=True is the transformers flag for models that will be exported
# with TorchScript (the model is traced in the next cell).
model = BertModel.from_pretrained(model_name, torchscript=True)

# Tokenize one example sentence into PyTorch tensors; these tensors are
# reused below as the example inputs for tracing.
text = "I am super happy right now to be trying out BERT."
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

In [None]:
# Trace the model with the sample tensors to obtain a TorchScript module.
# The positional order (input_ids, attention_mask) is the order used for tracing.
example_inputs = (inputs["input_ids"], inputs["attention_mask"])
traced_model = torch.jit.trace(model, example_inputs)

In [None]:
# Serialize the traced module to disk; this file is later moved into the
# Triton model repository.
model_path = "model.pt"
torch.jit.save(traced_model, model_path)

In [None]:
# Reload the serialized TorchScript file and run it on the sample tensors
# to confirm that the saved artifact is usable.
loaded_model = torch.jit.load("model.pt")
sample_ids = inputs["input_ids"]
sample_mask = inputs["attention_mask"]
res = loaded_model(input_ids=sample_ids, attention_mask=sample_mask)
res

### Model Configuration

We can understand the input and output shapes by observing the model configuration from the transformers library. This will help us shape our config.pbtxt file for our Triton Inference Server configuration.

In [None]:
from transformers import BertConfig
# Read the maximum sequence length and the hidden size from the checkpoint's
# configuration; both values feed the shapes declared for Triton further down.
bert_config = BertConfig.from_pretrained(model_name)
max_sequence_length, output_shape = (
    bert_config.max_position_embeddings,
    bert_config.hidden_size,
)
print(
    f"Maximum Input Sequence Length: {max_sequence_length}",
    f"Output Shape: {output_shape}",
    sep="\n",
)

## Local Triton Setup

Before creating the SageMaker endpoint, we want to verify that we can run inference locally with the Triton Inference Server container. This helps us debug any issues quickly, rather than discovering them only after the SageMaker endpoint has been created.

In [None]:
def tokenize_text(text, max_length=512):
    """Tokenize ``text`` into fixed-length BERT inputs.

    The text is padded or truncated to exactly ``max_length`` tokens using
    the notebook-level ``tokenizer``.

    Args:
        text: The input string to tokenize.
        max_length: Target sequence length. Defaults to 512, the length used
            for the request payload shape and in ``config.pbtxt`` in this
            notebook; a different value requires changing those to match.

    Returns:
        A tuple ``(input_ids, attention_mask)`` as returned by the tokenizer
        (plain Python lists, since no ``return_tensors`` is requested).
    """
    encoded_text = tokenizer(
        text, padding="max_length", max_length=max_length, truncation=True
    )
    return encoded_text["input_ids"], encoded_text["attention_mask"]

In [None]:
sample_text = """
                We are testing some sample text for BERT.
                This is a test with SageMaker MME GPU.
              """

input_ids, attention_mask = tokenize_text(sample_text)

# Build the inference request body. Both tensors share the same shape and
# datatype: [1, 512] because the text is padded/truncated to 512 tokens
# (see the configuration cell above for the maximum sequence length).
payload = {
    "inputs": [
        {"name": tensor_name, "shape": [1, 512], "datatype": "INT32", "data": tensor_data}
        for tensor_name, tensor_data in (
            ("input_ids", input_ids),
            ("attention_mask", attention_mask),
        )
    ]
}

# Uncomment to inspect the request body:
#payload

### Create Proper Directory Structure for Triton

Triton expects PyTorch models to be laid out in the following folder structure:

- bert_model
    - 1 (model_version)
        - model.pt
        - model.py (optionally add)
    - config.pbtxt
    
We can create our config file and move the serialized model artifact to where necessary.

In [None]:
%%writefile config.pbtxt
name: "bert_model"
platform: "pytorch_libtorch"

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [1, 512]
  },
  {
    name: "attention_mask"
    data_type: TYPE_INT32
    dims: [1, 512]
  }
]

output [
  {
    name: "OUTPUT"
    data_type: TYPE_FP32
    dims: [512, 768]
  }
]

In [None]:
%%sh
# Lay out the Triton model repository:
#   bert_model/config.pbtxt
#   bert_model/1/model.pt
# Abort on the first failing command instead of silently continuing.
set -e
# -p creates the version directory together with its parent and does not
# fail when the directories already exist.
mkdir -p bert_model/1
mv config.pbtxt bert_model/
mv model.pt bert_model/1/

### Start Triton Container

Make sure to start the Server with the following Docker command before running the local Inference cells.

```
docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /home/ec2-user/SageMaker:/models nvcr.io/nvidia/tritonserver:23.08-py3 tritonserver --model-repository=/models --exit-on-error=false --log-verbose=1
```

In [None]:
import requests
import json

In [None]:
# Specify the model name and version.
# NOTE: this rebinds the notebook-level `model_name`, which previously held
# the Hugging Face checkpoint name.
model_name = "bert_model" #specified in config.pbtxt
model_version = "1"

# Set the inference URL based on the Triton server's address
url = f"http://localhost:8000/v2/models/{model_name}/versions/{model_version}/infer"

# Sample invoke against the locally running container.
output = requests.post(url, data=json.dumps(payload))
# Fail loudly on a 4xx/5xx response instead of treating the error body
# as an inference result.
output.raise_for_status()
res = output.json()
#print(res)

## SageMaker MME GPU

First we create our model tarball which we will make copies of to create our MME GPU based endpoint.

In [None]:
import boto3
import sagemaker
import json
# One boto3 session is shared by the SageMaker SDK session and used to
# resolve the region and the default bucket.
sess = boto3.Session()
sm = sess.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=sess)
role = sagemaker.get_execution_role()
region = sess.region_name
bucket = sagemaker_session.default_bucket()
# Key prefix under which the model tarballs are stored in the bucket.
s3_model_prefix = "triton-bert"

client = boto3.client("sagemaker", region_name=region)
runtime_client = boto3.client("sagemaker-runtime")
s3_client = boto3.client("s3")

# Per-region AWS account ID used to build the ECR URI of the Triton image below.
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}


# Raise a real exception: raising a bare string is itself a TypeError and
# would hide the actual problem.
if region not in account_id_map:
    raise ValueError(f"UNSUPPORTED REGION: {region}")

print(f"SageMaker Role: {role}")
print(f"Region Name: {region}")

# China regions use a different registry domain suffix.
base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
triton_image_uri = "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:23.07-py3".format(
    account_id=account_id_map[region], region=region, base=base
)

print(f"Triton Inference server DLC image: {triton_image_uri}")

In [None]:
# Package the bert_model/ directory into the gzip-compressed archive
# model.tar.gz (the file uploaded to S3 in the next cell).
!tar -cvzf model.tar.gz bert_model/

In [None]:
%%time
# we make a 200 copies of the tarball, this will take about ~6 minutes to finish (can vary depending on model size)
for i in range(200):
    with open("model.tar.gz", "rb") as f:
        s3_client.upload_fileobj(f, bucket, "{}/model-{}.tar.gz".format(s3_model_prefix,i))

In [None]:
# S3 prefix that holds all model tarballs; passed as ModelDataUrl when the
# multi-model SageMaker model is created.
mme_artifacts = f"s3://{bucket}/{s3_model_prefix}/"
mme_artifacts

### Endpoint Creation

In [None]:
#Step 1: Model Creation
import time
from time import gmtime, strftime

# The model name carries a UTC timestamp suffix.
model_timestamp = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
model_name = f"triton-bert-mme{model_timestamp}"

# Container definition: Triton image, the S3 prefix with all tarballs,
# and multi-model mode.
container = dict(
    Image=triton_image_uri,
    ModelDataUrl=mme_artifacts,
    Mode="MultiModel",
)

create_model_response = client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=container,
)

print(f"Model Arn: {create_model_response['ModelArn']}")

In [None]:
# The endpoint configuration name carries a UTC timestamp suffix.
config_timestamp = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
endpoint_config_name = f"triton-epc-mme-gpu{config_timestamp}"

# A single production variant: one ml.g4dn.4xlarge instance serving the
# model created above and receiving all traffic.
production_variant = {
    "VariantName": "tritontraffic",
    "ModelName": model_name,
    "InstanceType": "ml.g4dn.4xlarge",
    "InitialInstanceCount": 1,
    "InitialVariantWeight": 1,
}

endpoint_config_response = client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[production_variant],
)
print(f"Endpoint Configuration Arn: {endpoint_config_response['EndpointConfigArn']}")

In [None]:
# The endpoint name carries a UTC timestamp suffix.
endpoint_timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_name = f"triton-mme-gpu-ep{endpoint_timestamp}"

# Start endpoint creation from the configuration defined above.
create_endpoint_response = client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

In [None]:
#Monitor creation: poll once a minute until the endpoint leaves "Creating".
describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
while describe_endpoint_response["EndpointStatus"] == "Creating":
    print(describe_endpoint_response["EndpointStatus"])
    # Sleep before re-polling so no extra minute is spent waiting after
    # the endpoint has already reached its final status.
    time.sleep(60)
    describe_endpoint_response = client.describe_endpoint(EndpointName=endpoint_name)
print(describe_endpoint_response["EndpointStatus"])
print(describe_endpoint_response)

# Surface a failed creation here rather than at the first invocation.
if describe_endpoint_response["EndpointStatus"] != "InService":
    raise RuntimeError(
        "Endpoint {} ended in status {}: {}".format(
            endpoint_name,
            describe_endpoint_response["EndpointStatus"],
            describe_endpoint_response.get("FailureReason", "no failure reason reported"),
        )
    )

### Sample Inference

In [None]:
# Serialize the request once, then route it to one specific model copy
# through TargetModel (the last of the uploaded tarballs).
request_body = json.dumps(payload)
response = runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/octet-stream",
    Body=request_body,
    TargetModel='model-199.tar.gz',
)

# The response body is a stream; read and decode it before parsing the JSON.
response_text = response["Body"].read().decode("utf8")
print(json.loads(response_text))