# MNIST distributed training and batch transform

The **SageMaker Python SDK** helps you deploy your models for training and hosting in optimized, production-ready containers in SageMaker. The SDK is easy to use, modular, extensible, and compatible with TensorFlow and MXNet. This tutorial shows how to train a convolutional neural network on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) using **TensorFlow distributed training**, and then how to run inference on the trained model with a batch transform job.



### Set up the environment

In [None]:
import os
import sagemaker
import boto3
from sagemaker import get_execution_role
from sagemaker.session import Session

sagemaker_session = sagemaker.Session()

role = get_execution_role()

### Download the MNIST dataset

In [None]:
import utils
from tensorflow.contrib.learn.python.learn.datasets import mnist
import tensorflow as tf

data_sets = mnist.read_data_sets('data', dtype=tf.uint8, reshape=False, validation_size=5000)

utils.convert_to(data_sets.train, 'train', 'data')
utils.convert_to(data_sets.validation, 'validation', 'data')
utils.convert_to(data_sets.test, 'test', 'data')

### Upload the data
We use the ```sagemaker.Session.upload_data``` function to upload our datasets to an S3 location. The return value, `inputs`, identifies that location; we will use it later when we start the training job.

In [None]:
inputs = sagemaker_session.upload_data(path='data', key_prefix='data/DEMO-mnist')
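
To make the shape of `inputs` concrete, here is a minimal sketch of the URI we would expect back when `path` is a directory: the bucket followed by the `key_prefix`. The helper `uploaded_prefix_uri` and the bucket name are hypothetical and used only for illustration; the real value comes from `upload_data` itself.

```python
def uploaded_prefix_uri(bucket, key_prefix):
    """Compose the S3 URI expected for a directory uploaded under `key_prefix`."""
    return 's3://{}/{}'.format(bucket, key_prefix.strip('/'))

# Hypothetical bucket name, used only for illustration.
example_inputs = uploaded_prefix_uri('sagemaker-us-west-2-123456789012', 'data/DEMO-mnist')
print(example_inputs)
```

Printing the real `inputs` value in the notebook is a quick way to confirm where the data landed before starting the training job.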

### Construct a script for distributed training
Here is the full code for the network model:

In [None]:
!cat 'mnist.py'

### Create a training job using the sagemaker.TensorFlow estimator

In [None]:
from sagemaker.tensorflow import TensorFlow

mnist_estimator = TensorFlow(entry_point='mnist.py',
                             role=role,
                             training_steps=1000, 
                             evaluation_steps=100,
                             train_instance_count=2,
                             train_instance_type='ml.c4.xlarge')

mnist_estimator.fit(inputs)

The **```fit```** method creates a training job on two **ml.c4.xlarge** instances. The logs above show the instances training, evaluating, and incrementing the number of **training steps**.

At the end of training, the job generates a saved model for TensorFlow Serving.

### SageMaker's transformer class

After training, we use our TensorFlow estimator object to create a `Transformer` by invoking the `transformer()` method. This method takes arguments that configure the batch transform job; these do not need to be the same values as the ones we used for the training job.

The `Transformer` class is responsible for running batch transform jobs, which host the trained model on the requested instances for the duration of the job and send it requests to perform inference.

In [None]:
transformer = mnist_estimator.transformer(instance_count=1, instance_type='ml.m4.xlarge')
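
The call above relies on defaults for everything except the instance settings. As a sketch of how the same call could pin the results to a location of our choosing, the hypothetical helper `make_transformer` below forwards an `output_path` argument to `transformer()`. The helper name and the bucket in the usage comment are assumptions for illustration.

```python
def make_transformer(estimator, output_path=None):
    """Create a Transformer from a trained estimator, optionally pinning the output location."""
    return estimator.transformer(
        instance_count=1,
        instance_type='ml.m4.xlarge',
        output_path=output_path,  # None leaves the choice of location to the SDK
    )

# Usage in this notebook (the bucket name is hypothetical):
# transformer = make_transformer(mnist_estimator, 's3://my-bucket/mnist-batch-output')
```

Taking the estimator as a parameter keeps the helper independent of the training cell, so it can be reused with any trained estimator.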

### Inspect input data

Before we perform inference, let's inspect the input data.

In [None]:
import boto3
import matplotlib.pyplot as plt
import os
from numpy import genfromtxt

plt.rcParams["figure.figsize"] = (2,10)

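def show_digit(img, caption='', subplot=None):
    # Create a new axes object when the caller does not supply one.
    if subplot is None:
        _, subplot = plt.subplots(1, 1)
    # Each digit is stored as 784 flat values; reshape to 28x28 for display.
    imgr = img.reshape((28, 28))
    subplot.axis('off')
    subplot.imshow(imgr, cmap='gray')
    # Title the axes we drew on rather than whichever axes is current.
    subplot.set_title(caption)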

tmp_dir = '/tmp/data'
input_file_name = 'data.csv'
input_file_path = 'batch-transform/mnist/' + input_file_name
input_bucket_name = 'sagemaker-sample-data-{}'.format(boto3.Session().region_name)

if not os.path.exists(tmp_dir):
    os.makedirs(tmp_dir)

s3 = boto3.resource('s3')

s3.Bucket(input_bucket_name).download_file(input_file_path, os.path.join(tmp_dir, input_file_name))
input_data = genfromtxt(os.path.join(tmp_dir, input_file_name), delimiter=',')

show_digit(input_data)
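
The cell above treats the CSV file as a single flat row of 784 pixel values that reshapes to 28x28. The sketch below shows the same reshaping without matplotlib, rendering the image as text. It uses a synthetic row rather than the real file, and the helpers `row_to_image` and `render_ascii` are hypothetical names introduced here.

```python
def row_to_image(row, side=28):
    """Split a flat sequence of side*side pixel values into `side` rows."""
    if len(row) != side * side:
        raise ValueError('expected {} values, got {}'.format(side * side, len(row)))
    return [row[i * side:(i + 1) * side] for i in range(side)]

def render_ascii(image, threshold=0.5):
    """Render pixels brighter than `threshold` times the peak value as '#'."""
    peak = max(max(r) for r in image) or 1  # avoid a zero peak for blank images
    return '\n'.join(
        ''.join('#' if p > threshold * peak else '.' for p in r) for r in image
    )

# Synthetic stand-in for one CSV row: a vertical bar in columns 13 and 14.
fake_row = [255.0 if col in (13, 14) else 0.0 for _ in range(28) for col in range(28)]
image = row_to_image(fake_row)
print(render_ascii(image))
```

Scaling by the peak value means the sketch does not depend on whether pixels are stored in the 0 to 255 range or normalized to 0 to 1.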

### Running a batch transform job

Now we can perform inference with the model we've trained by running a batch transform job. The request-handling behavior of the model hosted during the transform job is determined by the `mnist.py` script.

In [None]:
transformer.transform('s3://{}/{}'.format(input_bucket_name, input_file_path), content_type='text/csv')

Now we wait for the batch transform job to complete. The convenience method `wait()` blocks until the job has finished, so the cell below completes only when the job does.

In [None]:
transformer.wait()
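
To illustrate what "blocking until the job has completed" means, here is a generic polling loop. It is a sketch of the idea, not the SDK's implementation: `wait_for_job` is a hypothetical helper, the status source is simulated, and the status names are an assumption modeled on the job states SageMaker reports.

```python
import time

# Assumed terminal states; any other status means the job is still running.
TERMINAL_STATES = {'Completed', 'Failed', 'Stopped'}

def wait_for_job(get_status, poll_seconds=30, sleep=time.sleep):
    """Poll `get_status()` until it returns a terminal state, then return that state."""
    while True:
        status = get_status()
        if status in TERMINAL_STATES:
            return status
        sleep(poll_seconds)

# Simulated status source standing in for a real status lookup.
_statuses = iter(['InProgress', 'InProgress', 'Completed'])
final_status = wait_for_job(lambda: next(_statuses), poll_seconds=0)
print(final_status)
```

Passing `sleep` as a parameter makes the loop easy to exercise without real delays.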

### Downloading the results

The batch transform job uploads its predictions to S3. Since we did not specify `output_path` when creating the Transformer, one was generated based on the batch transform job name:

In [None]:
print(transformer.output_path)

We use that to download the results from S3:

In [None]:
import json
from urllib.parse import urlparse

parsed_url = urlparse(transformer.output_path)
bucket_name = parsed_url.netloc
file_key = '{}/{}.out'.format(parsed_url.path[1:], input_file_name)

output_obj = s3.Object(bucket_name, file_key)
output = output_obj.get()["Body"].read()

print('Prediction is {}'.format(json.loads(output)['outputs']['classes']['int64Val']))
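
The download cell combines two steps: deriving the result object's bucket and key from the output path, and pulling the predicted class out of the JSON body. The sketch below separates them into two hypothetical helpers so each can be checked without S3 access. The output path and the response body are made-up values for illustration, not real job output.

```python
import json
from urllib.parse import urlparse

def output_location(output_path, input_file_name):
    """Return (bucket, key) for the result object written for one input file."""
    parsed = urlparse(output_path)
    prefix = parsed.path.strip('/')
    file_part = '{}.out'.format(input_file_name)
    key = '{}/{}'.format(prefix, file_part) if prefix else file_part
    return parsed.netloc, key

def predicted_classes(raw_body):
    """Extract the predicted class list from a raw JSON response body."""
    return json.loads(raw_body)['outputs']['classes']['int64Val']

# Hypothetical values, for illustration only.
bucket, key = output_location('s3://example-bucket/example-job', 'data.csv')
sample_body = b'{"outputs": {"classes": {"int64Val": ["7"]}}}'
print(bucket, key, predicted_classes(sample_body))
```

Handling an empty prefix separately avoids a key with a leading slash when the output path points at the bucket root.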