# Custom Training with tf.distribute

This tutorial demonstrates how to use `tf.distribute.Strategy` with custom training loops. We will train a simple CNN model on the Fashion MNIST dataset, which contains 60,000 training images and 10,000 test images, each of size 28 x 28.

We use a custom training loop because it gives us more flexibility and greater control over training. It also makes the model and the training loop easier to debug.

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)

2.0.0


In [2]:
import os
root_dir = './distribute_s2' # section 2

In [3]:
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
# Limit GPU memory
MEMORY_LIMIT_CONFIG = [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)]
print(MEMORY_LIMIT_CONFIG)
tf.config.experimental.set_virtual_device_configuration(gpus[0], MEMORY_LIMIT_CONFIG)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
[VirtualDeviceConfiguration(memory_limit=5120)]


In [4]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

In [5]:
train_images.shape, test_images.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))

## Create a strategy to distribute the variables and the graph  
How does `tf.distribute.MirroredStrategy` work?

* All the variables and the model graph are replicated on the replicas.
* Input is evenly distributed across the replicas.
* Each replica calculates the loss and gradients for the input it received.
* The gradients are synced across all the replicas by summing them.
* After the sync, the same update is made to the copies of the variables on each replica.
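
The steps listed above can be sketched without TensorFlow. The following is a minimal toy simulation, not how `MirroredStrategy` is implemented: it assumes a one-parameter model `y_hat = w * x`, a squared-error loss, and a batch that divides evenly across replicas. The names `grad_sum` and `mirrored_step` are hypothetical helpers introduced only for this illustration.

```python
# Toy simulation of synchronous data-parallel training (no TensorFlow).
# Assumed model: y_hat = w * x, per-example loss: (w * x - y) ** 2.

def grad_sum(w, xs, ys):
    """Sum of the per-example gradients d/dw (w*x - y)^2 over one shard."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))

def mirrored_step(replica_weights, xs, ys, lr):
    """One synchronous step: shard the batch, sum gradients, update every copy."""
    num_replicas = len(replica_weights)
    global_batch_size = len(xs)
    shard = global_batch_size // num_replicas  # assumes an even split
    # Each replica computes gradients only for the input it received,
    # scaled by the global batch size.
    per_replica_grads = [
        grad_sum(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        / global_batch_size
        for i, w in enumerate(replica_weights)
    ]
    # Sync: sum the gradients so that every replica sees the same value.
    synced_grad = sum(per_replica_grads)
    # Apply the same update to the copy of the variable on each replica.
    return [w - lr * synced_grad for w in replica_weights]

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # generated with w = 2

weights = [0.0, 0.0]   # two replicas holding identical copies of w
single = [0.0]         # one device processing the whole batch
for _ in range(50):
    weights = mirrored_step(weights, xs, ys, lr=0.05)
    single = mirrored_step(single, xs, ys, lr=0.05)

print(weights, single)
```

Because every replica applies the same summed gradient, the copies never diverge, and the result agrees with the single-device run up to floating-point rounding.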

In [6]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Setup input pipeline  
Define the batch sizes and the number of epochs. The global batch size is the per-replica batch size multiplied by the number of replicas in sync.

In [7]:
BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10

Create the datasets and distribute them:

In [8]:
# create the datasets
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 
# and distribute them
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

In [9]:
img_batch, label_batch = next(iter(train_dataset))

In [10]:
img_batch.shape, label_batch.shape

(TensorShape([64, 28, 28, 1]), TensorShape([64]))

In [11]:
img_dist, label_dist = next(iter(train_dist_dataset))

In [12]:
img_dist.shape, label_dist.shape

(TensorShape([64, 28, 28, 1]), TensorShape([64]))
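
With a single replica, the distributed batch has the same shape as the plain one. The sketch below works through the batch arithmetic for 1 and 4 replicas, under two assumptions: the dataset has 60,000 examples, and a global batch is split as evenly as possible across replicas. `batch_sizes` and `split_across_replicas` are hypothetical helpers, and the actual split performed by `experimental_distribute_dataset` may differ.

```python
def batch_sizes(num_examples, global_batch_size):
    """Sizes of the batches produced by batching one epoch of data."""
    full, remainder = divmod(num_examples, global_batch_size)
    return [global_batch_size] * full + ([remainder] if remainder else [])

def split_across_replicas(batch_size, num_replicas):
    """Hypothetical near-even split of one global batch across replicas."""
    base, extra = divmod(batch_size, num_replicas)
    return [base + (1 if i < extra else 0) for i in range(num_replicas)]

BATCH_SIZE_PER_REPLICA = 64
summary = {}
for num_replicas in (1, 4):
    global_batch_size = BATCH_SIZE_PER_REPLICA * num_replicas
    sizes = batch_sizes(60000, global_batch_size)
    last_split = split_across_replicas(sizes[-1], num_replicas)
    summary[num_replicas] = (global_batch_size, len(sizes), sizes[-1], last_split)
    print(num_replicas, global_batch_size, len(sizes), sizes[-1], last_split)
```

In both cases the last batch of the epoch is smaller than the global batch size, so the per-replica batch size is not constant from step to step.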

## Create the model

In [13]:
def create_model():
    model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')])
    return model

In [14]:
model_example = create_model()

In [15]:
model_example(img_batch).shape

TensorShape([64, 10])

In [16]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = os.path.join(root_dir, 'checkpoints')
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function
Normally, on a single machine with 1 GPU/CPU, loss is divided by the number of examples in the batch of input.

So, how should the loss be calculated when using a tf.distribute.Strategy?

* For example, say you have 4 GPUs and a global batch size of 64. One batch of input is distributed across the replicas (4 GPUs), so each replica receives an input of size 16.

* The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by the GLOBAL_BATCH_SIZE (64).

Why do this?

* This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by summing them.

How to do this in TensorFlow?

* If you're writing a custom training loop, as in this tutorial, sum the per-example losses and divide the sum by the GLOBAL_BATCH_SIZE: `scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`. Alternatively, use `tf.nn.compute_average_loss`, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.

* If you are using regularization losses in your model, you need to scale the loss value by the number of replicas. You can do this with the `tf.nn.scale_regularization_loss` function.

* Using `tf.reduce_mean` is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.

* This reduction and scaling is done automatically in Keras `model.compile` and `model.fit`.

* If using `tf.keras.losses` classes (as in the example below), the loss reduction needs to be explicitly specified as either `NONE` or `SUM`. `AUTO` and `SUM_OVER_BATCH_SIZE` are disallowed when used with `tf.distribute.Strategy`. `AUTO` is disallowed because the user should explicitly think about which reduction is correct in the distributed case. `SUM_OVER_BATCH_SIZE` is disallowed because it would currently divide only by the per-replica batch size and leave the division by the number of replicas to the user, which is easy to miss. Instead, the user is asked to do the reduction explicitly.
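
The scaling rule can be checked with plain arithmetic. The sketch below uses made-up per-example losses split unevenly across two replicas and compares dividing by the global batch size against taking a per-replica mean. It also shows the regularization case, assuming every replica computes the same regularization term. `scaled_loss` and `local_mean` are hypothetical helpers, not TensorFlow functions.

```python
# Made-up per-example losses for one global batch of 6 examples,
# split unevenly across two replicas.
replica_losses = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0]]
GLOBAL_BATCH_SIZE = 6
num_replicas = len(replica_losses)

def scaled_loss(per_example_losses, global_batch_size):
    """Sum the per-example losses and divide by the GLOBAL batch size."""
    return sum(per_example_losses) / global_batch_size

def local_mean(per_example_losses):
    """What a plain mean does: divide by the per-replica batch size."""
    return sum(per_example_losses) / len(per_example_losses)

# Reference value: the mean loss over the whole global batch.
true_mean = sum(sum(r) for r in replica_losses) / GLOBAL_BATCH_SIZE  # 3.5

# Replicas are combined by summation.
summed_scaled = sum(scaled_loss(r, GLOBAL_BATCH_SIZE) for r in replica_losses)
summed_local = sum(local_mean(r) for r in replica_losses)  # 8.0, too large
averaged_local = summed_local / num_replicas                # 4.0, still wrong

# Regularization is computed identically on every replica, so each replica
# contributes 1 / num_replicas of it before the sum.
reg_loss = 0.6
summed_reg = sum(reg_loss / num_replicas for _ in range(num_replicas))

print(true_mean, summed_scaled, summed_local, averaged_local, summed_reg)
```

Only the sum of globally scaled losses reproduces the mean over the global batch. The per-replica means are wrong even after dividing by the number of replicas, because the replicas hold different numbers of examples.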

In [17]:
with strategy.scope():
    # Set reduction to `none` so we can do the reduction afterwards and divide by
    # global batch size.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    # or loss_fn = tf.keras.losses.sparse_categorical_crossentropy
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

In [18]:
label_batch.shape, model_example(img_batch).shape

(TensorShape([64]), TensorShape([64, 10]))

In [19]:
# __call__(y_true, y_pred) is predefined in tf.keras.losses.SparseCategoricalCrossentropy
loss_object(y_true=label_batch, y_pred=model_example(img_batch)).shape

TensorShape([64])

In [20]:
# example
cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
# cce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='auto')
# cce = tf.keras.losses.SparseCategoricalCrossentropy(reduction='none')
print(cce(tf.convert_to_tensor([0, 1, 2]), tf.convert_to_tensor([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]])))
# from_logits=True, reduction='auto' will lead to 0.6981444

tf.Tensor([0.61779296 0.8859636  0.5906766 ], shape=(3,), dtype=float32)
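
The per-example values above can be reproduced by hand. The sketch below assumes the standard definition of sparse softmax cross-entropy, `-log(softmax(logits)[label])`. Note that with `from_logits=True` each row is treated as raw logits, even though the rows in this example look like probabilities. `sparse_softmax_cross_entropy` is a hypothetical helper written for this check.

```python
import math

def sparse_softmax_cross_entropy(label, logits):
    """-log(softmax(logits)[label]), computed via log-sum-exp for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

labels = [0, 1, 2]
logits = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]]

per_example = [sparse_softmax_cross_entropy(y, z) for y, z in zip(labels, logits)]
# Dividing the sum by a global batch size of 3 gives the mean quoted
# in the cell comment above.
average = sum(per_example) / 3

print(per_example)
print(average)
```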


In [21]:
# GLOBAL_BATCH_SIZE = 3 example
tf.nn.compute_average_loss(per_example_loss=cce(tf.convert_to_tensor([0, 1, 2]), 
                                                tf.convert_to_tensor([[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]]))
                           ,global_batch_size=3)

<tf.Tensor: id=320, shape=(), dtype=float32, numpy=0.6981444>

## Define the metrics to track loss and accuracy

These metrics track the test loss and the training and test accuracy. You can call `.result()` to get the accumulated statistics at any time.

In [22]:
with strategy.scope():
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
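
To make the accumulate, read, and reset pattern concrete, here is a minimal stand-in for a streaming mean. It is an illustrative sketch, not the Keras implementation: `StreamingMean` is a hypothetical class that only mirrors the `update_state`, `result`, and `reset_states` method names used in this notebook.

```python
class StreamingMean:
    """Running mean over every value seen since the last reset."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate the sum and the number of values in this batch.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Mean over all values seen so far (0.0 if nothing was seen).
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0

metric = StreamingMean()
metric.update_state([1.0, 3.0])  # a batch of two per-example losses
metric.update_state([5.0])       # a smaller final batch
mean_over_examples = metric.result()   # (1 + 3 + 5) / 3
mean_of_batch_means = (2.0 + 5.0) / 2  # differs when batch sizes differ

metric.reset_states()
after_reset = metric.result()

print(mean_over_examples, mean_of_batch_means, after_reset)
```

Feeding per-example values keeps the result independent of how the data was batched, whereas averaging batch means would overweight the smaller final batch.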

## Training loop

In [23]:
# model and optimizer must be created under `strategy.scope`.
with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [24]:
# example of distribute
data_dist = (img_dist, label_dist)
def f_dist(inputs):
    images, labels = inputs
    print(images.shape, labels.shape)
    predictions = model(images)
    loss = compute_loss(labels, predictions)
    return loss 

with strategy.scope():
    per_replica_losses = strategy.experimental_run_v2(f_dist, args=(data_dist,))
    print(per_replica_losses)

(64, 28, 28, 1) (64,)
tf.Tensor(2.3160014, shape=(), dtype=float32)


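Note that the two steps compute the loss differently: `train_step` uses `compute_loss`, which scales the summed per-example losses by `GLOBAL_BATCH_SIZE`, while `test_step` passes the unreduced per-example losses from `loss_object` to the `test_loss` metric, which averages them.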

In [25]:
with strategy.scope():
    def train_step(inputs):
        images, labels = inputs

        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = compute_loss(labels, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        train_accuracy.update_state(labels, predictions)
        return loss 

    def test_step(inputs):
        images, labels = inputs

        predictions = model(images, training=False)
        t_loss = loss_object(labels, predictions)

        test_loss.update_state(t_loss)
        test_accuracy.update_state(labels, predictions)

In [26]:
with strategy.scope():
    # `experimental_run_v2` replicates the provided computation and runs it
    # with the distributed input.
    @tf.function
    def distributed_train_step(dataset_inputs):
        per_replica_losses = strategy.experimental_run_v2(train_step, args=(dataset_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

    @tf.function
    def distributed_test_step(dataset_inputs):
        return strategy.experimental_run_v2(test_step, args=(dataset_inputs,))

    for epoch in range(EPOCHS):
        
        # TRAIN LOOP
        total_loss = 0.0
        num_batches = 0
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        for x in test_dist_dataset:
            distributed_test_step(x)

        if epoch % 2 == 0:
            checkpoint.save(checkpoint_prefix)

        template = ("Epoch {0}, Loss: {1:.3f}, Accuracy: {2:.3f}, Test Loss: {3:.3f}, Test Accuracy: {4:.3f}")
        print (template.format(epoch+1, train_loss,
                               train_accuracy.result()*100, test_loss.result(),
                               test_accuracy.result()*100))
        
        # initialize tf.keras.metrics.* object
        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Things to note in the example above:

* We iterate over `train_dist_dataset` and `test_dist_dataset` using a `for x in ...` construct.
* The scaled loss is the return value of the `distributed_train_step`. This value is aggregated across replicas using the `tf.distribute.Strategy.reduce` call and then across batches by summing the return value of the `tf.distribute.Strategy.reduce` calls.
* `tf.keras.metrics` objects should be updated inside `train_step` and `test_step`, which are executed by `tf.distribute.Strategy.experimental_run_v2`.
* `tf.distribute.Strategy.experimental_run_v2` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can do `tf.distribute.Strategy.reduce` to get an aggregated value. You can also do `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.
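
The two levels of aggregation described in these notes can be sketched in plain Python. The numbers are made up and assume two replicas. `reduce_sum` and `local_results` are hypothetical stand-ins that only mimic what `strategy.reduce` with `ReduceOp.SUM` and `strategy.experimental_local_results` return. They are not the TensorFlow functions.

```python
# Scaled losses returned by each replica over three training steps
# (made-up numbers; each is already divided by the global batch size).
per_replica_losses_by_step = [(0.40, 0.50), (0.30, 0.35), (0.20, 0.25)]

def reduce_sum(per_replica_values):
    """Stand-in for a SUM reduction across replicas: one aggregated value."""
    return sum(per_replica_values)

def local_results(per_replica_values):
    """Stand-in for reading the raw results: one value per local replica."""
    return tuple(per_replica_values)

total_loss = 0.0
num_batches = 0
for step_values in per_replica_losses_by_step:
    # Level 1: aggregate across replicas within a step.
    total_loss += reduce_sum(step_values)
    num_batches += 1
# Level 2: average across the batches of the epoch.
train_loss = total_loss / num_batches

first_step_values = local_results(per_replica_losses_by_step[0])
print(train_loss, first_step_values)
```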

## Restore the latest checkpoint and test

A model checkpointed with a tf.distribute.Strategy can be restored with or without a strategy.

In [27]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

In [28]:
@tf.function
def eval_step(images, labels):
    predictions = new_model(images, training=False)
    eval_accuracy(labels, predictions)

In [29]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
print(checkpoint_dir)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

./distribute_s2/checkpoints


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fa9488e2d10>

In [30]:
for images, labels in test_dataset:
    eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {0:.3f}'.format(
    eval_accuracy.result()*100))

Accuracy after restoring the saved model without strategy: 90.440
