## Imports

In [6]:
from __future__ import absolute_import, division, print_function, unicode_literals

In [7]:
import tensorflow as tf
import tensorflow_hub as hub

import numpy as np
import os
from tqdm import tqdm

In [8]:
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

## Download the dataset

In [11]:
splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']
(train_examples, validation_examples, test_examples), info = tfds.load('oxford_flowers102',
                                                                     with_info=True, split=splits,
                                                                      as_supervised=True, data_dir='data/')
num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes


In [12]:
num_examples, num_classes

(1020, 102)

## Create a strategy to distribute the variables and the graph

In [13]:
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()

In [14]:
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Number of devices: 1


## Setup input pipeline

In [15]:
BUFFER_SIZE = num_examples
EPOCHS = 10
pixels = 224
# MODULE_HANDLE = 'data/resnet_50_feature_vector'
MODULE_HANDLE = 'https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4'
IMAGE_SIZE = (pixels, pixels)
print("Using {} with input size {}".format(MODULE_HANDLE, IMAGE_SIZE))

Using https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4 with input size (224, 224)


In [16]:
def format_image(image, label):
    image = tf.image.resize(image, IMAGE_SIZE) / 255.0
    return image, label

## Set the global batch size

In [17]:
def set_global_batch_size(batch_size_per_replica, strategy):
    global_batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
    return global_batch_size

In [18]:
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = set_global_batch_size(BATCH_SIZE_PER_REPLICA, strategy)
GLOBAL_BATCH_SIZE

64

In [19]:
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(BATCH_SIZE_PER_REPLICA).prefetch(1)
validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE_PER_REPLICA).prefetch(1)
test_batches = test_examples.map(format_image).batch(1)

In [23]:
def distribute_datasets(strategy, train_batches, validation_batches, test_batches):

    train_dist_dataset = strategy.experimental_distribute_dataset(train_batches)
    val_dist_dataset = strategy.experimental_distribute_dataset(validation_batches)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_batches)

    return train_dist_dataset, val_dist_dataset, test_dist_dataset

In [24]:
train_dist_dataset, val_dist_dataset, test_dist_dataset = distribute_datasets(strategy, train_batches, validation_batches, test_batches)

In [25]:
print(type(train_dist_dataset))
print(type(val_dist_dataset))
print(type(test_dist_dataset))

<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>
<class 'tensorflow.python.distribute.input_lib.DistributedDataset'>


In [26]:
for x in train_dist_dataset:
    break

print(f"x is a tuple that contains {len(x)} values ")
print(f"x[0] contains the features, and has shape {x[0].shape}")
print(f"  so it has {x[0].shape[0]} examples in the batch, each is an image that is {x[0].shape[1:]}")
print(f"x[1] contains the labels, and has shape {x[1].shape}")

x is a tuple that contains 2 values 
x[0] contains the features, and has shape (64, 224, 224, 3)
  so it has 64 examples in the batch, each is an image that is (224, 224, 3)
x[1] contains the labels, and has shape (64,)


## Create the model

In [27]:
class ResNetModel(tf.keras.Model):
    def __init__(self, classes):
        super(ResNetModel, self).__init__()
        self._feature_extractor = hub.KerasLayer(MODULE_HANDLE,
                                                 trainable=False)
        self._classifier = tf.keras.layers.Dense(classes, activation='softmax')

    def call(self, inputs):
        x = self._feature_extractor(inputs)
        x = self._classifier(x)
        return x

In [28]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

## Define the loss function

In [29]:
with strategy.scope():
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    def compute_loss(y_true, y_pred):
        per_example_loss = loss_object(y_true, y_pred)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

    test_loss = tf.keras.metrics.Mean(name='test_loss')


## Define the metrics to track loss and accuracy

In [30]:
with strategy.scope():
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='test_accuracy')


## Instantiate the model, optimizer, and checkpoints

In [31]:
with strategy.scope():
    model = ResNetModel(num_classes)
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)


## Training Loop

In [32]:
def train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        def train_step(inputs):
            images, labels = inputs
            with tf.GradientTape():
                predictions = model(images, training=True)
                loss = compute_loss(labels, predictions)
            gradients = tf.gradients(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients,  model.trainable_variables))

            train_accuracy.update_state(labels, predictions)
            return loss

        def test_step(inputs):
            images, labels = inputs
            predictions = model(images, training=False)
            t_loss = compute_loss(labels, predictions)
            test_loss.update_state(t_loss)
            test_accuracy.update_state(labels, predictions)

    return train_step, test_step

In [33]:
train_step, test_step = train_test_step_fns(strategy, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

In [34]:
def distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy):
    with strategy.scope():
        @tf.function
        def distributed_train_step(dataset_inputs):
            per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
            return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                   axis=None)

        @tf.function
        def distributed_test_step(dataset_inputs):
            return strategy.run(test_step, args=(dataset_inputs,))

        return distributed_train_step, distributed_test_step

In [35]:
distributed_train_step, distributed_test_step = distributed_train_test_step_fns(strategy, train_step, test_step, model, compute_loss, optimizer, train_accuracy, loss_object, test_loss, test_accuracy)

In [36]:
with strategy.scope():
    for epoch in range(EPOCHS):
        # TRAIN LOOP
        total_loss = 0.0
        num_batches = 0
        for x in tqdm(train_dist_dataset):
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches

        # TEST LOOP
        for x in test_dist_dataset:
            distributed_test_step(x)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
                    "Test Accuracy: {}")
        print (template.format(epoch+1, train_loss,
                               train_accuracy.result()*100, test_loss.result(),
                               test_accuracy.result()*100))

        test_loss.reset_state()
        train_accuracy.reset_state()
        test_accuracy.reset_state()

13it [02:27, 11.36s/it]


Epoch 1, Loss: 4.515346050262451, Accuracy: 4.779411792755127, Test Loss: 0.06089402735233307, Test Accuracy: 11.764705657958984


13it [02:19, 10.71s/it]


Epoch 2, Loss: 2.634058952331543, Accuracy: 48.897056579589844, Test Loss: 0.04581480473279953, Test Accuracy: 41.17647171020508


13it [02:18, 10.68s/it]


Epoch 3, Loss: 1.58052659034729, Accuracy: 82.96568298339844, Test Loss: 0.03760666400194168, Test Accuracy: 56.86274719238281


13it [03:22, 15.54s/it]


Epoch 4, Loss: 1.0048941373825073, Accuracy: 90.68627166748047, Test Loss: 0.03249422833323479, Test Accuracy: 56.86274719238281


13it [02:16, 10.53s/it]


Epoch 5, Loss: 0.686842679977417, Accuracy: 95.22058868408203, Test Loss: 0.029490171000361443, Test Accuracy: 62.74510192871094


13it [02:17, 10.55s/it]


Epoch 6, Loss: 0.4996119737625122, Accuracy: 97.30392456054688, Test Loss: 0.027804484590888023, Test Accuracy: 66.66667175292969


13it [02:27, 11.33s/it]


Epoch 7, Loss: 0.3796531856060028, Accuracy: 98.52941131591797, Test Loss: 0.02612422965466976, Test Accuracy: 68.62745666503906


13it [02:20, 10.81s/it]


Epoch 8, Loss: 0.2985183000564575, Accuracy: 99.50980377197266, Test Loss: 0.025496482849121094, Test Accuracy: 68.62745666503906


13it [02:25, 11.23s/it]


Epoch 9, Loss: 0.24128498136997223, Accuracy: 99.75489807128906, Test Loss: 0.024948669597506523, Test Accuracy: 67.64705657958984


13it [02:20, 10.81s/it]


Epoch 10, Loss: 0.19785989820957184, Accuracy: 100.0, Test Loss: 0.024284949526190758, Test Accuracy: 70.5882339477539


In [38]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [39]:
save_path = '/content/drive/My Drive/my_model'

# Save the model
tf.saved_model.save(model, save_path)