# Imports

In [1]:
import os
import random
import sys

import numpy as np
import sklearn.utils.class_weight as wgt
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Make the project-local helper modules (two directories up) importable.
sys.path.append('../../')
import utils
import custom_metrics

tfk = tf.keras
tfkl = tfk.layers

Init Plugin
Init Graph Optimizer
Init Kernel


# Setting seed for reproducibility

In [2]:
# Setting random seed
# Seed every RNG source this notebook can touch: Python's `random`, NumPy's
# legacy global generator, and TensorFlow's global seed. (In TF2,
# `tf.random.set_seed` is the canonical call; the former extra
# `tf.compat.v1.set_random_seed` call did the same thing a second time.)
seed = 17560
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# NOTE: PYTHONHASHSEED is read only at interpreter start-up. Setting it here
# affects subprocesses spawned later, not the already-running kernel. For
# full hash determinism, export it before launching Jupyter.
os.environ['PYTHONHASHSEED'] = str(seed)

# Data Loading and Model Building

In [3]:
# Dataset locations (one sub-directory per class, as flow_from_directory expects).
training_dir = os.path.join('.', 'training')
validation_dir = os.path.join('.', 'validation')

# Training hyper-parameters.
batch_size = 64
epochs = 1000       # upper bound; `patience` is forwarded to the callbacks helper (presumably early stopping)
patience = 10
input_shape = (256, 256, 3)
neurons = [128]     # units per hidden dense layer of the classification head
hiddens = 1         # number of hidden dense layers in the head

# Training images: VGG16 preprocessing plus random geometric augmentation.
train_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input,
                                    rotation_range=45,
                                    zoom_range=0.2,
                                    horizontal_flip=True,
                                    vertical_flip=True,
                                    height_shift_range=0.2,
                                    width_shift_range=0.2)

# Validation images: preprocessing ONLY. Randomly augmenting the validation set
# makes val metrics noisy and non-deterministic across epochs. Checkpoint
# selection on `val_f1_m` then tracks augmentation luck rather than generalization.
valid_data_gen = ImageDataGenerator(preprocessing_function=preprocess_input)

train_gen = train_data_gen.flow_from_directory(directory=training_dir,
                                               target_size=input_shape[:2],
                                               interpolation='bilinear',
                                               color_mode='rgb',
                                               batch_size=batch_size,
                                               class_mode='categorical',
                                               classes=None,
                                               shuffle=True,
                                               seed=seed)

# shuffle=False keeps validation order fixed so per-sample predictions can be
# aligned with valid_gen.classes / filenames.
valid_gen = valid_data_gen.flow_from_directory(directory=validation_dir,
                                               target_size=input_shape[:2],
                                               interpolation='bilinear',
                                               color_mode='rgb',
                                               class_mode='categorical',
                                               batch_size=batch_size,
                                               classes=None,
                                               shuffle=False,
                                               seed=seed)
                                        

Found 14176 images belonging to 14 classes.
Found 3552 images belonging to 14 classes.


In [6]:
# Build the VGG16-based transfer-learning model. Per the summary below, the
# vgg16 backbone's parameters are non-trainable, so only the dense head learns.
tl_model = utils.build_tl_vgg_model(
    hiddens=hiddens,
    neurons=neurons,
    input_shape=input_shape,
    seed=seed,
)

# Compile with focal loss, default Adam, and the project's metric list.
loss_fn = custom_metrics.categorical_focal_loss()
tl_model.compile(
    loss=loss_fn,
    optimizer=tfk.optimizers.Adam(),
    metrics=utils.metrics(),
)
tl_model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
vgg16 (Functional)           (None, 8, 8, 512)         14714688  
_________________________________________________________________
Flattening (Flatten)         (None, 32768)             0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 32768)             0         
_________________________________________________________________
dense_4 (Dense)              (None, 128)               4194432   
_________________________________________________________________
dense_5 (Dense)              (None, 14)                1806      
Total params: 18,910,926
Trainable params: 4,196,238
Non-trainable params: 14,714,688
_________________________________________

# Training and Results Visualization

In [None]:
# Train the classification head on top of the frozen backbone.
# `batch_size` is intentionally not passed: the inputs are generators that
# already yield batches of `batch_size`. Keras documents that batch_size must
# not be given for generator / Sequence inputs.
tl_history = tl_model.fit(
    x=train_gen,
    epochs=epochs,
    validation_data=valid_gen,
    # NOTE(review): `callbacks` was never defined in this notebook (hidden
    # kernel state). It is assumed to live in the project `utils` module next
    # to `utils.metrics`. Confirm the signature (monitor, mode, patience, name).
    callbacks=utils.callbacks("val_f1_m", "max", patience, "custom_weights_cat_focal_loss_model"),
).history

In [None]:
# Persist the trained model under a name derived from the head configuration.
# The fine-tuning stage reloads it from disk via `name`.
model_suffix = utils.get_name_model(hiddens=hiddens, neurons=neurons)
name = 'focal_' + model_suffix
tl_model.save(name)
del tl_model  # free the in-memory copy; it is reloaded from `name` below

# Fine Tuning

In [22]:
# Reload the trained transfer-learning model and fine-tune the top of VGG16.
# NOTE(review): the original referenced an undefined `metrics` module, and later
# called `metrics()` as a function, so the cell could not run on a fresh kernel.
# These custom objects are assumed to live in `custom_metrics`, which provides
# `categorical_focal_loss` as used when compiling above. Confirm that f1_m,
# precision_m, recall_m and focal_loss are defined there as well.
custom_objects = {
    'f1_m': custom_metrics.f1_m,
    'precision_m': custom_metrics.precision_m,
    'recall_m': custom_metrics.recall_m,
    'categorical_focal_loss': custom_metrics.categorical_focal_loss,
    'focal_loss': custom_metrics.focal_loss,
}
ft_model = tfk.models.load_model(name, custom_objects=custom_objects)

# Unfreeze the whole backbone first. The layer is looked up by the name shown
# in the model summary.
vgg = ft_model.get_layer('vgg16')
vgg.trainable = True
for i, layer in enumerate(vgg.layers):
    print(i, layer.name, layer.trainable)  # every layer is now trainable

# Then re-freeze the first `n_frozen_layers` layers: indices 0-13, i.e. input
# through block4_conv3. Early layers typically detect generic features such as
# edges, so they are kept fixed and only block4_pool onward is retrained.
n_frozen_layers = 14
for layer in vgg.layers[:n_frozen_layers]:
    layer.trainable = False
for i, layer in enumerate(vgg.layers):
    print(i, layer.name, layer.trainable)
ft_model.summary()

# Recompile so the new trainable flags take effect. Use a 10x smaller learning
# rate than Adam's default so fine-tuning does not wreck the pretrained weights.
ft_model.compile(loss=custom_metrics.categorical_focal_loss(),
                 optimizer=tfk.optimizers.Adam(1e-4),
                 metrics=utils.metrics())

# batch_size omitted: generator inputs already yield batches.
# NOTE(review): `callbacks` assumed to be `utils.callbacks` (undefined in the original).
ft_history = ft_model.fit(
    x=train_gen,
    epochs=epochs,
    validation_data=valid_gen,
    callbacks=utils.callbacks("val_f1_m", "max", patience),
).history

0 input_1 True
1 block1_conv1 True
2 block1_conv2 True
3 block1_pool True
4 block2_conv1 True
5 block2_conv2 True
6 block2_pool True
7 block3_conv1 True
8 block3_conv2 True
9 block3_conv3 True
10 block3_pool True
11 block4_conv1 True
12 block4_conv2 True
13 block4_conv3 True
14 block4_pool True
15 block5_conv1 True
16 block5_conv2 True
17 block5_conv3 True
18 block5_pool True
0 input_1 False
1 block1_conv1 False
2 block1_conv2 False
3 block1_pool False
4 block2_conv1 False
5 block2_conv2 False
6 block2_pool False
7 block3_conv1 False
8 block3_conv2 False
9 block3_conv3 False
10 block3_pool False
11 block4_conv1 False
12 block4_conv2 False
13 block4_conv3 False
14 block4_pool True
15 block5_conv1 True
16 block5_conv2 True
17 block5_conv3 True
18 block5_pool True
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 256, 256, 3)]     0         
_______________

2021-11-27 03:24:18.060889: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.




2021-11-27 03:27:55.691665: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 2/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 3/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 4/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 5/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 6/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 7/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 8/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorf

INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 13/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 14/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 15/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 16/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 17/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 18/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 19/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 20/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 21/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 22/1000
INFO:tensorflow:Assets written t

Epoch 25/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 26/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 27/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_best.ckpt/assets
Epoch 28/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 29/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 30/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 31/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 32/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 33/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 34/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 35/1000
INFO:tensorflow:Assets written to: ./ckpts-focal/cp_last.ckpt/assets
Epoch 36/1000
INFO

In [None]:
# Save the fine-tuned model next to the head-only model, tagged with "_ft".
ft_model.save(f'{name}_ft')