# MNIST with `TPUEstimator`
The purpose of this notebook is to convert the `Estimator` implementation of MNIST MLP in the notebook mnist-estimator.ipynb into a `TPUEstimator` implementation. With this done, we will be able to compare the runtime of MNIST on CPUs, GPUs, and TPUs using the same code and datasets.

Most of this code will be ported from the last notebook, mnist-estimator.ipynb. In order to port it, we will make use of [this](https://cloud.google.com/tpu/docs/tutorials/migrating-to-tpuestimator-api) Google tutorial on how to port `Estimator` to `TPUEstimator`.

## Setup
First, we import the required libraries.

In [None]:
import os
import sys
import time
import tensorflow as tf
import numpy as np

# Disable deprecation warnings and limit verbosity during training
try:
    from tensorflow.python.util import deprecation
    deprecation._PRINT_DEPRECATION_WARNINGS = False
except ImportError:
    print("Import of warning suppression module failed")

# Only show errors from TensorFlow's logger
tf.logging.set_verbosity(tf.logging.ERROR)

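Now we create some global variables which will define the learning process. Where the `keras` implementation of MNIST in this repository has a corresponding variable, the value is set to match it; the rest configure the Cloud TPU (storage paths, TPU name, zone, project) and the `TPUEstimator` training loop.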

In [None]:
NUM_CLASSES = 10
NUM_CHANNELS = 1
NUM_EPOCHS = 1
IMG_EDGE = 28
MODEL_DIR = 'gs://harrisgroup-ctpu/jtdinsmo/mnist/output/'
DATA_DIR = 'gs://harrisgroup-ctpu/jtdinsmo/mnist/data/'
TPU_NAME = 'jtdinsmo-tpu-2'
ZONE_NAME = 'us-central1-b'
PROJECT_NAME = 'harrisgroup-223921'
NUM_ITERATIONS = 50 # Number of iterations per TPU training loop
TRAIN_STEPS = 1000
EVALUATE_STEPS = 1000
INFERENCE_TIME_THRESHOLD = 10 # Seconds
NUM_SHARDS = 8 # Number of shards (TPU cores)
LEARNING_RATE = 0.05
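Two of these settings interact with the TPU. `NUM_ITERATIONS` is passed to `TPUConfig` as `iterations_per_loop`, the number of training steps run on the TPU before control returns to the host, and the global batch is expected to divide evenly across the `NUM_SHARDS` cores. A small arithmetic sketch, assuming those semantics (`host_round_trips` and `per_core_batch` are hypothetical helpers, not TensorFlow functions):

```python
import math

# Values copied from the configuration cell above.
TRAIN_STEPS = 1000
NUM_ITERATIONS = 50   # passed to TPUConfig as iterations_per_loop
NUM_SHARDS = 8        # number of TPU cores sharing each global batch

def host_round_trips(train_steps, iterations_per_loop):
    # Each trip to the TPU runs up to iterations_per_loop training steps.
    return math.ceil(train_steps / iterations_per_loop)

def per_core_batch(global_batch_size, num_shards):
    # Each core receives an equal slice of the global batch.
    if global_batch_size % num_shards != 0:
        raise ValueError("global batch size must be divisible by the number of shards")
    return global_batch_size // num_shards

print(host_round_trips(TRAIN_STEPS, NUM_ITERATIONS))  # 20
print(per_core_batch(64, NUM_SHARDS))                 # 8
```

This is why the batch sizes benchmarked later are all multiples of 8.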

We must download the MNIST dataset. We download it in its original form, formatted for neither the MLP nor the CNN; each implementation reshapes it later.

In [None]:
mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images  # Returns an np.array
train_labels = np.asarray(mnist.train.labels, dtype=np.int32)
eval_data = mnist.test.images  # Returns an np.array
eval_labels = np.asarray(mnist.test.labels, dtype=np.int32)
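Each image arrives as a flat vector of 784 (28 x 28) pixel values. The MLP consumes it as is, while the CNN reshapes it to `[-1, 28, 28, 1]`. Reshaping is row-major, so flat pixel `i` lands at row `i // 28` and column `i % 28`. A pure-Python sketch of that index mapping (`flat_to_rows` is a hypothetical helper for illustration only):

```python
IMG_EDGE = 28

def flat_to_rows(flat_image, edge=IMG_EDGE):
    """Split a flat, row-major pixel list into `edge` rows of `edge` pixels."""
    if len(flat_image) != edge * edge:
        raise ValueError("expected %d pixels, got %d" % (edge * edge, len(flat_image)))
    return [flat_image[r * edge:(r + 1) * edge] for r in range(edge)]

flat = list(range(IMG_EDGE * IMG_EDGE))   # stand-in for one MNIST image
rows = flat_to_rows(flat)
print(len(rows), len(rows[0]))  # 28 28
print(rows[1][0])               # 28, i.e. pixel 28 starts the second row
```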

We also need a metric function that describes how the `TPUEstimator` scores the model during evaluation, and input functions that feed it data during training, evaluation, and prediction.

In [None]:
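def metric_fn(labels, logits):
    # Classification accuracy computed from the logits during evaluation
    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(logits, axis=1))
    return {"accuracy": accuracy}

def _make_dataset(images, labels, batch_size, shuffle):
    # TPUEstimator passes the batch size to use in params["batch_size"].
    # TPUs need static shapes, so the final partial batch is dropped.
    dataset = tf.data.Dataset.from_tensor_slices(
        (images.astype(np.float32), labels.astype(np.int32)))
    if shuffle:
        dataset = dataset.shuffle(10000)
    dataset = dataset.repeat()
    return dataset.batch(batch_size, drop_remainder=True)

def eval_input_fn(params):
    return _make_dataset(eval_data, eval_labels, params["batch_size"], shuffle=False)

def train_input_fn(params):
    return _make_dataset(train_data, train_labels, params["batch_size"], shuffle=True)

def predict_input_fn(params):
    # Random images are enough to time inference
    batch_size = params["batch_size"]
    dataset_images = tf.data.Dataset.from_tensor_slices(
        tf.random_uniform([batch_size, IMG_EDGE**2]))
    return dataset_images.batch(batch_size, drop_remainder=True)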
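TPUs compile the model for fixed tensor shapes, so a batched input pipeline for `TPUEstimator` typically discards the final partial batch (for example with `tf.data.Dataset.batch(batch_size, drop_remainder=True)`). A sketch of what that means for a single pass over a dataset (`full_batches` is a hypothetical helper):

```python
def full_batches(num_examples, batch_size):
    """Return (batches per pass, examples dropped per pass) when the final
    partial batch is discarded."""
    return num_examples // batch_size, num_examples % batch_size

print(full_batches(1000, 64))  # (15, 40)
```

Repeating the dataset before batching, as a training pipeline usually does, means the dropped examples simply show up in a later batch.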

## Run
Now we will create a superclass &mdash; much like we did in the `keras` implementation of MNIST in this repository &mdash; which contains all of the functions common to both MLP and CNN implementations of MNIST. The methods of this superclass are described as follows:
- `__init__` stores the device label (`'/cpu:0'` or `'/gpu:0'`, carried over from the `keras` implementation) and sets `use_tpu`, which determines whether the `TPUEstimator` runs on the TPU.
- `_load` performs the implementation-specific work, such as defining the topology of the neural net and reshaping the data.
- `_load_batch_size_limits` and `_get_batch_sizes` build the list of batch sizes to benchmark.
- `model_fn` is a required argument of `tf.contrib.tpu.TPUEstimator.__init__` which describes the behavior of the `TPUEstimator`.
- `_create` creates the `TPUEstimator`. Sometimes this throws an `UNHEALTHY_TENSORFLOW` error if several models have already been created on the TPU cluster; stopping and restarting the cluster solves the issue. For the sake of parity, we therefore stop and restart the TPU cluster every time `_create` is called, even when no error would occur.
- `_train` trains and evaluates the `TPUEstimator` and returns the time per training iteration (one batch), just as in the `keras` implementation of MNIST in this repository.
- `_predict` does the same for inference. It returns the average time per inference (one image), where the batch size is the number of inferences run in parallel, together with the number of prediction rounds it completed.
- `_main` is the function that the `tensorflow` app runs, which calls `_train` and `_predict` and gathers all the data into `self.train_times` and `self.inference_times`.
- `get_data` is the only "public" function &mdash; the only one which is called outside of the class. It calls `_main`.
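The normalizations behind the timings reported by `_train` and `_predict` amount to two divisions. A sketch with made-up numbers (illustrative values, not measurements; the function names are hypothetical):

```python
def time_per_step(total_train_seconds, num_steps):
    # One training step processes one batch.
    return total_train_seconds / num_steps

def time_per_inference(total_predict_seconds, num_rounds, batch_size):
    # Each prediction round classifies batch_size images in parallel.
    return total_predict_seconds / num_rounds / batch_size

print(time_per_step(12.0, 1000))        # 0.012
print(time_per_inference(10.0, 5, 64))  # 0.03125
```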

In [None]:
batch_train_data = []
batch_inference_data = []
MAX_STEPS = train_data.shape[0] // 8 + 1

class MNIST:
    def __init__(self, device):
        self.device = device
        self.learning_rate = None
        self.use_tpu = True
    
    # To be overridden
    def _load(self, features):
        '''
        Returns the model, the new features formatted in the way
        MLP or CNN needs them.
        '''
        return None, None
    
    # To be overridden
    def _load_batch_size_limits(self):
        self.start_power = None
        self.end_power = None
    
    def _get_batch_sizes(self):
        self._load_batch_size_limits()
        self.batch_sizes = []
        for i in range(self.start_power, self.end_power):
            self.batch_sizes += list(range(8**i, 8**(i+1), 8**i))
        self.batch_sizes += [8**self.end_power]
        
    def model_fn(self, features, labels, mode, params):
        del params  # Unused
        image = features
        if isinstance(image, dict):
            image = features["image"]

        # Reshape the images into the form the MLP or CNN expects
        model, image = self._load(image)

        if mode == tf.estimator.ModeKeys.PREDICT:
            logits = model(image)
            predictions = {
                'class_ids': tf.argmax(logits, axis=1),
                'probabilities': tf.nn.softmax(logits),
            }
            return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

        logits = model(image)
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

        if mode == tf.estimator.ModeKeys.TRAIN:
            learning_rate = tf.train.exponential_decay(
                LEARNING_RATE,
                tf.train.get_global_step(),
                decay_steps=100000,
                decay_rate=0.96)
            optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
            if self.use_tpu:
                optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=optimizer.minimize(loss, tf.train.get_global_step()))

        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

    def _train(self, batch_size):
        train_time = 0
        for i in range(NUM_EPOCHS):
            print("Epoch " + str(i+1) + '/' + str(NUM_EPOCHS) + ': ', end='')
            epoch_start_time = time.time()
            # steps= counts from the current global step, so each call runs
            # TRAIN_STEPS new training steps
            self.estimator.train(input_fn=train_input_fn, steps=TRAIN_STEPS)
            train_time += time.time() - epoch_start_time
            metrics = self.estimator.evaluate(input_fn=eval_input_fn, steps=EVALUATE_STEPS)
            epoch_end_time = time.time()
            print('Accuracy:', metrics['accuracy'], '\tLoss:', metrics['loss'],
                  '\tTime:', epoch_end_time - epoch_start_time, 's')
        # Each training step processes one batch, so this is the time per
        # iteration; evaluation time is excluded
        return train_time / (NUM_EPOCHS * TRAIN_STEPS)
    
    def _predict(self, batch_size):
        inference_time = 0
        inference_num = 0
        start_inference = time.time()
        while True:
            start_time = time.time()
            predictions = self.estimator.predict(predict_input_fn)
            for pred_dict in predictions:
                class_id = pred_dict['class_ids']
                probability = pred_dict['probabilities'][class_id]
                #print("Prediction is", class_id, "("+str(100 * probability)+").")
                
            end_time = time.time()
            inference_time += end_time - start_time
            inference_num += 1
            if end_time - start_inference > INFERENCE_TIME_THRESHOLD:
                # Do as many inferences as possible in INFERENCE_TIME_THRESHOLD seconds
                break
        inference_time = inference_time / inference_num / batch_size
        return inference_time, inference_num
    
    def _create(self, batch_size):
        # Restart the TPU cluster to avoid UNHEALTHY_TENSORFLOW
        os.system("gcloud compute tpus stop %s --zone %s" % (TPU_NAME, ZONE_NAME))
        os.system("gcloud compute tpus start %s --zone %s" % (TPU_NAME, ZONE_NAME))
        # Pause for a bit to allow the TPU cluster to finish running
        # startup scripts if there are any.
        time.sleep(5)
        print("TPU cluster restarted")

        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            TPU_NAME,
            zone=ZONE_NAME,
            project=PROJECT_NAME)

        # Give every model and batch size its own checkpoint directory so that
        # training starts from scratch instead of restoring an earlier run
        model_dir = os.path.join(MODEL_DIR, '%s_%d_%d' % (
            type(self).__name__, batch_size, int(time.time())))

        run_config = tf.contrib.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=model_dir,
            session_config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=True),
            tpu_config=tf.contrib.tpu.TPUConfig(NUM_ITERATIONS, NUM_SHARDS))
        
        self.estimator = tf.contrib.tpu.TPUEstimator(
            model_fn=self.model_fn,
            use_tpu=self.use_tpu,
            train_batch_size=batch_size,
            eval_batch_size=batch_size,
            predict_batch_size=batch_size,
            params={"data_dir": DATA_DIR},
            config=run_config)
    
    def _main(self, _):
        self.train_times = []
        self.inference_times = []
        self._get_batch_sizes()
        for batch_size in self.batch_sizes:
            self._create(batch_size)
                            
            train_time = self._train(batch_size)
                
            inference_time, inference_num = self._predict(batch_size)
            print('\n','Batch size:', batch_size, '\tTrain time:', train_time,
                  '\tInference time', inference_time, '(%s)'%inference_num)
            print('+'*100)
            self.train_times.append(train_time)
            self.inference_times.append(inference_time)
            
            tf.reset_default_graph()  # For memory conservation
            
    def get_data(self):
        try:
            tf.app.run(self._main)
        except SystemExit:
            # Prevent the program from exiting when done
            pass
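The training branch of `model_fn` schedules the learning rate with `tf.train.exponential_decay(LEARNING_RATE, global_step, decay_steps=100000, decay_rate=0.96)`. With the default `staircase=False`, TensorFlow documents the decayed rate as `learning_rate * decay_rate ^ (global_step / decay_steps)`. A plain-Python sketch of that formula (`decayed_learning_rate` is a hypothetical helper) shows that over the `TRAIN_STEPS = 1000` steps used here the rate barely moves:

```python
LEARNING_RATE = 0.05
DECAY_STEPS = 100000
DECAY_RATE = 0.96

def decayed_learning_rate(global_step):
    # Continuous (non-staircase) exponential decay
    return LEARNING_RATE * DECAY_RATE ** (global_step / DECAY_STEPS)

print(decayed_learning_rate(0))       # 0.05
print(decayed_learning_rate(1000))    # just under 0.05
print(decayed_learning_rate(100000))  # about 0.048
```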

With the superclass defined, we now subclass it for the MLP and CNN versions of MNIST. For each, we define a class which contains the topology of the neural net (e.g., `ModelMLP`) and a subclass of the `MNIST` class we just defined whose `_load` returns it (e.g., `MNIST_MLP`).

In [None]:
class ModelMLP(object):
    def __call__(self, inputs):
        net = tf.layers.dense(inputs, 512, activation=tf.nn.relu, name='dense1')
        net = tf.layers.dropout(net, rate=0.2, name='drop1')
        net = tf.layers.dense(net, 512, activation=tf.nn.relu, name='dense2')
        net = tf.layers.dropout(net, rate=0.2, name='drop2')
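        # Return raw logits: model_fn applies the softmax itself, both in the
        # loss and for the predicted probabilities
        net = tf.layers.dense(net, NUM_CLASSES, activation=None, name='dense3')
        return net
    
class MNIST_MLP(MNIST):
    def _load(self, features):
        images = tf.reshape(features, [-1, IMG_EDGE**2])
        # Learning rate from keras.optimizers.RMSprop (stored for reference;
        # model_fn uses the global LEARNING_RATE)
        self.learning_rate = 0.001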
        return ModelMLP(), images
    def _load_batch_size_limits(self):
        self.start_power = 1
        self.end_power = 4
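For a sense of scale, the MLP's trainable parameters can be counted by hand: each dense layer holds `n_in * n_out` weights plus `n_out` biases, and the dropout layers add none. A quick sketch (`dense_params` is a hypothetical helper):

```python
IMG_EDGE = 28
NUM_CLASSES = 10

def dense_params(n_in, n_out):
    # Weight matrix plus one bias per output unit
    return n_in * n_out + n_out

mlp_layers = [(IMG_EDGE**2, 512), (512, 512), (512, NUM_CLASSES)]
total = sum(dense_params(n_in, n_out) for n_in, n_out in mlp_layers)
print(total)  # 669706
```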

In [None]:
class ModelCNN(object):
    def __call__(self, inputs):
        net = tf.layers.conv2d(inputs, 32, [5, 5], activation=tf.nn.relu, name='conv1')
        net = tf.layers.max_pooling2d(net, [2, 2], 2, name='pool1')
        net = tf.layers.conv2d(net, 64, [5, 5], activation=tf.nn.relu, name='conv2')
        net = tf.layers.max_pooling2d(net, [2, 2], 2, name='pool2')
        net = tf.layers.flatten(net)
        net = tf.layers.dense(net, NUM_CLASSES, activation=None, name='fc1')
        return net
          
class MNIST_CNN(MNIST):
    def _load(self, features):
        images = tf.reshape(features, [-1, IMG_EDGE, IMG_EDGE, NUM_CHANNELS])
        # Learning rate from keras.optimizers.Adadelta (stored for reference;
        # model_fn uses the global LEARNING_RATE)
        self.learning_rate = 1.0
        return ModelCNN(), images
    def _load_batch_size_limits(self):
        self.start_power = 1
        self.end_power = 2 # CNN nets are bigger, and for high batch size they run out of memory.
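The same bookkeeping works for the CNN, assuming the `tf.layers` defaults used above (`padding='valid'` and stride 1 for the convolutions, `padding='valid'` for the pooling). Each 5 x 5 convolution shrinks the edge by 4 and each 2 x 2 pooling with stride 2 halves it, so the flattened feature vector has 4 * 4 * 64 = 1024 entries. The helpers below are hypothetical:

```python
IMG_EDGE = 28
NUM_CHANNELS = 1
NUM_CLASSES = 10

def conv_valid(size, kernel):
    # 'valid' padding, stride 1
    return size - kernel + 1

def pool(size, window=2, stride=2):
    # Max pooling with 'valid' padding
    return (size - window) // stride + 1

size = pool(conv_valid(IMG_EDGE, 5))   # 28 -> 24 -> 12
size = pool(conv_valid(size, 5))       # 12 -> 8 -> 4
flat = size * size * 64                # 4 * 4 * 64

conv1 = 5 * 5 * NUM_CHANNELS * 32 + 32
conv2 = 5 * 5 * 32 * 64 + 64
fc1 = flat * NUM_CLASSES + NUM_CLASSES
print(size, flat, conv1 + conv2 + fc1)  # 4 1024 62346
```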

With the subclasses implemented, all we need to do is instantiate them and call `get_data` for both. This will train all the models and run all the inferences for all the batch sizes we need, with all possible combinations of machine type (CPU or GPU) and MNIST implementation (MLP or CNN).
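The batch sizes that `get_data` sweeps over come from `_get_batch_sizes`. A standalone re-implementation of that loop (`batch_sizes` is a hypothetical name) shows the resulting lists; every size is a multiple of 8, matching `NUM_SHARDS`, and each one triggers a TPU restart in `_create`:

```python
def batch_sizes(start_power, end_power, base=8):
    # Within each power-of-8 range, step by the range's lower bound,
    # then append base**end_power at the end.
    sizes = []
    for i in range(start_power, end_power):
        sizes += list(range(base**i, base**(i + 1), base**i))
    sizes.append(base**end_power)
    return sizes

print(batch_sizes(1, 2))       # CNN: [8, 16, 24, 32, 40, 48, 56, 64]
print(len(batch_sizes(1, 4)))  # MLP: 22 sizes, from 8 up to 4096
```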

In [None]:
print("MNIST MLP")
print()
mlp_cpu = MNIST_MLP('/cpu:0')
mlp_gpu = MNIST_MLP('/gpu:0')

print()
print("TRAIN ON CPUS")
print()
print('+'*100)
mlp_cpu.get_data()
 
print()
print("TRAIN ON GPUS")
print()
print('+'*100)
mlp_gpu.get_data()

print()
print('+'*47, "DONE", '+'*47)

As with the `keras` implementation, we back up the MLP timing data to disk so that, if the CNN training runs out of memory, we can recover the MLP results.

In [None]:
with open("backup.txt", 'w') as backup:
    assert len(mlp_cpu.batch_sizes) == len(mlp_gpu.batch_sizes) == len(mlp_cpu.train_times) \
        == len(mlp_cpu.inference_times) == len(mlp_gpu.train_times) == len(mlp_gpu.inference_times)
    for i in range(len(mlp_cpu.batch_sizes)):
        assert mlp_cpu.batch_sizes[i] == mlp_gpu.batch_sizes[i]
        # Format: batch_size|cpu_train|cpu_inference|gpu_train|gpu_inference|
        backup.write(str(mlp_cpu.batch_sizes[i]) + '|' +
                     str(mlp_cpu.train_times[i]) + '|' +
                     str(mlp_cpu.inference_times[i]) + '|' +
                     str(mlp_gpu.train_times[i]) + '|' +
                     str(mlp_gpu.inference_times[i]) + '|' + '\n')
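If the CNN run does crash, the MLP results can be read back from `backup.txt`. A minimal sketch of such a reader, assuming the pipe-separated line format written above (`load_backup` is a hypothetical helper, not part of the original notebook):

```python
def load_backup(path):
    """Parse lines of the form
    batch_size|cpu_train|cpu_inference|gpu_train|gpu_inference|
    back into five parallel lists."""
    columns = ([], [], [], [], [])
    with open(path) as f:
        for line in f:
            fields = line.strip().rstrip('|').split('|')
            if len(fields) != 5:
                continue  # skip blank or malformed lines
            columns[0].append(int(fields[0]))
            for col, value in zip(columns[1:], fields[1:]):
                col.append(float(value))
    return columns

# Example, assuming backup.txt was written by the cell above:
# batch_sizes, cpu_train, cpu_inference, gpu_train, gpu_inference = load_backup("backup.txt")
```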

Now we can actually run the CNN implementation of MNIST.

In [None]:
print("MNIST CNN")
print()
cnn_cpu = MNIST_CNN('/cpu:0')
cnn_gpu = MNIST_CNN('/gpu:0')

print()
print("TRAIN ON CPUS")
print()
print('+'*100)
cnn_cpu.get_data()
 
print()
print("TRAIN ON GPUS")
print()
print('+'*100)
cnn_gpu.get_data()

print()
print('+'*47, "DONE", '+'*47)

## Graphing Data
Now we will graph the data we have collected above. First we import the `matplotlib` libraries.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

x_mlp = np.array(mlp_cpu.batch_sizes)
x_cnn = np.array(cnn_cpu.batch_sizes)

Then we plot the training and inference times on CPUs and GPUs as a function of batch size. First, training time.

In [None]:
plt.scatter(x_mlp, mlp_cpu.train_times, c='b', alpha = 0.5)
plt.scatter(x_mlp, mlp_gpu.train_times, c='r', alpha = 0.5, marker='s')
plt.scatter(x_cnn, cnn_cpu.train_times, c='y', alpha = 0.5, marker='^')
plt.scatter(x_cnn, cnn_gpu.train_times, c='m', alpha = 0.5, marker='v')
plt.xlabel('Batch size')
plt.ylabel('Train time (s)')
plt.xscale('log')
plt.yscale('log')
plt.axis([1, 10000, 0.001, 2])
plt.legend(['MLP CPU', 'MLP GPU', 'CNN CPU', 'CNN GPU'])
plt.show()

This is inference time.

In [None]:
plt.scatter(x_mlp, mlp_cpu.inference_times, c='b', alpha = 0.5)
plt.scatter(x_mlp, mlp_gpu.inference_times, c='r', alpha = 0.5, marker='s')
plt.scatter(x_cnn, cnn_cpu.inference_times, c='y', alpha = 0.5, marker='^')
plt.scatter(x_cnn, cnn_gpu.inference_times, c='m', alpha = 0.5, marker='v')
plt.xlabel('Batch size')
plt.ylabel('Inference time (s)')
plt.xscale('log')
plt.yscale('log')
plt.axis([1, 10000, 0.0001, 0.1])
plt.legend(['MLP CPU', 'MLP GPU', 'CNN CPU', 'CNN GPU'])
plt.show()

We can also plot the performance gain from using GPUs instead of CPUs. This is the gain in training time.

In [None]:
def get_improvement(cpu_times, gpu_times):
    gain = []
    for i in range(len(cpu_times)):
        gain.append(cpu_times[i] / gpu_times[i] * 100)
    return np.array(gain)

gain_train_mlp = get_improvement(mlp_cpu.train_times, mlp_gpu.train_times)
gain_train_cnn = get_improvement(cnn_cpu.train_times, cnn_gpu.train_times)

plt.scatter(x_mlp, gain_train_mlp, c='k', alpha = 0.5, marker = 'd')
plt.scatter(x_cnn, gain_train_cnn, c='c', alpha = 0.5, marker = '>')
plt.xlabel('Batch size')
plt.ylabel('Train speed gain by using GPUs (%)')
plt.xscale('log')
plt.yscale('linear')
plt.legend(['MLP', 'CNN'])
plt.axhline(100, linestyle='--', linewidth=1, color='k')
plt.show()
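`get_improvement` reports the CPU time divided by the GPU time, in percent, so the dashed line at 100% marks parity: points above it mean the GPU run was faster, points below mean it was slower. For example (`gain_percent` is a hypothetical one-value version of the same formula):

```python
def gain_percent(cpu_time, gpu_time):
    # Same formula as get_improvement, for a single pair of timings
    return cpu_time / gpu_time * 100

print(gain_percent(4.0, 2.0))  # 200.0: the GPU is twice as fast
print(gain_percent(2.0, 4.0))  # 50.0: the GPU is twice as slow
```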

This is the performance gain in inference time.

In [None]:
gain_inference_mlp = get_improvement(mlp_cpu.inference_times, mlp_gpu.inference_times)
gain_inference_cnn = get_improvement(cnn_cpu.inference_times, cnn_gpu.inference_times)

plt.scatter(x_mlp, gain_inference_mlp, c='k', alpha = 0.5, marker='d')
plt.scatter(x_cnn, gain_inference_cnn, c='c', alpha = 0.5, marker='>')
plt.xlabel('Batch size')
plt.ylabel('Inference speed gain by using GPUs (%)')
plt.xscale('log')
plt.yscale('linear')
plt.legend(['MLP', 'CNN'])
plt.axhline(100, linestyle='--', linewidth=1, color='k')
plt.show()

This concludes the experiment.