# Mirrored Strategy: Basic

In this notebook, we'll cover some of the fundamentals of implementing [Mirrored Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy).


**Note:** `MirroredStrategy()` is meant for synchronous training across multiple GPUs on one machine, so run this notebook with a GPU (if no GPU is found, it falls back to the CPU). Your runtime may list only one GPU. That is not a problem, because the purpose of this notebook is to explore and understand distribution strategies.

## Imports

```python
import os

import tensorflow as tf
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
```

We will use the MNIST dataset for this notebook. TensorFlow Datasets already provides it with separate training and test splits, so we load it and take both splits.

```python
# Load the dataset we'll use for this lab
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True, data_dir='./data')

mnist_train, mnist_test = datasets['train'], datasets['test']
```

With the data loaded and split, let's define our strategy using the `MirroredStrategy()` class. We then print `num_replicas_in_sync`, the number of replicas the strategy runs, which equals the number of devices it found.

```python
# Define the strategy to use and print the number of devices found
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
```

```
Number of devices: 1
```


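Next, we get the number of examples in the training and test sets, set the shuffle buffer size, and define `BATCH_SIZE_PER_REPLICA`, the number of examples each device processes per step. The global `BATCH_SIZE` is this value multiplied by the number of replicas, so each device still receives 64 examples per step no matter how many devices are used.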

```python
# Get the number of examples in the train and test sets
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64

# Use for Mirrored Strategy
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Use for No Strategy
# BATCH_SIZE = BATCH_SIZE_PER_REPLICA * 1
```
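To make the batch-size arithmetic concrete, here is a small pure-Python sketch that needs no TensorFlow. `global_batch_size`, `split_global_batch`, and `steps_per_epoch` are hypothetical helpers written only for illustration; in practice `tf.distribute` splits each global batch across the replicas for you. The sketch assumes the 60,000-example MNIST training split and that the last partial batch is kept, as `batch()` does when `drop_remainder` is left at its default of `False`.

```python
import math


def global_batch_size(per_replica: int, num_replicas: int) -> int:
    """Global batch = per-replica batch x number of replicas."""
    return per_replica * num_replicas


def split_global_batch(batch: list, num_replicas: int) -> list:
    """Split one global batch into equal per-replica shards (illustration only)."""
    if len(batch) % num_replicas != 0:
        raise ValueError("global batch must divide evenly across replicas")
    shard = len(batch) // num_replicas
    return [batch[i * shard:(i + 1) * shard] for i in range(num_replicas)]


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Number of batches per epoch when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


NUM_TRAIN_EXAMPLES = 60_000  # size of the MNIST training split
PER_REPLICA = 64             # same value as BATCH_SIZE_PER_REPLICA

for replicas in (1, 2, 4):
    gbs = global_batch_size(PER_REPLICA, replicas)
    print(f"{replicas} replica(s): global batch {gbs}, "
          f"{steps_per_epoch(NUM_TRAIN_EXAMPLES, gbs)} steps/epoch")
```

With one replica this reports a global batch of 64 and 938 steps per epoch; with two replicas the global batch doubles to 128 and the epoch drops to 469 steps, while each device still processes 64 examples per step.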

We also define a mapping function that normalizes the images by casting them to `float32` and scaling the pixel values from [0, 255] to [0, 1].

```python
# Function for normalizing the image
def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label
```

Next, we build the training and evaluation datasets: apply `scale` to both, cache the training set, shuffle it using a buffer of `BUFFER_SIZE` examples, and batch both sets with the global `BATCH_SIZE`.

```python
# Set up the train and val data set
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
val_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
```
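The `shuffle(BUFFER_SIZE)` step does not shuffle the whole dataset at once. The `tf.data.Dataset.shuffle` documentation describes it as filling a buffer with `buffer_size` elements, then randomly sampling from that buffer and replacing each sampled element with the next one from the input. The toy pure-Python sketch below imitates the `map(scale)`, `shuffle`, `batch` pipeline on made-up two-pixel "images". `buffered_shuffle` and `batch_elements` are hypothetical helpers, not the tf.data implementation, and `cache()` is left out.

```python
import random


def scale(pixels, label):
    """Map step: convert 0-255 integer pixels to floats in [0, 1]."""
    return [p / 255.0 for p in pixels], label


def buffered_shuffle(elements, buffer_size, rng):
    """Approximate shuffle with a fixed-size buffer, in the spirit of tf.data's shuffle."""
    buffer = []
    for element in elements:
        if len(buffer) < buffer_size:
            buffer.append(element)  # still filling the buffer
            continue
        # Buffer full: emit a random element and put the new one in its place.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = element
    # Input exhausted: drain what is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer


def batch_elements(elements, batch_size):
    """Group consecutive elements; the last batch may be smaller."""
    current = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


rng = random.Random(0)
examples = [([0, 255], label) for label in range(10)]  # toy (pixels, label) pairs
mapped = (scale(px, label) for px, label in examples)
batches = list(batch_elements(buffered_shuffle(mapped, buffer_size=4, rng=rng), batch_size=4))
print([[label for _, label in b] for b in batches])
```

A buffer smaller than the dataset only shuffles locally, which is why `BUFFER_SIZE` matters for MNIST's 60,000 training examples; with a buffer of 1 the input order is preserved and no shuffling happens at all.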

To make the model use the strategy, create it inside the strategy's scope (`with strategy.scope():`), so that its variables are mirrored across the available devices.

- Run all the cells below and note the results.
- Then comment out `with strategy.scope():`, dedent the model definition, and rerun everything without the strategy.
- Compare the results of the two runs.

The key thing to compare is how long each epoch takes to complete. As mentioned in the lecture, using MirroredStrategy on a single device (which is what our lab environment has) can make training slower because of the overhead the strategy adds. Its benefits only become apparent when training is spread across multiple devices.
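Both the overhead and the benefit come from how MirroredStrategy trains: it keeps a copy of every model variable on each device, each replica computes gradients on its own slice of the global batch, and an all-reduce combines those gradients so every copy receives the same update. The pure-Python sketch below imitates one such synchronous step for a hypothetical one-weight model `y_hat = w * x` with mean-squared-error loss; it is a conceptual illustration, not how TensorFlow implements it. With equal-sized shards, the averaged per-replica gradient equals the full-batch gradient, so the update matches single-device training. The combining step and its bookkeeping are extra work that only pays off when several devices share the computation.

```python
def grad_mse(w, xs, ys):
    """Gradient of the mean squared error for the model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def mirrored_step(w, xs, ys, num_replicas, lr):
    """One synchronous data-parallel SGD step on a single scalar weight."""
    shard = len(xs) // num_replicas  # assumes the batch divides evenly
    # Each replica holds an identical copy of w and sees only its own shard.
    replica_grads = [
        grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(num_replicas)
    ]
    # "All-reduce": average the per-replica gradients (equal shard sizes).
    avg_grad = sum(replica_grads) / num_replicas
    # Every replica applies the same update, so the copies stay in sync.
    return w - lr * avg_grad


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # generated by the true weight w = 2
lr = 0.01

single_device_w = 0.0 - lr * grad_mse(0.0, xs, ys)
mirrored_w = mirrored_step(0.0, xs, ys, num_replicas=2, lr=lr)
print(single_device_w, mirrored_w)  # the two updates agree
```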

```python
# Use for Mirrored Strategy -- comment out `with strategy.scope():` and dedent for no strategy
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)  # raw logits; the loss below uses from_logits=True
    ])

# If no strategy is desired, we can use the commented code below

# model = tf.keras.Sequential([
#     tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
#     tf.keras.layers.MaxPooling2D(),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(64, activation='relu'),
#     tf.keras.layers.Dense(10)
# ])
```

```python
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(),
              metrics=['accuracy'])
```
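The loss uses `from_logits=True` because the final `Dense(10)` layer has no activation, so the model outputs raw scores (logits) rather than probabilities, and the loss applies the softmax itself. The pure-Python sketch below shows the per-example math, `-log(softmax(logits)[label])`, using a numerically stable log-sum-exp; `sparse_ce_from_logits` is a hypothetical helper for illustration, not the Keras implementation, which by default also averages the loss over the batch.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example given raw logits and an integer class label."""
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # -log softmax(logits)[label] == log_sum_exp - logits[label]
    return log_sum_exp - logits[label]


uniform_loss = sparse_ce_from_logits([0.0] * 10, label=3)       # no preference among 10 classes
confident_loss = sparse_ce_from_logits([10.0, 0.0, 0.0], label=0)  # strongly favors the true class
print(uniform_loss, confident_loss)
```

An untrained model that scores all ten digits equally has a loss of `log(10)`, about 2.30, which is roughly where the first epoch's loss starts; confident, correct logits drive the loss toward 0.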

```python
model.fit(train_dataset, epochs=12)
```

```
Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12
```

