#### Purpose: train a PyTorch MNIST classifier on Amazon SageMaker, deploy it to a real-time endpoint, and run a sample inference

In [22]:
# Initial setup: SageMaker session, default S3 bucket, S3 key prefix and IAM execution role.
import sagemaker

sagemaker_session = sagemaker.Session()        # SageMaker session used for S3 uploads below
bucket = sagemaker_session.default_bucket()    # default S3 bucket of this session
prefix = 'sagemaker/DEMO-pytorch-mnist'        # S3 key prefix for the uploaded training data
role = sagemaker.get_execution_role()          # IAM role passed to the training estimator

In [None]:
# install torchvision
!yes | pip uninstall torchvison
!pip install -qU torchvision

In [None]:
# Download MNIST into ./data from the SageMaker sample-files mirror.
# The ./data directory is uploaded to S3 for training and read again later to build inference samples.
from torchvision import transforms
from torchvision.datasets import MNIST

MNIST.mirrors = ["https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/"]

# Same preprocessing as the training script: tensor conversion, then per-channel normalization.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

MNIST('data', download=True, transform=mnist_transform)

In [24]:
# Upload the local ./data directory to S3; the returned S3 URI is the training input for estimator.fit().
inputs = sagemaker_session.upload_data(
    path='data',
    bucket=bucket,
    key_prefix=prefix,
)
input_message = 'input spec (in this case, just an S3 path): {}'.format(inputs)
print(input_message)

input spec (in this case, just an S3 path): s3://sagemaker-us-east-1-058199717680/sagemaker/DEMO-pytorch-mnist


In [None]:
################################################
# Instructions to get the training script (mnist.py)
################################################
# 1) Go to https://github.com/aws/amazon-sagemaker-examples/tree/main/sagemaker-python-sdk/pytorch_mnist
# 2) Download the mnist.py file from that path.
# 3) Open mnist.py and, in the method "_get_train_data_loader" (reproduced below),
#    add download=True to the datasets.MNIST(...) call.
# 4) Copy mnist.py into the root folder shown in the left sidebar of the SageMaker Studio notebook,
#    so the relative entry_point='mnist.py' used by the estimator below can find it.
#    See https://docs.aws.amazon.com/sagemaker/latest/dg/studio-ui.html#studio-ui-nav-bar for the left sidebar.
# 5) Review the edited mnist.py with syntax highlighting using `pygmentize` (next cell).
#    pygmentize only displays the file; it does not reformat or modify it.
################################################################################################
# Edited method for reference (download=True is the only addition):
# def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
#     logger.info("Get train data loader")
#     dataset = datasets.MNIST(
#         training_dir,
#         download=True,
#         train=True,
#         transform=transforms.Compose(
#             [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
#         ),
#     )
#     train_sampler = (
#         torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
#     )
#     return torch.utils.data.DataLoader(
#         dataset,
#         batch_size=batch_size,
#         shuffle=train_sampler is None,
#         sampler=train_sampler,
#         **kwargs
#     )

In [None]:
# Print mnist.py with syntax highlighting so the download=True edit can be checked before training.
# (pygmentize only displays the file; it does not modify it.)
!pygmentize mnist.py

In [31]:
# Configure the PyTorch training job: entry-point script, framework/Python versions,
# and the training cluster (2 x ml.c5.2xlarge).
# 'backend': 'gloo' is passed to mnist.py as a hyperparameter; presumably it selects the
# torch.distributed backend for the 2-instance job — confirm in mnist.py.
from sagemaker.pytorch import PyTorch

training_hyperparameters = {'epochs': 1, 'backend': 'gloo'}

estimator = PyTorch(
    entry_point='mnist.py',
    role=role,
    framework_version='1.8.0',
    py_version='py3',
    instance_type='ml.c5.2xlarge',
    instance_count=2,
    hyperparameters=training_hyperparameters,
)

In [None]:
# Launch training; the 'training' channel points at the S3 data uploaded earlier.
training_channels = {'training': inputs}
estimator.fit(training_channels)

In [None]:
# Host the trained model on a single ml.m4.xlarge real-time endpoint.
# NOTE: the endpoint incurs charges while it exists; call predictor.delete_endpoint() when finished.
predictor = estimator.deploy(
    instance_type='ml.m4.xlarge',
    initial_instance_count=1,
)

In [None]:
# List the raw MNIST files downloaded earlier; the next cell reads t10k-images-idx3-ubyte.gz from here.
!ls data/MNIST/raw

In [None]:
# Build a small, reproducible batch of MNIST test images to send to the endpoint.
import gzip
import os
import random

import numpy as np

# Must match transforms.Normalize((0.1307,), (0.3081,)) used during training.
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
N_SAMPLES = 16   # number of test images sent to the endpoint
SEED = 42        # fixes which test images are selected, so reruns give the same batch

data_dir = 'data/MNIST/raw'
with gzip.open(os.path.join(data_dir, "t10k-images-idx3-ubyte.gz"), "rb") as f:
    # Skip the 16-byte IDX header; the remaining bytes are uint8 pixels, 28x28 per image.
    images_raw = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)

# Reproduce the training preprocessing: ToTensor() scales pixels to [0, 1], then Normalize.
# Sending raw 0-255 values would feed the model inputs on a different scale than it was trained on.
images = ((images_raw.astype(np.float32) / 255.0 - MNIST_MEAN) / MNIST_STD).astype(np.float32)

# Randomly select N_SAMPLES distinct test images with a local, seeded generator.
# np.int was removed in NumPy 1.24, so use an explicit integer dtype for the index array.
mask = np.array(random.Random(SEED).sample(range(len(images)), N_SAMPLES), dtype=np.int64)
# input data: shape (N_SAMPLES, 28, 28), float32
data = images[mask]

In [None]:
# Inference: send the sampled images to the endpoint and decode the returned class scores.
import numpy as np

# Add a channel axis: (N, 28, 28) -> (N, 1, 28, 28).
response = predictor.predict(np.expand_dims(data, axis=1))
print("Raw prediction result:")
print(response)
print()

# Assumes one row of 10 per-class scores per image (digits 0-9); label the first image's scores.
labeled_predictions = list(zip(range(10), response[0]))
print("Labeled predictions: ")
print(labeled_predictions)
print()

# Highest score first. reverse=True keeps ties in label order and, unlike a `1.0 - score`
# key, does not merge nearly-equal scores through float rounding.
labeled_predictions.sort(key=lambda label_and_score: label_and_score[1], reverse=True)
print("Most likely answer: {}".format(labeled_predictions[0]))

# The output above covers only the first image; report the predicted digit for every sampled image.
predicted_digits = np.argmax(np.asarray(response), axis=1)
print("Predicted digits for all {} images: {}".format(len(predicted_digits), predicted_digits.tolist()))

#### Reference

[Amazon SageMaker examples: PyTorch MNIST (train and host)](https://sagemaker-examples.readthedocs.io/en/latest/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.html)