#1. Distributed training with Keras

The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to let users enable distributed training with existing models and training code, with minimal changes.

In [1]:
!pip install -q tf-nightly

%tensorflow_version 2.x
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os

In [2]:
print(tf.__version__)

2.2.0-dev20200419


##1.1. Download the dataset

Setting with_info to True includes the metadata for the entire dataset, which is saved here to info.

In [3]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.

Downloading and preparing dataset mnist/3.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /root/tensorflow_datasets/mnist/3.0.0...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.


##1.2. Define distribution strategy

Create a MirroredStrategy object. This will handle distribution, and provides a context manager to build your model inside.

In [4]:
strategy = tf.distribute.MirroredStrategy()





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




In [5]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


##1.3. Setup input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly. 
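
As a rough sketch of that scaling, the global batch size is the per-replica batch size multiplied by the number of replicas. The learning-rate adjustment shown here uses the linear scaling heuristic, which is an assumption for illustration (the text above only says to tune it). The helper names and replica counts are hypothetical.

```python
# Minimal sketch: how batch size and (optionally) learning rate scale with replicas.
# The linear learning-rate rule is a common heuristic, not something this notebook prescribes.

def global_batch_size(per_replica_batch_size, num_replicas):
    """Batch size consumed by all replicas together in one training step."""
    return per_replica_batch_size * num_replicas


def scaled_learning_rate(base_lr, num_replicas):
    """Hypothetical helper: scale the single-replica learning rate linearly."""
    return base_lr * num_replicas


for replicas in (1, 2, 8):
    print(replicas,
          global_batch_size(64, replicas),
          scaled_learning_rate(1e-3, replicas))
```

Whether a linearly scaled learning rate actually helps depends on the model and optimizer, so treat it as a starting point for tuning.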

In [0]:
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

In [0]:
# Pixel values, which are 0-255, have to be normalized to the 0-1 range
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  
  return image, label

In [0]:
# Apply this function to the training and test data,
# shuffle the training data, and batch it for training.
# Notice we are also keeping an in-memory cache of the training data to improve performance.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

##1.4. Create the model

Create and compile the Keras model inside the strategy.scope() context manager, so that the model's variables are created under the strategy.

In [0]:
with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

##1.5. Define the callbacks

In [0]:
# Define the checkpoint directory to store the checkpoints

checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

In [0]:
# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
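
The schedule can be checked without training. This sketch repeats the decay function above (with the redundant lower bound dropped) and lists the learning rate per epoch. It assumes, as the training log in this notebook suggests, that the callback passes a zero-based epoch index while the printed epoch number is one-based.

```python
# Same piecewise-constant schedule as above, evaluated for 12 epochs.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5

# Key: epoch number as printed in the log (1-based); value: learning rate.
schedule = {epoch + 1: decay(epoch) for epoch in range(12)}

for shown_epoch, lr in schedule.items():
  print(shown_epoch, lr)
```

Printed epochs 1 to 3 use 1e-3, epochs 4 to 7 use 1e-4, and epochs 8 to 12 use 1e-5.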

In [0]:
# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      self.model.optimizer.lr.numpy()))

In [0]:
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

##1.6. Train and evaluate

In [14]:
model.fit(train_dataset, epochs=12, callbacks=callbacks)

Epoch 1/12
Instructions for updating:
use `tf.profiler.experimental.stop` instead.




Learning rate for epoch 1 is 0.0010000000474974513
Epoch 2/12
Learning rate for epoch 2 is 0.0010000000474974513
Epoch 3/12
Learning rate for epoch 3 is 0.0010000000474974513
Epoch 4/12
Learning rate for epoch 4 is 9.999999747378752e-05
Epoch 5/12
Learning rate for epoch 5 is 9.999999747378752e-05
Epoch 6/12
Learning rate for epoch 6 is 9.999999747378752e-05
Epoch 7/12
Learning rate for epoch 7 is 9.999999747378752e-05
Epoch 8/12
Learning rate for epoch 8 is 9.999999747378752e-06
Epoch 9/12
Learning rate for epoch 9 is 9.999999747378752e-06
Epoch 10/12
Learning rate for epoch 10 is 9.999999747378752e-06
Epoch 11/12
Learning rate for epoch 11 is 9.999999747378752e-06
Epoch 12/12
Learning rate for epoch 12 is 9.999999747378752e-06


<tensorflow.python.keras.callbacks.History at 0x7fcfaa33b9b0>

In [15]:
# check the checkpoint directory
!ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_4.index
ckpt_10.index		     ckpt_5.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_5.index
ckpt_11.index		     ckpt_6.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_6.index
ckpt_12.index		     ckpt_7.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_7.index
ckpt_1.index		     ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index


In [16]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

Eval loss: 0.03579382598400116, Eval Accuracy: 0.9882000088691711


In [22]:
!tensorboard --logdir=/content/logs

Traceback (most recent call last):
  File "/usr/local/bin/tensorboard", line 8, in <module>
    sys.exit(run_main())
  File "/usr/local/lib/python3.6/dist-packages/tensorboard/main.py", line 75, in run_main
    app.run(tensorboard.main, flags_parser=tensorboard.configure)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "/usr/local/lib/python3.6/dist-packages/tensorboard/program.py", line 289, in main
    return runner(self.flags) or 0
  File "/usr/local/lib/python3.6/dist-packages/tensorboard/program.py", line 305, in _run_serve_subcommand
    server = self._make_server()
  File "/usr/local/lib/python3.6/dist-packages/tensorboard/program.py", line 409, in _make_server
    self.flags, self.plugin_loaders, self.assets_zip_provider
  File "/usr/local/lib/python3.6/dist-packages/tensorboard/backend/application.py", line 

In [23]:
ls -sh ./logs

total 4.0K
4.0K train/


##1.7. Export to SavedModel

In [24]:
path = 'saved_model/'
model.save(path, save_format='tf')

INFO:tensorflow:Assets written to: saved_model/assets




In [25]:
# Load model without strategy.scope
unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

Eval loss: 0.03579382598400116, Eval Accuracy: 0.9882000088691711


In [26]:
# Load with strategy.scope
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

Eval loss: 0.03579382598400116, Eval Accuracy: 0.9882000088691711


#2. Custom training with tf.distribute.Strategy

In [27]:
%tensorflow_version 2.x
# Import TensorFlow
!pip install -q tf-nightly
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)

2.2.0-dev20200419


##2.1. Download the fashion MNIST dataset

In [28]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


##2.2. Create a strategy to distribute the variables and the graph

In [29]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




In [30]:
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


##2.3. Setup input pipeline

In [0]:
BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

In [0]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

##2.4. Create the model

In [0]:
def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model

In [0]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

##2.5. Define the loss function

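Normally, on a single machine with 1 GPU/CPU, loss is divided by the number of examples in the batch of input. With a tf.distribute.Strategy, each replica sees only part of the global batch, so the per-example losses on each replica are summed and divided by the global batch size (GLOBAL_BATCH_SIZE) rather than by the per-replica batch size. This is what tf.nn.compute_average_loss does in the cell that follows.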
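
The arithmetic behind dividing by the global batch size can be sketched in plain Python. The loss values and the two-replica layout are made up for illustration; no TensorFlow API is involved.

```python
# Hypothetical per-example losses: 2 replicas, 2 examples each.
per_replica_losses = [[1.0, 3.0], [2.0, 6.0]]
global_batch_size = sum(len(losses) for losses in per_replica_losses)  # 4


def scaled_replica_loss(example_losses, global_batch_size):
    """Sum the losses on one replica and divide by the GLOBAL batch size."""
    return sum(example_losses) / global_batch_size


# Each replica scales by the global batch size; the results are then summed.
scaled = [scaled_replica_loss(l, global_batch_size) for l in per_replica_losses]
summed_across_replicas = sum(scaled)

# Reference value: mean loss over every example in the global batch.
all_losses = [x for losses in per_replica_losses for x in losses]
true_mean = sum(all_losses) / len(all_losses)

# What goes wrong if each replica divides by its own batch size instead.
naive = sum(sum(l) / len(l) for l in per_replica_losses)

print(summed_across_replicas, true_mean, naive)
```

Summing the scaled per-replica losses reproduces the mean over the whole batch, while dividing by the per-replica batch size overstates it by a factor equal to the number of replicas.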

In [0]:
with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

In [0]:
# Define the metrics to track loss and accuracy
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

##2.6. Training loop

In [0]:
# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [0]:
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)

In [39]:
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

Epoch 1, Loss: 0.5153678059577942, Accuracy: 81.36333465576172, Test Loss: 0.38318201899528503, Test Accuracy: 86.45999908447266
Epoch 2, Loss: 0.3339560031890869, Accuracy: 87.94999694824219, Test Loss: 0.32745468616485596, Test Accuracy: 88.47000122070312
Epoch 3, Loss: 0.289735347032547, Accuracy: 89.40666961669922, Test Loss: 0.3177132308483124, Test Accuracy: 88.6300048828125
Epoch 4, Loss: 0.2589462101459503, Accuracy: 90.55332946777344, Test Loss: 0.3060314953327179, Test Accuracy: 88.80000305175781
Epoch 5, Loss: 0.23737187683582306, Accuracy: 91.28500366210938, Test Loss: 0.2779774069786072, Test Accuracy: 89.69000244140625
Epoch 6, Loss: 0.21620415151119232, Accuracy: 92.05166625976562, Test Loss: 0.3102119565010071, Test Accuracy: 89.11000061035156
Epoch 7, Loss: 0.20048733055591583, Accuracy: 92.55333709716797, Test Loss: 0.25279226899147034, Test Accuracy: 91.18000030517578
Epoch 8, Loss: 0.18413300812244415, Accuracy: 93.17666625976562, Test Loss: 0.256083607673645, Test 

##2.7. Restore the latest checkpoint and test

In [0]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [0]:
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

In [42]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))

Accuracy after restoring the saved model without strategy: 91.37999725341797


##2.8. Alternate ways of iterating over a dataset

If you want to iterate over a given number of steps rather than the entire dataset, you can create an iterator with iter and explicitly call next on it.
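
The iter/next pattern itself is ordinary Python. In this minimal sketch a plain list stands in for the distributed dataset, and take_steps is a hypothetical helper. How a distributed iterator signals exhaustion is not shown here.

```python
def take_steps(iterable, num_steps):
    """Consume at most num_steps items via an explicit iterator."""
    iterator = iter(iterable)
    taken = []
    for _ in range(num_steps):
        try:
            taken.append(next(iterator))
        except StopIteration:
            # The iterable ran out before num_steps were consumed.
            break
    return taken


batches = [[0, 1], [2, 3], [4, 5], [6, 7]]  # stand-in for dataset batches
print(take_steps(batches, 2))
print(take_steps(batches, 10))
```

Creating a fresh iterator at the start of each epoch, as the next cell does, restarts from the beginning of the dataset each time.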

In [43]:
# Iterating outside a tf.function
for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()

Epoch 1, Loss: 0.16225148737430573, Accuracy: 94.21875
Epoch 2, Loss: 0.165340855717659, Accuracy: 93.59375
Epoch 3, Loss: 0.12995003163814545, Accuracy: 94.6875
Epoch 4, Loss: 0.16367259621620178, Accuracy: 94.375
Epoch 5, Loss: 0.13290463387966156, Accuracy: 95.78125
Epoch 6, Loss: 0.12039406597614288, Accuracy: 95.78125
Epoch 7, Loss: 0.1760627180337906, Accuracy: 93.75
Epoch 8, Loss: 0.12315411865711212, Accuracy: 94.6875
Epoch 9, Loss: 0.14033515751361847, Accuracy: 94.0625
Epoch 10, Loss: 0.16444949805736542, Accuracy: 95.15625


In [44]:
# Iterating inside a tf.function
@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()

Epoch 1, Loss: 0.1442921906709671, Accuracy: 94.59833526611328
Epoch 2, Loss: 0.1308281272649765, Accuracy: 95.17499542236328
Epoch 3, Loss: 0.12104317545890808, Accuracy: 95.57833099365234
Epoch 4, Loss: 0.1106492206454277, Accuracy: 95.9800033569336
Epoch 5, Loss: 0.10212011635303497, Accuracy: 96.17833709716797
Epoch 6, Loss: 0.09471506625413895, Accuracy: 96.47666931152344
Epoch 7, Loss: 0.0839567556977272, Accuracy: 96.86166381835938
Epoch 8, Loss: 0.07851479202508926, Accuracy: 97.16999816894531
Epoch 9, Loss: 0.07310765981674194, Accuracy: 97.29166412353516
Epoch 10, Loss: 0.06589626520872116, Accuracy: 97.52166748046875


##2.9. Tracking training loss across replicas

We do not recommend using tf.metrics.Mean to track the training loss across different replicas, because the loss on each replica has already been scaled by the global batch size, and taking a further mean over replicas would distort the value.
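
A small plain-Python illustration of the arithmetic, with made-up loss values. It only shows why a mean over already-scaled per-replica losses differs from their sum, and makes no claim about how any particular metric class is implemented.

```python
# Hypothetical setup: 2 replicas, 2 examples per replica, global batch size 4.
per_replica_example_losses = [[2.0, 1.0], [2.0, 1.0]]
global_batch_size = 4

# Loss each replica reports after scaling by the global batch size.
scaled = [sum(l) / global_batch_size for l in per_replica_example_losses]

# Summing the scaled values recovers the mean loss over the global batch.
reduced_sum = sum(scaled)

# Averaging them again divides by the number of replicas a second time.
mean_over_replicas = sum(scaled) / len(scaled)

print(scaled, reduced_sum, mean_over_replicas)
```

Here the summed value (1.5) matches the mean over all four examples, while the mean over replicas (0.75) is too small by a factor of two.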

#3. Multi-worker training with Keras

In [2]:
%tensorflow_version 2.x
!pip install -q tf-nightly
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

##3.1. Preparing dataset

In [3]:
BUFFER_SIZE = 10000
BATCH_SIZE = 64

def make_datasets_unbatched():
  # Scaling MNIST data from [0, 255] to [0., 1.]
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

  datasets, info = tfds.load(name='mnist',
                            with_info=True,
                            as_supervised=True)

  return datasets['train'].map(scale).cache().shuffle(BUFFER_SIZE)

train_datasets = make_datasets_unbatched().batch(BATCH_SIZE)

local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.

Downloading and preparing dataset mnist/3.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /root/tensorflow_datasets/mnist/3.0.0...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.


##3.2. Keras Model

In [0]:
def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model

In [4]:
single_worker_model = build_and_compile_cnn_model()
single_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f97801d7f28>

###3.2.1. Multi-worker Configuration

In [0]:
import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
})
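
As a sketch of how this configuration would differ between workers: every worker shares the same cluster description and differs only in its task index. The helper make_tf_config is hypothetical, and the host addresses are the placeholder values from the cell above.

```python
import json


def make_tf_config(worker_hosts, task_index):
    """Hypothetical helper: build the TF_CONFIG string for one worker."""
    return json.dumps({
        'cluster': {'worker': list(worker_hosts)},
        'task': {'type': 'worker', 'index': task_index},
    })


workers = ["localhost:12345", "localhost:23456"]

# One configuration string per worker; only 'task.index' changes.
configs = [make_tf_config(workers, i) for i in range(len(workers))]
for config in configs:
    print(config)
```

Each worker process would set its own string in the TF_CONFIG environment variable before the strategy is created.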

##3.3. Choose the right strategy

In [5]:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()





INFO:tensorflow:Using MirroredStrategy with devices ('/device:GPU:0',)




INFO:tensorflow:Single-worker MultiWorkerMirroredStrategy with local_devices = ('/device:GPU:0',), communication = CollectiveCommunication.AUTO




##3.4. Train the model with MultiWorkerMirroredStrategy

In [6]:
NUM_WORKERS = 2
# Here the batch size scales up by number of workers since 
# `tf.data.Dataset.batch` expects the global batch size. Previously we used 64, 
# and now this becomes 128.
GLOBAL_BATCH_SIZE = 64 * NUM_WORKERS

# Creation of dataset needs to be after MultiWorkerMirroredStrategy object
# is instantiated.
train_datasets = make_datasets_unbatched().batch(GLOBAL_BATCH_SIZE)
with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = build_and_compile_cnn_model()

# Keras' `model.fit()` trains the model with specified number of epochs and
# number of steps per epoch. Note that the numbers here are for demonstration
# purposes only and may not sufficiently produce a model with good quality.
multi_worker_model.fit(x=train_datasets, epochs=3, steps_per_epoch=5)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f976bfab358>

In [0]:
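# Dataset sharding and batch size
# In multi-worker training, sharding data into multiple parts is needed to
# ensure convergence and performance. The strategy shards the dataset
# automatically; the options below turn automatic sharding off if you
# prefer to shard manually.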

options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
train_datasets_no_auto_shard = train_datasets.with_options(options)
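
Conceptually, sharding gives each worker a disjoint slice of the data. This plain-Python sketch uses the same element selection rule as tf.data.Dataset.shard(num_shards, index), where every num_shards-th element goes to a given worker. It is an illustration only: the automatic policy may instead shard by input file, and the shard function here is hypothetical.

```python
def shard(elements, num_shards, index):
    """Keep the elements whose position modulo num_shards equals index."""
    return [e for i, e in enumerate(elements) if i % num_shards == index]


examples = list(range(10))  # stand-in for dataset elements
num_workers = 2

for worker_index in range(num_workers):
    print(worker_index, shard(examples, num_workers, worker_index))
```

With automatic sharding turned off, every worker would iterate over the full dataset unless you shard it yourself.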

##3.5. Fault tolerance

ModelCheckpoint callback:

In synchronous training, the cluster would fail if one of the workers fails and no failure-recovery mechanism exists. 

In [8]:
# Replace the `filepath` argument with a path in the file system
# accessible by all workers.
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='/tmp/keras-ckpt')]
with strategy.scope():
  multi_worker_model = build_and_compile_cnn_model()
multi_worker_model.fit(x=train_datasets,
                       epochs=3,
                       steps_per_epoch=5,
                       callbacks=callbacks)

Epoch 1/3


INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets


Epoch 2/3


INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets


Epoch 3/3








INFO:tensorflow:Assets written to: /tmp/keras-ckpt/assets






<tensorflow.python.keras.callbacks.History at 0x7f978028abe0>