# Deploying pre-trained PyTorch vision models with Amazon SageMaker Neo

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

---

Amazon SageMaker Neo is an API that compiles machine learning models and optimizes them for your choice of hardware target. Currently, Neo supports pre-trained PyTorch models from [TorchVision](https://pytorch.org/docs/stable/torchvision/models.html). General support for other PyTorch models is forthcoming.

### Runtime

This notebook takes approximately 8 minutes to run.

### Contents

1. [Import ResNet18 from TorchVision](#Import-ResNet18-from-TorchVision)
1. [Invoke Neo Compilation API](#Invoke-Neo-Compilation-API)
1. [Deploy the model](#Deploy-the-model)
1. [Send requests](#Send-requests)
1. [Delete the Endpoint](#Delete-the-Endpoint)

## Import ResNet18 from TorchVision

We import the [ResNet18](https://arxiv.org/abs/1512.03385) model from TorchVision and create a model artifact `model.tar.gz`.

In [None]:
import sys

!{sys.executable} -m pip install torch==1.13.0 torchvision==0.14.0
!{sys.executable} -m pip install s3transfer==0.5.0
!{sys.executable} -m pip install --upgrade sagemaker

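Load the pre-trained model, trace it with a dummy input of the expected shape, and package the traced model as `model.tar.gz`. Neo compiles the model for a fixed input data shape, so the shape used here must match the one passed to the compilation job later. For more information, see [Prepare Model for Compilation](https://docs.aws.amazon.com/sagemaker/latest/dg/neo-compilation-preparing-model.html).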

In [None]:
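import torch
import torchvision.models as models
import tarfile

# Load ResNet18 with ImageNet weights (`pretrained=True` is deprecated since torchvision 0.13)
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Trace the model in eval mode with a dummy input of the shape Neo will compile for
input_shape = [1, 3, 224, 224]
trace = torch.jit.trace(resnet18.float().eval(), torch.zeros(input_shape).float())
trace.save("model.pth")

# Package the traced model as the artifact to upload
with tarfile.open("model.tar.gz", "w:gz") as f:
    f.add("model.pth")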
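Before uploading, it can help to confirm what the archive actually contains. The following is a minimal, standard-library-only sketch: it builds a throwaway archive from a placeholder file (not a real traced model) and lists its members. The helper name `list_archive_members` and the placeholder contents are assumptions made for illustration.

```python
import os
import tarfile
import tempfile


def list_archive_members(archive_path):
    """Return the member names stored in a gzip-compressed tar archive."""
    with tarfile.open(archive_path, "r:gz") as tar:
        return tar.getnames()


# Build a throwaway archive so the example runs without a real model file.
workdir = tempfile.mkdtemp()
placeholder = os.path.join(workdir, "model.pth")
with open(placeholder, "wb") as f:
    f.write(b"placeholder bytes, not a real TorchScript model")

archive = os.path.join(workdir, "model.tar.gz")
with tarfile.open(archive, "w:gz") as tar:
    # arcname stores the file at the archive root instead of under the temp path.
    tar.add(placeholder, arcname="model.pth")

members = list_archive_members(archive)
print(members)
```

Calling `list_archive_members("model.tar.gz")` on the real artifact should show `model.pth` as a single top-level entry, because the notebook adds the file by its relative name.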

### Upload the model archive to S3

Specify parameters for the compilation job and upload the `model.tar.gz` archive file.

In [None]:
import sagemaker
from sagemaker.utils import name_from_base

# IAM role, session, and default S3 bucket for this notebook
role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
bucket = sess.default_bucket()

# Unique job name, also used as the S3 key prefix for the model archive
compilation_job_name = name_from_base("TorchVision-ResNet18-Neo")
prefix = compilation_job_name + "/model"

# Upload the archive to the default bucket and keep its S3 URI
model_path = sess.upload_data(path="model.tar.gz", key_prefix=prefix)

# Compilation parameters; data_shape must match the shape used for tracing
data_shape = '{"input0":[1,3,224,224]}'
target_device = "ml_c5"
framework = "PYTORCH"
framework_version = "1.13"
compiled_model_path = "s3://{}/{}/output".format(bucket, compilation_job_name)
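The `data_shape` string has to describe the same shape that was used to trace the model. One way to keep the two from drifting apart is to build the string from the shape list instead of typing it twice. This is a small sketch under that assumption; `input0` is the input name already used in this notebook.

```python
import json

# Same shape that was passed to torch.jit.trace when the model was traced.
input_shape = [1, 3, 224, 224]

# Compact separators reproduce the exact string used in the notebook.
data_shape = json.dumps({"input0": input_shape}, separators=(",", ":"))
print(data_shape)
```

With this approach, changing `input_shape` in one place updates both the traced model and the compilation parameters.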

## Invoke Neo Compilation API

### Create a PyTorch SageMaker model

Create a `PyTorchModel` and define its parameters, including the S3 path to the model archive, the `entry_point` script that performs inference, the framework and Python versions, and any environment variables.

In [None]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.predictor import Predictor

sagemaker_model = PyTorchModel(
    model_data=model_path,
    predictor_cls=Predictor,
    framework_version=framework_version,
    role=role,
    sagemaker_session=sess,
    entry_point="resnet18.py",
    source_dir="code",
    py_version="py3",
    env={"MMS_DEFAULT_RESPONSE_TIMEOUT": "500"},
)

### Use Neo compiler to compile the model

Run the compilation job. The compiled model artifact is saved in S3 at the location specified by `compiled_model_path`.

In [None]:
compiled_model = sagemaker_model.compile(
    target_instance_family=target_device,
    input_shape=data_shape,
    job_name=compilation_job_name,
    role=role,
    framework=framework.lower(),
    framework_version=framework_version,
    output_path=compiled_model_path,
)
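The `compile()` call is expected to block until the job finishes, so no extra waiting is needed in this notebook. If you want to inspect a job from another process, or after a kernel restart, the job can be polled by name. The following is a hedged sketch rather than part of the notebook: `wait_for_compilation` is a hypothetical helper, and `client` is assumed to behave like `boto3.client("sagemaker")`, whose `describe_compilation_job` response includes `CompilationJobStatus`, `FailureReason`, and `ModelArtifacts`.

```python
import time


def wait_for_compilation(client, job_name, poll_seconds=30, max_polls=60):
    """Poll a compilation job until it completes, fails, or the poll budget runs out.

    `client` is assumed to expose describe_compilation_job like the boto3
    SageMaker client. Returns the S3 URI of the compiled artifact on success.
    """
    for _ in range(max_polls):
        desc = client.describe_compilation_job(CompilationJobName=job_name)
        status = desc["CompilationJobStatus"]
        if status == "COMPLETED":
            return desc["ModelArtifacts"]["S3ModelArtifacts"]
        if status in ("FAILED", "STOPPED"):
            reason = desc.get("FailureReason", "no reason given")
            raise RuntimeError("Compilation {}: {}".format(status, reason))
        # Any other status (for example INPROGRESS) means the job is still running.
        time.sleep(poll_seconds)
    raise TimeoutError("Compilation job did not finish within the poll budget")
```

In the notebook this would be called as `wait_for_compilation(boto3.client("sagemaker"), compilation_job_name)`, assuming `boto3` is imported and the job name from the earlier cell is still in scope.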

## Deploy the model

Deploy the compiled model to an endpoint so it can be used for inference.

In [None]:
predictor = compiled_model.deploy(initial_instance_count=1, instance_type="ml.c5.9xlarge")
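The model was compiled for the `ml_c5` target and is deployed to an `ml.c5.9xlarge` instance, so the two settings refer to the same instance family. The sketch below derives the target string from the instance type to keep them aligned. `target_family_for` is a hypothetical helper that only reformats the string; it does not check whether Neo actually supports the resulting target.

```python
def target_family_for(instance_type):
    """Map an instance type such as "ml.c5.9xlarge" to a target string such as "ml_c5"."""
    parts = instance_type.split(".")
    if len(parts) != 3 or parts[0] != "ml":
        raise ValueError("Unexpected instance type: {!r}".format(instance_type))
    return "{}_{}".format(parts[0], parts[1])


print(target_family_for("ml.c5.9xlarge"))  # ml_c5
```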

## Send requests

Let's send a picture to the endpoint to predict the image subject.

![title](cat.jpg)

Read the image file and send its raw bytes to the predictor as a `bytearray`. The response is JSON that decodes to the per-class scores.

In [None]:
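import numpy as np
import json

# Read the raw JPEG bytes; decoding and preprocessing are left to the endpoint
with open("cat.jpg", "rb") as f:
    payload = bytearray(f.read())

# The endpoint returns JSON-encoded scores, one per ImageNet class
response = predictor.predict(payload)
result = json.loads(response.decode())
print("Most likely class: {}".format(np.argmax(result)))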

Use the predicted ImageNet class ID to look up the label for the image, along with its probability.

In [None]:
# Load names for ImageNet classes, keyed by class ID as a string
object_categories = {}
with open("imagenet1000_clsidx_to_labels.txt", "r") as f:
    for line in f:
        # Split on the first colon only, so labels that contain ":" stay intact
        key, val = line.strip().split(":", 1)
        object_categories[key] = val.strip(" ").strip(",")
print(
    "The label is",
    object_categories[str(np.argmax(result))],
    "with probability",
    str(np.amax(result))[:5],
)
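The lookup above reports only the single best class. When the top prediction is ambiguous, listing the few highest-scoring classes can be more informative. The sketch below uses made-up scores and labels and assumes the scores form a flat list indexed by class ID; if the endpoint returns a nested list, flatten it first. `top_k` is a hypothetical helper, not part of the notebook.

```python
def top_k(scores, labels, k=5):
    """Return the k highest-scoring (label, score) pairs, best first."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(labels.get(str(i), "unknown"), scores[i]) for i in ranked[:k]]


# Made-up values for illustration; real scores come from the endpoint response.
scores = [0.05, 0.7, 0.25]
labels = {"0": "tench", "1": "tabby cat", "2": "tiger cat"}

for label, score in top_k(scores, labels, k=2):
    print("{:<10} {:.3f}".format(label, score))
```

With the notebook's variables, the equivalent call would be `top_k(result, object_categories)`, provided `result` is a flat list.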

## Delete the Endpoint

Delete the model and the endpoint to avoid incurring costs now that they are no longer needed.

In [None]:
predictor.delete_model()
sess.delete_endpoint(predictor.endpoint_name)

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/sagemaker_neo_compilation_jobs|pytorch_torchvision|pytorch_torchvision_neo.ipynb)
