## ResNet50 IMAGE CLASSIFICATION NETWORK

Import packages

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os, datetime
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

Assign training and validation paths to variables

In [None]:
# Expected layout: <PATH>/{train,validation}/{real,fake}/
# NOTE(review): the leading '/' makes this an absolute path (it resolves to
# /datasets/faces/), unlike the relative '../../' paths used elsewhere — confirm intent.
PATH = '/../../datasets/faces/'

# Split roots.
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

# Per-class sub-directories of each split (each keeps its trailing '/').
train_real_dir, train_fake_dir = (os.path.join(train_dir, label) for label in ('real/', 'fake/'))
validation_real_dir, validation_fake_dir = (
    os.path.join(validation_dir, label) for label in ('real/', 'fake/')
)

Count the images in the training and validation sets

In [None]:
# Number of directory entries (images) per class and split.
num_real_tr = len(os.listdir(train_real_dir))
num_fake_tr = len(os.listdir(train_fake_dir))
num_real_val = len(os.listdir(validation_real_dir))
num_fake_val = len(os.listdir(validation_fake_dir))

# Split totals; these drive steps_per_epoch / validation_steps during training,
# so each must include BOTH classes of its split.
total_train = num_real_tr + num_fake_tr
total_val = num_real_val + num_fake_val

Set up variables for pre-processing dataset and training network

In [None]:
# Training hyper-parameters.
epochs = 50
batch_size = 32

# Resolution every image is resized to by the data generators (target_size).
IMG_HEIGHT = IMG_WIDTH = 1024

Data generators (rescaling for both sets, augmentation for training only)

In [None]:
# Validation images only get the pixel rescale (values multiplied by 1/255).
validation_image_generator = ImageDataGenerator(rescale=1. / 255)

# Training images get the same rescale plus random geometric augmentation.
train_image_generator = ImageDataGenerator(
    rescale=1. / 255,
    horizontal_flip=True,
    rotation_range=45,
    zoom_range=0.5,
    width_shift_range=.15,
    height_shift_range=.15,
)

Load images and apply rescaling

In [None]:
# Directory iterators: one sub-folder per class, images resized to
# (IMG_HEIGHT, IMG_WIDTH); training batches are explicitly shuffled.
train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
    shuffle=True,
)
val_data_gen = validation_image_generator.flow_from_directory(
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
)

Create and compile the model

In [None]:
# Pre-trained BiT-M R50x1 backbone loaded from TF-Hub as a Keras layer.
model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
module = hub.KerasLayer(handle=model_url)

class R50x1BiTModel(tf.keras.Model):
    """BiT R50x1 backbone followed by a 2-way softmax classification head.

    Args:
        module: callable backbone (here the TF-Hub ``KerasLayer``) that maps a
            batch of images to embedding vectors.
    """

    def __init__(self, module):
        super().__init__()
        # Assignment order (head, then backbone) is kept as-is because Keras
        # tracks sub-layers in the order they are set.
        self.head = tf.keras.layers.Dense(2, activation='softmax', name='Classifcation')
        self.model = module

    def call(self, images):
        """Return per-class softmax probabilities for ``images``."""
        # The hub module is used as a feature extractor, so its output can be
        # fed to the classifier directly — no head needs removing.
        return self.head(self.model(images))

model = R50x1BiTModel(module)

# Piecewise-constant learning-rate schedule: the rate drops at 30%, 60% and 90%
# of the total number of optimisation steps in the run.
lr = 0.003
total_steps = (total_train // batch_size) * epochs
SCHEDULE_BOUNDARIES = [int(total_steps*i) for i in [0.30, 0.60, 0.90]]

lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=SCHEDULE_BOUNDARIES,
                                                                  values=[lr, lr*0.1, lr*0.001, lr*0.0001])
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

# Input spec must match the generators' target_size=(IMG_HEIGHT, IMG_WIDTH).
inputs = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
# NOTE(review): `_set_inputs` is a private Keras API and may not exist in newer
# TF releases; `model.build((None, IMG_HEIGHT, IMG_WIDTH, 3))` is the public route.
model._set_inputs(inputs)

# The classification head ends in a softmax over 2 classes, and flow_from_directory
# uses class_mode='categorical' by default (one-hot labels), so the matching loss is
# categorical cross-entropy on probabilities, i.e. from_logits=False.
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

# model.load_weights('../../checkpoints/ResNet50_base/')

Define checkpoints

In [None]:
# Save weights only, and only when validation accuracy improves.
checkpoint_dir = '../../checkpoints/ResNet50_base/'
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_dir,
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)

Create TensorBoard callback

In [None]:
# One TensorBoard run directory per launch, suffixed with the launch timestamp;
# histogram_freq=1 requests weight histograms every epoch.
log_dir = f"../../log/ResNet50_base/{datetime.datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(histogram_freq=1, log_dir=log_dir)

Create EarlyStopping callback

In [None]:
# es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

Display model summary

In [None]:
# model.summary()

Train model using fit

In [None]:
# Train for `epochs` epochs; step counts come from the image totals (floor
# division, so a trailing partial batch is not counted).
history = model.fit(
    train_data_gen,
    epochs=epochs,
    steps_per_epoch=total_train // batch_size,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size,
    callbacks=[cp_callback, tensorboard_callback],
    verbose=1,
)

In [None]:
# Final-epoch weights go to their own prefix. The ModelCheckpoint callback
# (save_best_only=True) writes the best-val_accuracy weights to
# '../../checkpoints/ResNet50_base/'. Saving here to that same prefix would
# silently replace them with the (possibly worse) last-epoch weights.
model.save_weights('../../checkpoints/ResNet50_base_final/')

In [None]:
# Export the whole model in TensorFlow SavedModel format.
model.save(filepath='../../saved_models/ResNet50_base/', save_format='tf')