# Deploying pre-trained PyTorch vision models with Amazon SageMaker Neo

Amazon SageMaker Neo is an API that compiles machine learning models to optimize them for your choice of hardware target. Currently, Neo supports pre-trained PyTorch models from [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html). General support for other PyTorch models is forthcoming.

In [1]:
!~/anaconda3/envs/pytorch_p36/bin/pip install torch==1.2.0 torchvision==0.4.0

Collecting torch==1.2.0
  Downloading https://files.pythonhosted.org/packages/30/57/d5cceb0799c06733eefce80c395459f28970ebb9e896846ce96ab579a3f1/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl (748.8MB)
    100% |████████████████████████████████| 748.9MB 42kB/s
Collecting torchvision==0.4.0
  Downloading https://files.pythonhosted.org/packages/06/e6/a564eba563f7ff53aa7318ff6aaa5bd8385cbda39ed55ba471e95af27d19/torchvision-0.4.0-cp36-cp36m-manylinux1_x86_64.whl (8.8MB)
    100% |████████████████████████████████| 8.8MB 7.6MB/s
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.4.0
    Uninstalling torch-1.4.0:
      Successfully uninstalled torch-1.4.0
  Found existing installation: torchvision 0.5.0
    Uninstalling torchvision-0.5.0:
      Successfully uninstalled torchvision-0.5.0
Successfully installed torch-1.2.0 torchvision-0.4.0


## Import ResNet18 from TorchVision

We import the pre-trained [ResNet18](https://arxiv.org/abs/1512.03385) model from TorchVision, trace it with TorchScript, and package the traced model as the model artifact `model.tar.gz`:

In [2]:
import torch
import torchvision.models as models
import tarfile

resnet18 = models.resnet18(pretrained=True)
input_shape = [1,3,224,224]
trace = torch.jit.trace(resnet18.float().eval(), torch.zeros(input_shape).float())
trace.save('model.pth')

with tarfile.open('model.tar.gz', 'w:gz') as f:
    f.add('model.pth')

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /home/ec2-user/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 104MB/s]
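The cell above writes `model.pth` at the top level of `model.tar.gz`. Before uploading, it can be worth confirming that layout. The sketch below packs a file with `arcname` set to its bare file name, so the member lands at the archive root regardless of where the file sits locally, and then lists the members. `pack_model` and `list_members` are hypothetical helpers, and the placeholder bytes stand in for a real traced model.

```python
import os
import tarfile
import tempfile


def pack_model(model_file, archive_path):
    """Write model_file into a gzip-compressed tarball, stored under its bare file name."""
    with tarfile.open(archive_path, 'w:gz') as tar:
        tar.add(model_file, arcname=os.path.basename(model_file))
    return archive_path


def list_members(archive_path):
    """Return the member names stored in a .tar.gz archive."""
    with tarfile.open(archive_path, 'r:gz') as tar:
        return tar.getnames()


with tempfile.TemporaryDirectory() as tmp:
    model_file = os.path.join(tmp, 'model.pth')
    with open(model_file, 'wb') as f:
        f.write(b'placeholder')  # stand-in bytes, not a traced TorchScript model
    members = list_members(pack_model(model_file, os.path.join(tmp, 'model.tar.gz')))

print(members)  # ['model.pth']
```

The same pattern applies to the `sourcedir.tar.gz` archive built later for `resnet18.py`.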


## Invoke Neo Compilation API

We then upload the model artifact to S3 and set up the parameters for the Neo Compilation API:

In [3]:
import boto3
import sagemaker
import time
from sagemaker.utils import name_from_base

role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
bucket = sess.default_bucket()

compilation_job_name = name_from_base('TorchVision-ResNet18-Neo')

model_key = '{}/model/model.tar.gz'.format(compilation_job_name)
model_path = 's3://{}/{}'.format(bucket, model_key)
boto3.resource('s3').Bucket(bucket).upload_file('model.tar.gz', model_key)

sm_client = boto3.client('sagemaker')
data_shape = '{"input0":[1,3,224,224]}'
target_device = 'ml_c5'
framework = 'PYTORCH'
framework_version = '1.2.0'
compiled_model_path = 's3://{}/{}/output'.format(bucket, compilation_job_name)
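`DataInputConfig` is passed as a JSON string that maps the model's input name to its shape. Building that string with `json.dumps` from the same shape used for tracing keeps the two from drifting apart. A small sketch, assuming a single input named `input0` as in the cell above; `make_data_input_config`, `s3_uri`, and the bucket name `my-bucket` are hypothetical.

```python
import json


def make_data_input_config(shape, input_name='input0'):
    """Serialize {input_name: shape} into a JSON string for DataInputConfig."""
    return json.dumps({input_name: list(shape)})


def s3_uri(bucket, *parts):
    """Join a bucket name and key parts into an s3:// URI, trimming stray slashes."""
    return 's3://{}/{}'.format(bucket, '/'.join(p.strip('/') for p in parts))


input_shape = [1, 3, 224, 224]
data_shape = make_data_input_config(input_shape)
print(data_shape)  # {"input0": [1, 3, 224, 224]}
print(s3_uri('my-bucket', 'job-name', 'model/model.tar.gz'))  # s3://my-bucket/job-name/model/model.tar.gz
```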

In [4]:
response = sm_client.create_compilation_job(
    CompilationJobName=compilation_job_name,
    RoleArn=role,
    InputConfig={
        'S3Uri': model_path,
        'DataInputConfig': data_shape,
        'Framework': framework
    },
    OutputConfig={
        'S3OutputLocation': compiled_model_path,
        'TargetDevice': target_device
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 300
    }
)
print(response)

# Poll every 30 sec
while True:
    response = sm_client.describe_compilation_job(CompilationJobName=compilation_job_name)
    if response['CompilationJobStatus'] == 'COMPLETED':
        break
    elif response['CompilationJobStatus'] == 'FAILED':
        raise RuntimeError('Compilation failed')
    print('Compiling ...')
    time.sleep(30)
print('Done!')

# Extract compiled model artifact
compiled_model_path = response['ModelArtifacts']['S3ModelArtifacts']

{'CompilationJobArn': 'arn:aws:sagemaker:us-east-1:111652037296:compilation-job/TorchVision-ResNet18-Neo-2020-04-17-02-20-54-739', 'ResponseMetadata': {'RequestId': '53dd7e26-2121-44b3-8a38-b4ad954cd7a1', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '53dd7e26-2121-44b3-8a38-b4ad954cd7a1', 'content-type': 'application/x-amz-json-1.1', 'content-length': '129', 'date': 'Fri, 17 Apr 2020 02:20:54 GMT'}, 'RetryAttempts': 0}}
Compiling ...
Compiling ...
Compiling ...
Done!
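The loop above polls with no upper bound and treats every status other than `COMPLETED` and `FAILED` as still running. A slightly more defensive variant adds an overall timeout, also stops on `STOPPED`, and surfaces the `FailureReason` field of the `DescribeCompilationJob` response. `wait_for_compilation` is a hypothetical helper; it takes the describe call as a parameter so it can be exercised here with a fake instead of boto3.

```python
import time


def wait_for_compilation(describe, job_name, poll_seconds=30, timeout_seconds=900,
                         sleep=time.sleep):
    """Poll describe(CompilationJobName=job_name) until the job stops running.

    Returns the final response on COMPLETED, raises RuntimeError on FAILED or
    STOPPED, and raises TimeoutError if timeout_seconds elapse first.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        response = describe(CompilationJobName=job_name)
        status = response['CompilationJobStatus']
        if status == 'COMPLETED':
            return response
        if status in ('FAILED', 'STOPPED'):
            reason = response.get('FailureReason', 'no reason given')
            raise RuntimeError('Compilation {}: {}'.format(status, reason))
        if time.monotonic() >= deadline:
            raise TimeoutError('Still {} after {} s'.format(status, timeout_seconds))
        print('Compiling ... ({})'.format(status))
        sleep(poll_seconds)


# Exercise the helper with a fake describe function instead of the SageMaker client
fake_statuses = iter(['STARTING', 'INPROGRESS', 'COMPLETED'])


def fake_describe(**kwargs):
    return {'CompilationJobStatus': next(fake_statuses),
            'ModelArtifacts': {'S3ModelArtifacts': 's3://example-bucket/output'}}


final = wait_for_compilation(fake_describe, 'demo-job', sleep=lambda seconds: None)
print(final['ModelArtifacts']['S3ModelArtifacts'])
```

In the notebook, the call would be `wait_for_compilation(sm_client.describe_compilation_job, compilation_job_name)`. Injecting `sleep` is only there so the fake run does not actually wait.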


## Create prediction endpoint

To create a prediction endpoint, we first define two additional functions for the Neo Deep Learning Runtime to call:

* `neo_preprocess(payload, content_type)`: Takes the payload and Content-Type of each incoming request and returns a NumPy array. Here, the payload is an encoded image sent with Content-Type `application/x-image`, so the function decodes it with PIL, converts it to RGB, resizes it, and returns the result as a NumPy array.
* `neo_postprocess(result)`: Takes the prediction results produced by the Deep Learning Runtime and returns the response body.
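The listing of `resnet18.py` is truncated in this export, so here is a minimal NumPy-only sketch of what such a pair of functions commonly does for an ImageNet classifier: scale pixels to [0, 1], normalize with the ImageNet channel means and standard deviations used by TorchVision's pretrained models, reorder HWC to NCHW, and on the way out apply a softmax and serialize to JSON. This is an illustration under those assumptions, not the contents of `resnet18.py`. The names `preprocess_array` and `postprocess_logits` are hypothetical, and the sketch starts from an already decoded and resized `uint8` array so that it does not need PIL.

```python
import json

import numpy as np

# Channel statistics used by TorchVision's ImageNet-pretrained models
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def preprocess_array(image_hwc):
    """Turn a resized HxWx3 uint8 RGB array into a normalized 1x3xHxW float32 batch."""
    image = image_hwc.astype(np.float32) / 255.0       # scale to [0, 1]
    image = (image - IMAGENET_MEAN) / IMAGENET_STD     # per-channel normalization
    image = np.transpose(image, (2, 0, 1))             # HWC -> CHW
    return image[np.newaxis, ...].astype(np.float32)   # add batch dimension -> NCHW


def postprocess_logits(logits):
    """Apply a numerically stable softmax to 1xC logits and return a JSON list."""
    scores = np.squeeze(np.asarray(logits, dtype=np.float64))
    exp = np.exp(scores - np.max(scores))              # subtract max to avoid overflow
    probs = exp / np.sum(exp)
    return json.dumps(probs.tolist())


dummy = np.zeros((224, 224, 3), dtype=np.uint8)        # stand-in for a decoded image
batch = preprocess_array(dummy)
print(batch.shape, batch.dtype)                        # (1, 3, 224, 224) float32
body = postprocess_logits(np.array([[1.0, 2.0, 3.0]]))
print(body)
```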

In [5]:
!pygmentize resnet18.py

[34mdef[39;49;00m [32mneo_preprocess[39;49;00m(payload, content_type):
    [34mimport[39;49;00m [04m[36mPIL.Image[39;49;00m   [37m# Training container doesn't have this package[39;49;00m
    [34mimport[39;49;00m [04m[36mlogging[39;49;00m
    [34mimport[39;49;00m [04m[36mnumpy[39;49;00m [34mas[39;49;00m [04m[36mnp[39;49;00m
    [34mimport[39;49;00m [04m[36mio[39;49;00m

    logging.info([33m'[39;49;00m[33mInvoking user-defined pre-processing function[39;49;00m[33m'[39;49;00m)

    [34mif[39;49;00m content_type != [33m'[39;49;00m[33mapplication/x-image[39;49;00m[33m'[39;49;00m:
        [34mraise[39;49;00m [36mRuntimeError[39;49;00m([33m'[39;49;00m[33mContent type must be application/x-image[39;49;00m[33m'[39;49;00m)

    f = io.BytesIO(payload)
    [37m# Load image and convert to RGB space[39;49;00m
    image = PIL.Image.open(f).convert([33m'[39;49;00m[33mRGB[39;49;00m[33m'[39;49;00m)
    [37m# Resize[39;49;00m
    image = 

Upload the Python script containing the two functions to S3:

In [6]:
source_key = '{}/source/sourcedir.tar.gz'.format(compilation_job_name)
source_path = 's3://{}/{}'.format(bucket, source_key)

with tarfile.open('sourcedir.tar.gz', 'w:gz') as f:
    f.add('resnet18.py')

boto3.resource('s3').Bucket(bucket).upload_file('sourcedir.tar.gz', source_key)

We then create a SageMaker model record:

In [7]:
from sagemaker.model import NEO_IMAGE_ACCOUNT
from sagemaker.fw_utils import create_image_uri

model_name = name_from_base('TorchVision-ResNet18-Neo')

image_uri = create_image_uri(region, 'neo-' + framework.lower(), target_device.replace('_', '.'),
                             framework_version, py_version='py3', account=NEO_IMAGE_ACCOUNT[region])

response = sm_client.create_model(
    ModelName=model_name,
    PrimaryContainer={
        'Image': image_uri,
        'ModelDataUrl': compiled_model_path,
        'Environment': { 'SAGEMAKER_SUBMIT_DIRECTORY': source_path }
    },
    ExecutionRoleArn=role
)
print(response)

{'ModelArn': 'arn:aws:sagemaker:us-east-1:111652037296:model/torchvision-resnet18-neo-2020-04-17-02-22-27-353', 'ResponseMetadata': {'RequestId': 'd823f419-8df4-444b-baa4-6a834c8eb0cc', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'd823f419-8df4-444b-baa4-6a834c8eb0cc', 'content-type': 'application/x-amz-json-1.1', 'content-length': '110', 'date': 'Fri, 17 Apr 2020 02:22:27 GMT'}, 'RetryAttempts': 0}}


Then we create an Endpoint Configuration:

In [8]:
config_name = model_name

response = sm_client.create_endpoint_config(
    EndpointConfigName=config_name,
    ProductionVariants=[
        {
            'VariantName': 'default-variant-name',
            'ModelName': model_name,
            'InitialInstanceCount': 1,
            'InstanceType': 'ml.c5.xlarge',
            'InitialVariantWeight': 1.0
        },
    ],
)
print(response)

{'EndpointConfigArn': 'arn:aws:sagemaker:us-east-1:111652037296:endpoint-config/torchvision-resnet18-neo-2020-04-17-02-22-27-353', 'ResponseMetadata': {'RequestId': '874ef9ea-8f7b-4cf7-a00f-31cc4de38a30', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '874ef9ea-8f7b-4cf7-a00f-31cc4de38a30', 'content-type': 'application/x-amz-json-1.1', 'content-length': '129', 'date': 'Fri, 17 Apr 2020 02:22:27 GMT'}, 'RetryAttempts': 0}}


Finally, we create an Endpoint:

In [9]:
endpoint_name = model_name + '-Endpoint'

response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=config_name,
)
print(response)

print('Creating endpoint ...')
waiter = sm_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=endpoint_name)

response = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(response)

{'EndpointArn': 'arn:aws:sagemaker:us-east-1:111652037296:endpoint/torchvision-resnet18-neo-2020-04-17-02-22-27-353-endpoint', 'ResponseMetadata': {'RequestId': '0e7df91f-1300-4458-bc9c-898bb063e288', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '0e7df91f-1300-4458-bc9c-898bb063e288', 'content-type': 'application/x-amz-json-1.1', 'content-length': '125', 'date': 'Fri, 17 Apr 2020 02:22:28 GMT'}, 'RetryAttempts': 0}}
Creating endpoint ...
{'EndpointName': 'TorchVision-ResNet18-Neo-2020-04-17-02-22-27-353-Endpoint', 'EndpointArn': 'arn:aws:sagemaker:us-east-1:111652037296:endpoint/torchvision-resnet18-neo-2020-04-17-02-22-27-353-endpoint', 'EndpointConfigName': 'TorchVision-ResNet18-Neo-2020-04-17-02-22-27-353', 'ProductionVariants': [{'VariantName': 'default-variant-name', 'DeployedImages': [{'SpecifiedImage': '785573368785.dkr.ecr.us-east-1.amazonaws.com/sagemaker-neo-pytorch:1.2.0-cpu-py3', 'ResolvedImage': '785573368785.dkr.ecr.us-east-1.amazonaws.com/sagemaker-neo-pyto
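Resource names such as the model, endpoint configuration, and endpoint names above are, as far as we know, limited to 63 characters of letters, digits, and hyphens, and `name_from_base` appends a 24-character timestamp suffix. Appending `-Endpoint` to a longer base name could therefore exceed the limit. A quick pre-flight check, assuming the pattern `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}` from the SageMaker API reference; `is_valid_sagemaker_name` is a hypothetical helper. Because `-*` can absorb runs of hyphens, the length is checked explicitly as well.

```python
import re

# Name pattern assumed from the SageMaker API reference
SAGEMAKER_NAME = re.compile(r'^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$')
MAX_NAME_LENGTH = 63


def is_valid_sagemaker_name(name):
    """Return True if name fits the assumed SageMaker resource-name rules."""
    return len(name) <= MAX_NAME_LENGTH and bool(SAGEMAKER_NAME.match(name))


name = 'TorchVision-ResNet18-Neo-2020-04-17-02-22-27-353-Endpoint'
print(len(name), is_valid_sagemaker_name(name))  # 57 True
```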

## Send requests

Let's try to send a cat picture.

![cat](cat.jpg)

In [10]:
import json
import numpy as np

sm_runtime = boto3.Session().client('sagemaker-runtime')

with open('cat.jpg', 'rb') as f:
    payload = f.read()

response = sm_runtime.invoke_endpoint(EndpointName=endpoint_name,
                                      ContentType='application/x-image',
                                      Body=payload)
print(response)
result = json.loads(response['Body'].read().decode())
print('Most likely class: {}'.format(np.argmax(result)))

{'ResponseMetadata': {'RequestId': '9300f2c6-954c-4625-8c87-bc2ad2ffdbae', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '9300f2c6-954c-4625-8c87-bc2ad2ffdbae', 'x-amzn-invoked-production-variant': 'default-variant-name', 'date': 'Fri, 17 Apr 2020 02:28:59 GMT', 'content-type': 'application/json', 'content-length': '23362'}, 'RetryAttempts': 0}, 'ContentType': 'application/json', 'InvokedProductionVariant': 'default-variant-name', 'Body': <botocore.response.StreamingBody object at 0x7fa7a56bb278>}
Most likely class: 282
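`np.argmax` reports only the single best class. Since the endpoint returns a full probability vector, the top few candidates are often more informative. A minimal standard-library sketch; `top_k` is a hypothetical helper and the five-class response body is made up for illustration.

```python
import heapq
import json


def top_k(probabilities, k=5):
    """Return the k (class_index, probability) pairs with the highest probability."""
    return heapq.nlargest(k, enumerate(probabilities), key=lambda pair: pair[1])


# A made-up 5-class response body standing in for the endpoint's JSON output
sample_result = json.loads('[0.05, 0.10, 0.60, 0.20, 0.05]')
for index, prob in top_k(sample_result, k=3):
    print('class {}: {:.2f}'.format(index, prob))
```

With the real response, `top_k(result)` would list the five most likely ImageNet class indices.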


In [11]:
# Load names for ImageNet classes
object_categories = {}
with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
    for line in f:
        key, val = line.strip().split(':')
        object_categories[key] = val
print("Result: label - " + object_categories[str(np.argmax(result))]+ " probability - " + str(np.amax(result)))

Result: label -  'tiger cat', probability - 0.6977682113647461
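Splitting each line on `:` works here, but it keeps the surrounding quotes and trailing comma in the label (visible in the output above), and a label that itself contains a colon would break the tuple unpacking. If the file holds a single Python dict literal, as the `{index: 'label', ...}` format suggests, `ast.literal_eval` can parse it in one step. A sketch under that assumption; `load_imagenet_labels` is a hypothetical helper and the three-entry excerpt stands in for the full file.

```python
import ast


def load_imagenet_labels(text):
    """Parse a "{index: 'label', ...}" dict literal into a {int: str} mapping."""
    labels = ast.literal_eval(text)
    return {int(k): str(v) for k, v in labels.items()}


# Three-entry excerpt standing in for the full file contents
excerpt = """{0: 'tench, Tinca tinca',
 1: 'goldfish, Carassius auratus',
 282: 'tiger cat'}"""
labels = load_imagenet_labels(excerpt)
print(labels[282])  # tiger cat
```

With the real file, `load_imagenet_labels(open('imagenet1000_clsidx_to_labels.txt').read())` would return the full mapping, indexed by the integer from `np.argmax`.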


## Delete the Endpoint
A running endpoint incurs charges, so as a clean-up step we delete it once we are done.

In [12]:
# Delete the endpoint to stop incurring charges
sess.delete_endpoint(endpoint_name)
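Deleting the endpoint stops the instance charges, but the endpoint configuration and model records created earlier remain in the account. A sketch of a fuller cleanup using the boto3 SageMaker client's `delete_endpoint`, `delete_endpoint_config`, and `delete_model` calls. `cleanup` is a hypothetical helper; here it is exercised against a `MagicMock` standing in for the client, so nothing is actually deleted.

```python
from unittest.mock import MagicMock


def cleanup(sm_client, endpoint_name, config_name, model_name):
    """Delete the endpoint first, then the configuration and model it referenced."""
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    sm_client.delete_model(ModelName=model_name)


# In the notebook this would be: cleanup(sm_client, endpoint_name, config_name, model_name)
fake_client = MagicMock()
cleanup(fake_client, 'demo-endpoint', 'demo-config', 'demo-model')
print([c[0] for c in fake_client.method_calls])
```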