# Import dependencies


In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential

# Get access to google drive
import os
from google.colab import drive
drive.mount('/content/drive')

# To organize directories
import shutil

# For metrics
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score


In [None]:
# Optional: switch into the mounted Drive folder first (left disabled here).
#os.chdir('drive/MyDrive/')
# IPython shell escape: list the contents of the current working directory.
!ls

# Datasets

In [None]:
# Dataset locations, relative to the current working directory.
TRAIN_DIR = 'predict_numbers/train_numbers'
VAL_DIR = 'predict_numbers/val_numbers'

BATCH_SIZE = 64      # images per training batch
IMG_SIZE = (96, 96)  # size every image is resized to when loaded
VAL_BATCHES = 10     # NOTE: despite the name, this is used as the validation *batch size*

# BATCH_SIZE is passed explicitly: when `batch_size` is omitted,
# image_dataset_from_directory falls back to its own default (32) and the
# constant above would silently have no effect.
train_dataset = tf.keras.utils.image_dataset_from_directory(
    TRAIN_DIR, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
val_dataset = tf.keras.utils.image_dataset_from_directory(
    VAL_DIR, batch_size=VAL_BATCHES, image_size=IMG_SIZE)

# Class names as reported by the training dataset; reused later for plots and reports.
class_names = train_dataset.class_names
print(class_names)

In [None]:
def encode_labels(image, vector, num_classes=14):
    """Turn integer class labels into one-hot vectors, leaving the image untouched.

    Args:
        image: image tensor, returned unchanged.
        vector: integer class indices to encode.
        num_classes: depth of the one-hot encoding. Defaults to 14, which
            matches the 14-unit softmax heads of the models in this file.

    Returns:
        A tuple ``(image, one_hot_labels)``.
    """
    return image, tf.one_hot(vector, num_classes)

# One-hot encode the labels of both splits (train first, then validation).
encoded_ds_train, encoded_ds_val = (
    split.map(encode_labels) for split in (train_dataset, val_dataset)
)

# Model

## Data augmentation

In [None]:
class RandomBrightness(tf.keras.layers.Layer):
    """Augmentation layer that adds a random, non-negative brightness offset.

    Args:
        top_brightness: upper bound of the uniform range ``[0, top_brightness)``
            from which the brightness delta is drawn.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, top_brightness, **kwargs):
        super().__init__(**kwargs)
        self.top_brightness = top_brightness

    def call(self, x, training=True):
        """Brighten ``x`` by a random delta, or return it unchanged in inference.

        Args:
            x: image tensor (or batch of images) to adjust.
            training: when false the layer is the identity, so inference is
                deterministic. Defaults to True so that calling the layer
                directly, without a training flag, still augments.

        Returns:
            The brightness-adjusted tensor.
        """
        if not training:
            return x
        # One scalar delta is drawn per call, so every image in the batch
        # receives the same brightness shift.
        # NOTE(review): adjust_brightness adds the delta to the pixel values; if
        # inputs are floats in [0, 255] a delta below 1 is barely visible - confirm
        # the intended input range.
        delta = tf.random.uniform([], minval=0, maxval=self.top_brightness, dtype=tf.dtypes.float32)
        return tf.image.adjust_brightness(x, delta)

    def get_config(self):
        """Return the layer config, including ``top_brightness``, for serialization."""
        config = super().get_config()
        config.update({'top_brightness': self.top_brightness})
        return config

class DataAug(tf.keras.Model):
    """Random augmentation pipeline applied to a batch of images.

    The transformations run in this order: translation, zoom, brightness,
    rotation. Areas uncovered by the geometric transforms are filled with the
    nearest pixel (``fill_mode='nearest'``).

    Args:
        brightness: upper bound handed to ``RandomBrightness``.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, brightness=.28, **kwargs):
        # Forward kwargs so options such as `name` actually reach the Keras
        # base class instead of being accepted and silently dropped.
        super().__init__(**kwargs)
        self.transformations = [
            tf.keras.layers.RandomTranslation(0.1, 0.1, fill_mode='nearest'),
            tf.keras.layers.RandomZoom((-0.2, 0.2), fill_mode='nearest'),
            RandomBrightness(brightness),
            tf.keras.layers.RandomRotation(0.5, fill_mode='nearest'),
        ]

    def call(self, x):
        """Apply every transformation in sequence and return the result."""
        for transform in self.transformations:
            x = transform(x)
        return x


# Single shared augmentation pipeline instance used by the models below.
data_aug = DataAug()

# Prediction model

In [None]:
# ImageNet-pretrained MobileNetV2 without its classification top. It is frozen
# here so that only the layers added on top of it are trained at first.
pred_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=[*IMG_SIZE, 3],
)
pred_model.trainable = False

## Final model

In [None]:
class MyModel(tf.keras.Model):
    """Training model: optional augmentation -> rescaling -> backbone -> dense head.

    Args:
        base_model: feature extractor called on the rescaled images. Defaults
            to the module-level ``pred_model``.
    """

    def __init__(self, base_model=pred_model):
        super(MyModel, self).__init__()
        # Module-level augmentation pipeline (shared instance).
        self.data_aug = data_aug
        # x * (1 / 127.5) - 1 maps pixel values from [0, 255] to [-1, 1].
        self.rescaling = tf.keras.layers.Rescaling(1./127.5, offset=-1)
        # Use the constructor argument rather than always the global backbone.
        self.base_model = base_model
        # Attribute name (sic) kept unchanged so code referring to it keeps working.
        self.predition_head = [tf.keras.layers.GlobalAveragePooling2D(),
                               tf.keras.layers.Dense(32, activation='relu'),
                               tf.keras.layers.Dense(14, activation='softmax')]

    def call(self, x, data_aug=True):
        """Return class probabilities (14-way softmax) for a batch of images.

        Args:
            x: batch of images.
            data_aug: apply the augmentation pipeline first when true; pass
                False for validation / inference.
        """
        if data_aug:
            x = self.data_aug(x)
        x = self.rescaling(x)
        # The backbone is always called with training=False (inference mode),
        # whether or not its weights are trainable.
        x = self.base_model(x, training=False)
        for layer in self.predition_head:
            x = layer(x)
        return x


model = MyModel()

In [None]:
class MyModel_save(tf.keras.Model):
    """Export variant of the classifier: backbone -> dense head only.

    Unlike ``MyModel`` it contains neither the augmentation pipeline nor a
    ``Rescaling`` layer, so the backbone sees the inputs exactly as given.
    NOTE(review): callers presumably have to feed inputs already scaled the
    way ``MyModel`` scales them (to [-1, 1]) - confirm.

    Args:
        base_model: feature extractor called on the inputs. Defaults to the
            module-level ``pred_model``.
    """

    def __init__(self, base_model=pred_model):
        super(MyModel_save, self).__init__()
        # Use the constructor argument rather than always the global backbone.
        self.base_model = base_model
        # Attribute name (sic) kept unchanged so code referring to it keeps working.
        self.predition_head = [tf.keras.layers.GlobalAveragePooling2D(),
                               tf.keras.layers.Dense(32, activation='relu'),
                               tf.keras.layers.Dense(14, activation='softmax')]

    def call(self, x, data_aug=True):
        """Return class probabilities (14-way softmax) for a batch of inputs.

        ``data_aug`` is accepted only for signature parity with ``MyModel``
        and is ignored.
        """
        # The backbone is always called with training=False (inference mode).
        x = self.base_model(x, training=False)
        for layer in self.predition_head:
            x = layer(x)
        return x


model_save = MyModel_save()

In [None]:
# Call each model once on an all-zero dummy batch (augmentation off) before
# asking for the summary.
for net in (model, model_save):
    net(np.zeros((3, 96, 96, 3)), data_aug=False)
model.summary()

# Training

In [None]:
def my_plot_cm(cm, my_labels=class_names, title=''):
    """
    Draw a confusion matrix as a colour-mapped grid with every cell's value
    (rounded to 2 decimals) printed on top of it.

    cm is the matrix, indexed as cm[row, col]; my_labels is the list of tick
    labels used for both the x and the y axis; title is the plot title (string).
    """
    n_labels = len(my_labels)
    figure = plt.figure(figsize=(7, 7))
    axes = figure.add_subplot(111)

    axes.matshow(cm, cmap=plt.cm.summer)

    # Annotate the cells: the text x-coordinate is the column, y is the row.
    for col in range(n_labels):
        for row in range(n_labels):
            value = cm[row, col].round(2)
            axes.text(col, row, str(value), va='center', ha='center', color='red')

    plt.grid(False)
    axes.title.set_text(title)

    tick_positions = range(n_labels)
    axes.set_yticks(tick_positions)
    axes.set_xticks(tick_positions)
    axes.set_xticklabels(labels=my_labels)
    axes.set_yticklabels(labels=my_labels)
    plt.show()

In [None]:
# Default optimizer and loss shared by the training utilities below.
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.CategoricalCrossentropy()


def gradient_step(X, y, my_model, return_loss=True, my_optimizer=optimizer):
    """
    Run a single optimisation step on one batch.

    The forward pass my_model(X) and the loss (module-level loss_fn) are
    recorded on a gradient tape; the gradients with respect to
    my_model.trainable_variables are then applied with my_optimizer.
    X, y have to be encoded (y is compared against the model output by
    CategoricalCrossentropy).

    Returns the batch loss when return_loss is true, otherwise None.
    """
    with tf.GradientTape() as tape:
        batch_loss = loss_fn(y, my_model(X))

    # Read the variables after the forward pass, as the model may create them there.
    weights = my_model.trainable_variables
    gradients = tape.gradient(batch_loss, weights)
    my_optimizer.apply_gradients(zip(gradients, my_model.trainable_variables))

    return batch_loss if return_loss else None

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Defaults for training_loop.
EPOCHS = 15
STEPS_PER_EPOCH = 50  # training batches consumed per epoch
VAL_BATCH_SIZE = 20   # NOTE: used as the *number* of validation batches per epoch

# Histories shared with training_loop through its default arguments.
past_loss, past_loss_val, accuracy_scores = [], [], []
max_acc_score = .9  # a checkpoint is saved only when accuracy exceeds this
c_report = 0        # placeholder printed as the "previous" report in the first epoch

# Let tf.data choose the prefetch buffer size for both splits.
batched_ds = encoded_ds_train.prefetch(buffer_size=AUTOTUNE)
batched_ds_val = encoded_ds_val.prefetch(buffer_size=AUTOTUNE)

def _collect_validation(my_model, batched_ds_val, val_batches, past_loss_val):
    """Run my_model (augmentation off) over at most val_batches validation batches.

    Each batch loss is appended to past_loss_val. Returns a tuple
    (true_results, predictions) of arrays stacked over all evaluated batches.
    """
    batch_truths = []
    batch_predictions = []
    for X_batch, y_batch in batched_ds_val.take(val_batches):
        y_pred = my_model(X_batch, data_aug=False)
        batch_predictions.append(y_pred.numpy())
        batch_truths.append(y_batch.numpy())
        past_loss_val.append(loss_fn(y_batch, y_pred))

    if not batch_predictions:
        raise ValueError('The validation dataset yielded no batches.')
    # Concatenate once instead of re-copying the growing arrays on every batch.
    return np.concatenate(batch_truths), np.concatenate(batch_predictions)


def training_loop(my_model=model, my_optimizer=optimizer, batched_ds=batched_ds, batched_ds_val=batched_ds_val,
                  epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, val_batches=VAL_BATCH_SIZE,
                  save_model_at_checkpoint=False, model_for_saving=None,
                  past_loss=past_loss, past_loss_val=past_loss_val, decay_lr=1.,
                  accuracy_scores=accuracy_scores, max_acc_score=max_acc_score, c_report=c_report, classes=class_names):
    '''
    Training loop for classification.

    Every epoch runs steps_per_epoch gradient steps, multiplies the learning
    rate by decay_lr, evaluates on val_batches validation batches, then plots
    the loss curves, the confusion matrix and the accuracy history and prints
    the classification report.

    Parameters:
        my_model: model to train; called as my_model(X) during training and
            my_model(X, data_aug=False) during validation.
        my_optimizer: optimizer applied by gradient_step.
        batched_ds, batched_ds_val: training / validation datasets yielding
            (images, one-hot labels) batches.
        epochs: number of epochs to run.
        steps_per_epoch: training batches taken per epoch.
        val_batches: validation batches taken per epoch.
        save_model_at_checkpoint: when true, save a copy of the weights each
            time the validation accuracy beats max_acc_score.
        model_for_saving: model that receives my_model's weights and is saved;
            required when save_model_at_checkpoint is true.
        past_loss, past_loss_val, accuracy_scores: history lists, appended to
            in place. The defaults are the module-level lists, so history
            accumulates across calls.
        decay_lr: factor the learning rate is multiplied by after each epoch.
        max_acc_score: accuracy to beat before the first checkpoint is saved.
            Only updated locally; the new best is not returned.
        c_report: report printed as "previous epoch" during the first epoch.
        classes: class names used for the report and the confusion matrix.

    Returns None.

    Raises:
        ValueError: if checkpointing is requested without model_for_saving, or
            if the validation dataset yields no batches.
    '''
    # Fail before training rather than after the first successful epoch.
    if save_model_at_checkpoint and model_for_saving is None:
        raise ValueError('model_for_saving must be provided when save_model_at_checkpoint is True.')

    for epoch in range(1, epochs + 1):
        print('###########################')
        print('Epoch', epoch)
        print('---------------------------')

        ## Gradient step
        for X_batch, y_batch in batched_ds.take(steps_per_epoch):
            new_loss = gradient_step(X_batch, y_batch, my_model=my_model, my_optimizer=my_optimizer)
            past_loss.append(new_loss)

        my_optimizer.learning_rate = my_optimizer.learning_rate*decay_lr

        ## Validation
        # Evaluate the model being trained (my_model), not the module-level `model`.
        true_results, predictions = _collect_validation(my_model, batched_ds_val, val_batches, past_loss_val)
        y_true = true_results.argmax(axis=1)
        y_hat = predictions.argmax(axis=1)
        label_ids = range(len(classes))

        ## Plot loss
        loss1 = pd.DataFrame(past_loss, columns = ['train loss'])
        loss2 = pd.DataFrame(past_loss_val, columns = ['validation loss'])

        # Stretch the validation series so it can be drawn next to the training one.
        # NOTE(review): the repeat factor is built from the module-level constants
        # STEPS_PER_EPOCH and VAL_BATCHES, not from this function's steps_per_epoch /
        # val_batches arguments, so the two curves are not guaranteed to line up - confirm intent.
        newdf = pd.DataFrame(np.repeat(loss2.values, STEPS_PER_EPOCH//VAL_BATCHES, axis=0))
        newdf.columns = loss2.columns
        loss_df = loss1.join(newdf)
        loss_df.plot(figsize = (18,12))
        plt.show()

        # Moving average over 50 steps; rows without a full window are dropped.
        rolling_loss = loss_df.rolling(window=50).mean().dropna()
        rolling_loss.columns = ['rolling loss train', 'rolling loss validation']
        rolling_loss.plot(figsize = (18,12))
        plt.show()

        ## Metrics ##
        # Plot confusion matrix (rows normalised over the true labels)
        cm = confusion_matrix(y_true, y_hat, normalize='true', labels=label_ids)
        # Label the axes with the `classes` argument, consistent with the report below.
        my_plot_cm(cm, my_labels=classes, title='Confusion matrix validation set')

        # Print classification report
        print('---------------------')
        print('Classification report validation:')
        print('Previous epoch:')
        print(c_report)
        print('Current epoch')
        c_report = classification_report(y_true, y_hat, labels=label_ids, target_names=classes)
        print(c_report)
        print('---------------------')

        # Plot accuracy score
        new_accuracy_score = accuracy_score(y_true, y_hat)
        print('Latest accuracy score:', new_accuracy_score)
        print('---------------------')
        accuracy_scores.append(new_accuracy_score)
        acc = pd.DataFrame(accuracy_scores, columns = ['accuracy score'])
        acc.plot(figsize=(18,12))
        plt.show()

        ## Checkpoints
        if save_model_at_checkpoint and new_accuracy_score > max_acc_score:
            max_acc_score = new_accuracy_score
            print('**********************')
            print('New best accuracy score:', new_accuracy_score)
            print('**********************')
            # Copy the trained weights into the export model and save it under
            # a name that contains the accuracy reached.
            model_for_saving.set_weights(my_model.get_weights())
            model_for_saving.save('accuracy'+str(max_acc_score))
                               
# Initial training run. Disabled by the constant-false guard; change the
# condition to actually execute it.
if False: 
    training_loop(epochs=20)

# Fine-tune MobileNetV2

In [None]:
# Unfreeze the backbone, then freeze its first 75 layers again so that only
# the remaining, later layers are fine-tuned.
pred_model.trainable = True
for layer in pred_model.layers[:75]:
    layer.trainable = False

# Read the rate through `learning_rate`, the same attribute training_loop
# uses, instead of the legacy `lr` alias.
current_lr = optimizer.learning_rate
model.summary()

In [None]:
# Fine-tuning optimizer: RMSprop with momentum, at 1/100 of the learning rate
# read from the previous optimizer.
optimizer = tf.keras.optimizers.RMSprop(learning_rate=current_lr / 100, momentum=0.9)

# Fine-tuning run with checkpointing; disabled by the constant-false guard.
if False:
    training_loop(
        epochs=60,
        decay_lr=0.98,
        my_optimizer=optimizer,
        model_for_saving=model_save,
        save_model_at_checkpoint=True,
    )