# Week 4 Assignment: Custom training with tf.distribute.Strategy

Welcome to the final assignment of this course! For this week, you will implement a distribution strategy to train on the Oxford Flowers 102 dataset. As the name suggests, distribution strategies allow you to set up training across multiple devices. We are using just a single device in this lab, but the syntax you'll apply should also work when you have a multi-device setup. Let's begin!

In [6]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import tensorflow_hub as hub

import numpy as np
import os
from tqdm import tqdm

In [7]:
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

In [14]:
# Carve the dataset's 'train' split into 80% / 10% / 10% slices that serve as
# the training, validation and test sets respectively.
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

# as_supervised=True is relied on below: format_image takes (image, label) pairs.
# data_dir points TFDS at the local data/ directory.
(train_examples, validation_examples, test_examples), info = tfds.load('oxford_flowers102', with_info=True, as_supervised=True, split = splits, data_dir='data/')

# NOTE: this is the size of the whole 'train' split, not of the 80% training slice.
num_examples = info.splits['train'].num_examples
# Number of label classes; used as the width of the classifier layer.
num_classes = info.features['label'].num_classes

How does the `tf.distribute.MirroredStrategy` strategy work?

All the variables and the model graph are replicated on the replicas.
Input is evenly distributed across the replicas.
Each replica calculates the loss and gradients for the input it received.
The gradients are synced across all the replicas by summing them.
After the sync, the same update is made to the copies of the variables on each replica.

In [10]:
# Create the mirrored strategy and report how many replicas it keeps in sync.
strategy=tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


Number of devices: 1


In [15]:
# Total example count (presumably intended as a shuffle-buffer size; verify
# where it is used, the input pipeline below shuffles with num_examples // 4).
BUFFER_SIZE=num_examples
# Number of passes the training loop makes over the training data.
EPOCHS=10
# Side length, in pixels, of the square images fed to the model.
pixels=224
# Path handed to hub.KerasLayer when the model is built.
MODULE_HANDLE='data/resnet_50_feature_vector'
# Target size passed to tf.image.resize in format_image.
IMAGE_SIZE=(pixels,pixels)


In [16]:
def format_image(image, label):
    """Resize an image to IMAGE_SIZE and divide its pixel values by 255.

    Args:
        image: image tensor to preprocess.
        label: label belonging to the image; passed through untouched.

    Returns:
        A (preprocessed_image, label) tuple.
    """
    resized = tf.image.resize(image, IMAGE_SIZE)
    scaled = resized / 255.0
    return scaled, label

In [17]:
def set_global_batch_size(batch_size_per_replica, strategy):
    """Return the batch size summed over all replicas of the strategy.

    Args:
        batch_size_per_replica: number of examples each replica processes per step.
        strategy: object exposing ``num_replicas_in_sync``.

    Returns:
        ``batch_size_per_replica`` multiplied by ``strategy.num_replicas_in_sync``.
    """
    return batch_size_per_replica * strategy.num_replicas_in_sync

In [18]:
# Examples processed by each replica per step.
BATCH_SIZE_PER_REPLICA=64
# Examples processed per step across all replicas; compute_loss averages over this.
GLOBAL_BATCH_SIZE=set_global_batch_size(BATCH_SIZE_PER_REPLICA,strategy)


In [19]:
# These datasets are later handed to strategy.experimental_distribute_dataset,
# which splits every batch across the replicas. They therefore have to be
# batched with the GLOBAL batch size (the same value compute_loss divides by),
# not the per-replica size. On a single replica the two values are identical.
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(GLOBAL_BATCH_SIZE).prefetch(1)
validation_batches = validation_examples.map(format_image).batch(GLOBAL_BATCH_SIZE).prefetch(1)
# Test examples are evaluated one at a time.
test_batches = test_examples.map(format_image).batch(1)

In [22]:
def distribute_datasets(strategy, train_batches, validation_batches, test_batches):
    """Wrap the three batched datasets with the strategy's distribution logic.

    Args:
        strategy: object exposing ``experimental_distribute_dataset``.
        train_batches: batched training dataset.
        validation_batches: batched validation dataset.
        test_batches: batched test dataset.

    Returns:
        A ``(train, validation, test)`` tuple holding the distributed
        counterpart of each input, in that order.
    """
    # Datasets are wrapped in train -> validation -> test order.
    return tuple(
        strategy.experimental_distribute_dataset(batches)
        for batches in (train_batches, validation_batches, test_batches)
    )
    


# Wrap the three input pipelines for the strategy, then print the wrapper
# types as a quick sanity check.
train_dist_dataset, val_dist_dataset, test_dist_dataset = distribute_datasets(strategy, train_batches, validation_batches, test_batches)
print(type(train_dist_dataset))
print(type(val_dist_dataset))
print(type(test_dist_dataset))

<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>


In [25]:
class ResNetModel(tf.keras.Model):
    """Non-trainable hub feature extractor topped with a softmax classifier."""

    def __init__(self, classes):
        """Build the two layers.

        Args:
            classes: number of output units of the classification layer.
        """
        super(ResNetModel, self).__init__()
        # Feature extractor loaded from MODULE_HANDLE; its weights stay frozen.
        self._feature_extractor = hub.KerasLayer(MODULE_HANDLE, trainable=False)
        # Classification head: one softmax output per class.
        self._classifier = tf.keras.layers.Dense(units=classes, activation='softmax')

    def call(self, inputs):
        """Run the inputs through the feature extractor, then the classifier."""
        features = self._feature_extractor(inputs)
        return self._classifier(features)

In [28]:
# Location and filename prefix for training checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

with strategy.scope():
    # Reduction.NONE leaves the loss unreduced so that compute_loss can do the
    # averaging itself, against the global batch size.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE
    )

    def compute_loss(labels, predictions):
        """Return the loss of a batch averaged over GLOBAL_BATCH_SIZE."""
        unreduced_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(
            unreduced_loss, global_batch_size=GLOBAL_BATCH_SIZE
        )

    # Running mean used to report the test loss.
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    
    

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


In [29]:
# Accuracy metrics for the training and test passes, created under the
# strategy's scope.
with strategy.scope():
    train_accuracy=tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    test_accuracy=tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
    

INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


In [32]:
# Build the model, the optimizer and the checkpoint object under the
# strategy's scope.
with strategy.scope():
    model=ResNetModel(classes=num_classes)
    optimizer=tf.keras.optimizers.Adam()
    # Checkpoint tracking both the optimizer and the model.
    checkpoint=tf.train.Checkpoint(optimizer=optimizer, model=model)
    

In [33]:
def train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    """Build the per-replica training and test step functions.

    Args:
        strategy: distribution strategy whose scope the steps are defined in.
        model: callable model; must expose ``trainable_variables``.
        compute_loss: ``(labels, predictions) -> loss`` used for training.
        optimizer: optimizer exposing ``apply_gradients``.
        train_accuracy: metric updated with ``(labels, predictions)`` in training.
        loss_object: ``(labels, predictions) -> loss`` used for evaluation.
        test_loss: metric updated with the evaluation loss.
        test_accuracy: metric updated with ``(labels, predictions)`` in evaluation.

    Returns:
        A ``(train_step, test_step)`` tuple. ``train_step(inputs)`` returns the
        loss of the batch; ``test_step(inputs)`` only updates the test metrics.
    """
    with strategy.scope():
        def train_step(inputs):
            """Run one optimisation step on an ``(images, labels)`` batch."""
            images, labels = inputs
            with tf.GradientTape() as tape:
                predictions = model(images, training=True)
                loss = compute_loss(labels, predictions)

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            train_accuracy.update_state(labels, predictions)
            # The caller reduces this value across replicas and accumulates it
            # per epoch, so it has to be returned.
            return loss

        def test_step(inputs):
            """Evaluate one ``(images, labels)`` batch and update test metrics."""
            images, labels = inputs
            predictions = model(images, training=False)
            t_loss = loss_object(labels, predictions)
            # Feed the computed loss to the mean metric. Passing
            # (labels, predictions) would average the labels instead.
            test_loss.update_state(t_loss)
            test_accuracy.update_state(labels, predictions)

        return train_step, test_step
            
# Build the per-replica step functions from the objects created in the cells above.
train_step, test_step = train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)     

In [34]:
def distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    """Wrap the per-replica step functions so they run through the strategy.

    Args:
        strategy: distribution strategy used to run and reduce the steps.
        train_step: per-replica training step taking one ``inputs`` argument.
        test_step: per-replica test step taking one ``inputs`` argument.
        model, compute_loss, optimizer, train_accuracy, loss_object, test_loss,
        test_accuracy: accepted for signature compatibility; not used here.

    Returns:
        A ``(distributed_train_step, distributed_test_step)`` tuple. The
        training wrapper returns the per-replica losses summed across replicas.
    """
    with strategy.scope():
        @tf.function
        def distributed_train_step(dataset_inputs):
            # ``args`` must be a tuple of positional arguments. Wrapping the
            # batch keeps it a single ``inputs`` argument instead of being
            # unpacked into (images, labels).
            per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                   axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return strategy.run(test_step, args=(dataset_inputs,))

    return distributed_train_step, distributed_test_step

In [35]:
# Build the strategy-aware wrappers around train_step and test_step.
distributed_train_step, distributed_test_step = distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

In [None]:
# Custom training loop: for each epoch, one pass over the training data
# followed by one pass over the test data, then a one-line report.
with strategy.scope():
    for epoch in range(EPOCHS):
        # TRAIN LOOP
        # Accumulate the value returned by each distributed training step and
        # count the steps so the epoch's mean loss can be reported.
        total_loss = 0.0
        num_batches = 0
        for x in tqdm(train_dist_dataset):
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        # The test step returns nothing of interest; its results are read from
        # test_loss / test_accuracy below.
        for x in test_dist_dataset:
            distributed_test_step(x)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                    "Test Accuracy: {}")
        print (template.format(epoch+1, train_loss,
                               train_accuracy.result()*100, test_loss.result(),
                               test_accuracy.result()*100))

        # Clear the metrics so the next epoch's figures start from scratch.
        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()