# Custom training with tf.distribute.Strategy

In this notebook, we will use a distribution strategy to train a model on the [Oxford Flowers 102](https://www.tensorflow.org/datasets/catalog/oxford_flowers102) dataset. Distribution strategies enable training across multiple devices, although we'll start with a single-device setup. The same syntax applies to multi-device environments.

## Imports

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import tensorflow_hub as hub

# Helper libraries
import numpy as np
import os
from tqdm import tqdm

## Download the dataset

In [None]:
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

In [None]:
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

(train_examples, validation_examples, test_examples), info = tfds.load('oxford_flowers102', with_info=True, as_supervised=True, split = splits, data_dir='data/')

num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes

## Create a strategy to distribute the variables and the graph

The `tf.distribute.MirroredStrategy` is a TensorFlow strategy for synchronous training across multiple GPUs on the same machine. This strategy is designed to provide efficient distributed training with the following key steps:

1. **Replication of Variables and Graph:** It begins by replicating all the variables and the model graph across multiple devices or replicas. This setup ensures that each device has a copy of the model to work with.

2. **Distribution of Input:** The input data is evenly distributed or split among the replicas. Each replica receives a portion of the input data, ensuring parallel processing that maximizes the utilization of all available GPUs.

3. **Local Computation:** Each replica independently computes the outputs, losses, and gradients based on the subset of data it has received. This parallel computation allows for efficient handling of large datasets by dividing the workload.

4. **Synchronization of Gradients:** Once all replicas have computed their gradients, these gradients are synchronized across all replicas. Typically, this involves summing the gradients from all replicas to ensure that all replicas contribute equally to the learning process.

5. **Update of Variables:** After synchronization, the updates calculated from the combined gradients are applied uniformly across all replicas. This ensures that all copies of the variables are updated consistently, keeping the model's state synchronized across all devices.

The `tf.distribute.MirroredStrategy` effectively leverages multiple GPUs to speed up training by parallelizing the computation and gradient synchronization, making it a popular choice for optimizing training performance on multi-GPU setups.
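
To make the five steps above concrete, here is a minimal plain-Python sketch of one synchronous data-parallel update for a one-parameter model `y = w * x`. It is only an illustration of the idea, not how TensorFlow implements `MirroredStrategy`; the helper names (`local_gradient`, `mirrored_step`) and the toy data are hypothetical.

```python
# Plain-Python illustration of one synchronous data-parallel SGD step.
# NOT TensorFlow's implementation; all names here are hypothetical.

def local_gradient(w, shard, global_batch_size):
    """Gradient of the squared error summed over one shard,
    scaled by the GLOBAL batch size (not the shard size)."""
    return sum(2.0 * (w * x - y) * x for x, y in shard) / global_batch_size


def mirrored_step(weights, batch, learning_rate):
    """Run one update, where `weights` holds one copy of w per replica."""
    num_replicas = len(weights)
    shard_size = len(batch) // num_replicas
    # Step 2: split the global batch evenly across the replicas.
    shards = [batch[i * shard_size:(i + 1) * shard_size]
              for i in range(num_replicas)]
    # Step 3: each replica computes a gradient from its own shard.
    grads = [local_gradient(w, shard, len(batch))
             for w, shard in zip(weights, shards)]
    # Step 4: sum the gradients so every replica sees the same value.
    total_grad = sum(grads)
    # Step 5: apply the identical update to every copy of the variable.
    return [w - learning_rate * total_grad for w in weights]


batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x

# Step 1: replicate the variable on two hypothetical devices.
replica_weights = mirrored_step([0.0, 0.0], batch, learning_rate=0.01)
# The same batch processed on a single device, for comparison.
single_device = mirrored_step([0.0], batch, learning_rate=0.01)

print(replica_weights, single_device)
```

Both replicas end up with the same weight, and that weight equals the single-device result, which is the property synchronous training relies on.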

In [None]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

In [None]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Setup input pipeline

Let's set up the constants for our training process: the shuffle buffer size, the number of epochs, the image size, and the handle of the pretrained feature-extractor module.

In [None]:
BUFFER_SIZE = num_examples
EPOCHS = 10
pixels = 224
MODULE_HANDLE = 'data/resnet_50_feature_vector'
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

Using data/resnet_50_feature_vector with input size (224, 224)


We will define a function to properly format our images. This function will resize each image to a uniform dimension and scale the pixel values to a [0,1] range. Ensuring that all images are uniformly processed is vital for maintaining consistent model training.

In [None]:
def format_image(image, label):
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    return image, label

## Set the global batch size

Now, we'll define a function that computes the global batch size: the per-replica batch size multiplied by the number of replicas in the strategy. The global batch size is the number of examples processed in one training step across all devices.

In [None]:
def set_global_batch_size(batch_size_per_replica, strategy):
    '''
    Args:
        batch_size_per_replica (int) - batch size per replica
        strategy (tf.distribute.Strategy) - distribution strategy
    '''

    # Set the global batch size
    global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync

    return global_batch_size

With 64 examples per replica and the single replica detected above, the cell below should print `64`.

In [None]:
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = set_global_batch_size(BATCH_SIZE_PER_REPLICA, strategy)

print(GLOBAL_BATCH_SIZE)

64


We will also use the global batch size to create batches for the training and validation sets. The test set is batched one example at a time.

In [None]:
# Batch by the global batch size; the strategy splits each batch across replicas.
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(GLOBAL_BATCH_SIZE).prefetch(1)
validation_batches = validation_examples.map(format_image).batch(GLOBAL_BATCH_SIZE).prefetch(1)
test_batches = test_examples.map(format_image).batch(1)
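
The arithmetic behind the batching can be sketched as follows. This assumes the dataset is batched by the global batch size and that the 80% training slice holds 816 examples (80% of the 1,020 images that the TFDS catalog lists for the `train` split); the replica counts other than 1 are hypothetical.

```python
import math

BATCH_SIZE_PER_REPLICA = 64
# Assumption: 80% of the 1,020 images in the `train` split.
NUM_TRAIN_EXAMPLES = 816


def steps_per_epoch(num_examples, batch_size_per_replica, num_replicas):
    """Number of training steps needed to see every example once."""
    global_batch_size = batch_size_per_replica * num_replicas
    # The final, smaller batch still counts as a step.
    return math.ceil(num_examples / global_batch_size)


for num_replicas in (1, 2, 4):  # hypothetical replica counts
    steps = steps_per_epoch(NUM_TRAIN_EXAMPLES, BATCH_SIZE_PER_REPLICA, num_replicas)
    print(f"{num_replicas} replica(s): "
          f"global batch {BATCH_SIZE_PER_REPLICA * num_replicas}, "
          f"{steps} steps per epoch")
```

With one replica this gives 13 steps, which matches the `13it` progress-bar count in the training log at the end of the notebook.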

## Define the distributed datasets

We'll create the distributed datasets for training, validation, and testing by calling the `experimental_distribute_dataset()` method of the [Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy) class. This method splits each batch across the replicas in the strategy, so every device receives its own share of each batch.

In [None]:
def distribute_datasets(strategy, train_batches, validation_batches, test_batches):

    train_dist_dataset = strategy.experimental_distribute_dataset(train_batches)
    val_dist_dataset = strategy.experimental_distribute_dataset(validation_batches)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_batches)

    return train_dist_dataset, val_dist_dataset, test_dist_dataset

Let's use the function we just defined to create distributed datasets.

In [None]:
train_dist_dataset, val_dist_dataset, test_dist_dataset = distribute_datasets(strategy, train_batches, validation_batches, test_batches)

It is a good idea to check the `type` of the datasets we just created.

In [None]:
print(type(train_dist_dataset))
print(type(val_dist_dataset))
print(type(test_dist_dataset))

<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>


Let's take a quick look at a single batch from `train_dist_dataset` to see what our distributed dataset looks like.

In [None]:
# Take a look at a single batch from the train_dist_dataset
x = iter(train_dist_dataset).get_next()

print(f"x is a tuple that contains {len(x)} values ")
print(f"x[0] contains the features, and has shape {x[0].shape}")
print(f"  so it has {x[0].shape[0]} examples in the batch, each is an image that is {x[0].shape[1:]}")
print(f"x[1] contains the labels, and has shape {x[1].shape}")

x is a tuple that contains 2 values 
x[0] contains the features, and has shape (64, 224, 224, 3)
  so it has 64 examples in the batch, each is an image that is (224, 224, 3)
x[1] contains the labels, and has shape (64,)


## Create the model

We'll define `ResNetModel` as a subclass of `tf.keras.Model` using the Model Subclassing API. This approach provides flexibility in defining custom behaviors and layers for our model.

In [None]:
MODULE_HANDLE = "https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/5"

class ResNetModel(tf.keras.Model):
    def __init__(self, classes):
        super(ResNetModel, self).__init__()
        self._feature_extractor = hub.KerasLayer(MODULE_HANDLE,
                                                 trainable=False)
        self._classifier = tf.keras.layers.Dense(classes, activation='softmax')

    def call(self, inputs):
        x = self._feature_extractor(inputs)
        x = self._classifier(x)
        return x

We will define the directory and file prefix under which checkpoints are stored during training. Checkpoints capture the model's weights at various stages, allowing for recovery or starting subsequent training from a known state.

In [None]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function

Within `strategy.scope()`, we'll define `loss_object`, which returns the unreduced per-example losses, and `compute_loss`, which sums those losses and divides by the global batch size. We also create the `test_loss` metric here. Defining these inside the scope keeps the loss computation compatible with the distribution strategy.

In [None]:
with strategy.scope():
    # Set reduction to `NONE` so we can do the reduction afterwards and divide by global batch size.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    # or loss_fn = tf.keras.losses.sparse_categorical_crossentropy
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    test_loss = tf.keras.metrics.Mean(name='test_loss')
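
Why divide by the global batch size rather than letting each replica average its own losses? The plain-Python sketch below illustrates the arithmetic with made-up loss values and two hypothetical replicas. The local `compute_average_loss` helper only mimics the sum-then-divide behaviour of `tf.nn.compute_average_loss`; it is not the TensorFlow function.

```python
# Hypothetical per-example losses for a global batch of 4 examples.
per_example_losses = [0.5, 1.5, 2.0, 4.0]
global_batch_size = len(per_example_losses)

# Two hypothetical replicas, each holding half of the global batch.
shards = [per_example_losses[:2], per_example_losses[2:]]


def compute_average_loss(shard_losses, global_batch_size):
    """Sum the shard's losses and divide by the GLOBAL batch size."""
    return sum(shard_losses) / global_batch_size


# Each replica scales by the global batch size; a SUM across replicas
# then yields the true mean over the whole batch.
summed_scaled = sum(compute_average_loss(s, global_batch_size) for s in shards)
true_mean = sum(per_example_losses) / global_batch_size

# If each replica averaged over its own shard instead, the SUM across
# replicas would overshoot by a factor equal to the number of replicas.
summed_local_means = sum(sum(s) / len(s) for s in shards)

print(summed_scaled, true_mean, summed_local_means)
```

Because the gradients from all replicas are summed during synchronization, scaling each replica's loss by the global batch size keeps the combined gradient equal to the one a single device would compute on the full batch.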

## Define the metrics to track loss and accuracy

To track accuracy during training and testing, we'll define two `SparseCategoricalAccuracy` metrics (the `test_loss` metric was defined above). Each metric accumulates statistics across batches, and calling `.result()` returns the accumulated value, providing insight into the model's performance throughout the training and evaluation phases.

In [None]:
with strategy.scope():
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='test_accuracy')
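
The accumulate-then-read pattern of these metrics can be sketched in plain Python. The `RunningMean` class below is a hypothetical stand-in written for illustration, not the Keras implementation; its method names simply mirror the ones used in this notebook.

```python
class RunningMean:
    """Hypothetical streaming metric: accumulates a total and a count."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Fold one batch of values into the running statistics.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Mean over everything seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset_states(self):
        self.total = 0.0
        self.count = 0


metric = RunningMean()
metric.update_state([1.0, 0.0, 1.0, 1.0])  # batch 1: 3 of 4 correct
metric.update_state([0.0, 1.0])            # batch 2: 1 of 2 correct
epoch_accuracy = metric.result()           # 4 correct out of 6 examples
metric.reset_states()                      # start fresh for the next epoch

print(epoch_accuracy, metric.result())
```

The result is weighted by the number of examples seen, not by the number of batches, which is why a smaller final batch does not skew the epoch-level figure.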

## Instantiate the model, optimizer, and checkpoints

Also within `strategy.scope()`, we will instantiate `ResNetModel` with the number of classes and create an instance of the Adam optimizer. We also set up a `tf.train.Checkpoint` that tracks the model and its optimizer so that training progress can be saved.

In [None]:
# Model and optimizer must be created under `strategy.scope`.
with strategy.scope():
    model = ResNetModel(classes=num_classes)
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

## Training Loop

We will define a regular training step and test step, which can function without a distributed strategy. We will then use `strategy.run` to apply these functions in a distributed manner.
- Notice that we will define `train_step` and `test_step` inside another function `train_test_step_fns`, which will then return these two functions.

### Define train_step
Within the strategy's scope, we will define `train_step(inputs)`
- `inputs` will be a tuple containing `(images, labels)`.
- Create a gradient tape block.
- Within the gradient tape block:
  - Call the model, passing in the images and setting training to `True`.
  - Call the `compute_loss` function (defined earlier) to compute the training loss.
  - Use the gradient tape to calculate the gradients.
  - Use the optimizer to update the weights using the gradients.

### Define test_step
Also within the strategy's scope, we will define `test_step(inputs)`
- `inputs` is a tuple containing `(images, labels)`.
  - Call the model, passing in the images and setting training to `False`, because the model is not going to train on the test data.
  - Call `compute_loss` (defined earlier) on the labels and predictions to get `t_loss`. Because `compute_loss` divides by `GLOBAL_BATCH_SIZE` while the test batches contain a single example, the reported test loss is scaled down by that factor; pass the labels and predictions to `loss_object` directly if you want the unscaled per-example loss.
  - Next, update `test_loss` (the running test loss) with the `t_loss` (the loss for the current batch).
  - Also update the `test_accuracy`.

In [None]:
def train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        def train_step(inputs):
            images, labels = inputs

            with tf.GradientTape() as tape:
                predictions = model(images, training=True)
                loss = compute_loss(labels, predictions)

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            train_accuracy.update_state(labels, predictions)
            return loss

        def test_step(inputs):
            images, labels = inputs

            predictions = model(images, training=False)
            t_loss = compute_loss(labels, predictions)

            test_loss.update_state(t_loss)
            test_accuracy.update_state(labels, predictions)

        return train_step, test_step

We will use the `train_test_step_fns` function to generate the `train_step` and `test_step` functions. These are ordinary per-replica functions that would also work without a distribution strategy; the next section wraps them with `strategy.run` for distributed training and testing.

In [None]:
train_step, test_step = train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

## Distributed training and testing

`Distributed Train Step`
To apply `train_step` in a distributed environment, we will use the `strategy.run` method:

*   **Using `strategy.run`:** We call `strategy.run`, passing `train_step` as the function to execute along with the dataset inputs. The per-replica losses it returns are then combined with `strategy.reduce`.
*   **Function Call Format:** The signature is `run(fn, args=())`, where `fn` is the function to execute (here, `train_step`) and `args` is a tuple of the positional arguments for that function. Because `train_step` takes a single argument, we pass `args=(dataset_inputs,)`.


`Distributed Test Step`
Similarly, for the testing phase:

* **Setting up `distributed_test_step`:** We use `strategy.run` again, this time passing `test_step` along with its dataset inputs.
* **Execution:** `strategy.run` executes the test step on every available device. No reduction is needed because `test_step` updates the `test_loss` and `test_accuracy` metrics directly.

In [None]:
# See various ways of passing in the inputs

def fun1(args=()):
    print(f"number of arguments passed is {len(args)}")


list_of_inputs = [1,2]
print("When passing in args=list_of_inputs:")
fun1(args=list_of_inputs)
print()
print("When passing in args=(list_of_inputs)")
fun1(args=(list_of_inputs))
print()
print("When passing in args=(list_of_inputs,)")
fun1(args=(list_of_inputs,))

When passing in args=list_of_inputs:
number of arguments passed is 2

When passing in args=(list_of_inputs)
number of arguments passed is 2

When passing in args=(list_of_inputs,)
number of arguments passed is 1


In [None]:
def distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                   axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return strategy.run(test_step, args=(dataset_inputs,))

        return distributed_train_step, distributed_test_step

Once these functions are defined within the scope of our distribution strategy, we will call `distributed_train_test_step_fns` to retrieve `distributed_train_step` and `distributed_test_step`. These wrappers connect our training and testing functions to the distribution logic provided by TensorFlow, so they are ready to process data across multiple devices.

In [None]:
distributed_train_step, distributed_test_step = distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

## Run the distributed training in a loop

To effectively train and evaluate our model using a distributed strategy across multiple epochs, we'll implement a structured approach within a for-loop. This loop will handle both training and testing phases, ensuring that our model learns from the training data and is accurately assessed using the test data at each epoch. Here's how we'll set it up:

### Training and Testing in a Distributed Manner

1. **Epoch Loop**:
   - Begin by looping through the desired number of epochs to train the model for multiple cycles over the dataset.

2. **Training Phase**:
   - **Loop through Training Batches**: For each epoch, iterate through each batch of the distributed training dataset.
   - **Execute Training Step**: For each training batch, call the `distributed_train_step`. This function will apply the training step across all replicas and return the loss for that batch.
   - **Calculate Average Training Loss**: After completing all training batches for the epoch, calculate the average training loss to get a sense of how well the model is learning over time.

3. **Testing Phase**:
   - **Loop through Test Batches**: Similar to training, loop through each batch of the distributed test set.
   - **Execute Test Step**: For each test batch, run the `distributed_test_step`. This step will update the test loss and test accuracy based on the outcomes of the model's predictions compared to the actual labels.
   - **Accumulate Metrics**: The `test_loss` and `test_accuracy` metric objects accumulate results across batches, so no separate bookkeeping is needed to obtain the epoch-level values.

4. **Reporting**:
   - **Print Results**: At the end of each epoch, print out the epoch number, the average training loss, the training accuracy, the test loss, and the test accuracy. This step provides a checkpoint to monitor the model's performance and improvements epoch by epoch.
   - **Reset Metrics**: Before moving to the next epoch, reset `test_loss`, `train_accuracy`, and `test_accuracy` so that each epoch's metrics are computed independently of previous epochs. The training loss needs no reset because `total_loss` is re-initialized at the start of every epoch.

5. **Repeat**:
   - This process repeats for the specified number of epochs, allowing the model to progressively learn and improve its predictive accuracy.

In [None]:
with strategy.scope():
    for epoch in range(EPOCHS):
        # TRAIN LOOP
        total_loss = 0.0
        num_batches = 0
        for x in tqdm(train_dist_dataset):
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        for x in test_dist_dataset:
            distributed_test_step(x)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                    "Test Accuracy: {}")
        print(template.format(epoch+1, train_loss,
                               train_accuracy.result()*100, test_loss.result(),
                               test_accuracy.result()*100))

        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()

13it [00:32,  2.46s/it]


Epoch 1, Loss: 4.741341590881348, Accuracy: 4.779411792755127, Test Loss: 0.0638691708445549, Test Accuracy: 8.823529243469238


13it [00:02,  4.92it/s]


Epoch 2, Loss: 2.8097712993621826, Accuracy: 42.52450942993164, Test Loss: 0.048645444214344025, Test Accuracy: 34.31372833251953


13it [00:02,  4.94it/s]


Epoch 3, Loss: 1.6825740337371826, Accuracy: 78.06372833251953, Test Loss: 0.039348047226667404, Test Accuracy: 50.0


13it [00:02,  5.22it/s]


Epoch 4, Loss: 1.0582751035690308, Accuracy: 91.17646789550781, Test Loss: 0.034339699894189835, Test Accuracy: 55.88235092163086


13it [00:02,  4.88it/s]


Epoch 5, Loss: 0.7129174470901489, Accuracy: 95.3431396484375, Test Loss: 0.030914997681975365, Test Accuracy: 60.78431701660156


13it [00:02,  5.18it/s]


Epoch 6, Loss: 0.5220160484313965, Accuracy: 97.67156982421875, Test Loss: 0.028901422396302223, Test Accuracy: 64.70588684082031


13it [00:02,  5.24it/s]


Epoch 7, Loss: 0.3930404782295227, Accuracy: 99.14215850830078, Test Loss: 0.027198707684874535, Test Accuracy: 65.68627166748047


13it [00:02,  4.93it/s]


Epoch 8, Loss: 0.3086308240890503, Accuracy: 99.50980377197266, Test Loss: 0.026264455169439316, Test Accuracy: 66.66667175292969


13it [00:02,  5.37it/s]


Epoch 9, Loss: 0.24907761812210083, Accuracy: 99.75489807128906, Test Loss: 0.025534609332680702, Test Accuracy: 64.70588684082031


13it [00:02,  5.49it/s]


Epoch 10, Loss: 0.20324990153312683, Accuracy: 99.87745666503906, Test Loss: 0.025066642090678215, Test Accuracy: 66.66667175292969
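
The test loss in the log above is far smaller than the training loss. A likely reason, reading the code in this notebook, is that `test_step` calls `compute_loss`, which divides by `GLOBAL_BATCH_SIZE` (64), while `test_batches` uses a batch size of 1. Under that assumption, multiplying the reported value by 64 recovers the mean per-example test loss. A small sketch using two values copied from the log:

```python
GLOBAL_BATCH_SIZE = 64  # value printed earlier in the notebook

# Test-loss values copied from the training log (epochs 1 and 10).
reported_test_loss = {
    1: 0.0638691708445549,
    10: 0.025066642090678215,
}

# Assumption: each reported value is the per-example loss divided by
# GLOBAL_BATCH_SIZE, so multiplying undoes the scaling.
rescaled = {epoch: loss * GLOBAL_BATCH_SIZE
            for epoch, loss in reported_test_loss.items()}

for epoch, loss in rescaled.items():
    print(f"Epoch {epoch}: per-example test loss is roughly {loss:.2f}")
```

After rescaling, the test loss is on the same scale as the training loss, which makes the two figures directly comparable.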
