Overview
The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. It allows you to carry out distributed training using existing models and training code with minimal changes.

This tutorial demonstrates how to use the tf.distribute.MirroredStrategy to perform in-graph replication with synchronous training on many GPUs on one machine. The strategy essentially copies all of the model's variables to each processor. Then, it uses all-reduce to combine the gradients from all processors, and applies the combined value to all copies of the model.

MirroredStrategy trains your model on multiple GPUs on a single machine. For synchronous training on many GPUs on multiple workers, use the tf.distribute.MultiWorkerMirroredStrategy with the Keras Model.fit or a custom training loop. For other options, refer to the Distributed training guide.

In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf

import os

# Load the TensorBoard notebook extension.

In [1]:

%load_ext tensorboard

In [3]:
print(tf.__version__)

2.18.0


In [4]:
import os

In [5]:
# Fetch MNIST into the current working directory. as_supervised=True yields
# (image, label) pairs, and with_info=True also returns the dataset metadata
# object that is read later for the split sizes.
datasets, info = tfds.load(
    name='mnist',
    data_dir=os.getcwd(),
    as_supervised=True,
    with_info=True,
)

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to d:\a27_YEARS_OLD\deep_learning\distributed_training\mnist\3.0.1...[0m


  from .autonotebook import tqdm as notebook_tqdm
Dl Completed...: 0 url [00:00, ? url/s]
Dl Completed...:   0%|          | 0/1 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/2 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/3 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:   0%|          | 0/4 [00:00<?, ? url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:02,  1.23 url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:02,  1.13 url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:02,  1.08 url/s]
Dl Completed...:  25%|██▌       | 1/4 [00:00<00:02,  1.07 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:01<00:01,  1.98 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:01<00:01,  1.98 url/s]
[A
Dl Completed...:  50%|█████     | 2/4 [00:01<00:01,  1.98 url/s]
Dl Completed...:  50%|█████     | 2/4 [00:01<00:01, 

[1mDataset mnist downloaded and prepared to d:\a27_YEARS_OLD\deep_learning\distributed_training\mnist\3.0.1. Subsequent calls will reuse this data.[0m


In [6]:
# Pull the training and test splits out of the object returned by tfds.load.
mnist_train = datasets['train']
mnist_test = datasets['test']

In [8]:
mnist_train

<_PrefetchDataset element_spec=(TensorSpec(shape=(28, 28, 1), dtype=tf.uint8, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))>

In [9]:
next(iter(mnist_train))

(<tf.Tensor: shape=(28, 28, 1), dtype=uint8, numpy=
 array([[[  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0]],
 
        [[  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0],
         [  0]],
 
        [[  0],
         [  0],
         [  0]


Define the distribution strategy
Create a MirroredStrategy object. This will handle distribution and provide a context manager (MirroredStrategy.scope) to build your model inside.

In [10]:
# No device list is passed, so the strategy selects the devices itself; the
# INFO log printed below shows which ones it picked on this machine.
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [11]:
# Report how many replicas the strategy keeps in sync.
replica_count = strategy.num_replicas_in_sync
print('Number of devices: {}'.format(replica_count))

Number of devices: 1


Set up the input pipeline
When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory and tune the learning rate accordingly.

In [32]:
# Per-split example counts, read from the dataset metadata.
# (info.splits.total_num_examples gives the total number of examples in the
# dataset instead.)
splits = info.splits
num_train_examples = splits['train'].num_examples
num_test_examples = splits['test'].num_examples

# Buffer size handed to Dataset.shuffle for the training pipeline.
BUFFER_SIZE = 1000

# The global batch is the per-replica batch multiplied by the number of
# replicas, so it grows automatically when more devices are available.
BATCH_SIZE_PER_REPLICA = 16
BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

Define a function that normalizes the image pixel values from the [0, 255] range to the [0, 1] range (feature scaling):

In [33]:
def scale(image, label):
  """Rescale an image from the [0, 255] range to [0, 1]; pass the label through.

  Args:
    image: Image tensor; it is cast to float32 before the division.
    label: Label belonging to the image, returned unchanged.

  Returns:
    A (scaled_image, label) tuple.
  """
  as_float = tf.cast(image, tf.float32)
  return as_float / 255, label

Apply this scale function to the training and test data, and then use the tf.data.Dataset APIs to shuffle the training data (Dataset.shuffle) and batch it (Dataset.batch). Notice that you are also keeping an in-memory cache of the training data to improve performance (Dataset.cache).

In [34]:
# Training pipeline: scale, cache the scaled examples, shuffle, then batch.
train_dataset = (
    mnist_train
    .map(scale)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Evaluation pipeline: scale and batch only (no cache, no shuffle).
scaled_test = mnist_test.map(scale)
eval_dataset = scaled_test.batch(BATCH_SIZE)

In [35]:
# Build and compile the model inside the strategy's scope, which is the
# context manager the strategy provides for creating the model.
with strategy.scope():
  model = tf.keras.Sequential([
      # Declare the input with an explicit Input object instead of passing
      # `input_shape=` to the first layer. Keras emits a UserWarning for the
      # latter ("Do not pass an `input_shape`/`input_dim` argument to a
      # layer ... prefer using an `Input(shape)` object as the first layer").
      tf.keras.Input(shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      # No activation on the last layer: it outputs raw logits, which is why
      # the loss below is configured with from_logits=True.
      tf.keras.layers.Dense(10),
  ])

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
      metrics=['accuracy'],
  )


In [37]:
model.summary()

In [38]:
input_shape=(28, 28, 1)

In [21]:
input_shape[-1]

1

In [22]:
input_shape[-2]

28

In [26]:
input_shape=(None, 28, 1)

In [27]:
input_shape[-2]

28

For this toy example with the MNIST dataset, you will be using the Adam optimizer's default learning rate of 0.001.

For larger datasets, the key benefit of distributed training is to learn more in each training step, because each step processes more training data in parallel, which allows for a larger learning rate (within the limits of the model and dataset).

Define the callbacks
Define the following Keras Callbacks:

tf.keras.callbacks.TensorBoard: writes a log for TensorBoard, which allows you to visualize the graphs.
tf.keras.callbacks.ModelCheckpoint: saves the model at a certain frequency, such as after every epoch.
tf.keras.callbacks.BackupAndRestore: provides the fault tolerance functionality by backing up the model and current epoch number. Learn more in the Fault tolerance section of the Multi-worker training with Keras tutorial.
tf.keras.callbacks.LearningRateScheduler: schedules the learning rate to change after, for example, every epoch/batch.

In [43]:
# Define the checkpoint directory to store the checkpoints.
# Directory that receives the checkpoint files.
checkpoint_dir = './training_checkpoints'
# File-name pattern for the checkpoints; "{epoch}" stays a literal placeholder
# here so that each epoch's save gets its own file name.
checkpoint_filename = "ckpt_{epoch}.weights.h5"
checkpoint_prefix = os.path.join(checkpoint_dir, checkpoint_filename)

In [40]:
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  """Return the learning rate for the given epoch index.

  Step schedule: 1e-3 while epoch < 3, 1e-4 while epoch < 7, and 1e-5 for
  every later epoch.

  Args:
    epoch: Epoch index to look up.

  Returns:
    The learning rate as a float.
  """
  # (exclusive upper bound, rate) pairs, checked in ascending order.
  steps = ((3, 1e-3), (7, 1e-4))
  for upper_bound, rate in steps:
    if epoch < upper_bound:
      return rate
  return 1e-5

In [48]:
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  """Callback that prints the optimizer's learning rate after each epoch."""

  def on_epoch_end(self, epoch, logs=None):
    """Print the current learning rate for the epoch that just finished.

    Args:
      epoch: Index of the finished epoch; 1 is added for display.
      logs: Metrics dict supplied by Keras; not used here.
    """
    # Read the rate from the model Keras attached to this callback rather than
    # from a module-level `model` variable, so the callback reports on
    # whichever model it is actually passed to.
    learning_rate = self.model.optimizer.learning_rate.numpy()
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1, learning_rate))

In [49]:
# Put all the callbacks together.
# TensorBoard logs go under ./logs.
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir='./logs')
# Weights only (not the full model) are written to the checkpoint path.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True,
)
# Applies the `decay` schedule defined above.
lr_schedule_cb = tf.keras.callbacks.LearningRateScheduler(decay)

callbacks = [tensorboard_cb, checkpoint_cb, lr_schedule_cb, PrintLR()]

Train and evaluate
Now, train the model in the usual way by calling Keras Model.fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.

In [50]:
# Number of passes over the training dataset.
EPOCHS = 5

# Train with the callbacks assembled above; the call is the same whether or
# not the training is distributed.
model.fit(
    train_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
)

Epoch 1/5
[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.9882 - loss: 0.0391
Learning rate for epoch 1 is 0.0010000000474974513
[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 7ms/step - accuracy: 0.9882 - loss: 0.0391 - learning_rate: 0.0010
Epoch 2/5
[1m3749/3750[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 7ms/step - accuracy: 0.9917 - loss: 0.0263
Learning rate for epoch 2 is 0.0010000000474974513
[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 7ms/step - accuracy: 0.9917 - loss: 0.0263 - learning_rate: 0.0010
Epoch 3/5
[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.9941 - loss: 0.0181
Learning rate for epoch 3 is 0.0010000000474974513
[1m3750/3750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 7ms/step - accuracy: 0.9941 - loss: 0.0181 - learning_rate: 0.0010
Epoch 4/5
[1m3749/3750[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0

<keras.src.callbacks.history.History at 0x171e6c23710>

In [53]:
os.listdir(checkpoint_dir)

['ckpt_1.weights.h5',
 'ckpt_2.weights.h5',
 'ckpt_3.weights.h5',
 'ckpt_4.weights.h5',
 'ckpt_5.weights.h5']

In [56]:
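# NOTE: tf.train.latest_checkpoint(checkpoint_dir) is not used here. It locates
# checkpoints through a TensorFlow `checkpoint` index file, and the directory
# listing above contains only Keras `.weights.h5` files, so the final epoch's
# file is loaded by name in the next cell instead.
# model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))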


In [59]:
# Restore the weights saved at the end of epoch 5, the last epoch trained.
final_weights_path = os.path.join(checkpoint_dir, 'ckpt_5.weights.h5')
model.load_weights(final_weights_path)

In [60]:

# Score the restored model on the evaluation batches; evaluate returns the
# loss followed by the compiled metric (accuracy).
evaluation = model.evaluate(eval_dataset)
eval_loss, eval_acc = evaluation

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))

[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.9891 - loss: 0.0399
Eval loss: 0.04309540241956711, Eval accuracy: 0.9883999824523926


To visualize the output, launch TensorBoard and view the logs. The launch command, `%tensorboard --logdir=logs/train`, is in the final cell of this notebook.

Save the model

In [63]:
# Target file for the model save in the next cell.
path = 'my_model.keras'

In [64]:
# Write the trained model to `path` so it can be reloaded below.
model.save(path)

In [65]:
# Reload the saved model without entering strategy.scope(), then compile and
# evaluate it on the same evaluation data for comparison.
unreplicated_model = tf.keras.models.load_model(path)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
unreplicated_model.compile(
    loss=loss_fn,
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

unreplicated_results = unreplicated_model.evaluate(eval_dataset)
eval_loss, eval_acc = unreplicated_results

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - accuracy: 0.9891 - loss: 0.0399
Eval loss: 0.04309540241956711, Eval Accuracy: 0.9883999824523926


In [66]:
# Same reload/compile/evaluate sequence as above, but performed entirely
# inside the strategy's scope.
with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)

  scoped_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  scoped_optimizer = tf.keras.optimizers.Adam()
  replicated_model.compile(
      loss=scoped_loss,
      optimizer=scoped_optimizer,
      metrics=['accuracy'],
  )

  replicated_results = replicated_model.evaluate(eval_dataset)
  eval_loss, eval_acc = replicated_results
  print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))







[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - accuracy: 0.9891 - loss: 0.0399
Eval loss: 0.04309540241956711, Eval Accuracy: 0.9883999824523926


In [62]:
%tensorboard --logdir=logs/train