# Parallelizing TensorFlow with Mirrored Strategy

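This notebook shows how to use the TensorFlow distribution API (`tf.distribute`) to train a Keras model on several GPUs of a single machine with `MirroredStrategy`.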

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

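We will create two virtual (logical) GPUs on the first physical GPU, each limited to 1024 MB of memory. This has to be done before TensorFlow initializes the GPU.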

In [None]:
# Split the first physical GPU into two virtual (logical) GPUs.
# Nothing happens when TensorFlow reports no physical GPU.
gpu_devices = tf.config.list_physical_devices('GPU')
if gpu_devices:
    # One configuration per virtual GPU; memory_limit is expressed in MB
    # according to the TensorFlow documentation.
    virtual_gpu_configs = [
        tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
        for _ in range(2)
    ]
    try:
        tf.config.experimental.set_virtual_device_configuration(
            gpu_devices[0], virtual_gpu_configs)
    except RuntimeError as e:
        # Virtual devices must be configured before the GPU has been
        # initialized; report the problem instead of stopping the notebook.
        print(e)

Next, we will load the MNIST dataset via the `tensorflow_datasets` API.

In [None]:
# Load MNIST together with its metadata. `as_supervised=True` makes each
# example an (image, label) pair, which is what `normalize_img` expects.
datasets, info = tfds.load(
    'mnist',
    as_supervised=True,
    with_info=True,
)
mnist_train = datasets['train']
mnist_test = datasets['test']

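Then, we will prepare the data: pixel values are cast to `float32` and scaled by 1/255, and both datasets are cached and prefetched. The training set is also shuffled.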

In [None]:
def normalize_img(image, label):
  """Cast `image` to `float32` and divide it by 255; return it with `label`.

  For `uint8` pixel data this maps the values into the range [0, 1].
  The label is passed through unchanged.
  """
  scaled_image = tf.cast(image, tf.float32) / 255.
  return scaled_image, label

# Training pipeline: normalize -> cache -> shuffle -> prefetch.
# The shuffle buffer covers the whole training split.
mnist_train = (
    mnist_train
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(info.splits['train'].num_examples)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Test pipeline: same steps, but without shuffling.
mnist_test = (
    mnist_test
    .map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .prefetch(tf.data.experimental.AUTOTUNE)
)

We are now ready to apply a mirrored strategy. This strategy replicates the model across all GPUs on the same machine.
Each replica is trained on a different slice of every batch, and the gradients are aggregated across replicas at each step so that all copies of the model stay identical (synchronous training).

In [None]:
# Create the distribution strategy with its default arguments. No device list
# is passed, so the strategy selects its devices itself (per the TensorFlow
# docs, all GPUs visible to TensorFlow).
mirrored_strategy = tf.distribute.MirroredStrategy()

We check that we have two devices corresponding to the two virtual GPUs created at the beginning of this recipe.

In [None]:
# Report how many replicas the strategy keeps in sync.
replica_count = mirrored_strategy.num_replicas_in_sync
print('Number of devices: {}'.format(replica_count))

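We'll define the global batch size as the per-replica batch size multiplied by the number of replicas, and then batch both datasets with it.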

In [None]:
# The global batch is the per-replica batch times the number of replicas,
# so that each replica is meant to receive BATCH_SIZE_PER_REPLICA examples
# per step.
BATCH_SIZE_PER_REPLICA = 128
BATCH_SIZE = mirrored_strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

# Batch both splits with the global batch size.
mnist_train, mnist_test = (
    mnist_train.batch(BATCH_SIZE),
    mnist_test.batch(BATCH_SIZE),
)

We'll define and compile our model using the mirrored strategy.


In [None]:
with mirrored_strategy.scope():
    # The model is built and compiled inside the strategy scope so that it is
    # tied to the mirrored strategy.
    # Architecture: flatten -> Dense(128, relu) -> Dense(64, relu)
    # -> Dense(10, softmax).
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(name="FLATTEN"),
        tf.keras.layers.Dense(units=128, activation="relu", name="D1"),
        tf.keras.layers.Dense(units=64, activation="relu", name="D2"),
        tf.keras.layers.Dense(units=10, activation="softmax", name="OUTPUT"),
    ])

    # Labels are integer class ids, hence the sparse cross-entropy loss.
    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer="sgd",
        metrics=["accuracy"],
    )


In [None]:
# Train for 10 epochs on the batched training set, evaluating on the batched
# test set after each epoch.
model.fit(
    x=mnist_train,
    validation_data=mnist_test,
    epochs=10,
)