# Multi-GPU training with `tf.keras` or Estimators and `tf.data`

We can train on multiple GPUs directly via `tf.keras`'s distributed strategy scope.

TensorFlow's [Estimators](https://www.tensorflow.org/programmers_guide/estimators) API is another useful way to train models in a distributed environment, such as a single node with multiple GPUs or many nodes with GPUs. This is particularly useful when training on huge datasets, especially when used with the `tf.keras` API.

Here we will first present the `tf.keras` API on the tiny Fashion-MNIST dataset and then show a practical use case at the end via Estimators.

**TL;DR**: Essentially what we want to remember is that a `tf.keras.Model` can be trained with the `tf.estimator` API by converting it to a `tf.estimator.Estimator` object via the `tf.keras.estimator.model_to_estimator` method. Once converted, we can apply the machinery that `Estimator` provides to train on different hardware configurations.

In [1]:
import os
import time

import tensorflow as tf
from tensorflow.python.ops import lookup_ops

import numpy as np

## Import the Fashion-MNIST dataset

We will use the [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) dataset, a drop-in replacement for MNIST, which contains 70,000 grayscale images (60,000 for training and 10,000 for testing) of [Zalando](https://www.zalando.de/) fashion articles. Getting the training and test data is as simple as:

In [2]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

We want to convert the pixel values of these images from a number between 0 and 255 to a number between 0 and 1, and convert the dataset to the `[B, H, W, C]` format, where `B` is the number of images, `H` and `W` are the height and width, and `C` is the number of channels (1 for grayscale) of our images:

In [3]:
TRAINING_SIZE = len(train_images)
TEST_SIZE = len(test_images)

train_images = np.asarray(train_images, dtype=np.float32) / 255

# Convert the train images and add channels
train_images = train_images.reshape((TRAINING_SIZE, 28, 28, 1))

test_images = np.asarray(test_images, dtype=np.float32) / 255
# Convert the test images and add channels
test_images = test_images.reshape((TEST_SIZE, 28, 28, 1))
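
As a quick check of the preprocessing above, the following sketch applies the same scaling and reshape to a small synthetic array standing in for the real images. The array `fake_images` and its size are illustrative assumptions, not part of the dataset:

```python
import numpy as np

# Synthetic stand-in for raw Fashion-MNIST images: 4 images of 28x28 uint8 pixels.
fake_images = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Scale to [0, 1] and add a trailing channel axis, as done for the real data.
scaled = np.asarray(fake_images, dtype=np.float32) / 255
scaled = scaled.reshape((len(fake_images), 28, 28, 1))

print(scaled.shape, scaled.dtype)
print(scaled.min() >= 0.0, scaled.max() <= 1.0)
```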

Next, we want to convert the labels from an integer format (e.g., `2` for `Pullover`) to a [one-hot encoding](https://en.wikipedia.org/wiki/One-hot) (e.g., `0, 0, 1, 0, 0, 0, 0, 0, 0, 0`). To do so, we'll use the `tf.keras.utils.to_categorical` [function](https://www.tensorflow.org/api_docs/python/tf/keras/utils/to_categorical):

In [4]:
# How many categories we are predicting from (0-9)
LABEL_DIMENSIONS = 10

train_labels = tf.keras.utils.to_categorical(train_labels, LABEL_DIMENSIONS)
test_labels = tf.keras.utils.to_categorical(test_labels, LABEL_DIMENSIONS)

# Cast the labels to floats, needed later
train_labels = train_labels.astype(np.float32)
test_labels = test_labels.astype(np.float32)
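
To illustrate what the one-hot conversion produces for integer labels, here is a minimal NumPy sketch. It is not the Keras implementation, and the `labels` array is a made-up example:

```python
import numpy as np

LABEL_DIMENSIONS = 10
labels = np.array([2, 0, 9])  # hypothetical integer class labels

# Row i of the identity matrix is the one-hot vector for class i,
# so indexing the identity matrix by the labels encodes all of them at once.
one_hot = np.eye(LABEL_DIMENSIONS, dtype=np.float32)[labels]

print(one_hot[0])  # encoding of class 2
```

Taking `argmax` along the last axis recovers the original integer labels.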

## Distribution strategy


So how do we go about training a `tf.keras` model on multiple GPUs? We can use the `tf.distribute.MirroredStrategy` paradigm, which does in-graph replication with synchronous training. See this talk on [Distributed TensorFlow training](https://www.youtube.com/watch?v=bRMGoPqsn20) for more information about this strategy.

Essentially each worker GPU has a copy of the graph and gets a subset of the data, on which it computes the local gradients, and then waits for all the workers to finish in a synchronous manner. Then the workers communicate their local gradients to each other via a ring all-reduce operation, which is typically optimized to reduce network bandwidth and increase throughput. Once all the gradients have arrived, each worker averages them and updates its parameters, and the next step begins. This is ideal in situations where you have multiple GPUs on a single node connected via some high-speed interconnect.
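
The synchronous update described above can be sketched in plain NumPy. This is only a toy simulation under stated assumptions (a linear model, two simulated replicas, equal shard sizes), not what `MirroredStrategy` does internally; all names here are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear model y = x @ w with a mean squared error loss.
w = rng.normal(size=3)
x = rng.normal(size=(8, 3))  # one global batch of 8 examples
y = rng.normal(size=8)

def grad(w, x, y):
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)

# Each of the 2 simulated replicas gets an equal shard of the global batch
# and computes its local gradient independently.
shards = zip(np.split(x, 2), np.split(y, 2))
local_grads = [grad(w, xs, ys) for xs, ys in shards]

# Stand-in for the all-reduce: every replica ends up with the averaged gradient.
avg_grad = np.mean(local_grads, axis=0)

# Every replica applies the identical update, so the model copies stay in sync.
learning_rate = 0.1
w_new = w - learning_rate * avg_grad

# With equal shard sizes the average matches the gradient on the full batch.
print(np.allclose(avg_grad, grad(w, x, y)))
```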

To create a `MirroredStrategy` just instantiate it via:

In [5]:
strategy = tf.distribute.MirroredStrategy()

## Build a `tf.keras` model

We will create our neural network using the [Keras Functional API](https://www.tensorflow.org/guide/keras#functional_api). Keras is a high-level API to build and train deep learning models and is user friendly, modular and easy to extend. `tf.keras` is TensorFlow's implementation of this API and it supports such things as [Eager Execution](https://www.tensorflow.org/guide/eager), `tf.data` pipelines and Estimators.

In terms of the architecture we will use ConvNets. On a very high level, ConvNets are stacks of convolutional layers (`Conv2D`) and pooling layers (`MaxPooling2D`). Most importantly, for each training example they take a 3D tensor of shape (`height`, `width`, `channels`), where `channels=1` for grayscale images, and return a 3D tensor.

Therefore, after the ConvNet part we will need to `Flatten` the tensor and add `Dense` layers, the last one returning the `LABEL_DIMENSIONS` outputs with the `softmax` activation.

To allow this model to train on multiple GPUs via the strategy we defined above, we need to create and compile the `tf.keras` model inside `strategy.scope()`:

In [6]:
with strategy.scope():
    inputs = tf.keras.Input(shape=(28,28,1))  # Returns a placeholder tensor
    x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu)(inputs)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu)(x)
    x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu)(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(64, activation=tf.nn.relu)(x)
    predictions = tf.keras.layers.Dense(LABEL_DIMENSIONS, activation='softmax')(x)
    
    model = tf.keras.Model(inputs=inputs, outputs=predictions)
    
    model.compile(loss='categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  metrics=['accuracy'])
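
To see why the `Flatten` layer is needed and how large its output is, the spatial sizes through the convolution and pooling layers can be traced by hand. The sketch below assumes the Keras default of unpadded (`'valid'`) windows for both layer types; `conv_out` is a helper introduced only for this illustration:

```python
def conv_out(size, kernel, stride=1):
    """Output size of an unpadded ('valid') convolution or pooling window."""
    return (size - kernel) // stride + 1

size = 28                          # input height and width
size = conv_out(size, 3)           # Conv2D 3x3, 32 filters
size = conv_out(size, 2, stride=2) # MaxPooling2D 2x2
size = conv_out(size, 3)           # Conv2D 3x3, 64 filters
size = conv_out(size, 2, stride=2) # MaxPooling2D 2x2
size = conv_out(size, 3)           # Conv2D 3x3, 64 filters

# The last convolution outputs a (size, size, 64) tensor per example,
# which Flatten turns into a single vector for the Dense layers.
flattened = size * size * 64
print(size, flattened)
```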

## Create a `tf.data` input function

Next we define a data importing function which returns a `tf.data` dataset of `(images, labels)` batches of our data. The function below takes in `numpy` arrays and returns the dataset via an ETL process.

Note that in the end we are also calling the `prefetch` method which will buffer the data to the GPUs while they are training so that the next batch is ready and waiting for the GPUs rather than having the GPUs wait for the data at each iteration. The GPU might still not be fully utilized and to improve this we can use fused versions of the transformation operations like `shuffle_and_repeat` instead of two separate operations, but I have kept the simple case here.

In [7]:
def input_fn(images, labels, epochs, batch_size):
    # Convert the inputs to a Dataset. (E)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))

    # Shuffle, repeat, and batch the examples. (T)
    SHUFFLE_SIZE = 5000
    dataset = dataset.shuffle(SHUFFLE_SIZE).repeat(epochs).batch(batch_size)
    dataset = dataset.prefetch(None)

    # Return the dataset. (L)
    return dataset
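
To make the role of `prefetch` concrete, here is a framework-free sketch of the idea: a background thread fills a bounded buffer while the consumer works on the current item. This is an illustrative analogy using the standard library only, not how `tf.data` implements prefetching, and the function name `prefetch` is chosen just for this example:

```python
import queue
import threading

def prefetch(iterable, buffer_size=1):
    """Yield items from `iterable`, producing up to `buffer_size` items ahead."""
    buffer = queue.Queue(maxsize=buffer_size)
    sentinel = object()  # marks the end of the input

    def producer():
        for item in iterable:
            buffer.put(item)  # blocks while the buffer is full
        buffer.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        item = buffer.get()  # blocks only if the producer has fallen behind
        if item is sentinel:
            return
        yield item

# The consumer sees the same items in the same order, but the next item
# is usually already waiting in the buffer when it asks for it.
print(list(prefetch(range(5), buffer_size=2)))
```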

## Training

In [8]:
BATCH_SIZE = 512
EPOCHS = 10
steps_per_epoch = int(np.ceil(60000 / float(BATCH_SIZE))) 

model.fit(input_fn(train_images, train_labels,
                   epochs=EPOCHS,
                   batch_size=BATCH_SIZE), epochs=EPOCHS, steps_per_epoch=steps_per_epoch)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10

ValueError: Input tensor shapes do not match for distributed tensor inputs PerReplica:{
  0 /job:localhost/replica:0/task:0/device:GPU:0: tf.Tensor(..., shape=(256, 28, 28, 1), dtype=float32),
  1 /job:localhost/replica:0/task:0/device:GPU:1: tf.Tensor(..., shape=(192, 28, 28, 1), dtype=float32)
}
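
The shapes reported in this error, 256 and 192, are consistent with a partial final batch being split unevenly across the two GPUs. The arithmetic below is a sketch of that reading, assuming the global batch of 512 is divided into per-replica batches of 256; the names `NUM_REPLICAS` and `per_replica` are introduced here for illustration:

```python
TRAINING_SIZE = 60000
EPOCHS = 10
BATCH_SIZE = 512
NUM_REPLICAS = 2  # assumption: two GPUs, as in the error message

# repeat(EPOCHS) comes before batch(), so batches run across epoch boundaries
# and only the very last batch of the whole stream can be partial.
total_examples = TRAINING_SIZE * EPOCHS
full_batches, remainder = divmod(total_examples, BATCH_SIZE)

# Split the leftover examples into per-replica chunks of BATCH_SIZE // NUM_REPLICAS.
per_replica = BATCH_SIZE // NUM_REPLICAS
shards = []
left = remainder
for _ in range(NUM_REPLICAS):
    take = min(per_replica, left)
    shards.append(take)
    left -= take

print(full_batches, remainder, shards)
```

If this is indeed the cause, one possible fix is to pass `drop_remainder=True` to `batch` in `input_fn`, so that every batch has exactly `batch_size` examples, and to set `steps_per_epoch = TRAINING_SIZE // BATCH_SIZE` so that `fit` does not request more batches than the dataset provides.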

## Create an Estimator

To create an Estimator from the compiled Keras model we call the `model_to_estimator` method. Note that the initial model state of the Keras model is preserved in the created Estimator.

So what's so good about Estimators? Well to start off with:

* you can run Estimator-based models on a local host or on a distributed multi-GPU environment without changing your model;
* Estimators simplify sharing implementations between model developers;
* Estimators build the graph for you, so a bit like Eager Execution, there is no explicit session.

In [9]:
inputs = tf.keras.Input(shape=(28,28,1))  # Returns a placeholder tensor
x = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu)(inputs)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu)(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation=tf.nn.relu)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation=tf.nn.relu)(x)
predictions = tf.keras.layers.Dense(LABEL_DIMENSIONS, activation='softmax')(x)

model = tf.keras.Model(inputs=inputs, outputs=predictions)

model.compile(loss='categorical_crossentropy',
              optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001),
              metrics=['accuracy'])

In [10]:
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.keras.estimator.model_to_estimator(model, config=config)

W0227 23:00:46.923566 140663967582016 estimator.py:1752] Using temporary folder as model directory: /tmp/tmpm6jfjs84
W0227 23:00:46.935023 140663967582016 deprecation.py:506] From /home/kashif/.env/tf-2/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:97: calling GlorotUniform.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0227 23:00:46.935882 140663967582016 deprecation.py:506] From /home/kashif/.env/tf-2/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1257: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0227 23:00:46.936684 140663967582016 deprecation.py:506] From /home/

## Train the Estimator

Let's first define a `SessionRunHook` class for recording the time of each iteration of stochastic gradient descent:

In [12]:
class TimeHistory(tf.estimator.SessionRunHook):
    def begin(self):
        self.times = []

    def before_run(self, run_context):
        self.iter_time_start = time.time()

    def after_run(self, run_context, run_values):
        self.times.append(time.time() - self.iter_time_start)

Now the good part! We can call the `train` function on our Estimator, giving it the `input_fn` we defined (with the batch size and the number of epochs we wish to train for) and a `TimeHistory` instance via its `hooks` argument:

In [12]:
BATCH_SIZE = 512
EPOCHS = 5

time_hist = TimeHistory()

estimator.train(input_fn=lambda:input_fn(train_images,
                                         train_labels,
                                         epochs=EPOCHS,
                                         batch_size=BATCH_SIZE), 
                hooks=[time_hist])

<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7fed184f4390>

## Performance

Since we have our timing hook we can now use it to calculate the total time of training as well as the number of images we train on per second:

In [13]:
NUM_GPUS = 2
total_time =  sum(time_hist.times)
print(f"total time with {NUM_GPUS} GPUs: {total_time} seconds")

total time with 2 GPUs: 2.8365120887756348 seconds


In [14]:
avg_time_per_batch = np.mean(time_hist.times)
print(f"{BATCH_SIZE*NUM_GPUS/avg_time_per_batch} images/second with {NUM_GPUS} GPUs")

105774.97666491779 images/second with 2 GPUs
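
As a sanity check on the `BATCH_SIZE*NUM_GPUS` factor, the throughput can be recomputed from the number of training steps. The sketch below assumes that each training step consumes one batch of `BATCH_SIZE` per GPU, which is what the multiplication by `NUM_GPUS` implies; the resulting step count agrees with the `global_step` of 293 reported by the evaluation in the next section. The variable names `num_batches` and `num_steps` are introduced for this example:

```python
import math

TRAINING_SIZE = 60000
BATCH_SIZE = 512
EPOCHS = 5
NUM_GPUS = 2

# Batches produced by shuffle -> repeat(EPOCHS) -> batch(BATCH_SIZE).
num_batches = math.ceil(TRAINING_SIZE * EPOCHS / BATCH_SIZE)

# Assumption: each training step consumes one batch per GPU.
num_steps = math.ceil(num_batches / NUM_GPUS)

total_time = 2.8365120887756348  # seconds, as printed by the timing cell

# Same quantity as BATCH_SIZE * NUM_GPUS / mean(time per step).
# It slightly overcounts because the final batch is not full.
images_per_second = num_steps * BATCH_SIZE * NUM_GPUS / total_time
print(num_steps, round(images_per_second))
```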


## Evaluate the Estimator

In order to check the performance of our model we then call the `evaluate` method on our Estimator:

In [15]:
estimator.evaluate(lambda:input_fn(test_images, 
                                   test_labels,
                                   epochs=1,
                                   batch_size=BATCH_SIZE))

W0227 23:01:34.007677 140663967582016 deprecation.py:323] From /home/kashif/.env/tf-2/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py:363: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
W0227 23:01:34.090094 140663967582016 deprecation.py:323] From /home/kashif/.env/tf-2/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


{'accuracy': 0.8425, 'loss': 0.44238263, 'global_step': 293}

## Retinal OCT (optical coherence tomography) images example

To test the scaling performance on a bigger dataset we can use the [Retinal OCT images](https://www.kaggle.com/paultimothymooney/kermany2018) dataset, one of the many great datasets from [Kaggle](https://www.kaggle.com/datasets). This dataset consists of cross sections of the retinas of living patients grouped into four categories: NORMAL, CNV, DME and DRUSEN.

![](https://i.imgur.com/fSTeZMd.png)

The dataset has a total of 84,495 X-Ray JPEG images, typically `512x496`, and can be downloaded via the `kaggle` CLI:

In [13]:
#!pip install kaggle
#!kaggle datasets download -d paultimothymooney/kermany2018

Once downloaded, the training and test sets each contain one folder per image class, so we can define the file patterns as:

In [6]:
train_folder = os.path.join('OCT2017', 'train', '**', '*.jpeg')
test_folder = os.path.join('OCT2017', 'test', '**', '*.jpeg')

In [7]:
labels = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

Next we have our Estimator's input function which takes any file pattern and returns resized images and one hot encoded labels as a `tf.data.Dataset`. Here we follow the best practices from the [Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance). Note in particular that if the `prefetch_buffer_size` is `None` then TensorFlow will use an optimal prefetch buffer size automatically:

In [20]:
def input_fn(file_pattern, labels,
             image_size=(224,224),
             shuffle=False,
             batch_size=64, 
             num_epochs=None, 
             buffer_size=4096,
             prefetch_buffer_size=None):

    table = lookup_ops.index_table_from_tensor(tf.constant(labels))
    num_classes = len(labels)

    def _map_func(filename):
        label = tf.string_split([filename], delimiter=os.sep).values[-2]
        image = tf.image.decode_jpeg(tf.io.read_file(filename), channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image = tf.image.resize(image, size=image_size)
        return (image, tf.one_hot(table.lookup(label), num_classes))
    
    dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)

    if num_epochs is not None and shuffle:
        dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, num_epochs))
    elif shuffle:
        dataset = dataset.shuffle(buffer_size)
    elif num_epochs is not None:
        dataset = dataset.repeat(num_epochs)

    dataset = dataset.apply(
        tf.data.experimental.map_and_batch(map_func=_map_func,
                                      batch_size=batch_size,
                                      num_parallel_calls=os.cpu_count()))
    dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)
    
    return dataset
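
The label handling inside `_map_func` relies on the class name being the parent folder of each image file. A plain-Python sketch of the same idea, without TensorFlow ops, is shown below; `label_from_path`, `one_hot` and the file name `example.jpeg` are hypothetical and only serve the illustration:

```python
import os

labels = ['CNV', 'DME', 'DRUSEN', 'NORMAL']

def label_from_path(filename):
    """Return the class name, i.e. the name of the file's parent folder."""
    return filename.split(os.sep)[-2]

def one_hot(label):
    """Encode a class name as a one-hot list ordered like `labels`."""
    index = labels.index(label)
    return [1.0 if i == index else 0.0 for i in range(len(labels))]

# Hypothetical file path following the OCT2017/<split>/<class>/<file> layout.
example = os.path.join('OCT2017', 'train', 'DME', 'example.jpeg')

print(label_from_path(example), one_hot(label_from_path(example)))
```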

In order to train this we will use a pretrained VGG16 and train just its last 5 layers:

In [25]:
with strategy.scope():
    keras_vgg16 = tf.keras.applications.VGG16(input_shape=(224,224,3),
                                              include_top=False)
    output = keras_vgg16.output
    output = tf.keras.layers.Flatten()(output)
    predictions = tf.keras.layers.Dense(len(labels), activation=tf.nn.softmax)(output)

    model = tf.keras.Model(inputs=keras_vgg16.input, outputs=predictions)
    for layer in keras_vgg16.layers[:-4]:
        layer.trainable = False
    
    model.compile(loss='categorical_crossentropy', 
                  optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.001),
                  metrics=['accuracy'])
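
To see which layers the `[:-4]` slice leaves trainable, the sketch below applies the same slice to a hand-written list of VGG16 layer names for `include_top=False`. The list is typed out here so the example runs without TensorFlow; the input layer's real name varies between sessions, so `'input'` is a placeholder:

```python
# Layer names of VGG16 without its classification head, listed by hand.
vgg16_layers = (
    ['input']
    + ['block1_conv1', 'block1_conv2', 'block1_pool']
    + ['block2_conv1', 'block2_conv2', 'block2_pool']
    + ['block3_conv1', 'block3_conv2', 'block3_conv3', 'block3_pool']
    + ['block4_conv1', 'block4_conv2', 'block4_conv3', 'block4_pool']
    + ['block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']
)

# Same slicing as in the loop that sets layer.trainable = False.
frozen = vgg16_layers[:-4]
still_trainable = vgg16_layers[-4:]

print(len(frozen), still_trainable)
```

The `Flatten` and `Dense` layers added on top are not part of `keras_vgg16.layers`, so they remain trainable as well. Pooling and flatten layers have no weights, so the parameters actually updated belong to the three `block5` convolutions and the final `Dense` layer.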

Now we have all we need and can proceed as above and train our model in a few minutes using `NUM_GPUS` GPUs:

In [26]:
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.keras.estimator.model_to_estimator(model, config=config)

W0227 23:31:40.365042 140404414289728 estimator.py:1752] Using temporary folder as model directory: /tmp/tmp6zu3qj78


In [27]:
BATCH_SIZE = 32
EPOCHS = 2

time_hist = TimeHistory()

estimator.train(input_fn=lambda:input_fn(train_folder,
                                         labels,
                                         shuffle=True,
                                         batch_size=BATCH_SIZE,
                                         buffer_size=2048,
                                         num_epochs=EPOCHS,
                                         prefetch_buffer_size=4),
                hooks=[time_hist])

<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7fb12c408ac8>

In [29]:
NUM_GPUS = 2
total_time =  sum(time_hist.times)

print(f"total time with {NUM_GPUS} GPUs: {total_time} seconds")

total time with 2 GPUs: 351.3928852081299 seconds


In [30]:
avg_time_per_batch = np.mean(time_hist.times)

print(f"{BATCH_SIZE*NUM_GPUS/avg_time_per_batch} images/second with {NUM_GPUS} GPUs")

475.1832123780769 images/second with 2 GPUs


Once trained we can evaluate the accuracy on the test set, which should be around 95% (not bad for an initial baseline!), although the two-epoch run recorded below reached only about 80%:

In [31]:
estimator.evaluate(input_fn=lambda:input_fn(test_folder,
                                            labels, 
                                            shuffle=True,
                                            batch_size=BATCH_SIZE,
                                            buffer_size=2048,
                                            num_epochs=1))

W0227 23:38:14.842677 140404414289728 deprecation.py:323] From /home/kashif/.env/tf-2/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py:363: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
W0227 23:38:14.947003 140404414289728 deprecation.py:323] From /home/kashif/.env/tf-2/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


{'accuracy': 0.80475205, 'loss': 0.4852837, 'global_step': 2609}