# Classification

This notebook assumes that the models directory is located in the same directory as this notebook.  
This notebook assumes that the image directory is located under ../classification/samples/images/  

Import packages

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from IPython.display import Image
import matplotlib.pyplot as plt
import ipywidgets as widgets
import itertools
import os
import shutil
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

## Keras Model

Image size selection

In [None]:
# Integer input box for the image dimension. A later cell reads its value and
# uses it for both the image height and the image width.
size = widgets.IntText(disabled=False, description='Image Dim.:')
display(size)

Set up parameters

In [None]:
# Number of images per batch produced by the image generator.
batch_size=1
# Root directory that the image generator scans for images.
sample_path = '../classification/samples/'

# The image dimension comes from the widget above. A non-positive value (for
# example when nothing was entered in the widget) would only fail much later,
# while images are being resized, so it is rejected here with a clear message.
image_dim = size.value
if image_dim <= 0:
    raise ValueError(
        f"Image Dim. must be a positive integer, got {image_dim}; "
        "enter a value in the widget above and re-run this cell."
    )

# Images are resized to a square: the same value is used for height and width.
IMG_HEIGHT = image_dim
IMG_WIDTH = image_dim

Format data, load images and apply rescaling

In [None]:
# Stream the sample images from disk, multiplying every pixel value by 1/255
# and resizing each image to (IMG_HEIGHT, IMG_WIDTH).
#
# shuffle=False is required: flow_from_directory shuffles the images by
# default, while sample_image_gen.filepaths always stays in directory order.
# Later cells pair predictions with filepaths by position, so a shuffled
# generator would attach each prediction to the wrong file.
sample_image_gen = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    directory=sample_path,
    batch_size=batch_size,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    shuffle=False,
)
                                                                        

Load model

In [None]:
# TF Hub handle of the module that R50x1BiTModel uses as its backbone.
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"

# Wrap the hub module as a Keras layer so it can be called like any other layer.
module = hub.KerasLayer(handle=model_url)

class R50x1BiTModel(tf.keras.Model):
    """Keras model that runs images through `module` and a 2-unit softmax Dense head."""

    def __init__(self, module):
        """Store the backbone and create the classification head.

        Args:
            module: Callable layer applied to the input images in `call`
                (in this notebook, the hub.KerasLayer built from `model_url`).
        """
        super().__init__()
        # The head applies softmax itself, so the model outputs two class
        # probabilities rather than raw logits.
        # NOTE(review): the layer name is spelled 'Classifcation' (sic). It is
        # left as-is because saved checkpoints may depend on it; confirm
        # before renaming.
        self.head = tf.keras.layers.Dense(2, activation='softmax', name='Classifcation')
        self.model = module
    
    def call(self, images):
        """Return the head's softmax output for the embedding of `images`.

        Args:
            images: Batch of images passed straight to the backbone module.

        Returns:
            The output of the Dense head applied to the backbone's output.
        """
        # No need to cut head off since we are using feature extractor model
        bit_embedding = self.model(images)
        return self.head(bit_embedding)

# Build the classifier: hub backbone plus the two-class softmax head.
model = R50x1BiTModel(module)

optimizer = tf.keras.optimizers.SGD(learning_rate=3e-7, momentum=0.9)

# The model's head already applies softmax, so its outputs are probabilities,
# not logits. The loss must therefore be built with from_logits=False;
# from_logits=True would apply a sigmoid on top of the softmax output and
# compute the loss on the wrong quantity.
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=['accuracy'])

# Restore the trained weights from the checkpoint directory.
model.load_weights('../checkpoints/ResNet50_base/')

Predict

In [None]:
# Run inference over every batch the generator yields. batch_size is left at
# its default (None) because the generator already produces batches itself.
predictions_probabilities = model.predict(sample_image_gen, verbose=1)

## Classify

Separate results among classes

In [None]:
# Bare file names (directory part stripped) of every image the generator found.
files = [os.path.basename(path) for path in sample_image_gen.filepaths]

# Build one tuple per image: (probabilities, predicted class, file name, full path).
# zip() pairs items purely by position, so the predictions must be in the same
# order as sample_image_gen.filepaths.
predicted_labels = predictions_probabilities.argmax(axis=-1)
results = zip(predictions_probabilities, predicted_labels, files, list(sample_image_gen.filepaths))

fakes, reals = [], []
for element in results:
    # Predicted class index 0 goes to fakes; any other index goes to reals.
    bucket = fakes if element[1] == 0 else reals
    bucket.append(element)

Move files to respective directories

In [None]:
# Destination directories for the classified images.
fake_dir = '../classification/results/fake/'
real_dir = '../classification/results/real/'

for target_dir, classified in ((fake_dir, fakes), (real_dir, reals)):
    # Create the destination first: shutil.move only places a file *inside*
    # the destination when that directory already exists.
    os.makedirs(target_dir, exist_ok=True)
    for element in classified:
        # element[3] is the full path of the source image.
        shutil.move(element[3], target_dir)

Display most probable fake images (Top 10)

In [None]:
# Order the fakes in place by the probability of class 0, highest first.
fakes.sort(key=lambda entry: entry[0][0], reverse=True)

# Show at most the first ten images. entry[2] is the bare file name, which is
# looked up inside fake_dir.
for entry in fakes[:10]:
    display(Image(filename=os.path.join(fake_dir, entry[2])))