# Custom training with tf.distribute.Strategy

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
import tensorflow_hub as hub

# Helper libraries
import numpy as np
import os

print(tf.__version__)

## Download the dataset

In [None]:
!pip install tfds-nightly

In [None]:
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

In [None]:
# Slice the dataset's 'train' split into 80% / 10% / 10% portions that serve as
# the training, validation and test sets respectively.
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

(train_examples, validation_examples, test_examples), info = tfds.load(
    'oxford_flowers102',
    split=splits,
    as_supervised=True,
    with_info=True)

# Both values are read from the dataset metadata. Note that `num_examples` is
# the size of the whole 'train' split, not of the 80% training slice.
num_classes = info.features['label'].num_classes
num_examples = info.splits['train'].num_examples

## Create a strategy to distribute the variables and the graph

How does `tf.distribute.MirroredStrategy` strategy work?

*   All the variables and the model graph are replicated on the replicas.
*   Input is evenly distributed across the replicas.
*   Each replica calculates the loss and gradients for the input it received.
*   The gradients are synced across all the replicas by summing them.
*   After the sync, the same update is made to the copies of the variables on each replica.

Note: You can put all the code below inside a single scope. We are dividing it into several code cells for illustration purposes.


In [None]:
# Create the distribution strategy object used by the rest of the notebook.
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

In [None]:
# Report how many replicas the strategy keeps in sync.
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))

## Setup input pipeline

Export the graph and the variables to the platform-agnostic SavedModel format. After your model is saved, you can load it with or without the scope.

In [None]:
# Set to the example count of the 'train' split.
# NOTE(review): BUFFER_SIZE is not referenced anywhere else in the visible
# source (the shuffle call uses `num_examples // 4` directly) -- confirm
# whether it is still needed.
BUFFER_SIZE = num_examples

# Batch size handled by a single replica, and the global batch size obtained by
# multiplying it by the number of replicas kept in sync by the strategy.
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Number of passes over the training dataset made by the training loop.
EPOCHS = 10

In [None]:
# Side length, in pixels, of the square images fed to the model.
pixels = 224
# TF Hub handle of the feature-vector module wrapped by the model.
MODULE_HANDLE = 'https://tfhub.dev/tensorflow/resnet_50/feature_vector/1'
# (height, width) target passed to `tf.image.resize` during preprocessing.
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

In [None]:
def format_image(image, label):
  """Resize an image to `IMAGE_SIZE` and divide its pixel values by 255.

  Args:
    image: image tensor to preprocess.
    label: label paired with the image; returned unchanged.

  Returns:
    A `(image, label)` tuple holding the preprocessed image and the label.
  """
  resized = tf.image.resize(image, IMAGE_SIZE)
  return resized / 255.0, label

Create the datasets and distribute them:

In [None]:
# `strategy.experimental_distribute_dataset` splits every batch it receives
# across the replicas, so the training and validation pipelines must be batched
# by GLOBAL_BATCH_SIZE. Each replica then processes BATCH_SIZE_PER_REPLICA
# examples per step, which is what the `global_batch_size=GLOBAL_BATCH_SIZE`
# loss scaling assumes. Batching by BATCH_SIZE_PER_REPLICA here would shrink
# the per-replica batch (and mis-scale the loss) whenever more than one replica
# is in use.
train_batches = (
    train_examples
    .shuffle(num_examples // 4)
    .map(format_image)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(1))
validation_batches = (
    validation_examples
    .map(format_image)
    .batch(GLOBAL_BATCH_SIZE)
    .prefetch(1))
# Test examples are evaluated one per batch; the test metrics are accumulated
# per example, so they do not depend on the batch size.
test_batches = test_examples.map(format_image).batch(1)

# Wrap each pipeline so that iterating it yields per-replica values.
train_dist_dataset = strategy.experimental_distribute_dataset(train_batches)
val_dist_dataset = strategy.experimental_distribute_dataset(validation_batches)
test_dist_dataset = strategy.experimental_distribute_dataset(test_batches)

## Create the model

We use the Model Subclassing API to do this.

In [None]:
class ResNetModel(tf.keras.Model):
  """Non-trainable TF Hub feature extractor followed by a softmax classifier."""

  def __init__(self, classes):
    """Create the two layers of the model.

    Args:
      classes: number of units in the final dense (softmax) layer.
    """
    super(ResNetModel, self).__init__()
    # The hub module is loaded with `trainable=False`, so only the dense
    # classifier contributes trainable variables.
    self._feature_extractor = hub.KerasLayer(MODULE_HANDLE, trainable=False)
    self._classifier = tf.keras.layers.Dense(classes, activation='softmax')

  def call(self, inputs):
    """Pass `inputs` through the feature extractor, then the classifier."""
    features = self._feature_extractor(inputs)
    return self._classifier(features)

In [None]:
# Directory that holds the training checkpoints, and the path prefix
# (`<checkpoint_dir>/ckpt`) intended for the checkpoint files written into it.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function


In [None]:
with strategy.scope():
  # Reduction is `NONE` so the loss object returns unreduced values; the
  # reduction is done afterwards in `compute_loss`, which divides by the global
  # batch size.
  # (Alternative: loss_fn = tf.keras.losses.sparse_categorical_crossentropy)
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE)

  def compute_loss(labels, predictions):
    """Return the loss for one replica batch, scaled by GLOBAL_BATCH_SIZE.

    Args:
      labels: ground-truth labels for the batch.
      predictions: model outputs for the batch.

    Returns:
      The result of `tf.nn.compute_average_loss` applied to the unreduced
      losses with `global_batch_size=GLOBAL_BATCH_SIZE`.
    """
    return tf.nn.compute_average_loss(
        loss_object(labels, predictions),
        global_batch_size=GLOBAL_BATCH_SIZE)

  # Running mean used to track the test loss.
  test_loss = tf.keras.metrics.Mean(name='test_loss')

## Define the metrics to track loss and accuracy

These metrics track the test loss and training and test accuracy. You can use `.result()` to get the accumulated statistics at any time.

In [None]:
with strategy.scope():
  # One sparse-categorical accuracy metric per phase, created under the
  # strategy scope.
  _Accuracy = tf.keras.metrics.SparseCategoricalAccuracy
  train_accuracy = _Accuracy(name='train_accuracy')
  test_accuracy = _Accuracy(name='test_accuracy')
  del _Accuracy  # Local alias only; do not leave it in the namespace.

## Training loop

In [None]:
# model and optimizer must be created under `strategy.scope`.
with strategy.scope():
  # Classifier with one output unit per class in the dataset.
  model = ResNetModel(classes=num_classes)
  optimizer = tf.keras.optimizers.Adam()
  # Checkpoint object that tracks both the optimizer and the model.
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [None]:
with strategy.scope():
  def train_step(inputs):
    """Run one optimization step on a batch and return its scaled loss.

    Args:
      inputs: an `(images, labels)` pair.

    Returns:
      The value of `compute_loss` for this batch.
    """
    images, labels = inputs

    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
      predictions = model(images, training=True)
      loss = compute_loss(labels, predictions)

    trainable_vars = model.trainable_variables
    grads = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))

    train_accuracy.update_state(labels, predictions)
    return loss

  def test_step(inputs):
    """Evaluate a batch and fold the results into the test metrics.

    Args:
      inputs: an `(images, labels)` pair.
    """
    images, labels = inputs
    predictions = model(images, training=False)

    # The unreduced losses are fed to the running mean `test_loss`.
    test_loss.update_state(loss_object(labels, predictions))
    test_accuracy.update_state(labels, predictions)

In [None]:
from tqdm import tqdm

In [None]:
with strategy.scope():
  @tf.function
  def distributed_train_step(dataset_inputs):
    """Run `train_step` on every replica and return the summed loss.

    Args:
      dataset_inputs: one element taken from `train_dist_dataset`.

    Returns:
      The per-replica losses reduced with `ReduceOp.SUM`.
    """
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)

  @tf.function
  def distributed_test_step(dataset_inputs):
    """Run `test_step` on every replica.

    Args:
      dataset_inputs: one element taken from `test_dist_dataset`.
    """
    return strategy.run(test_step, args=(dataset_inputs,))

  for epoch in range(EPOCHS):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    for x in tqdm(train_dist_dataset):
      total_loss += distributed_train_step(x)
      num_batches += 1
    train_loss = total_loss / num_batches

    # TEST LOOP
    for x in test_dist_dataset:
      distributed_test_step(x)

    # Write a checkpoint at the end of every epoch. Without this nothing is
    # ever saved under `checkpoint_dir`, and a later
    # `tf.train.latest_checkpoint(checkpoint_dir)` has nothing to return.
    checkpoint.save(checkpoint_prefix)

    template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                "Test Accuracy: {}")
    print(template.format(epoch+1, train_loss,
                          train_accuracy.result()*100, test_loss.result(),
                          test_accuracy.result()*100))

    # Metrics accumulate across calls, so clear them before the next epoch.
    test_loss.reset_states()
    train_accuracy.reset_states()
    test_accuracy.reset_states()

Things to note in the example above:

* We are iterating over the `train_dist_dataset` and `test_dist_dataset` using  a `for x in ...` construct.
* The scaled loss is the return value of the `distributed_train_step`. This value is aggregated across replicas using the `tf.distribute.Strategy.reduce` call and then across batches by summing the return value of the `tf.distribute.Strategy.reduce` calls.
* `tf.keras.metrics` objects should be updated inside the `train_step` and `test_step` functions that get executed by `tf.distribute.Strategy.run`.
* `tf.distribute.Strategy.run` returns results from each local replica in the strategy, and there are multiple ways to consume this result. You can do `tf.distribute.Strategy.reduce` to get an aggregated value. You can also do `tf.distribute.Strategy.experimental_local_results` to get the list of values contained in the result, one per local replica.


## Restore the latest checkpoint and test

A model checkpointed with a `tf.distribute.Strategy` can be restored with or without a strategy.

In [None]:
with strategy.scope():
  # Accuracy metric for the evaluation pass over the test dataset.
  eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

  def eval_step(images, labels):
    """Update `eval_accuracy` with predictions for one batch.

    `loaded_model` is looked up when the step runs, so it must be defined
    before the first call.

    Args:
      images: batch of input images.
      labels: ground-truth labels for the batch.
    """
    predictions = loaded_model(images, training=False)
    eval_accuracy(labels, predictions)

  @tf.function
  def distributed_test_step(images, labels):
    """Run `eval_step` on every replica."""
    # `strategy.run` replaces the deprecated `experimental_run_v2` and is the
    # same call the training cell uses.
    return strategy.run(eval_step, args=(images, labels))

In [None]:
with strategy.scope():
  # Build a fresh model/optimizer pair and attach them to a checkpoint object
  # using the same keyword names (`optimizer`, `model`) as the training cell.
  loaded_model = ResNetModel(classes=num_classes)
  optimizer = tf.keras.optimizers.Adam()
  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=loaded_model)

  latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
  if latest_checkpoint is None:
    # `restore(None)` does not fail, so without this message the accuracy
    # printed below would silently describe a model with no restored weights.
    print('WARNING: no checkpoint found in {}; the model is evaluated '
          'without restored weights.'.format(checkpoint_dir))
  checkpoint.restore(latest_checkpoint)

# Evaluate over the distributed test dataset; each element unpacks into the
# images and labels passed to the distributed step.
for images, labels in test_dist_dataset:
  distributed_test_step(images, labels)

print ('Accuracy after restoring the saved model with strategy: {}'.format(
    eval_accuracy.result()*100))

## Next steps

Try out the new `tf.distribute.Strategy` API on your models.