<br>

<div align=center><font color=maroon size=6><b>Custom training with tf.distribute.Strategy</b></font></div>

<br>

<font size=4><b>References:</b></font>
1. TF2 official tutorials: <a href="https://www.tensorflow.org/tutorials" style="text-decoration:none;">TensorFlow Tutorials</a> 
    * `TensorFlow > Learn > TensorFlow Core > Tutorials >` <a href="https://www.tensorflow.org/tutorials/distribute/custom_training" style="text-decoration:none;">Custom training with tf.distribute.Strategy</a>
        * Run in <a href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/custom_training.ipynb" style="text-decoration:none;">Google Colab</a>

<br>
<br>
<br>

This tutorial demonstrates how to use <font size=3 color=maroon>`tf.distribute.Strategy` — a TensorFlow API that provides an abstraction for [distributing your training](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/distributed_training.ipynb) across multiple processing units (GPUs, multiple machines, or TPUs) — with custom training loops.</font> In this example, you will train a simple convolutional neural network on the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) containing 70,000 images of size 28 x 28.

<font size=3 color=maroon>[Custom training loops](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/customization/custom_training_walkthrough.ipynb) provide flexibility and greater control over training. They also make it easier to debug the model and the training loop.</font>

<br>

In [1]:
# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)

2.8.0


<br>

## Download the Fashion MNIST dataset

In [9]:
# help(tf.keras.datasets.fashion_mnist.load_data)

In [4]:
fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()


# Default download path: ~/.keras/datasets (on this machine: C:\Users\18617\.keras\datasets)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [10]:
# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

<br>
<br>
<br>

## Create a strategy to distribute the variables and the graph

How does `tf.distribute.MirroredStrategy` work?

*   All the variables and the model graph are replicated across the replicas.
*   Input is evenly distributed across the replicas.
*   Each replica calculates the loss and gradients for the input it received.
*   The gradients are synced across all the replicas by summing them.
*   After the sync, the same update is made to the copies of the variables on each replica.
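The steps above can be simulated in plain Python, with no TensorFlow involved, to see why the replicas stay in sync. This is a sketch under stated assumptions: a hypothetical one-parameter linear model `y = w * x` with squared-error loss, plain SGD, and a global batch that divides evenly across the simulated replicas. The helper names `grad_on_shard` and `mirrored_step` are illustrative only.

```python
# Plain-Python simulation of one MirroredStrategy-style synchronous step.
# Hypothetical setup: model y = w * x, squared-error loss, SGD.

def grad_on_shard(w, xs, ys, global_batch_size):
    """Gradient w.r.t. w of sum((w*x - y)**2) / global_batch_size on one shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / global_batch_size


def mirrored_step(w, xs, ys, num_replicas, lr=0.1):
    global_batch_size = len(xs)
    per_replica = global_batch_size // num_replicas  # assumes an even split
    # 1. Every replica starts from an identical copy of the variable.
    copies = [w] * num_replicas
    # 2. The global batch is split evenly across the replicas.
    shards = [(xs[i * per_replica:(i + 1) * per_replica],
               ys[i * per_replica:(i + 1) * per_replica])
              for i in range(num_replicas)]
    # 3. Each replica computes the gradient for its own shard.
    grads = [grad_on_shard(c, sx, sy, global_batch_size)
             for c, (sx, sy) in zip(copies, shards)]
    # 4. All-reduce: the gradients are summed across replicas.
    total_grad = sum(grads)
    # 5. The same update is applied to every copy.
    return [c - lr * total_grad for c in copies]


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # generated by the "true" weight w = 2
copies = mirrored_step(0.0, xs, ys, num_replicas=2)
single = mirrored_step(0.0, xs, ys, num_replicas=1)
print(copies, single)
```

Both replica copies end up with the same weight, and it matches the single-device update, because each replica divides its loss by the global batch size before the gradients are summed.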

<font size=3 color=maroon>**Note**:</font> You can put all the code below inside a single scope. This example divides it into several code cells for illustration purposes.


In [11]:
# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.

strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [12]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


<br>
<br>
<br>

## Set up the input pipeline

In [13]:
BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

<br>

Create the datasets and distribute them:

In [14]:
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)) \
                               .shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)) \
                              .batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
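For a rough picture of what these datasets yield, the arithmetic below counts the global batches produced from the 60,000 training and 10,000 test images at a global batch size of 64 (this run has a single replica, so each replica sees all 64). The 4-replica split and the helper `split_evenly` are hypothetical illustrations; how `experimental_distribute_dataset` actually divides a short final batch among replicas is handled inside TensorFlow and may differ from this even split.

```python
# Batch arithmetic for the Fashion MNIST pipeline (60,000 train / 10,000 test).

def batch_sizes(num_examples, global_batch_size):
    """Sizes of the global batches produced by .batch() without drop_remainder."""
    full, remainder = divmod(num_examples, global_batch_size)
    return [global_batch_size] * full + ([remainder] if remainder else [])


def split_evenly(batch_size, num_replicas):
    """Hypothetical even split of one global batch across replicas."""
    base, extra = divmod(batch_size, num_replicas)
    return [base + (1 if i < extra else 0) for i in range(num_replicas)]


train_batches = batch_sizes(60_000, 64)
test_batches = batch_sizes(10_000, 64)
print(len(train_batches), train_batches[-1])  # number of steps, size of last batch
print(len(test_batches), test_batches[-1])
print(split_evenly(64, 4), split_evenly(train_batches[-1], 4))
```

The last training batch holds only 32 examples, which is why the loss section recommends dividing by the fixed `GLOBAL_BATCH_SIZE` rather than by the actual size of each batch.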

<br>
<br>
<br>

## Create the model

Create a model using `tf.keras.Sequential`. You can also use the [Model Subclassing API](https://www.tensorflow.org/guide/keras/custom_layers_and_models) or the [functional API](https://www.tensorflow.org/guide/keras/functional) to do this.

In [15]:
def create_model():
    model = tf.keras.Sequential([tf.keras.layers.Conv2D(32, 3, activation='relu'),
                                 tf.keras.layers.MaxPooling2D(),
                                 tf.keras.layers.Conv2D(64, 3, activation='relu'),
                                 tf.keras.layers.MaxPooling2D(),
                                 tf.keras.layers.Flatten(),
                                 tf.keras.layers.Dense(64, activation='relu'),
                                 tf.keras.layers.Dense(10)
                                ])

    return model

In [16]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

<br>
<br>
<br>

## Define the loss function

<font size=3 color=maroon>Normally, on a single machine with a single GPU/CPU, the loss is divided by the number of examples in the input batch.

***So, how should the loss be calculated when using a `tf.distribute.Strategy`?***</font>

* For example, say you have 4 GPUs and a global batch size of 64. One batch of input is distributed across the replicas (4 GPUs), with each replica getting an input of size 16.

* The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (`BATCH_SIZE_PER_REPLICA` = 16), the loss should be divided by the `GLOBAL_BATCH_SIZE` (64).

<font size=3 color=maroon>***Why do this?***</font>

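* This needs to be done because after the gradients are calculated on each replica, they are synced across the replicas by **summing** them. Dividing each replica's summed loss by `GLOBAL_BATCH_SIZE` makes that sum equal to the gradient of the mean loss over the whole global batch, just as on a single device; dividing by the per-replica batch size instead would make the summed gradient too large by a factor equal to the number of replicas.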
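A small numeric check of this reasoning, in plain Python with made-up per-example losses: with 4 replicas and a global batch of 64 (16 per replica), dividing each replica's summed loss by the global batch size and then summing across replicas reproduces the single-device mean, while dividing by the per-replica batch size overstates it by a factor of 4. Loss values stand in for gradients here, since the gradient of a sum is the sum of the gradients.

```python
# Compare two per-replica scalings of the loss, then sum across replicas
# (the same reduction MirroredStrategy applies to gradients).
NUM_REPLICAS = 4
BATCH_SIZE_PER_REPLICA = 16
GLOBAL_BATCH_SIZE = NUM_REPLICAS * BATCH_SIZE_PER_REPLICA  # 64

# Made-up per-example losses for one global batch: 0.0, 0.1, ..., 6.3
per_example_loss = [i / 10 for i in range(GLOBAL_BATCH_SIZE)]
shards = [per_example_loss[r * BATCH_SIZE_PER_REPLICA:(r + 1) * BATCH_SIZE_PER_REPLICA]
          for r in range(NUM_REPLICAS)]

single_device_mean = sum(per_example_loss) / GLOBAL_BATCH_SIZE

# Recommended: divide by the global batch size on each replica, then sum.
scaled_by_global = sum(sum(s) / GLOBAL_BATCH_SIZE for s in shards)
# Not recommended: divide by the per-replica batch size, then sum.
scaled_by_local = sum(sum(s) / BATCH_SIZE_PER_REPLICA for s in shards)

print(single_device_mean, scaled_by_global, scaled_by_local / single_device_mean)
```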

<font size=3 color=maroon>***How to do this in TensorFlow?***</font>

* If you're writing a custom training loop, as in this tutorial, you should sum the per-example losses and divide the sum by the `GLOBAL_BATCH_SIZE`: `scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`. Alternatively, you can use `tf.nn.compute_average_loss`, which takes the per-example loss, optional sample weights, and `GLOBAL_BATCH_SIZE` as arguments and returns the scaled loss.

* If you are using regularization losses in your model, you also need to divide the regularization loss by the number of replicas. You can do this with the `tf.nn.scale_regularization_loss` function.

* Using `tf.reduce_mean` <font color=maroon>is not recommended.</font> Doing so divides the loss by the actual per-replica batch size, which may vary from step to step.

* <font color=maroon>This reduction and scaling are done automatically in Keras `model.compile` and `model.fit`.

* If using `tf.keras.losses` classes (as in the example below), the loss reduction needs to be explicitly specified to be one of `NONE` or `SUM`.</font> `AUTO` and `SUM_OVER_BATCH_SIZE` are disallowed when used with `tf.distribute.Strategy`.
    * `AUTO` is disallowed because the user should explicitly think about which reduction they want, to make sure it is correct in the distributed case.
    * `SUM_OVER_BATCH_SIZE` is disallowed because currently it would only divide by the per-replica batch size and leave the division by the number of replicas to the user, which might be easy to miss.
    * So instead, we ask users to do the reduction themselves, explicitly.


* If `labels` is multi-dimensional, then average the `per_example_loss` across the number of elements in each sample. For example, if the shape of `predictions` is `(batch_size, H, W, n_classes)` and `labels` is `(batch_size, H, W)`, you will need to update `per_example_loss` like: `per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)`

  <font color=maroon>**Caution**:</font> **Verify the shape of your loss**. 
  Loss functions in `tf.losses`/`tf.keras.losses` typically
  return the average over the last dimension of the input. The loss
  classes wrap these functions. Passing `reduction=Reduction.NONE` when
  creating an instance of a loss class means "no **additional** reduction".
  
  * For categorical losses with an example input shape of `[batch, W, H, n_classes]` the `n_classes`
  dimension is reduced. 
  
  * For pointwise losses like
  `losses.mean_squared_error` or `losses.binary_crossentropy`, include a
  dummy axis so that `[batch, W, H, 1]` is reduced to `[batch, W, H]`. Without
  the dummy axis, `[batch, W, H]` will be incorrectly reduced to `[batch, W]`.
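The arithmetic behind the helpers mentioned above can be sketched in plain Python. This is a hedged re-implementation of their documented behavior for illustration, not TensorFlow's source: `average_loss` and `scaled_regularization_loss` are hypothetical stand-ins for `tf.nn.compute_average_loss` (sum of optionally weighted per-example losses divided by the global batch size) and `tf.nn.scale_regularization_loss` (summed regularization losses divided by the number of replicas). The short final shard of 8 examples is a made-up case that shows how a `tf.reduce_mean`-style division changes with the batch size.

```python
# Plain-Python stand-ins for the loss arithmetic described above.
# The function names are hypothetical.

def average_loss(per_example_loss, global_batch_size, sample_weight=None):
    """Sum of (optionally weighted) per-example losses / global batch size."""
    if sample_weight is None:
        sample_weight = [1.0] * len(per_example_loss)
    weighted = [loss * w for loss, w in zip(per_example_loss, sample_weight)]
    return sum(weighted) / global_batch_size


def scaled_regularization_loss(regularization_losses, num_replicas):
    """Sum of regularization losses / number of replicas in sync."""
    return sum(regularization_losses) / num_replicas


GLOBAL_BATCH_SIZE, NUM_REPLICAS = 64, 2

# Short final batch: this replica only received 8 examples, each with loss 1.0.
last_shard = [1.0] * 8
fixed_denominator = average_loss(last_shard, GLOBAL_BATCH_SIZE)  # 8 / 64
mean_style = sum(last_shard) / len(last_shard)                   # 8 / 8, like tf.reduce_mean
print(fixed_denominator, mean_style)

# Every replica adds the same regularization term and the gradients are then
# summed across replicas, so dividing by the replica count keeps it counted once.
print(scaled_regularization_loss([0.2, 0.4], NUM_REPLICAS))
```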


In [17]:
with strategy.scope():
    # Set reduction to `NONE` so you can do the reduction afterwards and divide by
    # global batch size
    loss_object = tf.keras.losses \
                          .SparseCategoricalCrossentropy(from_logits=True,
                                                         reduction=tf.keras.losses.Reduction.NONE)
    
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

<br>
<br>
<br>

## Define the metrics to track loss and accuracy

These metrics track the test loss and training and test accuracy. You can use `.result()` to get the accumulated statistics at any time.
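Keras metrics are stateful: `update_state` accumulates statistics batch by batch, `result()` reads the running value, and `reset_states()` clears it, which is why the training loop resets them every epoch. The plain-Python class below is a hypothetical stand-in that mimics this pattern for sparse categorical accuracy on logits; it is a sketch of the idea, not the Keras implementation.

```python
class StreamingSparseAccuracy:
    """Hypothetical stand-in for a stateful sparse categorical accuracy metric."""

    def __init__(self):
        self.reset_states()

    def reset_states(self):
        self.correct = 0
        self.count = 0

    def update_state(self, labels, logits):
        # A prediction is correct when the index of the largest logit equals the label.
        for label, row in zip(labels, logits):
            predicted = max(range(len(row)), key=row.__getitem__)
            self.correct += int(predicted == label)
            self.count += 1

    def result(self):
        return self.correct / self.count if self.count else 0.0


metric = StreamingSparseAccuracy()
metric.update_state([0, 2], [[2.0, 0.1, 0.3], [0.5, 1.5, 0.2]])  # batch 1: 1 of 2 correct
metric.update_state([1], [[0.0, 3.0, 1.0]])                      # batch 2: 1 of 1 correct
print(metric.result())  # accumulated over both batches: 2 / 3
metric.reset_states()
```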

In [18]:
with strategy.scope():
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

<br>
<br>
<br>

## Training loop

In [19]:
# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
    model = create_model()
    
    optimizer = tf.keras.optimizers.Adam()
    
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [20]:
def train_step(inputs):
    images, labels = inputs
    
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = compute_loss(labels, predictions)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_accuracy.update_state(labels, predictions)
    
    return loss

In [21]:
def test_step(inputs):
    images, labels = inputs
    
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)
    
    test_loss.update_state(t_loss)
    test_accuracy.update_state(labels, predictions)

<br>

In [22]:
# `run` replicates the provided computation and runs it
# with the distributed input.

@tf.function
def distributed_train_step(dataset_inputs):
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs, ))
    
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)



@tf.function
def distributed_test_step(dataset_inputs):
    
    return strategy.run(test_step, args=(dataset_inputs, ))

In [23]:
for epoch in range(EPOCHS):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    
    for x in train_dist_dataset:
        total_loss += distributed_train_step(x)
        num_batches += 1
    train_loss = total_loss / num_batches
    
    
    
    # TEST LOOP
    for x in test_dist_dataset:
        distributed_test_step(x)

    if epoch % 2 == 0:
        checkpoint.save(checkpoint_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    
    print(template.format(epoch + 1, train_loss,
                          train_accuracy.result() * 100, test_loss.result(),
                          test_accuracy.result() * 100))

    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()

Epoch 1, Loss: 0.5101115703582764, Accuracy: 81.6883316040039, Test Loss: 0.37020131945610046, Test Accuracy: 86.94000244140625
Epoch 2, Loss: 0.33030104637145996, Accuracy: 88.06999969482422, Test Loss: 0.33064186573028564, Test Accuracy: 88.12999725341797
Epoch 3, Loss: 0.2831439971923828, Accuracy: 89.68499755859375, Test Loss: 0.2856801152229309, Test Accuracy: 89.61000061035156
Epoch 4, Loss: 0.25157228112220764, Accuracy: 90.74833679199219, Test Loss: 0.27305877208709717, Test Accuracy: 90.22000122070312
Epoch 5, Loss: 0.22915494441986084, Accuracy: 91.66999816894531, Test Loss: 0.27353930473327637, Test Accuracy: 89.9800033569336
Epoch 6, Loss: 0.2078235149383545, Accuracy: 92.36166381835938, Test Loss: 0.25566616654396057, Test Accuracy: 90.72999572753906
Epoch 7, Loss: 0.18994446098804474, Accuracy: 92.98500061035156, Test Loss: 0.2537948489189148, Test Accuracy: 90.69999694824219
Epoch 8, Loss: 0.17425008118152618, Accuracy: 93.4433364868164, Test Loss: 0.2527698576450348, Te

<br>

Things to note in the example above:

* Iterate over the `train_dist_dataset` and `test_dist_dataset` using a `for x in ...` construct.
* The scaled loss is the return value of `distributed_train_step`. This value is aggregated across replicas using the `tf.distribute.Strategy.reduce` call and then across batches by summing the return values of the `tf.distribute.Strategy.reduce` calls.
* `tf.keras.Metrics` should be updated inside `train_step` and `test_step`, which are executed by `tf.distribute.Strategy.run`.
* `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can call `tf.distribute.Strategy.reduce` to get an aggregated value, or `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.


<br>
<br>
<br>

## Restore the latest checkpoint and test

A model checkpointed with a `tf.distribute.Strategy` can be restored with or without a strategy.

In [24]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)) \
                              .batch(GLOBAL_BATCH_SIZE)

In [25]:
@tf.function
def eval_step(images, labels):
    predictions = new_model(images, training=False)
    eval_accuracy(labels, predictions)

In [26]:
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
    eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))

Accuracy after restoring the saved model without strategy: 91.18000030517578


<br>
<br>
<br>

## Alternate ways of iterating over a dataset

### Using iterators

<font size=3 color=maroon>If you want to iterate over a given number of steps rather than through the entire dataset, you can create an iterator using the `iter` call and explicitly call `next` on the iterator. You can choose to iterate over the dataset either inside or outside the `tf.function`.</font>

Here is a small snippet demonstrating iteration of the dataset outside the tf.function using an iterator.


In [28]:
for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    train_iter = iter(train_dist_dataset)

    for _ in range(10):
        total_loss += distributed_train_step(next(train_iter))
        num_batches += 1
    average_train_loss = total_loss / num_batches

    template = ("Epoch {}, Loss: {}, Accuracy: {}")
    print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
    train_accuracy.reset_states()

Epoch 1, Loss: 0.12687364220619202, Accuracy: 95.625
Epoch 2, Loss: 0.14564602077007294, Accuracy: 94.6875
Epoch 3, Loss: 0.14935092628002167, Accuracy: 94.6875
Epoch 4, Loss: 0.1348617970943451, Accuracy: 95.0
Epoch 5, Loss: 0.10454156249761581, Accuracy: 95.625
Epoch 6, Loss: 0.12649133801460266, Accuracy: 94.6875
Epoch 7, Loss: 0.1216755285859108, Accuracy: 96.40625
Epoch 8, Loss: 0.1177939623594284, Accuracy: 95.78125
Epoch 9, Loss: 0.09366155415773392, Accuracy: 96.875
Epoch 10, Loss: 0.10629288852214813, Accuracy: 95.9375
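The loop above creates a fresh iterator at the start of every epoch and takes only 10 steps from it, so each "epoch" starts again from the beginning of the dataset (reshuffled, with `shuffle`'s default settings). If you would rather continue through the data across calls, one option is to keep a single long-lived iterator and recreate it only when it is exhausted. The helper `take_steps` below is a hypothetical plain-Python sketch of that pattern, with a list of integers standing in for batches; in the tutorial's setting `make_iterator` would be something like `lambda: iter(train_dist_dataset)`.

```python
def take_steps(make_iterator, state, num_steps):
    """Pull `num_steps` batches, reusing the iterator stored in `state` across calls.

    `make_iterator` is any zero-argument callable returning a fresh iterator.
    """
    batches = []
    for _ in range(num_steps):
        if "it" not in state:
            state["it"] = make_iterator()
        try:
            batches.append(next(state["it"]))
        except StopIteration:
            # The data ran out: start a new pass over it.
            state["it"] = make_iterator()
            batches.append(next(state["it"]))
    return batches


data = list(range(5))  # stand-in for five batches
state = {}
print(take_steps(lambda: iter(data), state, 3))  # [0, 1, 2]
print(take_steps(lambda: iter(data), state, 3))  # [3, 4, 0]
```

The trade-off is that step-based training no longer lines up with full passes over the data, so "epoch" becomes a reporting interval rather than one pass through the dataset.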


<br>
<br>

### Iterating inside a tf.function
You can also iterate over the entire input `train_dist_dataset` inside a `tf.function` using the `for x in ...` construct or by creating iterators like you did above. The example below demonstrates wrapping one epoch of training with a `@tf.function` decorator and iterating over `train_dist_dataset` inside the function.

In [30]:
@tf.function
def distributed_train_epoch(dataset):
    total_loss = 0.0
    num_batches = 0
    for x in dataset:
        per_replica_losses = strategy.run(train_step, args=(x,))
        total_loss += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
        num_batches += 1
    return total_loss / tf.cast(num_batches, dtype=tf.float32)



for epoch in range(EPOCHS):
    train_loss = distributed_train_epoch(train_dist_dataset)

    template = ("Epoch {}, Loss: {}, Accuracy: {}")
    print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))

    train_accuracy.reset_states()



Epoch 1, Loss: 0.13530224561691284, Accuracy: 94.9949951171875
Epoch 2, Loss: 0.1233699694275856, Accuracy: 95.40332794189453
Epoch 3, Loss: 0.11012730747461319, Accuracy: 95.90166473388672
Epoch 4, Loss: 0.10032857954502106, Accuracy: 96.20833587646484
Epoch 5, Loss: 0.09551814943552017, Accuracy: 96.4383316040039
Epoch 6, Loss: 0.08562958985567093, Accuracy: 96.84333038330078
Epoch 7, Loss: 0.0772203728556633, Accuracy: 97.086669921875
Epoch 8, Loss: 0.07069350779056549, Accuracy: 97.30833435058594
Epoch 9, Loss: 0.06665600836277008, Accuracy: 97.47999572753906
Epoch 10, Loss: 0.059388745576143265, Accuracy: 97.8550033569336


<br>
<br>

### Tracking training loss across replicas

Note: As a general rule, you should use `tf.keras.Metrics` to track per-sample values and avoid values that have been aggregated within a replica.

Because of the loss scaling computation that is carried out, it's not recommended to use `tf.metrics.Mean` to track the training loss across different replicas.

For example, if you run a training job with the following characteristics:
* Two replicas
* Two samples are processed on each replica
* Resulting loss values: [2, 3] on the first replica and [4, 5] on the second
* Global batch size = 4

With loss scaling, you calculate the per-sample value of loss on each replica by adding the loss values, and then dividing by the global batch size. In this case: `(2 + 3) / 4 = 1.25` and `(4 + 5) / 4 = 2.25`. 

If you use `tf.metrics.Mean` to track loss across the two replicas, the result is different. In this example, you end up with a `total` of 3.50 and a `count` of 2, which results in `total`/`count` = 1.75 when `result()` is called on the metric, whereas the actual loss for the global batch is 1.25 + 2.25 = 3.50 (the mean of 2, 3, 4, and 5). <font size=3 color=maroon>Loss calculated with `tf.keras.Metrics` is thus scaled down by an additional factor equal to the number of replicas in sync.</font>
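The numbers in this example can be checked directly. The plain-Python sketch below models a `Mean` metric as a running `total` and `count`, an assumption about its internals made for illustration, and feeds it one scaled loss value per replica.

```python
# Worked example: 2 replicas, 2 samples each, global batch size 4.
per_replica_losses = [[2.0, 3.0], [4.0, 5.0]]
GLOBAL_BATCH_SIZE = 4

# Loss scaling on each replica: sum of per-sample losses / global batch size.
scaled = [sum(losses) / GLOBAL_BATCH_SIZE for losses in per_replica_losses]  # [1.25, 2.25]

# Correct global-batch loss: sum the scaled values across replicas.
reduced_sum = sum(scaled)  # 3.5, the same as the mean of [2, 3, 4, 5]

# What a Mean metric reports if fed one scaled value per replica: total / count.
total, count = sum(scaled), len(scaled)
mean_metric = total / count  # 1.75 = 3.5 / number of replicas

print(scaled, reduced_sum, mean_metric)
```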

<br>
<br>
<br>

## Guide and examples
Here are some examples of using distribution strategies with custom training loops:

1. [Distributed training guide](https://github.com/tensorflow/docs/blob/master/site/en/guide/distributed_training)
2. [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) example using `MirroredStrategy`.
3. [BERT](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) example trained using `MirroredStrategy` and `TPUStrategy`. This example is particularly helpful for understanding how to load from a checkpoint and generate periodic checkpoints during distributed training.
4. [NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py) example trained using `MirroredStrategy` that can be enabled using the `keras_use_ctl` flag.
5. [NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) example trained using `MirroredStrategy`.

More examples are listed in the [Distribution strategy guide](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/distributed_training.ipynb#examples_and_tutorials).

<br>
<br>
<br>

## Next steps

*   Try out the new `tf.distribute.Strategy` API on your models.
*   Visit the [Better performance with tf.function](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/function.ipynb) and [TensorFlow Profiler](https://github.com/tensorflow/docs/blob/master/site/en/guide/profiler.md) guides to learn more about tools to optimize the performance of your TensorFlow models.
*   The [Distributed training in TensorFlow](https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/distributed_training.ipynb) guide provides an overview of the available distribution strategies.

<br>
<br>
<br>

```python
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
```

<br>
<br>
<br>