<a href="https://colab.research.google.com/github/amitshakarchy/HW_3_CNN/blob/master/HW_3_CNN_Transfer_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from keras.layers import BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.keras import Input
from tensorflow.python.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras import layers
import tensorflow_hub as hub
import matplotlib.pyplot as plt
from tensorflow.python.keras.callbacks import Callback
from tensorboard.plugins.hparams import api as hp
from tensorflow.python.keras.layers import Dropout, GlobalAveragePooling2D


def data(split_another_way=False):
    """Load Oxford Flowers 102 and build batched train/validation/test pipelines.

    The 'train', 'test' and 'validation' splits are loaded as one dataset and
    re-partitioned by position into roughly 50% training, 25% validation and
    25% test examples.

    Reads the module-level names ``IMG_SIZE``, ``BATCH_SIZE`` and
    ``normalization``, which must be defined before this function is called.

    Args:
        split_another_way: If False (default), the training set is the first
            half of the combined dataset and test/validation come from the
            second half. If True, test/validation come from the first half
            and the training set is the remainder.

    Returns:
        A tuple ``(train_batches, validation_batches, test_batches)`` of
        batched, prefetched ``tf.data.Dataset`` objects.
    """
    split_names = ('train', 'test', 'validation')
    dataset, dataset_info = tfds.load('oxford_flowers102', with_info=True, as_supervised=True,
                                      split='+'.join(split_names))
    # Splitting approach based on
    # https://stackoverflow.com/questions/60646972/trouble-with-splitting-data-from-tensorflow-datasets
    # The total comes from the dataset metadata, so no pass over the
    # images is needed just to count them.
    df_all_length = sum(dataset_info.splits[name].num_examples for name in split_names)

    train_size = int(0.5 * df_all_length)
    val_test_size = int(0.25 * df_all_length)

    if split_another_way:
        # First 2 * val_test_size examples -> test then validation;
        # everything after that -> training.
        test_val_set = dataset.take(val_test_size * 2)
        training_set = dataset.skip(val_test_size * 2)
        validation_set = test_val_set.skip(val_test_size)
        test_set = test_val_set.take(val_test_size)

        num_training_examples = df_all_length - val_test_size * 2
        num_validation_examples = val_test_size
        num_test_examples = val_test_size
    else:
        # First train_size examples -> training; of the remainder, the first
        # val_test_size -> test and whatever is left -> validation.
        training_set = dataset.take(train_size)
        df_test_val = dataset.skip(train_size)
        validation_set = df_test_val.skip(val_test_size)
        test_set = df_test_val.take(val_test_size)

        num_training_examples = train_size
        num_validation_examples = df_all_length - train_size - val_test_size
        num_test_examples = val_test_size

    print('Total Number of Training Images: {}'.format(num_training_examples))
    print('Total Number of Validation Images: {}'.format(num_validation_examples))
    print('Total Number of Test Images: {} \n'.format(num_test_examples))
    # Get the number of classes in the dataset from the dataset info.
    num_classes = dataset_info.features['label'].num_classes
    print('Total Number of Classes: {}'.format(num_classes))

    def format_image(image, label):
        """Resize an image to IMG_SIZE x IMG_SIZE, dividing by 255 if enabled."""
        if normalization:
            image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) / 255.0
        else:
            image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
        return image, label

    # Only the training pipeline is shuffled; the shuffle buffer holds a
    # quarter of the training examples.
    train_batches = training_set.cache().shuffle(num_training_examples // 4).map(format_image).batch(
        BATCH_SIZE).prefetch(1)

    validation_batches = validation_set.cache().map(format_image).batch(BATCH_SIZE).prefetch(1)

    test_batches = test_set.cache().map(format_image).batch(BATCH_SIZE).prefetch(1)

    return train_batches, validation_batches, test_batches


def get_mobilenet_v2_adapted(hparams):
    """Build a classifier on top of a frozen TF Hub MobileNetV2 feature vector.

    Reads the module-level ``IMG_SIZE`` and ``N_CLASSES``.

    Args:
        hparams: Mapping indexed by ``HP_NUM_UNITS`` (width of the hidden
            dense layer) and ``HP_DROPOUT`` (dropout rate after it).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model ending in a softmax layer
        with ``N_CLASSES`` outputs.
    """
    hub_handle = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"
    backbone = hub.KerasLayer(hub_handle, input_shape=(IMG_SIZE, IMG_SIZE, 3))
    # Keep the pre-trained weights fixed; only the head below is trained.
    backbone.trainable = False

    # Classification head: hidden dense layer -> dropout -> softmax.
    classifier = tf.keras.Sequential()
    classifier.add(backbone)
    classifier.add(Dense(hparams[HP_NUM_UNITS], activation='relu'))
    classifier.add(Dropout(hparams[HP_DROPOUT]))
    classifier.add(layers.Dense(N_CLASSES, activation='softmax'))
    return classifier


def get_inceptionv3_adapted(hparams):
    """Build a classifier on top of a frozen ImageNet-pretrained InceptionV3.

    The input passes through a batch-normalization layer, the frozen
    InceptionV3 base (without its top), global average pooling, a hidden
    dense layer with dropout, and a final softmax layer.

    Reads the module-level ``IMG_SIZE`` and ``N_CLASSES``.

    Args:
        hparams: Mapping indexed by ``HP_NUM_UNITS`` (width of the hidden
            dense layer) and ``HP_DROPOUT`` (dropout rate after it).

    Returns:
        An uncompiled Keras ``Model`` with ``N_CLASSES`` softmax outputs.
    """
    image_shape = (IMG_SIZE, IMG_SIZE, 3)
    inputs = Input(shape=image_shape)
    backbone = InceptionV3(include_top=False,
                           weights='imagenet',
                           input_shape=image_shape)
    # Keep the pre-trained weights fixed; only the layers around it train.
    backbone.trainable = False

    normalized = BatchNormalization()(inputs)
    features = backbone(normalized)
    pooled = GlobalAveragePooling2D()(features)
    hidden = Dense(hparams[HP_NUM_UNITS], activation='relu')(pooled)
    hidden = Dropout(hparams[HP_DROPOUT])(hidden)
    predictions = Dense(N_CLASSES, activation='softmax')(hidden)
    return Model(inputs, predictions)

# Implementation based on https://github.com/keras-team/keras/issues/2548
class TestCallback(Callback):
    """Keras callback that evaluates the model on a test dataset after each epoch.

    The per-epoch results are appended to ``history_test`` under the keys
    ``'test_accuracy'`` and ``'test_loss'``.
    """

    def __init__(self, test_data):
        """Store the dataset to evaluate on.

        Args:
            test_data: Data accepted by ``Model.evaluate`` as its input,
                e.g. a batched ``tf.data.Dataset`` of ``(image, label)``.
        """
        super().__init__()
        self.test_data = test_data
        self.history_test = {'test_accuracy': [], 'test_loss': []}

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the whole test dataset and record loss and accuracy.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metrics dict supplied by Keras (unused).
        """
        # Evaluate on every batch of the test data, not just the first one,
        # so the recorded values describe the full test set.
        # Unpacking into two values assumes the model was compiled with
        # exactly one metric (accuracy) besides the loss.
        loss, acc = self.model.evaluate(self.test_data, verbose=0)
        self.history_test['test_accuracy'].append(acc)
        self.history_test['test_loss'].append(loss)
        print('\nTesting loss: {}, acc: {}\n'.format(loss, acc))


def plot(history, history_test, session_num):
    """Plot accuracy and loss curves side by side, save the figure, and show it.

    Args:
        history: Object whose ``history`` attribute is a dict containing the
            per-epoch lists 'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
        history_test: Dict with per-epoch lists 'test_accuracy' and
            'test_loss'.
        session_num: Value whose string form is used as the output file name
            (with '.png' appended).
    """
    recorded = history.history
    epochs_range = range(len(recorded['accuracy']))
    plt.figure(figsize=(8, 8))

    # Left panel: accuracy curves.
    plt.subplot(1, 2, 1)
    accuracy_curves = (
        (recorded['accuracy'], 'Training Accuracy'),
        (recorded['val_accuracy'], 'Validation Accuracy'),
        (history_test['test_accuracy'], 'Test Accuracy'),
    )
    for values, caption in accuracy_curves:
        plt.plot(epochs_range, values, label=caption)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc='lower right')
    plt.title('Accuracy')

    # Right panel: loss curves.
    plt.subplot(1, 2, 2)
    loss_curves = (
        (recorded['loss'], 'Training Loss'),
        (recorded['val_loss'], 'Validation Loss'),
        (history_test['test_loss'], 'Test Loss'),
    )
    for values, caption in loss_curves:
        plt.plot(epochs_range, values, label=caption)
    plt.legend(loc='upper right')
    plt.title('Loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")

    # Save before show() so the written file contains the drawn figure.
    plt.savefig(f"{session_num!s}" + '.png')
    plt.show()


# Experiment script: configure one run, train the chosen model, then
# evaluate on the test set and plot the training curves.
print("Let's go!")
N_CLASSES = 102
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 10
HP_NUM_UNITS = hp.HParam('num_units', hp.Discrete([128, 256]))
HP_DROPOUT = hp.HParam('dropout', hp.Discrete([0.0, 0.3]))
# The declared optimizer names the one actually passed to compile() below.
# Previously this said 'adam' while the model was compiled with RMSprop, so
# the printed hyperparameters misreported the run.
HP_OPTIMIZER = hp.HParam('optimizer', hp.Discrete(['rmsprop']))
optimizer = 'rmsprop'
# NOTE(review): INPUT_SHAPE is not used anywhere in the visible code.
INPUT_SHAPE = Input(shape=(IMG_SIZE, IMG_SIZE, 3))

parameters = dict(model='feature_vector', normalization=True, num_units=128, dropout_rate=0.3, split_another_way=False)
# model - 'inceptionv3' , 'feature_vector'
# normalization - True, False
# num_units - 128, 256
# dropout - 0.0, 0.3
# split_another_way - False, True

run_model = parameters['model']
# `normalization` is read as a global by data()'s image formatting step.
normalization = parameters['normalization']
num_units = parameters['num_units']
dropout_rate = parameters['dropout_rate']
train_batches, validation_batches, test_batches = data(parameters['split_another_way'])

hparams = {
    HP_NUM_UNITS: num_units,
    HP_DROPOUT: dropout_rate,
    HP_OPTIMIZER: optimizer,
}
if run_model == 'feature_vector':
    model = get_mobilenet_v2_adapted(hparams)
else:
    model = get_inceptionv3_adapted(hparams)
print(f"-----------------------------------------------------------{run_model} "
      f"--------------------------------------------------------------")
# 'rmsprop' resolves to RMSprop with default settings, matching the
# previously hard-coded tf.keras.optimizers.RMSprop().
model.compile(
    optimizer=hparams[HP_OPTIMIZER],
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

print({h.name: value for h, value in hparams.items()})

# Stop training when the validation loss has not improved for 2 consecutive epochs
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)
callable_test = TestCallback(test_batches)

history = model.fit(train_batches,
                    epochs=EPOCHS,
                    validation_data=validation_batches,
                    callbacks=[early_stopping, callable_test])


# Final evaluation on the test batches; the result is embedded in the name
# of the saved plot.
loss_test, acc_test = model.evaluate(test_batches)
str_loss_acc = "loss_{:.3f}_acc_{:.3f}".format(loss_test, acc_test)
plot(history, callable_test.history_test, str_loss_acc)
