# Part 2: Deploy a model trained using SageMaker distributed data parallel

Use this notebook after you have completed **Part 1: Distributed data parallel MNIST training with PyTorch and SageMaker's distributed data parallel library** in the notebook `pytorch_smdataparallel_mnist_demo.ipynb`. To deploy the model you trained there, you create a SageMaker endpoint, which is a hosted prediction service that you can use to perform inference.

## Finding the model

This notebook uses a stored model location if one exists. If you recently ran a training example that uses the `%store` magic, the `model_data` variable is restored in the next cell.

Otherwise, you can pass the URI to the model file (a .tar.gz file) in the `model_data` variable.

To find the location of model files in the [SageMaker console](https://console.aws.amazon.com/sagemaker/home), do the following:

1. Go to the SageMaker console: https://console.aws.amazon.com/sagemaker/home.
1. Select **Training** in the left navigation pane and then select **Training jobs**.
1. Find your recent training job and choose it.
1. In the **Output** section, you should see an S3 URI under **S3 model artifact**. Copy this S3 URI.
1. In the next cell, replace the placeholder value assigned to `model_data` in the `except` branch with that S3 URI.
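
The console steps above can also be done programmatically. The following is a minimal sketch, assuming you know the name of your training job; the helper name `get_model_artifact_uri` and the job name in the usage comment are hypothetical. It relies on the `describe_training_job` call of the boto3 SageMaker client, which returns the artifact location under `ModelArtifacts`.

```python
def get_model_artifact_uri(training_job_name, sagemaker_client=None):
    """Return the S3 URI of the model artifact of a completed training job.

    `sagemaker_client` is optional; when omitted, a boto3 SageMaker client
    is created with the default credentials and region.
    """
    if sagemaker_client is None:
        import boto3  # imported lazily so the helper can be tested offline

        sagemaker_client = boto3.client("sagemaker")
    description = sagemaker_client.describe_training_job(
        TrainingJobName=training_job_name
    )
    return description["ModelArtifacts"]["S3ModelArtifacts"]


# Hypothetical usage (replace the job name with your own):
# model_data = get_model_artifact_uri("my-training-job-name")
```

The returned value should match the **S3 model artifact** URI shown in the console.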

In [None]:
%store -r model_data
try:
    model_data
except NameError:
    model_data = 's3://sagemaker-sample-files/models/pytorch-smdataparallel-mnist-2021-06-14-23-25-29-876/output/model.tar.gz'
    
print("Using this model: {}".format(model_data))

## Create a model object

You define the model object by using the SageMaker Python SDK's `PyTorchModel` class, passing in the model location (`model_data`) and the `entry_point` script. Within that script, the `model_fn` function defines how the endpoint loads the model, as seen in the following code cells, which download a sample model archive and print out `inference.py`. The function loads the model and sets it to use a GPU, if available.

In [None]:
!aws s3 cp s3://sagemaker-sample-files/datasets/image/MNIST/model/pytorch-training-2020-11-21-22-02-56-203/model.tar.gz .

In [None]:
!tar -xvzf model.tar.gz

In [None]:
!pygmentize code/inference.py

In [None]:
import sagemaker

role = sagemaker.get_execution_role()

from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(
    model_data=model_data,
    source_dir="code",
    entry_point="inference.py",
    role=role,
    framework_version="1.6.0",
    py_version="py3",
)

### Deploy the model on an endpoint

You create a `predictor` by using the `model.deploy` function. You can optionally change both the instance count and instance type.

In [None]:
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.m4.xlarge",
)

## Test the model
You can test the deployed model using samples from the test set.


In [None]:
!aws s3 cp s3://sagemaker-sample-files/datasets/image/MNIST/pytorch/ data/ --recursive

In [None]:
# Load the test set from the local "data" directory (downloaded in the previous cell)
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


test_set = datasets.MNIST(
    "data",
    download=False,
    train=False,
    transform=transforms.Compose(
        [transforms.ToTensor(), 
         transforms.Normalize((0.1307,), (0.3081,))]
    ),
)


# Randomly sample 16 images from the test set
test_loader = DataLoader(test_set, shuffle=True, batch_size=16)
test_images, _ = next(iter(test_loader))

# inspect the images
import torchvision
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline


def imshow(img):
    img = img.numpy()
    img = np.transpose(img, (1, 2, 0))
    plt.imshow(img)
    return


# unnormalize the test images for displaying
unnorm_images = (test_images * 0.3081) + 0.1307

print("Sampled test images: ")
imshow(torchvision.utils.make_grid(unnorm_images))

In [None]:
import json
x = json.dumps(
    {'inputs': test_images.numpy().tolist()}
)

In [None]:
# invoke endpoint
import boto3
import json

sm_runtime = boto3.client("sagemaker-runtime")

body = json.dumps(
    {'inputs': test_images.numpy().tolist()}
)
content_type = "application/json"

# response type
accept = "application/json"

res = sm_runtime.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=body,  # JSON-encoded input data
    ContentType=content_type,  # tells the endpoint how the request body is encoded
    Accept=accept,  # tells the endpoint which format to use for the response
)

# decode the response body
res_body = res["Body"]
pred = res_body.read().decode("utf-8")

print("Type of the response: ", type(pred))
print()
print(pred)
print()

# string -> list
pred = json.loads(pred)

# list -> numpy
pred = np.array(pred, dtype=np.float32)

# predicted class
predicted = np.argmax(pred, axis=1)
print("Predictions: ", predicted)
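
The decoding steps in the previous cell can be illustrated without an endpoint. The following is a minimal sketch, assuming the response body is a JSON list with one list of per-class scores for each input image, which is what the `np.argmax(pred, axis=1)` call above implies. The function name `decode_predictions` and the sample scores are hypothetical.

```python
import json


def decode_predictions(response_text):
    """Parse a JSON list of per-class score lists and return the predicted class per row."""
    scores = json.loads(response_text)
    # For each row, pick the index of the highest score (first index on ties).
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


# Made-up response for two inputs and three classes
sample_response = json.dumps([[0.1, 2.5, -1.0], [3.0, 0.2, 0.1]])
print(decode_predictions(sample_response))  # [1, 0]
```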

## Cleanup

If you don't plan to run more inference requests or otherwise use the endpoint, delete it to avoid ongoing charges.

In [None]:
predictor.delete_endpoint()
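
Deploying also creates a SageMaker model resource and an endpoint configuration. As a hedged sketch, the following helper (the name `clean_up_endpoint` is hypothetical) removes all three, assuming the `predictor` object comes from version 2 of the SageMaker Python SDK, where `delete_model()` and the `delete_endpoint_config` argument of `delete_endpoint()` are available. Use it in place of the plain `delete_endpoint()` call, not in addition to it.

```python
def clean_up_endpoint(predictor):
    """Delete the model resource, then the endpoint and its configuration."""
    # Delete the model first, while the endpoint still exists to look it up.
    predictor.delete_model()
    # Delete the endpoint together with its endpoint configuration.
    predictor.delete_endpoint(delete_endpoint_config=True)


# Hypothetical usage:
# clean_up_endpoint(predictor)
```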