# MNIST distributed training with TensorFlow  

## Contents

1. [Background](#Background)
1. [Setup](#Setup)
1. [Data](#Data)
1. [Train](#Train)
1. [Host](#Host)
1. [Predict](#Predict)


## Background

The **SageMaker Python SDK** helps you deploy your models for training and hosting in optimized, production-ready containers in SageMaker. The SageMaker Python SDK is easy to use, modular, extensible and compatible with TensorFlow and MXNet. This tutorial focuses on how to create a convolutional neural network model and train it on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using **TensorFlow distributed training**.




## Setup

We start by downloading the helper scripts `utils.py` and `mnist.py` from the workshop's S3 bucket and importing the libraries this notebook needs.

In [None]:
import os
# Download the helper scripts before importing them below.
os.system("aws s3 cp s3://sagemaker-workshop-pdx/mnist/utils.py utils.py")
os.system("aws s3 cp s3://sagemaker-workshop-pdx/mnist/mnist.py mnist.py")
import sagemaker
import utils
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.contrib.learn.python.learn.datasets import mnist
import tensorflow as tf
import boto3


Next we specify the IAM role ARN that gives training and hosting access to your data. See the SageMaker documentation for how to create these roles. If you need more than one role for notebook instances, training, and/or hosting, replace the `get_execution_role()` call with the appropriate full IAM role ARN string(s).

In [None]:
role = sagemaker.get_execution_role()
sagemaker_session = sagemaker.Session()
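A full IAM role ARN has the form `arn:aws:iam::<account-id>:role/<role-name>`. The minimal sketch below builds one, assuming a role without an IAM path; `role_arn`, the account ID and the role name are hypothetical placeholders, and the result goes into `example_role` so it does not overwrite the `role` defined above.

```python
def role_arn(account_id: str, role_name: str) -> str:
    """Hypothetical helper: build arn:aws:iam::<account-id>:role/<role-name>."""
    # AWS account IDs are 12-digit strings.
    if not (len(account_id) == 12 and account_id.isdigit()):
        raise ValueError("account_id must be a 12-digit string")
    return "arn:aws:iam::{}:role/{}".format(account_id, role_name)


# Placeholder values; substitute your own account ID and role name.
example_role = role_arn("123456789012", "MySageMakerExecutionRole")
print(example_role)  # arn:aws:iam::123456789012:role/MySageMakerExecutionRole
```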


## Data


### Download the MNIST dataset

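First we download the compressed MNIST files from the workshop's S3 bucket and read them with `read_data_sets`, which holds out 5,000 training images for validation. Then we convert the train, validation and test splits to TFRecord files with `utils.convert_to`.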

In [None]:
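# Copy the compressed MNIST files into mnist/data, where read_data_sets looks for them.
os.system("aws s3 cp --recursive s3://sagemaker-workshop-pdx/mnist/data mnist/data")

# Keep raw uint8 pixels in 28x28x1 arrays and hold out 5,000 training images for validation.
data_sets = mnist.read_data_sets('mnist/data', dtype=tf.uint8, reshape=False, validation_size=5000)

# Write each split to mnist/data, the directory uploaded to S3 below.
utils.convert_to(data_sets.train, 'train', 'mnist/data')
utils.convert_to(data_sets.validation, 'validation', 'mnist/data')
utils.convert_to(data_sets.test, 'test', 'mnist/data')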
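MNIST ships 60,000 training images and 10,000 test images. With `validation_size=5000`, `read_data_sets` takes the first 5,000 training images as the validation split, leaving 55,000 for training, while the test set is unchanged. A minimal NumPy sketch of that split follows; `split_validation` is a hypothetical helper, and the zero-filled arrays are stand-ins shaped like the raw data with `reshape=False`, not real images.

```python
import numpy as np


def split_validation(images, labels, validation_size):
    """Hold out the first `validation_size` examples as a validation split."""
    if not 0 <= validation_size <= len(images):
        raise ValueError("validation_size must be between 0 and len(images)")
    validation = (images[:validation_size], labels[:validation_size])
    train = (images[validation_size:], labels[validation_size:])
    return train, validation


# Dummy stand-ins shaped like the MNIST training set with reshape=False:
# 60,000 uint8 images of 28x28 pixels with a single channel.
images = np.zeros((60000, 28, 28, 1), dtype=np.uint8)
labels = np.zeros(60000, dtype=np.uint8)

(train_x, train_y), (val_x, val_y) = split_validation(images, labels, 5000)
print(train_x.shape, val_x.shape)  # (55000, 28, 28, 1) (5000, 28, 28, 1)
```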

### Some sample images from the MNIST data set

Here are some images from the MNIST training data set. Re-run this cell to draw the next batch and see other images from the collection, or change `data_sets.train` to `data_sets.test` to see images from the test set.

In [None]:
!cat utils.py

In [None]:
batch_xs, batch_ys = data_sets.train.next_batch(5)  # Re-run for the next batch; change "train" to "test" to sample the test set.
utils.gen_image(batch_xs[0]).show()
utils.gen_image(batch_xs[1]).show()
utils.gen_image(batch_xs[2]).show()

### Upload the data
We use the ```sagemaker.Session.upload_data``` function to upload our datasets to an S3 bucket. Its return value, stored in ```inputs```, is the S3 location of the uploaded data; we will pass it to the training job later.

In [None]:
inputs = sagemaker_session.upload_data(path='mnist/data', key_prefix='data/mnist')
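The value in `inputs` is an S3 URI of the form `s3://<bucket>/<prefix>`, which `fit` accepts directly. If you need the bucket and prefix separately, a minimal sketch of splitting the URI follows; `split_s3_uri` is a hypothetical helper and `example-bucket` is a placeholder, not the bucket your session will use.

```python
def split_s3_uri(uri):
    """Split 's3://bucket/some/prefix' into ('bucket', 'some/prefix')."""
    scheme = "s3://"
    if not uri.startswith(scheme):
        raise ValueError("not an S3 URI: {}".format(uri))
    # Everything before the first '/' is the bucket; the rest is the key prefix.
    bucket, _, key = uri[len(scheme):].partition("/")
    return bucket, key


# Placeholder URI; your `inputs` names your session's default bucket.
bucket, key_prefix = split_s3_uri("s3://example-bucket/data/mnist")
print(bucket, key_prefix)  # example-bucket data/mnist
```

The two parts can then be passed to boto3, for example `boto3.client('s3').list_objects_v2(Bucket=bucket, Prefix=key_prefix)` to list the uploaded files.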


## Train

Here is the full code for the network model:

In [None]:
!cat 'mnist.py'

The script here is an adaptation of the [TensorFlow MNIST example](https://github.com/tensorflow/models/tree/master/official/mnist). We have defined ```model_fn(features, labels, mode)```, which includes all the logic to support training, evaluation and inference. 

### A regular ```model_fn```

A regular **```model_fn```** follows this pattern:
1. [defines a neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L96)
1. [passes the ```features``` through the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L178)
1. [if the ```mode``` is ```PREDICT```, returns the output from the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L186)
1. [calculates the loss function by comparing the output with the ```labels```](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L188)
1. [creates an optimizer that minimizes the loss function to improve the neural network](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L193)
1. [returns the output, the loss and the optimizer's training operation](https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py#L205)
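To make these steps concrete without TensorFlow, here is a framework-free sketch that follows the same pattern for a softmax-regression model in NumPy. It is not the tutorial's `mnist.py` and not the `tf.estimator` API: the mode strings, the `params` dictionary and the returned dictionary are hypothetical stand-ins for `tf.estimator.ModeKeys` and `EstimatorSpec`, and the update is applied immediately instead of being returned as a `train_op`.

```python
import numpy as np

PREDICT, TRAIN = "predict", "train"  # hypothetical stand-ins for tf.estimator.ModeKeys


def model_fn(features, labels, mode, params):
    """Framework-free sketch of the model_fn pattern (softmax regression)."""
    # 1. Define the network: a single linear layer whose weights live in params.
    weights, bias = params["weights"], params["bias"]
    # 2. Pass the features through the network to get class probabilities.
    logits = features @ weights + bias
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    predictions = {"classes": probs.argmax(axis=1), "probabilities": probs}
    # 3. In PREDICT mode, return only the network output.
    if mode == PREDICT:
        return {"predictions": predictions}
    # 4. Cross-entropy loss comparing the output with the integer labels.
    n = features.shape[0]
    loss = -np.log(probs[np.arange(n), labels] + 1e-12).mean()
    # 5. "Optimizer": one gradient-descent step, applied to params in place.
    grad_logits = probs.copy()
    grad_logits[np.arange(n), labels] -= 1.0
    grad_logits /= n
    lr = params["learning_rate"]
    params["weights"] = weights - lr * (features.T @ grad_logits)
    params["bias"] = bias - lr * grad_logits.sum(axis=0)
    # 6. Return the output and the loss measured before this update.
    return {"predictions": predictions, "loss": loss}


# Usage on random data shaped like flattened MNIST images (64 x 784).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 784))
y = rng.integers(0, 10, size=64)
params = {"weights": np.zeros((784, 10)), "bias": np.zeros(10), "learning_rate": 0.05}
losses = [model_fn(x, y, TRAIN, params)["loss"] for _ in range(20)]
print(losses[0] > losses[-1])  # True
```

Branching on `mode` before the loss is computed is what lets one function serve training, evaluation and inference: in `PREDICT` mode no labels are needed.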

### Writing a ```model_fn``` for distributed training
When training is distributed, the same neural network runs on multiple training instances. On each instance, the model reads a batch of the dataset, computes the loss, and runs the optimizer to minimize it. One full pass through this loop is called a **training step**.

### Synchronizing training steps
A [global step](https://www.tensorflow.org/api_docs/python/tf/train/global_step) is a counter shared between the instances. When it is passed to ```minimize```, the optimizer increments it after each update, so it tracks the number of **training steps** taken across all instances. This is necessary for distributed training:

```python
train_op = optimizer.minimize(loss, global_step=tf.train.get_or_create_global_step())
```

That is also the **only** required change for distributed training!
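To see why a shared counter matters, the toy simulation below runs two worker threads that each keep taking steps until a shared global step reaches a fixed total. It is only an illustration: the threads and lock stand in for TensorFlow's distributed machinery, and `GlobalStep` and `run_worker` are hypothetical names, not TensorFlow APIs.

```python
import threading


class GlobalStep:
    """A step counter shared by all workers, protected by a lock."""

    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def increment_if_below(self, limit):
        """Atomically claim the next step; return False once `limit` is reached."""
        with self._lock:
            if self._value >= limit:
                return False
            self._value += 1
            return True

    @property
    def value(self):
        with self._lock:
            return self._value


def run_worker(global_step, max_steps, local_counts, worker_id):
    # Each iteration stands in for one training step: read a batch,
    # compute the loss and apply the optimizer update.
    while global_step.increment_if_below(max_steps):
        local_counts[worker_id] += 1


global_step = GlobalStep()
local_counts = [0, 0]  # steps taken by each worker; each thread writes only its own slot
workers = [
    threading.Thread(target=run_worker, args=(global_step, 1000, local_counts, i))
    for i in range(2)
]
for w in workers:
    w.start()
for w in workers:
    w.join()

print(global_step.value, sum(local_counts))  # 1000 1000
```

How the 1,000 steps divide between the two workers varies from run to run; only the total is fixed, because every worker checks the same shared counter.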

### Create a training job using the sagemaker.TensorFlow estimator

In [None]:
from sagemaker.tensorflow import TensorFlow

mnist_estimator = TensorFlow(entry_point='mnist.py',
                             role=role,
                             training_steps=1000, 
                             evaluation_steps=100,
                             train_instance_count=2,
                             train_instance_type='ml.c4.8xlarge')

mnist_estimator.fit(inputs)

The **```fit```** method creates a training job on two **ml.c4.8xlarge** instances. The cell output shows the status of the training job on each instance during training and evaluation.

When training completes, the training job saves the trained model in a format that a SageMaker endpoint can serve.


## Host


### Deploy the trained model to prepare for predictions

The deploy() method creates a SageMaker endpoint which serves prediction requests in real-time.

In [None]:
mnist_predictor = mnist_estimator.deploy(initial_instance_count=1,
                                         instance_type='ml.m4.xlarge')


## Predict

### Invoking the endpoint

Now we will pass some of the test images to the model endpoint for inference.

In [None]:
from tensorflow.examples.tutorials.mnist import input_data

# Use a distinct name so the `mnist` module imported during setup is not shadowed.
# one_hot=True returns one-hot label vectors; images are flattened 784-float arrays.
mnist_data = input_data.read_data_sets("/tmp/data/", one_hot=True)

for i in range(10):
    # Send one flattened test image to the endpoint as a 1 x 784 float tensor.
    data = mnist_data.test.images[i].tolist()
    tensor_proto = tf.make_tensor_proto(values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32)
    predict_response = mnist_predictor.predict(tensor_proto)

    # Show the image, its true label and the predicted class.
    image = np.array(mnist_data.test.images[i], dtype='float')
    plt.imshow(image.reshape(28, 28))
    plt.show()
    label = np.argmax(mnist_data.test.labels[i])
    print("Label is: {}".format(label))
    prediction = predict_response['outputs']['classes']['int64Val'][0]
    print("Prediction is: {}".format(prediction))
    print("_________________________________")
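To score more than a handful of images, you can tally how often the prediction matches the label. Under the protobuf JSON mapping, 64-bit integers such as `int64Val` entries may be encoded as strings, so the sketch converts with `int()`. The response dictionaries below are hypothetical and hand-written in the same shape the loop above reads; they are not real endpoint output, and `predicted_class` and `accuracy` are illustrative helpers.

```python
def predicted_class(response):
    """Return the predicted digit from a response shaped like the loop above expects."""
    # int() accepts both 7 and "7", since int64 values may arrive as strings.
    return int(response["outputs"]["classes"]["int64Val"][0])


def accuracy(responses, labels):
    """Fraction of responses whose predicted class equals the true label."""
    if not responses or len(responses) != len(labels):
        raise ValueError("responses and labels must be non-empty and the same length")
    correct = sum(predicted_class(r) == int(l) for r, l in zip(responses, labels))
    return correct / len(labels)


# Hypothetical, hand-written responses for three images labelled 7, 2 and 1.
fake_responses = [
    {"outputs": {"classes": {"int64Val": ["7"]}}},
    {"outputs": {"classes": {"int64Val": ["2"]}}},
    {"outputs": {"classes": {"int64Val": ["3"]}}},
]
print(accuracy(fake_responses, [7, 2, 1]))  # 0.6666666666666666
```

In the notebook, you could append each `predict_response` and `label` to two lists inside the loop and call `accuracy` on them afterwards.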

## Deleting the endpoint
When you are done with the notebook, run the following cell to delete the endpoint so that you do not incur unnecessary charges.

In [None]:
sagemaker.Session().delete_endpoint(mnist_predictor.endpoint)