In [4]:
import tensorflow as tf
import numpy as np 
import os 

In [5]:
# Create the distribution strategy used by the rest of the notebook and
# report how many replicas it keeps in sync (this drives the global batch size).
strategy = tf.distribute.MirroredStrategy()
print(strategy.num_replicas_in_sync)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
1


In [6]:
# Fashion-MNIST ships with Keras; load_data() returns the
# (images, labels) pairs for the train split and the test split.
fashion_mnist = tf.keras.datasets.fashion_mnist
(
    (train_images, train_label),
    (test_images, test_label),
) = fashion_mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [7]:
# Inspect the dimensions of the raw training images before any reshaping.
train_images.shape

(60000, 28, 28)

In [8]:

# Append a trailing singleton axis to both image arrays so that each image
# carries an explicit channel dimension for the Conv2D layers.
train_images, test_images = (
    train_images[..., np.newaxis],
    test_images[..., np.newaxis],
)

In [9]:
# Scale pixel values by 1/255; dividing by a float32 scalar converts the
# images to floating point in the same step.
# NOTE: this rebinds the same names, so re-running the cell scales them again.
train_images, test_images = (
    images / np.float32(255) for images in (train_images, test_images)
)


In [10]:
# Per-replica batch size; the global batch is this times the replica count.
batch_size = 64
# Shuffle buffer covers the entire training set.
BUFFER_SIZE = train_images.shape[0]
global_batch_size = strategy.num_replicas_in_sync * batch_size

In [11]:
# Training pipeline: shuffle across the whole training set, then cut it
# into global batches (one global batch is split across all replicas).
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_label))
    .shuffle(BUFFER_SIZE)
    .batch(global_batch_size)
)

# Evaluation pipeline: must be built from the held-out test split, not the
# training arrays, otherwise test metrics would just re-measure training data.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_label))
    .batch(global_batch_size)
)

In [12]:
# Wrap both pipelines so the strategy can hand each replica its share
# of every global batch.
train_dist_dataset, test_dist_dataset = (
    strategy.experimental_distribute_dataset(dataset)
    for dataset in (train_dataset, test_dataset)
)


In [13]:
from tensorflow.keras import Model, layers, Sequential
def createmodel():
  """Build the small convolutional classifier used in this notebook.

  Two Conv2D/MaxPooling2D stages are followed by a flatten, a 64-unit ReLU
  dense layer and a final 10-unit dense layer with no activation, so the
  model emits raw logits.

  Returns:
    A fresh, uncompiled `Sequential` model.
  """
  layer_stack = [
      layers.Conv2D(32, 3, activation='relu'),
      layers.MaxPooling2D(),
      layers.Conv2D(64, 3, activation='relu'),
      layers.MaxPooling2D(),
      layers.Flatten(),
      layers.Dense(64, activation='relu'),
      layers.Dense(10),
  ]
  return Sequential(layer_stack)

In [14]:
!pip install ipython-autotime

%load_ext autotime

Collecting ipython-autotime
  Downloading https://files.pythonhosted.org/packages/b4/c9/b413a24f759641bc27ef98c144b590023c8038dfb8a3f09e713e9dff12c1/ipython_autotime-0.3.1-py2.py3-none-any.whl
Installing collected packages: ipython-autotime
Successfully installed ipython-autotime-0.3.1
time: 146 µs (started: 2021-04-29 11:36:37 +00:00)


In [15]:
with strategy.scope():
  # Keep the loss unreduced (Reduction.NONE): the averaging over the global
  # batch is done explicitly in compute_loss below.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE,
  )

  def compute_loss(labels, predictions):
    """Return this replica's loss contribution for one batch.

    Per-example cross-entropy values are passed to
    `tf.nn.compute_average_loss` together with the *global* batch size.

    Args:
      labels: integer class labels for the replica's examples.
      predictions: model logits for the same examples.

    Returns:
      A scalar loss tensor.
    """
    return tf.nn.compute_average_loss(
        loss_object(labels, predictions),
        global_batch_size=global_batch_size,
    )

  # Metrics are created inside the scope alongside the model and optimizer.
  test_loss = tf.keras.metrics.Mean(name='test_loss')
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

  optimizer = tf.keras.optimizers.Adam()
  model = createmodel()


time: 53.4 ms (started: 2021-04-29 11:36:37 +00:00)


In [16]:

@tf.function
def distributed_training(dataset_inputs):
  """Run one training step on every replica and return the combined loss.

  Wrapped in `tf.function` so the per-replica step is traced into a graph
  rather than dispatched eagerly on every batch.

  Args:
    dataset_inputs: one element drawn from the distributed training dataset
      (a per-replica `(images, labels)` pair).

  Returns:
    A scalar tensor: the sum over replicas of the losses returned by
    `train_step`.
  """
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  # Each replica's loss is already scaled by the global batch size, so the
  # replicas are combined with SUM.
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

def train_step(inputs):
  """Perform a single optimisation step on one replica's batch.

  Args:
    inputs: an `(images, labels)` pair for this replica.

  Returns:
    The scalar loss computed by `compute_loss` for this batch.
  """
  images, label = inputs

  with tf.GradientTape() as tape:
    logits = model(images, training=True)
    batch_loss = compute_loss(label, logits)

  # Look the weights up only after the forward pass: the model has no
  # declared input shape, so its variables do not exist before the first call.
  weights = model.trainable_weights
  grads = tape.gradient(batch_loss, weights)
  optimizer.apply_gradients(zip(grads, weights))

  train_accuracy.update_state(label, logits)
  return batch_loss



time: 9.7 ms (started: 2021-04-29 11:36:37 +00:00)


In [17]:
epochs = 10
for epoch in range(epochs):
  print('start')

  # Per-epoch accumulators; enumerate keeps num_batches equal to the number
  # of batches consumed so far.
  total_loss, num_batches = 0, 0
  for num_batches, batch in enumerate(train_dist_dataset, start=1):
    total_loss += distributed_training(batch)

  # Epoch loss is the mean of the per-batch losses.
  train_loss = total_loss / num_batches
  train_acc = train_accuracy.result()
  print(epoch, train_loss, train_acc)

  # Clear the accuracy metric so the next epoch starts from zero.
  train_accuracy.reset_states()


start
0 tf.Tensor(0.505495, shape=(), dtype=float32) tf.Tensor(0.81436664, shape=(), dtype=float32)
start
1 tf.Tensor(0.33438087, shape=(), dtype=float32) tf.Tensor(0.87911665, shape=(), dtype=float32)
start
2 tf.Tensor(0.2881608, shape=(), dtype=float32) tf.Tensor(0.8965833, shape=(), dtype=float32)
start
3 tf.Tensor(0.25674978, shape=(), dtype=float32) tf.Tensor(0.9072, shape=(), dtype=float32)
start
4 tf.Tensor(0.23394898, shape=(), dtype=float32) tf.Tensor(0.9144167, shape=(), dtype=float32)
start
5 tf.Tensor(0.21388002, shape=(), dtype=float32) tf.Tensor(0.9217, shape=(), dtype=float32)
start
6 tf.Tensor(0.19426897, shape=(), dtype=float32) tf.Tensor(0.9277667, shape=(), dtype=float32)
start
7 tf.Tensor(0.17845084, shape=(), dtype=float32) tf.Tensor(0.9336, shape=(), dtype=float32)
start
8 tf.Tensor(0.16298343, shape=(), dtype=float32) tf.Tensor(0.93948334, shape=(), dtype=float32)
start
9 tf.Tensor(0.15063448, shape=(), dtype=float32) tf.Tensor(0.94451666, shape=(), dtype=float32