## MNIST Training and Prediction with SageMaker Chainer

[MNIST](http://yann.lecun.com/exdb/mnist/), the "Hello World" of machine learning, is a popular dataset for handwritten digit classification. It consists of 70,000 28x28 grayscale images labeled in 10 digit classes (0 to 9). This tutorial will show how to train on MNIST using SageMaker Chainer in [sagemaker-python-sdk](https://github.com/aws/sagemaker-python-sdk).

In [None]:
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()

# Get a SageMaker-compatible role used by this Notebook Instance.
role = get_execution_role()

## Download MNIST datasets

We can use Chainer's built-in `get_mnist()` function to download, import, and preprocess the MNIST dataset.

In [None]:
import os

import chainer

# exist_ok=True makes the cell safe to re-run and still creates
# 'data/test' when only 'data/train' already exists.
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/test', exist_ok=True)

train, test = chainer.datasets.get_mnist()

## Parse, save, and upload the data

We save our data, then use sagemaker_session.upload_data to upload the data to an S3 location used for training. The return value identifies the S3 path to the uploaded data.

In [None]:
import numpy as np

train_images = np.array([data[0] for data in train])
train_labels = np.array([data[1] for data in train])
test_images = np.array([data[0] for data in test])
test_labels = np.array([data[1] for data in test])

np.savez('data/train/train.npz', images=train_images, labels=train_labels)
np.savez('data/test/test.npz', images=test_images, labels=test_labels)

# Use a separate key prefix per channel so each S3 location holds only its own file.
train_input = sagemaker_session.upload_data(path=os.path.join('data', 'train'), key_prefix='notebook/chainer/mnist/train')
test_input = sagemaker_session.upload_data(path=os.path.join('data', 'test'), key_prefix='notebook/chainer/mnist/test')
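
To make the on-disk format concrete, here is a minimal sketch of the `.npz` round trip that the save step relies on. The arrays are synthetic stand-ins (assumed to be flattened 28x28 `float32` images with integer labels), not the real MNIST data, and the temporary directory is only for illustration.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-ins for the MNIST arrays: 4 flattened 28x28 images and labels.
images = np.random.rand(4, 784).astype(np.float32)
labels = np.array([3, 1, 4, 1], dtype=np.int32)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'train.npz')
    # Keyword names passed to np.savez become the keys of the archive.
    np.savez(path, images=images, labels=labels)

    # np.load returns an archive object; index it by key to read each array.
    with np.load(path) as archive:
        keys = sorted(archive.files)
        loaded_images = archive['images']
        loaded_labels = archive['labels']

print(keys)
print(loaded_images.shape, loaded_images.dtype)
```

A training script can read the uploaded files the same way, using the keys `images` and `labels`.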

## Prepare user script for training and prediction

We need to provide a user script that can run on the SageMaker platform. The script is essentially the same as one you would write for local training, except that it must define a `train` function that returns a trained `chainer.Chain` model. By default, this model is saved to disk as an npz file named `model.npz`.
For prediction, the script must also define a `model_fn` function that loads the `chainer.Chain` saved during training. When SageMaker calls your `train` and `model_fn` functions, it passes in arguments that describe the training environment.
Detailed instructions for writing a user script are available in [sagemaker-python-sdk](https://github.com/aws/sagemaker-python-sdk).

The user script used in this tutorial, shown below, is adapted from the [Chainer MNIST example](https://github.com/chainer/chainer/blob/master/examples/mnist/train_mnist_data_parallel.py).

In [None]:
!cat 'chainer_mnist_single_machine.py'

## Create SageMaker chainer estimator

To train MNIST, let's construct a `sagemaker.chainer.estimator.Chainer` estimator. A quick explanation for some configurable arguments here:

`entry_point`: The user script SageMaker runs for training and prediction.

`train_instance_count`: The number of SageMaker instances for training. Since we only do single machine training in this tutorial, it should be 1.

`train_instance_type`: The type of SageMaker instances for training. We pass the string `'local'` here to enable [local mode](https://github.com/aws/sagemaker-python-sdk#local-mode), which trains in the local environment. To train on a remote instance, specify a SageMaker ML instance type instead. See [Amazon SageMaker ML Instance Types](https://aws.amazon.com/sagemaker/pricing/instance-types/).

`hyperparameters`: The hyper-parameters defined in the user script. In this tutorial, `epochs`, `batch_size` and `frequency` can be configured and passed.

In [None]:
from sagemaker.chainer.estimator import Chainer

chainer_estimator = Chainer(entry_point='chainer_mnist_single_machine.py', role=role,
                            train_instance_count=1, train_instance_type='local',
                            hyperparameters={'epochs': 16, 'batch_size': 128})
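
The estimator above passes `epochs` and `batch_size` but not `frequency`, so the user script has to fall back to a default for anything omitted. The following is a minimal sketch of that pattern. The helper name `resolve_hyperparameters` and the default values are hypothetical and are not taken from `chainer_mnist_single_machine.py`.

```python
# Hypothetical defaults; the real user script may use different values.
DEFAULT_HYPERPARAMETERS = {'epochs': 20, 'batch_size': 100, 'frequency': 20}


def resolve_hyperparameters(passed, defaults=DEFAULT_HYPERPARAMETERS):
    """Overlay the hyperparameters passed to the estimator on script defaults."""
    unknown = sorted(set(passed) - set(defaults))
    if unknown:
        # Fail early on typos such as 'epoch' instead of silently ignoring them.
        raise ValueError('Unknown hyperparameters: {}'.format(unknown))
    resolved = dict(defaults)
    resolved.update(passed)
    return resolved


resolved = resolve_hyperparameters({'epochs': 16, 'batch_size': 128})
print(resolved)
```

Rejecting unknown keys is a design choice in this sketch: it surfaces misspelled hyperparameter names before a training job is spent on them.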

## Train on MNIST data in S3

After we've constructed our Chainer object, we can fit it using the MNIST data we uploaded to S3. SageMaker makes sure our data is available in the local filesystem, so our user script can simply read the data from disk.

In [None]:
chainer_estimator.fit({'train': train_input, 'test': test_input})

Our user script writes various artifacts, such as plots, to a directory `output_data_dir`, the contents of which SageMaker uploads to S3. Now we download and extract these artifacts.

In [None]:
# exist_ok=True makes the cell safe to re-run.
os.makedirs('output/single_machine_mnist', exist_ok=True)

chainer_training_job = chainer_estimator.latest_training_job.name

desc = chainer_estimator.sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=chainer_training_job)
output_data = desc['ModelArtifacts']['S3ModelArtifacts'].replace('s3_model_artifacts', '')
status = os.system('cp {}algo*/output/data/*.png output/single_machine_mnist'.format(output_data))
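
The shell `cp` above reports failure only through its exit status, which the cell stores in `status` but never checks. As a sketch of an alternative, the same copy can be done with the standard library so that the caller sees exactly which files were copied. The helper name `copy_matching` and the `algo-1` directory used in the demo are hypothetical.

```python
import glob
import os
import shutil
import tempfile


def copy_matching(pattern, dest_dir):
    """Copy every file matching a glob pattern into dest_dir.

    Returns the sorted base names of the files that were copied.
    """
    os.makedirs(dest_dir, exist_ok=True)
    copied = []
    for src in sorted(glob.glob(pattern)):
        shutil.copy(src, dest_dir)
        copied.append(os.path.basename(src))
    return copied


# Demo on a throwaway tree that mimics the layout globbed in the cell above.
root = tempfile.mkdtemp()
data_dir = os.path.join(root, 'algo-1', 'output', 'data')
os.makedirs(data_dir)
for name in ('accuracy.png', 'loss.png', 'log.txt'):
    open(os.path.join(data_dir, name), 'w').close()

dest = os.path.join(root, 'plots')
copied = copy_matching(os.path.join(root, 'algo*', 'output', 'data', '*.png'), dest)
print(copied)
shutil.rmtree(root)
```

An empty return value signals that the pattern matched nothing, which is easier to detect than a non-zero shell status.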

These plots show the accuracy and loss over epochs:

In [None]:
from IPython.display import Image
from IPython.display import display

accuracy_graph = Image(filename = "output/single_machine_mnist/accuracy.png", width=800, height=800)
loss_graph = Image(filename = "output/single_machine_mnist/loss.png", width=800, height=800)

display(accuracy_graph, loss_graph)

## Deploy model to endpoint

After training, we deploy the model to an endpoint using the Chainer estimator object. Here we also set `instance_type` to `'local'` to deploy the model to the local environment. To deploy the model to a remote instance, specify a SageMaker ML instance type instead.

In [None]:
predictor = chainer_estimator.deploy(initial_instance_count=1, instance_type='local')

## Predict Hand-Written Digit

We can now use this predictor to classify hand-written digits. Drawing into the image box loads the pixel data into a variable named 'data' in this notebook, which we can then pass to the Chainer predictor.

In [None]:
from IPython.display import HTML
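# Read the widget markup and close the file handle before rendering it.
with open("input.html") as f:
    input_html = f.read()
HTML(input_html)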

Now let's see if your writing can be recognized!

In [None]:
image = np.array(data, dtype=np.float32)
prediction = predictor.predict(image)
predicted_label = prediction.argmax(axis=1)[0]
print('What you wrote is: {}'.format(predicted_label))

We can also get some random test images in MNIST.

In [None]:
import random

import matplotlib.pyplot as plt

num_samples = 5
# range(n) already excludes n, so subtracting 1 would never sample the last test image.
indices = random.sample(range(test_images.shape[0]), num_samples)
images, labels = test_images[indices], test_labels[indices]

for i in range(num_samples):
    plt.subplot(1,num_samples,i+1)
    plt.imshow(images[i].reshape(28, 28), cmap='gray')
    plt.title(labels[i])
    plt.axis('off')

Now let's see if we can make correct predictions.

In [None]:
prediction = predictor.predict(images)
predicted_label = prediction.argmax(axis=1)
print('The predicted labels are: {}'.format(predicted_label))
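
To go from eyeballing the printed labels to a number, the predictions can be compared with the true labels. The sketch below assumes, as the cell above does, that the predictor returns one row of 10 class scores per image. The `accuracy` helper is hypothetical and the scores are made up for illustration, not real model output.

```python
import numpy as np


def accuracy(scores, true_labels):
    """Fraction of rows whose highest-scoring class equals the true label."""
    predicted = np.asarray(scores).argmax(axis=1)
    return float((predicted == np.asarray(true_labels)).mean())


# Made-up scores for 3 samples over 10 classes.
scores = np.zeros((3, 10), dtype=np.float32)
scores[0, 7] = 2.5  # sample 0: class 7 scores highest
scores[1, 2] = 1.0  # sample 1: class 2 scores highest
scores[2, 9] = 0.3  # sample 2: class 9 scores highest

print(scores.argmax(axis=1))
print(accuracy(scores, [7, 2, 4]))
```

In this notebook the equivalent call would be `accuracy(prediction, labels)`, using the `labels` array sampled from the test set.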

## Clean up resources

When you have finished with this example, delete the prediction endpoint to release the instance associated with it, and remove the data and outputs saved locally.

In [None]:
import shutil

chainer_estimator.delete_endpoint()
shutil.rmtree('data')
shutil.rmtree('output')