# Predictions

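This notebook evaluates a custom ResNet50-based classifier (a BiT-M R50x1 backbone from TF Hub with a two-class head) on a dataset's test split, using `model.predict()`.  
It assumes that the datasets directory is located in the parent directory,  
i.e. each dataset is saved under `../datasets/<dataset name>/`, with its test images in `test/real/` and `test/fake/`.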

Import packages

In [None]:
# Standard library
import itertools
import os
# Uncomment to hide all GPUs and force CPU-only inference. This must run before
# TensorFlow is imported/initialised to take effect reliably.
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Third-party
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
# NOTE: sklearn.metrics.plot_confusion_matrix is intentionally not imported. It
# was removed in scikit-learn 1.2, and this notebook defines its own helper
# with that name further down.
from sklearn.metrics import confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Dataset selection

In [None]:
# Offer every (non-hidden) dataset folder under ../datasets/ as a choice.
# Sorting keeps the dropdown order stable across machines, since os.listdir
# order is arbitrary.
DATASETS_ROOT = '../datasets/'
list_of_db = sorted(
    entry for entry in os.listdir(DATASETS_ROOT)
    if not entry.startswith('.')
    and os.path.isdir(os.path.join(DATASETS_ROOT, entry))
)
dropdown_db = widgets.Dropdown(
    options=list_of_db,
    description='Dataset',
)
display(dropdown_db)

## Keras Model

Set up path variables

In [None]:
# Resolve the test-split folders of the dataset picked in the dropdown.
# Layout: ../datasets/<dataset>/test/real/ and ../datasets/<dataset>/test/fake/
db_PATH = os.path.join('../datasets/', dropdown_db.value + '/')
test_dir = os.path.join(db_PATH, 'test')
test_real_dir, test_fake_dir = [os.path.join(test_dir, sub) for sub in ('real/', 'fake/')]

Image size selection (the value is used for both image height and width)

In [None]:
# Integer input for the square image dimension (pixels) read by the next cell.
size = widgets.IntText(description='Image Dim.:', disabled=False)
display(size)

Set up parameters and compute dataset sizes

In [None]:
batch_size = 32
# The same widget value is used for both dimensions (square inputs).
IMG_HEIGHT = IMG_WIDTH = size.value
# ipywidgets.IntText defaults to 0, which is not a valid resize target.
assert IMG_HEIGHT > 0, 'Enter a positive image dimension in the "Image Dim." box above, then re-run this cell.'


def _count_visible_files(directory):
    """Return the number of entries in `directory`, ignoring hidden ones (e.g. .DS_Store)."""
    return sum(1 for name in os.listdir(directory) if not name.startswith('.'))


num_real_test = _count_visible_files(test_real_dir)
num_fake_test = _count_visible_files(test_fake_dir)
total_test = num_real_test + num_fake_test

Load the test images, resize them, and rescale pixel values by 1/255

In [None]:
# shuffle=False keeps the iteration order fixed, so predictions line up
# row-for-row with test_data_gen.classes.
# Passing `classes` explicitly pins index 0 = 'fake' and index 1 = 'real',
# the order assumed by the confusion-matrix labels and metrics below, instead
# of relying on the alphabetical order of the folder names.
test_data_gen = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    directory=test_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    classes=['fake', 'real'],
    class_mode='categorical',
    batch_size=batch_size,
    shuffle=False,
)

Load model

In [None]:
# BiT-M R50x1 backbone from TF Hub, used as a feature extractor
# (the classification head is added by R50x1BiTModel below).
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
module = hub.KerasLayer(handle=model_url)

class R50x1BiTModel(tf.keras.Model):
    """Two-class softmax classifier stacked on a BiT R50x1 feature extractor.

    Args:
        module: a Keras layer (in this notebook, a TF Hub ``KerasLayer``) that
            maps a batch of images to embedding vectors.
    """

    def __init__(self, module):
        super().__init__()
        # NOTE: TF-format checkpoints track variables by attribute name, so the
        # attribute names below (and the layer name, typo included) are likely
        # load-bearing for load_weights(); do not rename them casually.
        self.head = tf.keras.layers.Dense(2, activation='softmax', name='Classifcation')
        self.model = module

    def call(self, images):
        # The hub module already has no classification head, so its embedding
        # feeds the new softmax head directly.
        return self.head(self.model(images))

model = R50x1BiTModel(module)

# The optimizer only matters if this model is trained further; predict() does
# not use it.
optimizer = tf.keras.optimizers.SGD(learning_rate=3e-7, momentum=0.9)

# The head outputs softmax probabilities over 2 classes, and the generator
# yields one-hot ("categorical") labels. The matching loss is therefore
# categorical cross-entropy on probabilities (from_logits=False); treating the
# softmax output as logits would give incorrect loss values in evaluate().
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

model.load_weights('../checkpoints/ResNet50_base-v3/')

Predict

In [None]:
# Inference over the whole (unshuffled) test generator: one row of 2 class
# probabilities per image.
predictions = model.predict(x=test_data_gen, verbose=1)

## Confusion Matrix

In [None]:
# Reduce per-class probabilities to the index of the most likely class.
predictions_classified = predictions.argmax(axis=-1)
test_data_labels = test_data_gen.classes
# Relies on shuffle=False in the generator: predictions and labels share the
# same order.
assert len(predictions_classified) == len(test_data_labels), 'prediction/label count mismatch'
# labels=[0, 1] always yields a 2x2 matrix, even when a class is missing from
# both the ground truth and the predictions, so cm[i][j] indexing stays valid.
cm = confusion_matrix(test_data_labels, predictions_classified, labels=[0, 1])

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print and plot a confusion matrix with per-cell value annotations.

    This local helper shadows the scikit-learn function of the same name.

    Args:
        cm: square confusion matrix with rows = true labels and columns =
            predicted labels (the layout of sklearn.metrics.confusion_matrix).
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its total so that each cell
            shows the fraction of its true class; all-zero rows stay 0.
        title: figure title.
        cmap: matplotlib colormap for the heatmap.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize BEFORE drawing, so the heatmap colours, the colorbar and
        # the annotated numbers all describe the same matrix.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print as-is; fractions/floats get two decimals.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # White text on dark (high-valued) cells keeps annotations readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass after the axis labels exist, so they are not clipped.
    plt.tight_layout()

In [None]:
# Take tick labels from the generator's own class -> index mapping (ordered by
# index), so they always match the integer labels used to build `cm`.
class_names = sorted(test_data_gen.class_indices, key=test_data_gen.class_indices.get)
plot_confusion_matrix(cm, class_names)
plt.show()

Calculate metrics (accuracy, plus precision and recall with `fake`, class index 0, treated as the positive class):

In [None]:
# Rows of `cm` are true labels and columns are predictions. Index 0 ('fake')
# is the positive class for precision and recall.
true_pos = cm[0][0]
false_neg = cm[0][1]
false_pos = cm[1][0]


def _pct(numerator, denominator):
    """Return numerator/denominator as a percentage, or NaN when the denominator is 0."""
    return numerator * 100 / denominator if denominator else float('nan')


# Divide by the number of images actually classified (the matrix total), not
# by a raw directory file count, which may include non-image files.
accuracy = _pct(np.trace(cm), cm.sum())
precision = _pct(true_pos, true_pos + false_pos)
recall = _pct(true_pos, true_pos + false_neg)
print("Accuracy: %.2f" % accuracy)
print("Precision: %.2f" % precision)
print("Recall: %.2f" % recall)
