## Libraries

In [None]:
### Uncomment the next two lines to install
### tensorflow_hub and tensorflow_datasets

#!pip install tensorflow_hub
#!pip install tensorflow_datasets

In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import tensorflow_hub as hub
import tensorflow_datasets as tfds

from tensorflow.keras import layers

### Download the data and split it into training and validation sets

In [None]:
def get_data():
    """Download tf_flowers and split its single 'train' split 70/30.

    Returns:
        (train_set, validation_set, info): the first 70% and the remaining
        30% of the 'train' split as (image, label) datasets, plus the
        dataset metadata returned by ``tfds.load(..., with_info=True)``.
    """
    splits = ['train[:70%]', 'train[70%:]']
    datasets, metadata = tfds.load(
        'tf_flowers',
        split=splits,
        as_supervised=True,
        with_info=True,
    )
    train_part, validation_part = datasets
    return train_part, validation_part, metadata


train_set, validation_set, info = get_data()

In [None]:
# num_examples counts the entire 'train' split (both the training and the
# validation slices taken from it); it is used to size the shuffle buffer.
num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes

for template, value in (
    ('Total Number of Classes: {}', num_classes),
    ('Total Number of Training Images: {}', len(train_set)),
    ('Total Number of Validation Images: {} \n', len(validation_set)),
):
    print(template.format(value))

In [None]:
# Square input side length in pixels, and number of images per batch.
img_shape, batch_size = 299, 32

def format_image(image, label):
    """Resize an image to img_shape x img_shape and divide it by 255.

    Dividing by 255 maps 0-255 pixel values into [0, 1]. The label is
    returned untouched so the function can be mapped over an (image, label)
    dataset.
    """
    target_size = (img_shape, img_shape)
    resized = tf.image.resize(image, target_size)
    return resized / 255.0, label

# Only the training data is shuffled (buffer = a quarter of the full 'train'
# split); both pipelines resize/scale, batch, and prefetch one batch ahead.
train_batches = (
    train_set
    .shuffle(num_examples // 4)
    .map(format_image)
    .batch(batch_size)
    .prefetch(1)
)
validation_batches = (
    validation_set
    .map(format_image)
    .batch(batch_size)
    .prefetch(1)
)

### Loading the pretrained InceptionV3 feature extractor

In [None]:
def get_mobilenet_features(
    url="https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4",
    image_size=None,
):
    """Wrap the pretrained InceptionV3 TF-Hub feature-vector module as a Keras layer.

    Note: despite its historical name, this function loads InceptionV3 (see
    the default ``url``), not MobileNet. The name is kept so existing
    callers keep working.

    Args:
        url: TF Hub handle of the feature-vector module to load.
        image_size: side length of the square RGB input images. Defaults to
            the notebook-level ``img_shape``.

    Returns:
        A ``hub.KerasLayer`` built with input_shape (image_size, image_size, 3).
    """
    if image_size is None:
        # Read-only access to the notebook constant; no ``global`` needed.
        image_size = img_shape
    return hub.KerasLayer(url, input_shape=(image_size, image_size, 3))

In [None]:
# Load the pretrained InceptionV3 feature extractor and freeze it, so that
# training only updates the layers stacked on top of it.
feature_extractor = get_mobilenet_features()
feature_extractor.trainable = False  # transferred weights stay fixed

## Deep Learning Model - Transfer Learning using InceptionV3

In [None]:
def create_transfer_learned_model(feature_extractor, n_classes=None,
                                  hidden_units=512, dropout_rate=0.4):
    """Stack a classification head on a feature extractor and compile it.

    Architecture: feature_extractor -> Dense(hidden_units, relu) ->
    Dropout(dropout_rate) -> Dense(n_classes, softmax). The model is
    compiled with Adam, sparse categorical cross-entropy and accuracy, and
    its summary is printed.

    Args:
        feature_extractor: Keras layer mapping an image to a feature vector
            (frozen beforehand if only the head should be trained).
        n_classes: number of output classes. Defaults to the notebook-level
            ``num_classes``.
        hidden_units: width of the hidden Dense layer.
        dropout_rate: dropout rate applied after the hidden layer.

    Returns:
        The compiled ``tf.keras.Sequential`` model; its outputs are softmax
        class probabilities.
    """
    if n_classes is None:
        n_classes = num_classes  # read-only access; no ``global`` needed

    model = tf.keras.Sequential([
        feature_extractor,
        layers.Dense(hidden_units, activation='relu'),
        layers.Dropout(dropout_rate),
        layers.Dense(n_classes, activation='softmax'),
    ])

    # The output layer already applies softmax, so the model emits
    # probabilities, not logits. from_logits must therefore be False;
    # with True, the loss would apply softmax a second time.
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy'])

    model.summary()

    return model

### Training the classification head on top of the frozen InceptionV3 features

Validation accuracy achieved in a previous run: 92.10% (a significant improvement over the simpler baseline architecture).

In [None]:
epochs = 10

# Build the transfer-learning model and fit it, evaluating on the
# validation batches after every epoch.
model = create_transfer_learned_model(feature_extractor)
history = model.fit(
    train_batches,
    validation_data=validation_batches,
    epochs=epochs,
)

### Plotting Accuracy and Loss Curves

In [None]:
def create_plots(history):
    """Plot training vs. validation accuracy (left) and loss (right) per epoch.

    Args:
        history: object returned by ``model.fit`` with ``validation_data``;
            its ``history`` dict must contain 'accuracy', 'val_accuracy',
            'loss' and 'val_loss'.
    """
    metrics = history.history
    acc, val_acc = metrics['accuracy'], metrics['val_accuracy']
    loss, val_loss = metrics['loss'], metrics['val_loss']

    # Take the x-axis from the recorded history rather than the global
    # `epochs`, so the plot stays correct if training stopped early or
    # `epochs` was changed after fitting.
    epochs_range = range(len(acc))

    plt.figure(figsize=(8, 8))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()


create_plots(history)

### Prediction

In [None]:
def predict(dataset=None):
    """Run the trained model on one batch and decode its predictions.

    Args:
        dataset: batched (image, label) dataset to take the batch from.
            Defaults to ``validation_batches``, so the demo shows predictions
            on images the model was not trained on. The previous behaviour
            used a shuffled training batch, which overstates accuracy.

    Returns:
        (image_batch, label_batch, predicted_ids, predicted_class_names):
        the images and true labels as NumPy arrays, the argmax class index
        per image, and the matching entries of the dataset's label names.
    """
    if dataset is None:
        dataset = validation_batches

    images, labels = next(iter(dataset.take(1)))
    image_batch = images.numpy()
    label_batch = labels.numpy()

    # model.predict yields one score vector per image. Take the argmax over
    # the last axis directly: squeezing first would drop the batch axis for a
    # single-image batch and turn the ids into a scalar.
    predicted_batch = model.predict(image_batch)
    predicted_ids = np.argmax(predicted_batch, axis=-1)

    class_names = np.array(info.features['label'].names)
    predicted_class_names = class_names[predicted_ids]

    return image_batch, label_batch, predicted_ids, predicted_class_names

In [None]:
image_batch, label_batch, predicted_ids, predicted_class_names = predict()

# Show true label ids next to predicted ids for the same batch.
for caption, values in (("Labels: ", label_batch),
                        ("Predicted labels: ", predicted_ids)):
    print(caption, values)

In [None]:
def plot_figures(num_images=30, n_cols=5):
    """Show a grid of batch images titled with their predicted class names.

    Titles are blue when the predicted id equals the true label and red
    otherwise. Reads the notebook-level image_batch, label_batch,
    predicted_ids and predicted_class_names produced by ``predict()``.

    Args:
        num_images: maximum number of images to show. It is capped at the
            batch size, so a batch smaller than this does not raise IndexError.
        n_cols: number of grid columns; the row count is derived from it.
    """
    n_shown = min(num_images, len(image_batch))
    n_rows = -(-n_shown // n_cols)  # ceiling division

    plt.figure(figsize=(10, 9))
    for n in range(n_shown):
        plt.subplot(n_rows, n_cols, n + 1)
        plt.imshow(image_batch[n])
        color = "blue" if predicted_ids[n] == label_batch[n] else "red"
        plt.title(predicted_class_names[n].title(), color=color)
        plt.axis('off')
    # Adjust spacing once for the whole figure instead of on every iteration.
    plt.subplots_adjust(hspace=0.3)
    _ = plt.suptitle("Model predictions (blue: correct, red: incorrect)")


plot_figures()