In [73]:
import os
import math
import random

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.applications.inception_v3 import preprocess_input
from keras.regularizers import l2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Flatten
from tensorflow.keras import regularizers

In [74]:
# Short aliases for the Keras namespaces used throughout the notebook.
tfk, tfkl = tf.keras, tf.keras.layers

# Record which TensorFlow version produced the results below.
print(tf.__version__)

2.10.0


In [75]:
# Global seed shared by every stochastic component of the notebook
# (generator shuffling, weight initializers, dropout).
seed = 42

# NOTE: PYTHONHASHSEED is only read when a Python interpreter starts, so setting it
# here does not change hash randomisation of the already-running kernel; it only
# propagates to subprocesses. Export it before launching Jupyter for full effect.
os.environ['PYTHONHASHSEED'] = str(seed)

random.seed(seed)
np.random.seed(seed)
# tf.random.set_seed is the TF2 API for the global seed; the former additional
# tf.compat.v1.set_random_seed call set the same global seed and was redundant.
tf.random.set_seed(seed)

In [76]:
# Root folder of the image dataset.
dataset_dir = 'dataset96'

# Class sub-folder names in label-index order: 'Species1' -> 0, ..., 'Species8' -> 7.
# Passing this list as `classes=` to flow_from_directory fixes that index mapping.
labels = [f'Species{idx}' for idx in range(1, 9)]


In [77]:
# Input geometry: the generators resize every image to img_h x img_w RGB.
img_w = 96
img_h = 96
# Keras order is (height, width, channels); derived so it cannot drift from img_h/img_w.
input_shape = (img_h, img_w, 3)
# Number of output classes, derived from the label list so the two stay in sync.
classes = len(labels)

# Per-class loss weights for fit(class_weight=...), countering class imbalance.
# The values appear to follow the "balanced" heuristic n_samples / (n_classes * count_c)
# over the 2829 training images (e.g. class 0: 2829 / (8 * 148)); recompute them if
# the training split changes.
class_weights = {0: 2.389358108108108,
                 1: 0.8320588235294117,
                 2: 0.8583131067961165,
                 3: 0.8667279411764706,
                 4: 0.8340212264150944,
                 5: 1.9978813559322033,
                 6: 0.8243006993006993,
                 7: 0.8709975369458128}
epochs = 30
# NOTE(review): currently unused — early stopping is not enabled in the training cell.
patience_epochs = 8
batch_size = 24

# Backbone layers with index < last_nonTrainable_layer are frozen during fine-tuning.
last_nonTrainable_layer = 207

In [72]:
# Dataset split folders, resolved relative to dataset_dir.
training_dir, validation_dir = (os.path.join(dataset_dir, split) for split in ('train', 'val'))
# A held-out test split is not used in this notebook:
# test_dir = os.path.join(dataset_dir, 'test')

In [47]:
# Training generator: random augmentation plus InceptionV3 input preprocessing.
# inception_v3.preprocess_input maps pixel values from [0, 255] to [-1, 1], the input
# range the ImageNet-pretrained InceptionV3 weights expect; rescale=1/255 would feed
# [0, 1] inputs that do not match the pretrained backbone.
train_data_gen = ImageDataGenerator(rotation_range=20,
                                    height_shift_range=10,   # int => shift in pixels
                                    width_shift_range=10,    # int => shift in pixels
                                    zoom_range=0.1,
                                    shear_range=0.2,
                                    horizontal_flip=True,
                                    vertical_flip=True,
                                    brightness_range=[0.3, 1.4],
                                    fill_mode='reflect',
                                    preprocessing_function=preprocess_input)

# Images are read from one sub-folder per entry of `labels`; the labels become one-hot vectors.
train_gen = train_data_gen.flow_from_directory(directory=training_dir,
                                               target_size=(img_h, img_w),
                                               color_mode='rgb',
                                               classes=labels,
                                               class_mode='categorical',
                                               batch_size=batch_size,
                                               shuffle=True,
                                               seed=seed)

Found 2829 images belonging to 8 classes.


In [48]:
# Validation generator: NO augmentation, so validation metrics reflect the real images.
# It copies the training generator's input preprocessing (rescale and/or
# preprocessing_function), so both splits reach the network in the same value range.
valid_data_gen = ImageDataGenerator(rescale=train_data_gen.rescale,
                                    preprocessing_function=train_data_gen.preprocessing_function)

# Built from valid_data_gen, not train_data_gen: the training generator applies random
# rotations/flips/brightness changes, which must not be applied to validation data.
valid_gen = valid_data_gen.flow_from_directory(directory=validation_dir,
                                               target_size=(img_h, img_w),
                                               color_mode='rgb',
                                               classes=labels,
                                               class_mode='categorical',
                                               batch_size=batch_size,
                                               shuffle=False,  # keep file order stable across epochs
                                               seed=seed)

Found 713 images belonging to 8 classes.


In [54]:

# Download InceptionV3 pretrained on ImageNet (without its classification head)
# to use as the convolutional backbone, sized to this notebook's input_shape.
supernet = tfk.applications.InceptionV3(
    include_top=False,
    weights="imagenet",
    input_shape=input_shape,
)

# Fine-tuning setup: make the whole backbone trainable, then re-freeze the first
# last_nonTrainable_layer layers so only the top of the network is updated.
supernet.trainable = True
for layer in supernet.layers[:last_nonTrainable_layer]:
    layer.trainable = False

# List every layer with its index and trainable flag to check where the cut falls.
for idx, layer in enumerate(supernet.layers):
    print(idx, layer.name, layer.trainable)

0 input_9 False
1 conv2d_376 False
2 batch_normalization_376 False
3 activation_376 False
4 conv2d_377 False
5 batch_normalization_377 False
6 activation_377 False
7 conv2d_378 False
8 batch_normalization_378 False
9 activation_378 False
10 max_pooling2d_16 False
11 conv2d_379 False
12 batch_normalization_379 False
13 activation_379 False
14 conv2d_380 False
15 batch_normalization_380 False
16 activation_380 False
17 max_pooling2d_17 False
18 conv2d_384 False
19 batch_normalization_384 False
20 activation_384 False
21 conv2d_382 False
22 conv2d_385 False
23 batch_normalization_382 False
24 batch_normalization_385 False
25 activation_382 False
26 activation_385 False
27 average_pooling2d_36 False
28 conv2d_381 False
29 conv2d_383 False
30 conv2d_386 False
31 conv2d_387 False
32 batch_normalization_381 False
33 batch_normalization_383 False
34 batch_normalization_386 False
35 batch_normalization_387 False
36 activation_381 False
37 activation_383 False
38 activation_386 False
39 activati

In [50]:
def step_decay(epoch):
    """Step-wise learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Starts at 0.005 and multiplies the rate by 0.1 every time ``epoch + 1`` reaches a
    multiple of 10: 0.005 for epochs 0-8, 5e-4 for epochs 9-18, 5e-5 for 19-28, ...

    Args:
        epoch: 0-based epoch index supplied by the scheduler.

    Returns:
        The learning rate (float) to use for that epoch.
    """
    base_rate, factor, period = 0.005, 0.1, 10.0
    n_steps = math.floor((epoch + 1) / period)
    return base_rate * math.pow(factor, n_steps)

In [51]:
from datetime import datetime

def create_folders_and_callbacks(model_name, early_stopping_patience=None):
  """Create a timestamped experiment folder and the Keras callbacks that write into it.

  Folder layout: callbackSaves/<model_name>_<MonDD_HH-MM-SS>/{ckpts, tb_logs}

  Args:
    model_name: Prefix of the experiment folder name.
    early_stopping_patience: If not None, also add an EarlyStopping callback on
      val_loss with this patience that restores the best weights. The default
      (None) adds no early stopping, matching the previous behaviour.

  Returns:
    List of callbacks, in order: ModelCheckpoint, TensorBoard, optionally
    EarlyStopping, and LearningRateScheduler(step_decay).
  """
  exps_dir = 'callbackSaves'
  now = datetime.now().strftime('%b%d_%H-%M-%S')
  exp_dir = os.path.join(exps_dir, model_name + '_' + str(now))
  ckpt_dir = os.path.join(exp_dir, 'ckpts')
  tb_dir = os.path.join(exp_dir, 'tb_logs')

  # makedirs also creates the intermediate exps_dir/exp_dir folders; exist_ok=True
  # replaces the racy "if not os.path.exists(...)" check-then-create pattern.
  for folder in (ckpt_dir, tb_dir):
    os.makedirs(folder, exist_ok=True)

  callbacks = []

  # Model checkpoint: save the full model only when the monitored quantity
  # (ModelCheckpoint's default monitor, val_loss) improves.
  callbacks.append(tf.keras.callbacks.ModelCheckpoint(
      filepath=os.path.join(ckpt_dir, 'cp.ckpt'),
      save_weights_only=False,
      save_best_only=True))

  # TensorBoard: training/validation losses and metrics, plus weight histograms every epoch.
  callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=tb_dir,
                                                  profile_batch=0,
                                                  histogram_freq=1))

  # Optional early stopping. val_loss is minimised, so mode must be 'min'.
  if early_stopping_patience is not None:
    callbacks.append(tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      mode='min',
                                                      patience=early_stopping_patience,
                                                      restore_best_weights=True))

  # Learning-rate schedule: overwrites the optimizer's learning rate at each epoch start.
  callbacks.append(tf.keras.callbacks.LearningRateScheduler(step_decay))

  return callbacks

In [62]:
# Fine-tuning model: InceptionV3 backbone -> global average pooling -> small dense head.
inputs = tfk.Input(shape=input_shape)

# training=False keeps the backbone's BatchNormalization layers in inference mode
# (using their pretrained moving statistics), as the Keras fine-tuning guide recommends.
# The unfrozen backbone weights are still updated by the optimizer.
x = supernet(inputs, training=False)

glob_pooling = tfkl.GlobalAveragePooling2D(name='GlobalPooling')(x)

x = tfkl.Dense(
    128,
    kernel_initializer=tfk.initializers.GlorotUniform(seed),
)(glob_pooling)

leaky_relu_layer = tfkl.LeakyReLU()(x)

x = tfkl.Dropout(0.3, seed=seed)(leaky_relu_layer)

outputs = tfkl.Dense(
    classes,
    activation='softmax',
    kernel_initializer=tfk.initializers.GlorotUniform(seed),
)(x)

# Connect input and output through the Model class
ft_model = tfk.Model(inputs=inputs, outputs=outputs, name='model')

# SGD with momentum. When trained with the LearningRateScheduler from
# create_folders_and_callbacks, the learning rate is reset every epoch by step_decay;
# the legacy `decay` argument additionally shrinks it per update step.
ft_model.compile(
    loss=tfk.losses.CategoricalCrossentropy(),
    optimizer=tfk.optimizers.SGD(momentum=0.9, decay=0.0005, nesterov=False),
    metrics=['accuracy', tfk.metrics.Precision(), tfk.metrics.Recall()],
)
ft_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_12 (InputLayer)       [(None, 96, 96, 3)]       0         
                                                                 
 inception_v3 (Functional)   (None, 1, 1, 2048)        21802784  
                                                                 
 GlobalPooling (GlobalAverag  (None, 2048)             0         
 ePooling2D)                                                     
                                                                 
 dense_10 (Dense)            (None, 128)               262272    
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 128)               0         
                                                                 
 dropout_3 (Dropout)         (None, 128)               0         
                                                             

In [65]:

# Alternative head on the same backbone:
# Flatten -> Dropout -> L2-regularised Dense(512, relu) -> softmax.
# NOTE: this reuses the *same* supernet layer objects as ft_model, so the two models
# share backbone weights; training either one changes the other's backbone.
x = supernet.output
# The layer must be called on the tensor: `x = Flatten()` would bind the layer object
# itself to x and make the next layer call fail.
x = Flatten()(x)
# Alternative to Flatten: x = GlobalAveragePooling2D()(x)
x = tfkl.Dropout(0.3, seed=seed)(x)
x = tfkl.Dense(
    512,
    kernel_regularizer=regularizers.l2(0.01),
    activation='relu',
    kernel_initializer=tfk.initializers.GlorotUniform(seed),
)(x)
# Output size follows the shared `classes` constant, with a seeded initializer for reproducibility.
predictions = Dense(
    classes,
    activation='softmax',
    kernel_initializer=tfk.initializers.GlorotUniform(seed),
)(x)

# Connect input and output through the Model class
model = tfk.Model(inputs=supernet.inputs, outputs=predictions, name='model')

In [66]:
# Train the fine-tuning model (ft_model), which is already compiled in its own cell
# (SGD + accuracy/precision/recall).
# NOTE: this cell used to compile `model` with Adam(5e-4) and then fit `ft_model`; that
# compile never affected training, so it was removed. To train `model` instead, compile
# it here and call model.fit(...) with the same arguments.
callbacks = create_folders_and_callbacks(model_name='GoogleNetModel')

# batch_size is not passed to fit(): with a generator input, the batch size comes
# from the generator (flow_from_directory(batch_size=...)).
history = ft_model.fit(
    x=train_gen,
    epochs=epochs,
    validation_data=valid_gen,
    class_weight=class_weights,
    callbacks=callbacks,
).history

ft_model.save("fineTuningModel")

Epoch 1/30



INFO:tensorflow:Assets written to: callbackSaves\GoogleNetModel_Nov16_18-37-01\ckpts\cp.ckpt\assets


INFO:tensorflow:Assets written to: callbackSaves\GoogleNetModel_Nov16_18-37-01\ckpts\cp.ckpt\assets


Epoch 2/30
Epoch 3/30



INFO:tensorflow:Assets written to: callbackSaves\GoogleNetModel_Nov16_18-37-01\ckpts\cp.ckpt\assets


INFO:tensorflow:Assets written to: callbackSaves\GoogleNetModel_Nov16_18-37-01\ckpts\cp.ckpt\assets


Epoch 4/30



INFO:tensorflow:Assets written to: callbackSaves\GoogleNetModel_Nov16_18-37-01\ckpts\cp.ckpt\assets


INFO:tensorflow:Assets written to: callbackSaves\GoogleNetModel_Nov16_18-37-01\ckpts\cp.ckpt\assets


Epoch 5/30
Epoch 6/30
Epoch 7/30

KeyboardInterrupt: 