In [None]:
!pip install -qU pip awscli boto3 sagemaker

In [None]:
import boto3, json, sagemaker, time 
from sagemaker import get_execution_role

# A single boto3 session is shared by the SageMaker API client and the
# SageMaker SDK session; the execution role is resolved by the SDK helper.
sess = boto3.Session()
sm = sess.client(service_name='sagemaker')
sagemaker_session = sagemaker.Session(boto_session=sess)
role = get_execution_role()

In [None]:
# Triton server container image. Note that the registry host pins the
# us-west-2 region.
triton_image_uri = (
    '195202947636.dkr.ecr.us-west-2.amazonaws.com'  # ECR registry host
    '/tritonserver:21.06-py3'                       # repository:tag
)

In [None]:
import numpy as np
from PIL import Image

def load_sample_image(image_path='./kitten.jpg'):
    """Load an image and return it as a normalized, channels-first nested list.

    The image is converted to RGB, resized to 224x224, scaled to [0, 1], and
    normalized per channel with the mean/std constants below. The result is
    then transposed from height-width-channel to channel-height-width so that
    its row-major element order matches the [1, 3, 224, 224] tensor shape the
    inference payloads in this notebook declare.

    Args:
        image_path: Path of the image file to load. Defaults to './kitten.jpg'.

    Returns:
        A nested list of Python floats with shape (3, 224, 224).
    """
    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)

    img = Image.open(image_path).convert("RGB")
    img = img.resize((224, 224))
    img = (np.array(img).astype(np.float32) / 255) - channel_mean
    img = img / channel_std
    # The pixel array is (height, width, channel); the payload declares
    # (channel, height, width), so reorder the axes before serializing.
    img = np.transpose(img, (2, 0, 1))
    return img.tolist()

In [None]:
# Run generate_models.sh inside the nvcr.io/nvidia/pytorch:21.06-py3 container
# with all GPUs exposed; ./workspace on the host is mounted at /workspace.
# Later cells expect workspace/model.pt and workspace/model.plan to exist
# after this step.
!docker run --gpus=all --rm -it \
            -v `pwd`/workspace:/workspace nvcr.io/nvidia/pytorch:21.06-py3 \
            /bin/bash generate_models.sh

In [None]:
# Place model.pt under triton-serve-pt/resnet/1/, archive the `resnet`
# directory into model.tar.gz, and upload the archive under the
# "triton-serve-pt" key prefix.
# NOTE(review): presumably triton-serve-pt/resnet/ also needs a Triton model
# config alongside the version directory — confirm it exists before packaging.
!mkdir -p triton-serve-pt/resnet/1/
!mv -f workspace/model.pt triton-serve-pt/resnet/1/
!tar -C triton-serve-pt/ -czf model.tar.gz resnet
# The value returned by upload_data is passed as ModelDataUrl when the model is created.
model_uri = sagemaker_session.upload_data(path="model.tar.gz", key_prefix="triton-serve-pt")

In [None]:
# Append a UTC timestamp to the model name.
sm_model_name = ''.join([
    'triton-resnet-pt-',
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()),
])

# Container definition: serving image, model archive location, and the
# default model name passed to the container through its environment.
container = {'Image': triton_image_uri, 'ModelDataUrl': model_uri}
container['Environment'] = {'SAGEMAKER_TRITON_DEFAULT_MODEL_NAME': 'resnet'}

create_model_response = sm.create_model(
    ModelName=sm_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=container,
)

print("Model Arn: " + create_model_response['ModelArn'])

In [None]:
# Append a UTC timestamp to the endpoint config name.
endpoint_config_name = ''.join([
    'triton-resnet-pt-',
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()),
])

# One production variant named 'AllTraffic': a single ml.g4dn.4xlarge
# instance serving the model created above.
create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            'InstanceType': 'ml.g4dn.4xlarge',
            'InitialVariantWeight': 1,
            'InitialInstanceCount': 1,
            'ModelName': sm_model_name,
            'VariantName': 'AllTraffic',
        },
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])

In [None]:
# Append a UTC timestamp to the endpoint name.
endpoint_name = ''.join([
    'triton-resnet-pt-',
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()),
])

# Request creation of the endpoint from the endpoint config defined above.
create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

print("Endpoint Arn: " + create_endpoint_response['EndpointArn'])

In [None]:
# Poll the endpoint once a minute until it leaves the 'Creating' state.
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Status: " + status)

while status == 'Creating':
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp['EndpointStatus']
    print("Status: " + status)

# Leaving 'Creating' does not imply success (the endpoint may be 'Failed').
# Stop here with the reported reason instead of letting the invocation cell
# fail later with a less informative error.
if status != 'InService':
    raise RuntimeError(
        "Endpoint {} ended in status {}: {}".format(
            endpoint_name,
            status,
            resp.get('FailureReason', 'no failure reason reported'),
        )
    )

print("Arn: " + resp['EndpointArn'])
print("Status: " + status)

In [None]:
# Client for the `sagemaker-runtime` service, used to invoke the endpoint.
client = boto3.client('sagemaker-runtime')

# Request with a single FP32 input tensor named INPUT__0, declared as
# 1x3x224x224 and filled from the sample image.
payload = {
    "inputs": [
        {
            "name": "INPUT__0",
            "shape": [1, 3, 224, 224],
            "datatype": "FP32",
            "data": load_sample_image(),
        },
    ],
}

request_body = json.dumps(payload)
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/octet-stream',
    Body=request_body,
)

# The response body is a stream of UTF-8 encoded JSON.
response_text = response['Body'].read().decode('utf8')
print(json.loads(response_text))

In [None]:
# Tear down the PyTorch deployment. These are SageMaker API operations, so
# they go through `sm` (the `sagemaker-runtime` client only serves inference
# calls), and boto3 operations accept keyword arguments only.
# Order: endpoint first, then its config, then the model.
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)

In [None]:
# Place model.plan under triton-serve-trt/resnet/1/, archive the `resnet`
# directory, and upload the archive under the "triton-serve-trt" key prefix.
# This overwrites the local model.tar.gz and rebinds model_uri from the
# earlier PyTorch packaging cell.
!mkdir -p triton-serve-trt/resnet/1/
!mv -f workspace/model.plan triton-serve-trt/resnet/1/model.plan
!tar -C triton-serve-trt/ -czf model.tar.gz resnet
# The value returned by upload_data is passed as ModelDataUrl when the model is created.
model_uri = sagemaker_session.upload_data(path="model.tar.gz", key_prefix="triton-serve-trt")

In [None]:
# Append a UTC timestamp to the model name.
sm_model_name = ''.join([
    'triton-resnet-trt-',
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()),
])

# Container definition: serving image, model archive location, and the
# default model name passed to the container through its environment.
container = {'Image': triton_image_uri, 'ModelDataUrl': model_uri}
container['Environment'] = {'SAGEMAKER_TRITON_DEFAULT_MODEL_NAME': 'resnet'}

create_model_response = sm.create_model(
    ModelName=sm_model_name,
    ExecutionRoleArn=role,
    PrimaryContainer=container,
)

print("Model Arn: " + create_model_response['ModelArn'])

In [None]:
# Append a UTC timestamp to the endpoint config name.
endpoint_config_name = ''.join([
    'triton-resnet-trt-',
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()),
])

# One production variant named 'AllTraffic': a single ml.g4dn.4xlarge
# instance serving the model created above.
create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            'InstanceType': 'ml.g4dn.4xlarge',
            'InitialVariantWeight': 1,
            'InitialInstanceCount': 1,
            'ModelName': sm_model_name,
            'VariantName': 'AllTraffic',
        },
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response['EndpointConfigArn'])

In [None]:
# Append a UTC timestamp to the endpoint name.
endpoint_name = ''.join([
    'triton-resnet-trt-',
    time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()),
])

# Request creation of the endpoint from the endpoint config defined above.
create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

print("Endpoint Arn: " + create_endpoint_response['EndpointArn'])

In [None]:
# Poll the endpoint once a minute until it leaves the 'Creating' state.
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp['EndpointStatus']
print("Status: " + status)

while status == 'Creating':
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp['EndpointStatus']
    print("Status: " + status)

# Leaving 'Creating' does not imply success (the endpoint may be 'Failed').
# Stop here with the reported reason instead of letting the invocation cell
# fail later with a less informative error.
if status != 'InService':
    raise RuntimeError(
        "Endpoint {} ended in status {}: {}".format(
            endpoint_name,
            status,
            resp.get('FailureReason', 'no failure reason reported'),
        )
    )

print("Arn: " + resp['EndpointArn'])
print("Status: " + status)

In [None]:
# Request with a single FP32 input tensor named "input", declared as
# 1x3x224x224 and filled from the sample image.
payload = {
    "inputs": [
        {
            "name": "input",
            "shape": [1, 3, 224, 224],
            "datatype": "FP32",
            "data": load_sample_image(),
        },
    ],
}

request_body = json.dumps(payload)
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/octet-stream',
    Body=request_body,
)

# The response body is a stream of UTF-8 encoded JSON.
response_text = response['Body'].read().decode('utf8')
print(json.loads(response_text))

In [None]:
# Tear down the TensorRT deployment through the SageMaker client:
# endpoint first, then its endpoint config, then the model.
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)