# Multi-GPU training with Estimators, `tf.keras` and `tf.data`

At [Zalando Research](https://research.zalando.com/), as in most AI research departments, we realise the importance of experimenting and quickly prototyping ideas. With datasets getting bigger it thus becomes useful to know how to train deep learning models quickly and efficiently on the shared resources we have.

TensorFlow's [Estimators](https://www.tensorflow.org/programmers_guide/estimators) API is useful for training models in a distributed environment, such as on a node with multiple GPUs or on many nodes with GPUs. This is particularly useful when training on huge datasets. We will first present the API on the tiny Fashion-MNIST dataset and then show a practical use case at the end.

**TL;DR**: Essentially what we want to remember is that a `tf.keras.Model` can be trained with the `tf.estimator` API by converting it to a `tf.estimator.Estimator` object via the `tf.keras.estimator.model_to_estimator` method. Once converted, we can apply the machinery that `Estimator` provides to train on different hardware configurations.

In [1]:
import os

#!pip install -q -U tensorflow-gpu
import tensorflow as tf

import numpy as np



## Import the Fashion-MNIST dataset

We will use the [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset, a drop-in replacement for MNIST, which contains 70,000 grayscale images (60,000 for training and 10,000 for testing) of [Zalando](https://www.zalando.de/) fashion articles. Getting the training and test data is as simple as:

In [2]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

We want to convert the pixel values of these images from an integer between 0 and 255 to a float between 0 and 1, and convert the dataset to the `[B, H, W, C]` format, where `B` is the number of images, `H` and `W` are the height and width, and `C` is the number of channels (1 for grayscale) of our images:

In [3]:
TRAINING_SIZE = len(train_images)
TEST_SIZE = len(test_images)

train_images = np.asarray(train_images, dtype=np.float32) / 255

# Convert the train images and add channels
train_images = train_images.reshape((TRAINING_SIZE, 28, 28, 1))

test_images = np.asarray(test_images, dtype=np.float32) / 255
# Convert the test images and add channels
test_images = test_images.reshape((TEST_SIZE, 28, 28, 1))
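As a quick sanity check of the conversion above, here is a minimal sketch that applies the same two steps to a small synthetic array instead of the real dataset. The array `fake_images` and its size are assumptions made only for illustration.

```python
import numpy as np

# Synthetic stand-in for a few 28x28 grayscale images with uint8 pixels.
rng = np.random.RandomState(0)
fake_images = rng.randint(0, 256, size=(4, 28, 28)).astype(np.uint8)

# Scale to [0, 1] and append a channel axis -> [B, H, W, C].
scaled = np.asarray(fake_images, dtype=np.float32) / 255
scaled = scaled.reshape((len(fake_images), 28, 28, 1))

print(scaled.shape, scaled.dtype, scaled.min() >= 0.0, scaled.max() <= 1.0)
```

The reshape only adds a trailing axis of size 1, so no pixel values are moved or copied into a different order.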

Next, we want to convert the labels from an integer format (e.g., `2`, which stands for `Pullover`) to a [one hot encoding](https://en.wikipedia.org/wiki/One-hot) (e.g., `0, 0, 1, 0, 0, 0, 0, 0, 0, 0`). To do so, we'll use the `tf.keras.utils.to_categorical` [function](https://www.tensorflow.org/api_docs/python/tf/keras/utils/to_categorical):

In [4]:
# How many categories we are predicting from (0-9)
LABEL_DIMENSIONS = 10

train_labels = tf.keras.utils.to_categorical(train_labels, LABEL_DIMENSIONS)
test_labels = tf.keras.utils.to_categorical(test_labels, LABEL_DIMENSIONS)

# Cast the labels to floats, needed later
train_labels = train_labels.astype(np.float32)
test_labels = test_labels.astype(np.float32)
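To make the encoding concrete, the following is a rough NumPy sketch of what a one hot conversion does. The helper `one_hot` is a hypothetical name introduced here for illustration; it is not how `to_categorical` is implemented, only a small equivalent for integer labels.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Return a float32 one-hot matrix of shape (len(labels), num_classes)."""
    labels = np.asarray(labels, dtype=np.int64)
    # Row i of the identity matrix is the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.float32)[labels]

encoded = one_hot([2, 0, 9], 10)
print(encoded[0])  # the vector for label 2 (Pullover)
```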

## Build a `tf.keras` model

We will create our neural network using the [Keras Sequential API](https://www.tensorflow.org/api_docs/python/tf/keras/Sequential). Keras is a high-level API to build and train deep learning models and is user friendly, modular and easy to extend. `tf.keras` is TensorFlow's implementation of this API and it supports such things as eager execution, `tf.data` pipelines and Estimators.

In terms of the architecture we will use ConvNets. On a very high level, ConvNets are stacks of convolutional layers (`Conv2D`) and pooling layers (`MaxPooling2D`). Most importantly, they take for each training example a 3D tensor of shape (`height`, `width`, `channels`), where `channels=1` for grayscale images, and return a 3D tensor.

Therefore, after the ConvNet part we need to `Flatten` the tensor and add `Dense` layers, the last one returning `LABEL_DIMENSIONS` outputs with the `softmax` activation:

In [5]:
inputs = tf.keras.Input(shape=(28,28,1))  # Returns a placeholder tensor
x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu)(inputs)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu)(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation=tf.nn.relu)(x)
predictions = tf.keras.layers.Dense(LABEL_DIMENSIONS, activation=tf.nn.softmax)(x)
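To see what the `Flatten` layer receives, the spatial sizes can be traced by hand. The sketch below assumes the Keras defaults used in the cell above (`'valid'` padding, stride 1 for `Conv2D`); the helper names `conv_out` and `pool_out` are made up for this illustration.

```python
def conv_out(size, kernel):
    """Spatial size after a convolution with 'valid' padding and stride 1."""
    return size - kernel + 1

def pool_out(size, pool, stride):
    """Spatial size after max pooling with 'valid' padding."""
    return (size - pool) // stride + 1

size = 28
shapes = []
# Same order of layers as in the model: conv, pool, conv, pool, conv.
for kind in ["conv", "pool", "conv", "pool", "conv"]:
    size = conv_out(size, 3) if kind == "conv" else pool_out(size, 2, 2)
    shapes.append(size)

flat_features = shapes[-1] * shapes[-1] * 64  # the last Conv2D has 64 filters
print(shapes, flat_features)
```

Under these assumptions the image shrinks from 28 to 3 pixels per side, so `Flatten` hands 576 features to the first `Dense` layer.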

We can now define our model, select the optimizer (we choose one from TensorFlow rather than using one from `tf.keras.optimizers`) and compile it:

In [6]:
model = tf.keras.Model(inputs=inputs, outputs=predictions)

In [7]:
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

In [8]:
model.compile(loss='categorical_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])
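For intuition about the `categorical_crossentropy` loss chosen above, here is a small NumPy sketch of the quantity it measures. It is an illustration under simplifying assumptions (mean over the batch, a clipping constant `eps` picked here), not the Keras implementation.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean cross-entropy between one-hot targets and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

y_true = np.eye(10, dtype=np.float32)[[2]]          # one example, class 2
uniform = np.full((1, 10), 0.1, dtype=np.float32)   # a guess with no preference
confident = np.full((1, 10), 0.01, dtype=np.float32)
confident[0, 2] = 0.91                              # mostly correct prediction

print(categorical_crossentropy(y_true, uniform))
print(categorical_crossentropy(y_true, confident))
```

A uniform guess over 10 classes gives a loss of about `ln(10)`, which is a useful reference point when reading the training logs: values well below it mean the model has learned something.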

## Create an Estimator

To create an Estimator from the compiled Keras model we call the `model_to_estimator` method. Note that the initial model state of the Keras model is preserved in the created Estimator.

So what's so good about Estimators? Well to start off with:

* you can run Estimator-based models on a local host or on a distributed multi-GPU environment without changing your model;
* Estimators simplify sharing implementations between model developers;
* Estimators build the graph for you, so a bit like Eager Execution, there is no explicit session.


So how do we go about training our simple `tf.keras` model on multiple GPUs? We can use the `tf.contrib.distribute.MirroredStrategy` paradigm, which does in-graph replication with synchronous training. See this talk on [Distributed TensorFlow training](https://www.youtube.com/watch?v=bRMGoPqsn20) for more information about this strategy.

Essentially, each worker GPU has a copy of the graph and gets a subset of the data, on which it computes the local gradients; it then waits for all the workers to finish in a synchronous manner. The workers then communicate their local gradients to each other via a ring all-reduce operation, which is typically optimized to reduce network bandwidth and increase throughput. Once all the gradients have arrived, each worker averages them and updates its parameters, and the next step begins. This is ideal in situations where you have multiple GPUs on a single node connected via some high-speed interconnect.
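The averaging step can be illustrated without any GPUs. The sketch below is a toy NumPy simulation, assuming a linear model with a mean squared error loss and two equally sized shards; the plain `np.mean` stands in for the all-reduce and says nothing about how the ring algorithm moves data.

```python
import numpy as np

def mse_gradient(w, x, y):
    """Gradient of the mean squared error of a linear model x @ w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)

rng = np.random.RandomState(0)
x = rng.randn(256, 3)
y = rng.randn(256)
w = np.zeros(3)

# Each of two replicas gets an equal shard of the global batch.
shards = zip(np.split(x, 2), np.split(y, 2))
local_grads = [mse_gradient(w, xs, ys) for xs, ys in shards]

# Stand-in for the all-reduce: every replica ends up with the same average.
averaged = np.mean(local_grads, axis=0)

# With equal shard sizes this equals the gradient over the whole batch.
print(np.allclose(averaged, mse_gradient(w, x, y)))
```

Because every replica applies the same averaged gradient to identical parameters, the copies of the model stay in sync from step to step.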


To use this strategy we first create an Estimator from the compiled `tf.keras` model and pass it the `MirroredStrategy` configuration via a `RunConfig` object. By default this configuration uses all the GPUs, but you can also give it a `num_gpus` option to use a specific number of GPUs:

In [9]:
NUM_GPUS = 2

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)
config = tf.estimator.RunConfig(train_distribute=strategy)

estimator = tf.keras.estimator.model_to_estimator(model, config=config)

INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpp9jik95r', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.contrib.distribute.python.mirrored_strategy.MirroredStrategy object at 0x7fd276307400>, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd276307438>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


## Create an Estimator input function

To pipe data into Estimators we need to define a data importing function which returns a `tf.data` dataset of `(images, labels)` batches of our data. The function below takes in `numpy` arrays and returns the dataset via an ETL process.

Note that at the end we also call the `prefetch` method, which buffers the data for the GPUs while they are training, so that the next batch is ready and waiting for the GPUs rather than having the GPUs wait for the data at each iteration. The GPUs might still not be fully utilized; to improve this we could use fused versions of the transformation operations, like `shuffle_and_repeat` instead of two separate operations, but I have kept the simple case here.

In [10]:
def input_fn(images, labels, epochs, batch_size):
    # Convert the inputs to a Dataset. (E)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    # Shuffle, repeat, and batch the examples. (T)
    SHUFFLE_SIZE = 5000
    dataset = dataset.shuffle(SHUFFLE_SIZE).repeat(epochs).batch(batch_size)
    dataset = dataset.prefetch(2)

    # Return the dataset. (L)
    return dataset

## Train the Estimator

Now the good part! We can call the `train` function on our Estimator giving it the `input_fn` we defined as well as the number of steps of the optimizer we wish to run:

In [13]:
BATCH_SIZE = 256
EPOCHS = 10
STEPS = 1000

estimator.train(input_fn=lambda:input_fn(train_images,
                                         train_labels,
                                         epochs=EPOCHS,
                                         batch_size=BATCH_SIZE//NUM_GPUS), 
                steps=STEPS)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:batch_all_reduce invoked for batches size = 10 with algorithm = nccl, num_packs = 1, agg_small_grads_max_bytes = 0 and agg_small_grads_max_group = 10
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpp9jik95r/model.ckpt-1000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/tmpp9jik95r/model.ckpt.
INFO:tensorflow:loss = 0.31297606, step = 1000
INFO:tensorflow:global_step/sec: 144.403
INFO:tensorflow:loss = 0.28465617, step = 1100 (0.694 sec)
INFO:tensorflow:global_step/sec: 233.725
INFO:tensorflow:loss = 0.33305198, step = 1200 (0.428 sec)
INFO:tensorflow:global_step/sec: 230.657
INFO:tensorflow:loss = 0.3072132, step = 1300 (0.434 sec)
INFO:tensorflow:global_st

<tensorflow.python.estimator.estimator.Estimator at 0x7fd277699438>
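It helps to work out how much data these settings cover. The arithmetic below is a sketch that assumes, as the division by `NUM_GPUS` in the call suggests, that each GPU consumes one batch from `input_fn` per step; the names `per_gpu_batch`, `global_batch` and `max_steps_available` are introduced only for this illustration.

```python
TRAINING_SIZE = 60000   # Fashion-MNIST training images
BATCH_SIZE = 256
NUM_GPUS = 2
EPOCHS = 10
STEPS = 1000

per_gpu_batch = BATCH_SIZE // NUM_GPUS          # what input_fn is given
global_batch = per_gpu_batch * NUM_GPUS         # examples consumed per step
examples_seen = STEPS * global_batch
epochs_covered = examples_seen / TRAINING_SIZE
# Full steps the repeated dataset could supply before running out.
max_steps_available = (TRAINING_SIZE * EPOCHS) // global_batch

print(per_gpu_batch, examples_seen, round(epochs_covered, 2), max_steps_available)
```

Under that assumption, 1000 steps cover a little over four passes through the training set, and the dataset repeated for 10 epochs is large enough that training stops because of `steps` rather than because the input runs out.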

## Evaluate the Estimator

In order to check the performance of our model we then call the `evaluate` method on our Estimator:

In [15]:
estimator.evaluate(input_fn=lambda:input_fn(test_images,
                                            test_labels,
                                            epochs=1,
                                            batch_size=BATCH_SIZE//NUM_GPUS))

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-08-18-07:44:25
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpp9jik95r/model.ckpt-2000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-08-18-07:44:26
INFO:tensorflow:Saving dict for global step 2000: accuracy = 0.89645964, global_step = 2000, loss = 0.2830377
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2000: /tmp/tmpp9jik95r/model.ckpt-2000


{'accuracy': 0.89645964, 'loss': 0.2830377, 'global_step': 2000}

## Retinal OCT (optical coherence tomography) images example

To test the scaling performance on a bigger dataset we can use the [Retinal OCT images](https://www.kaggle.com/paultimothymooney/kermany2018) dataset, one of the many great datasets from [Kaggle](https://www.kaggle.com/datasets). This dataset consists of cross sections of the retinas of living patients grouped into four categories: NORMAL, CNV, DME and DRUSEN.

![](https://i.imgur.com/fSTeZMd.png)

The dataset has a total of 84,495 OCT JPEG images, typically `512x496`, and can be downloaded via the `kaggle` CLI:

In [None]:
#!pip install kaggle
#!kaggle datasets download -d paultimothymooney/kermany2018

Once downloaded, the image classes of the training and test sets are in their own respective folders, so we can define the file patterns as:

In [None]:
train_folder = os.path.join('OCT2017', 'train', '**', '*.jpeg')
test_folder = os.path.join('OCT2017', 'test', '**', '*.jpeg')

In [None]:
labels = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

Next we have our Estimator's input function which takes any file pattern and returns resized images and one hot encoded labels as a `tf.data.Dataset`. Here we follow the best practices from the [Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance):

In [None]:
def input_fn(file_pattern, labels,
             image_size=(224,224),
             shuffle=False,
             batch_size=64, 
             num_epochs=None, 
             buffer_size=4096,
             prefetch_buffer_size=4):

    table = tf.contrib.lookup.index_table_from_tensor(mapping=tf.constant(labels))
    num_classes = len(labels)

    def _map_func(filename):
        label = tf.string_split([filename], delimiter=os.sep).values[-2]
        image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.resize_images(image, size=image_size)
        return ({"input_1": image}, tf.one_hot(table.lookup(label), num_classes))
    
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)

    if num_epochs is not None and shuffle:
        dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, num_epochs))
    elif shuffle:
        dataset = dataset.shuffle(buffer_size)
    elif num_epochs is not None:
        dataset = dataset.repeat(num_epochs)

    dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=_map_func,
                                                          batch_size=batch_size,
                                                          num_parallel_calls=os.cpu_count()))
    dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)
    
    return dataset
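The label handling inside `_map_func` can be mimicked in plain Python to check the folder convention. This is a sketch, not the TensorFlow graph code: `label_from_path` is a hypothetical helper, `list.index` plays the role of the lookup table, and the file name is made up for illustration.

```python
import os

labels = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

def label_from_path(filename, labels):
    """Return (class name, one-hot list) for a path like .../train/<CLASS>/<file>.jpeg."""
    class_name = filename.split(os.sep)[-2]       # parent folder is the class
    index = labels.index(class_name)              # plays the role of the lookup table
    one_hot = [1.0 if i == index else 0.0 for i in range(len(labels))]
    return class_name, one_hot

# Hypothetical file name, used only for illustration.
example = os.path.join('OCT2017', 'train', 'DME', 'example-0001.jpeg')
print(label_from_path(example, labels))
```

The order of `labels` fixes the position of each class in the one hot vector, so the same list has to be used for training and evaluation.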

In order to train this we will use a pretrained VGG16 and train just its last 5 layers:

In [None]:
keras_vgg16 = tf.keras.applications.VGG16(input_shape=(224,224,3),
                                          include_top=False)

In [None]:
output = keras_vgg16.output
output = tf.keras.layers.Flatten()(output)
predictions = tf.keras.layers.Dense(len(labels), activation=tf.nn.softmax)(output)

In [None]:
model = tf.keras.Model(inputs=keras_vgg16.input, outputs=predictions)

for layer in keras_vgg16.layers[:-4]:
    layer.trainable = False
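To see which layers the `[:-4]` slice leaves trainable, the layer names can be listed by hand. The sketch below assumes the standard Keras naming of VGG16 layers without the top and writes them out manually, so it runs without downloading the model; the input layer name `input_1` is taken from the feature key used in the input function.

```python
# Layer names of VGG16 with include_top=False, written out by hand here
# so the sketch runs without downloading the model.
vgg16_layers = ['input_1']
for block, n_convs in [(1, 2), (2, 2), (3, 3), (4, 3), (5, 3)]:
    vgg16_layers += ['block%d_conv%d' % (block, i + 1) for i in range(n_convs)]
    vgg16_layers.append('block%d_pool' % block)

frozen = vgg16_layers[:-4]       # same slice as in the loop above
trainable = vgg16_layers[-4:]    # what remains trainable in the base model

print(len(vgg16_layers), len(frozen), trainable)
```

Under this assumption the fifth convolutional block stays trainable, together with the new `Dense` layer added on top; the pooling and `Flatten` layers have no weights of their own.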

Now we have all we need and can proceed as above and train our model in a few minutes using `NUM_GPUS` GPUs:

In [None]:
optimizer = tf.train.AdamOptimizer()

In [None]:
model.compile(loss='categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])

In [None]:
strategy = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.keras.estimator.model_to_estimator(model, config=config)

In [None]:
estimator.train(input_fn=lambda:input_fn(train_folder,
                                         labels,
                                         shuffle=True,
                                         batch_size=32,
                                         buffer_size=2048,
                                         num_epochs=20),
                steps=1000)

Once trained we can evaluate the accuracy on the test set, which should be around 95% (not bad for an initial baseline!):

In [None]:
estimator.evaluate(input_fn=lambda:input_fn(test_folder,
                                            labels, 
                                            shuffle=True,
                                            batch_size=64,
                                            buffer_size=1024,
                                            num_epochs=1,
                                            prefetch_buffer_size=2))