In [None]:
# Mount Google Drive into the Colab VM so files stored on Drive become reachable.
# NOTE(review): the dataset paths used below are relative ('splitDataset/...')
# rather than under this mount point — confirm where the data actually lives.
from google.colab import drive

drive_mount_point = '/gdrive'
drive.mount(drive_mount_point)

In [None]:
# Dataset split locations, relative to the current working directory.
# flow_from_directory (classes=None) infers the class labels from the
# sub-folders of each split directory.
train_dir, val_dir = 'splitDataset/train', 'splitDataset/validation'

In [None]:
from keras.layers import Input, Lambda, Dense, Flatten
from keras.models import Model
from keras.applications.xception import preprocess_input
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.optimizers import adam_v2

import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import tensorflow.keras as tfk
import tensorflow.keras.layers as tfkl
import random
import os
import tensorflow as tf
from PIL import Image
import logging

In [None]:
# Seed every random number generator this notebook touches so runs are repeatable.
seed = 42

# Python-level randomness.
# NOTE: PYTHONHASHSEED is read at interpreter start-up; setting it here only
# affects child processes, not string hashing in the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)

# NumPy's global RNG.
np.random.seed(seed)

# TensorFlow global seed, through both the TF2 API and its v1-compat alias.
for set_tf_seed in (tf.random.set_seed, tf.compat.v1.set_random_seed):
    set_tf_seed(seed)

In [None]:
# Build the image pipelines. Both splits go through the Xception preprocess_input
# (applied after augmentation); only the training split is augmented.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Random transformations applied on the fly to every training batch.
train_augmentation = dict(
    horizontal_flip=True,
    vertical_flip=True,
    rotation_range=90,
    width_shift_range=0.3,
    height_shift_range=0.3,
    brightness_range=[0.7, 1.3],
)
aug_train_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input,
                                        **train_augmentation)
noaug_val_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input)

# Iterator options shared by the training and validation splits.
flow_options = dict(
    target_size=(256, 256),
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=16,
    shuffle=True,
    seed=seed,
)

# Augmented training batches, then un-augmented validation batches.
aug_train_gen = aug_train_data_gen.flow_from_directory(directory=train_dir, **flow_options)
noaug_val_gen = noaug_val_data_gen.flow_from_directory(directory=val_dir, **flow_options)

In [None]:
# Shape of one input image: (height, width, channels). It has to agree with
# target_size=(256, 256) and color_mode='rgb' used by the data generators.
img_size = 256
input_shape = (img_size, img_size, 3)

# Maximum number of training epochs (the EarlyStopping callback can stop sooner).
epochs = 200

In [None]:

# Transfer-learning model: frozen ImageNet-pretrained Xception backbone
# followed by a small fully connected classification head.
num_classes = 14  # NOTE(review): must equal the number of class folders found in train_dir

# Instantiate the Xception convolutional base without its ImageNet classifier top.
base_model = tfk.applications.Xception(weights='imagenet',  # Load weights pre-trained on ImageNet.
                                       input_shape=input_shape,
                                       include_top=False)
base_model.trainable = False

inputs = tfk.Input(shape=input_shape)

# training=False runs the backbone in inference mode (e.g. its BatchNormalization
# layers use their stored statistics), even if it is later unfrozen for fine-tuning.
x = base_model(inputs, training=False)

# Fully connected head.
x = tfk.layers.Flatten(name='Flattening')(x)
x = tfkl.Dropout(0.3, seed=seed)(x)
x = tfkl.Dense(256, activation='relu', kernel_initializer=tfk.initializers.GlorotUniform(seed))(x)
# The classifier must consume the output of this second Dropout; connecting the
# output Dense to `x` would silently bypass the dropout.
classifier_layer = tfkl.Dropout(0.3, seed=seed)(x)
output_layer = tfkl.Dense(units=num_classes,
                          activation='softmax',
                          kernel_initializer=tfk.initializers.GlorotUniform(seed),
                          name='Output')(classifier_layer)

# Instantiate the Model class
model = tfk.Model(inputs, output_layer)

# Compile the model
model.compile(optimizer=tfk.optimizers.Adam(),
              loss=tfk.losses.CategoricalCrossentropy(),
              metrics=['accuracy', tfk.metrics.AUC(), tfk.metrics.Precision(), tfk.metrics.Recall()])

# Print a layer-by-layer summary of the model
model.summary()

In [None]:
# Utility function to create folders and callbacks for training
from datetime import datetime

def create_folders_and_callbacks(model_name, exps_dir='data_callbacks_experiments', patience=10):
  """Create a timestamped experiment folder and the Keras callbacks for one run.

  Folder layout: <exps_dir>/<model_name>_<MonDD_HH-MM-SS>/ckpts/

  Args:
    model_name: prefix of the experiment folder name.
    exps_dir: root folder collecting all experiments (default keeps the old location).
    patience: epochs without val_loss improvement before EarlyStopping halts training.

  Returns:
    A list [ModelCheckpoint, EarlyStopping] to pass as `model.fit(callbacks=...)`.
  """
  now = datetime.now().strftime('%b%d_%H-%M-%S')
  exp_dir = os.path.join(exps_dir, model_name + '_' + now)
  ckpt_dir = os.path.join(exp_dir, 'ckpts')

  # makedirs also creates the missing parents (exps_dir, exp_dir); exist_ok
  # replaces the racy os.path.exists + os.makedirs pattern.
  os.makedirs(ckpt_dir, exist_ok=True)

  # Save the full model (not only weights) every epoch; the fixed file name
  # means each save overwrites the previous one.
  ckpt_callback = tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(ckpt_dir, 'cp.ckpt'),
                                                     save_weights_only=False,  # True to save only weights
                                                     save_best_only=False)  # True to save only the best epoch

  # Stop when val_loss stops improving and restore the best weights observed.
  es_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                 patience=patience,
                                                 restore_best_weights=True)

  return [ckpt_callback, es_callback]

In [None]:
# Phase 1 (transfer learning): only the new dense head is trained, so every
# layer of the Xception backbone is frozen.
xception_layers = model.get_layer('xception').layers
for backbone_layer in xception_layers:
  backbone_layer.trainable = False

# List index, name and trainable flag of each backbone layer to verify the freeze.
for idx, backbone_layer in enumerate(xception_layers):
  print(idx, backbone_layer.name, backbone_layer.trainable)

model.summary()

In [None]:
# Phase 1: train the dense head on top of the frozen backbone.
# Fresh experiment folder plus checkpoint / early-stopping callbacks for this run.
aug_callbacks = create_folders_and_callbacks(model_name='SteveJobs_transfL')

# initial_epoch is only meaningful when resuming a run whose weights were restored
# from a checkpoint. Nothing is restored before this cell, so start at epoch 0;
# a non-zero value would just shrink the budget to `epochs - initial_epoch`.
resume_from_epoch = 0

# Train the model
history = model.fit(
    x=aug_train_gen,
    epochs=epochs,
    validation_data=noaug_val_gen,
    callbacks=aug_callbacks,
    initial_epoch=resume_from_epoch,
).history

In [None]:
# Save the transfer learning model
model.save("experiments/SteveJobs_transfL")

In [None]:
# Phase 2 (fine-tuning): make the Xception backbone trainable again;
# its early layers are re-frozen in the following cell.
xception_backbone = model.get_layer('xception')
xception_backbone.trainable = True

In [None]:
# we decide to freeze the first 126 layers and unfreeze the rest
for i, layer in enumerate(model.get_layer('xception').layers[:126]):
  layer.trainable=False
for i, layer in enumerate(model.get_layer('xception').layers):
   print(i, layer.name, layer.trainable)
# Show the structure of the model
model.summary()

In [None]:
# Create folders and callbacks
# New experiment folder and checkpoint / early-stopping callbacks for the fine-tuning run.
fine_tune_model_name = 'SteveJobs_fineTuned'
aug_callbacks = create_folders_and_callbacks(model_name=fine_tune_model_name)

In [None]:
# Train the model
history = model.fit(
    x = aug_train_gen,
    epochs = epochs,
    validation_data = noaug_val_gen,
    callbacks = aug_callbacks).history

In [None]:
# Save the fine tuned model
model.save("experiments/SteveJobs_fineTuned")