# MNIST distributed training and batch transform

The SageMaker Python SDK helps you train and host your models in optimized, production-ready containers on SageMaker. It is easy to use, modular, extensible, and compatible with TensorFlow and MXNet. This tutorial shows how to create a convolutional neural network model and train it on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using TensorFlow distributed training.

## Set up the environment

First, we set up a few things needed for this example.

In [None]:
import sagemaker
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_session.region_name

role = get_execution_role()

### Download the MNIST dataset

Next, we download the MNIST dataset, prepare it for training, and upload it to a location in S3.

In [None]:
import utils
from tensorflow.contrib.learn.python.learn.datasets import mnist
import tensorflow as tf

data_sets = mnist.read_data_sets('data', dtype=tf.uint8, reshape=False, validation_size=5000)

utils.convert_to(data_sets.train, 'train', 'data')
utils.convert_to(data_sets.validation, 'validation', 'data')
utils.convert_to(data_sets.test, 'test', 'data')

### Upload the data
We use the `sagemaker.Session.upload_data` function to upload our datasets to an S3 location. The return value, `inputs`, identifies the location; we will use it later when we start the training job.

In [None]:
inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-mnist')
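
As a rough illustration of what `inputs` holds: for a directory upload, the returned value is an S3 URI built from the bucket and the key prefix. The helper below is a hypothetical stand-in that only mimics that string shape; it does not call S3, and the bucket name is made up (when no bucket is passed, the session's default bucket is used).

```python
def expected_upload_uri(bucket, key_prefix):
    """Hypothetical helper: mimic the URI shape returned for a directory upload."""
    return 's3://{}/{}'.format(bucket, key_prefix)


# 'example-bucket' is a made-up name used only for illustration.
example_inputs = expected_upload_uri('example-bucket', 'data/DEMO-mnist')
print(example_inputs)
```

Printing the real `inputs` value in the notebook is the simplest way to confirm where the data landed.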

# Construct a script for distributed training 
Here is the full code for the network model:

In [None]:
!cat 'mnist.py'

## Create a training job

In [None]:
from sagemaker.tensorflow import TensorFlow

mnist_estimator = TensorFlow(entry_point='mnist.py',
                             role=role,
                             framework_version='1.11.0',
                             training_steps=1000, 
                             evaluation_steps=100,
                             train_instance_count=2,
                             train_instance_type='ml.c4.xlarge')

mnist_estimator.fit(inputs)

The `fit()` method creates a training job on two ml.c4.xlarge instances. The logs above show the instances training, evaluating, and incrementing the number of training steps.

At the end of training, the training job generates a saved model for TensorFlow Serving.

## SageMaker's transformer class

After training, we use our TensorFlow estimator object to create a `Transformer` by invoking the `transformer()` method. This method takes arguments for configuring the batch transform job; these do not need to be the same values as the ones we used for the training job. The method also creates a SageMaker Model to be used for the batch transform jobs.

The `Transformer` class is responsible for running batch transform jobs, which will deploy the trained model to an endpoint and send requests for performing inference.

In [None]:
transformer = mnist_estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')

# Perform inference

Now that we've trained a model, we're going to use it to perform inference with a SageMaker batch transform job. The request handling behavior of the Endpoint deployed during the transform job is determined by the `mnist.py` script we looked at earlier.

## Run a batch transform job

For our batch transform job, we're going to use input data that contains 1000 MNIST images, located in the public SageMaker sample data S3 bucket. To create the batch transform job, we simply call `transform()` on our transformer with information about the input data.

In [None]:
input_bucket_name = 'sagemaker-sample-data-{}'.format(region)
input_file_path = 'batch-transform/mnist-1000-samples'

transformer.transform('s3://{}/{}'.format(input_bucket_name, input_file_path), content_type='text/csv')

Now we wait for the batch transform job to complete. The convenience method `wait()` blocks until the job has finished, so the cell below completes only when the batch transform job does.

In [None]:
transformer.wait()
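
Conceptually, a blocking wait like this amounts to polling the job status until it reaches a terminal state. The following is a minimal sketch of that idea, not the SDK's actual implementation. `wait_for_job` and the fake status source are hypothetical; the status strings are assumed to follow the `TransformJobStatus` values reported by the SageMaker `DescribeTransformJob` API.

```python
import time


def wait_for_job(get_status, poll_seconds=30,
                 in_progress=('InProgress', 'Stopping')):
    """Call get_status() repeatedly until it returns a terminal status."""
    status = get_status()
    while status in in_progress:
        time.sleep(poll_seconds)
        status = get_status()
    return status


# Stand-in status source for illustration. With boto3, one might instead read
# describe_transform_job(TransformJobName=...)['TransformJobStatus'].
fake_statuses = iter(['InProgress', 'InProgress', 'Completed'])
final_status = wait_for_job(lambda: next(fake_statuses), poll_seconds=0)
print(final_status)
```

Passing the status source in as a callable keeps the polling loop independent of any particular client, which also makes it easy to exercise without a running job.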

## Download the results

The batch transform job uploads its predictions to S3. Since we did not specify `output_path` when creating the Transformer, one was generated based on the batch transform job name:

In [None]:
print(transformer.output_path)
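
The output path is an S3 URI, so it can be split into a bucket name and a key prefix with the standard library's URL parser. A small sketch, assuming Python 3; `split_s3_uri` is a hypothetical helper and the URI is a made-up example, not the path generated for this job.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/some/prefix' into ('bucket', 'some/prefix')."""
    parsed = urlparse(uri)
    # netloc holds the bucket; path starts with '/', which S3 keys do not use.
    return parsed.netloc, parsed.path.lstrip('/')


# Made-up URI used only for illustration.
example_bucket, example_prefix = split_s3_uri(
    's3://example-bucket/example-transform-job')
print(example_bucket, example_prefix)
```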

Now let's download the first ten results from S3:

In [None]:
import json
from six.moves.urllib import parse

import boto3

parsed_url = parse.urlparse(transformer.output_path)
bucket_name = parsed_url.netloc
prefix = parsed_url.path[1:]

s3 = boto3.resource('s3')

predictions = []
for i in range(10):
    file_key = '{}/data-{}.csv.out'.format(prefix, i)

    output_obj = s3.Object(bucket_name, file_key)
    output = output_obj.get()["Body"].read().decode('utf-8')

    predictions.extend(json.loads(output)['outputs']['classes']['int64Val'])
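
To make the indexing in the loop above easier to follow, here is a sketch that parses one output document of the same shape. The payload is a hypothetical sample, not real job output, and it assumes the class values arrive as JSON strings (which is how 64-bit integers are commonly serialized); `extract_classes` is a made-up helper name.

```python
import json

# Hypothetical sample shaped like the structure the loop indexes into.
sample_output = '{"outputs": {"classes": {"int64Val": ["7"]}}}'


def extract_classes(raw_json):
    """Return the predicted class labels from one output document as ints."""
    values = json.loads(raw_json)['outputs']['classes']['int64Val']
    return [int(v) for v in values]


sample_classes = extract_classes(sample_output)
print(sample_classes)
```

Converting to `int` up front is a design choice for later comparisons; the notebook itself keeps the values as returned.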

For demonstration purposes, we're also going to download the corresponding original input data so that we can see how the model did with its predictions.

In [None]:
import os

import matplotlib.pyplot as plt
from numpy import genfromtxt

plt.rcParams['figure.figsize'] = (2,10)

def show_digit(img, caption='', subplot=None):
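    # Create a fresh set of axes when the caller does not supply one.
    if subplot is None:
        _, subplot = plt.subplots(1, 1)
    imgr = img.reshape((28, 28))
    subplot.axis('off')
    subplot.imshow(imgr, cmap='gray')
    # Set the title on the axes we drew on, not on whichever axes is current.
    subplot.set_title(caption)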

tmp_dir = '/tmp/data'
if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)

for i in range(10):
    input_file_name = 'data-{}.csv'.format(i)
    input_file_key = '{}/{}'.format(input_file_path, input_file_name)
    
    s3.Bucket(input_bucket_name).download_file(input_file_key, os.path.join(tmp_dir, input_file_name))
    input_data = genfromtxt(os.path.join(tmp_dir, input_file_name), delimiter=',')

    show_digit(input_data)

Here, we can see the original labels are:

```
7, 2, 1, 0, 4, 1, 4, 9, 5, 9
```

Now let's print out the predictions to compare:

In [None]:
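# Convert each value to str so the join works whether the values are strings or ints.
print(', '.join(str(p) for p in predictions))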
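
Rather than comparing the two lists by eye, the match count can be computed directly. A minimal sketch: the expected labels are the ten listed above, while `sample_predictions` is a hypothetical list with one deliberate mismatch, standing in for the real `predictions`.

```python
expected_labels = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9]


def count_matches(predicted, expected):
    """Count positions where the prediction equals the expected label."""
    return sum(int(p) == e for p, e in zip(predicted, expected))


# Hypothetical predictions (one wrong on purpose), for illustration only.
sample_predictions = ['7', '2', '1', '0', '4', '1', '4', '9', '6', '9']
matches = count_matches(sample_predictions, expected_labels)
print('{} of {} predictions match'.format(matches, len(expected_labels)))
```

In the notebook, passing `predictions` in place of `sample_predictions` would give the count for the actual batch transform output.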