# Model training
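This notebook trains five architectures — ResNet50, InceptionV3, DenseNet121, MobileNetV2 and SqueezeNet — with three strategies: from scratch, FTC (fine-tune the classifier only) and FTAL (fine-tune all layers).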


## Importing modules

In [None]:
from utils.miscellaneous import read_config
from copy import deepcopy

import utils.train_val_test_dataset_import as tvt
import utils.class_imbalances as ci
import utils.plots as plot
import models.Models as models
import tensorflow as tf
import matplotlib.pyplot as plt


## Parse configuration file and initialize datasets

In [None]:
# Read all tunable settings from config.yaml so a run is defined by the config alone.
cfg = read_config('./config.yaml')

# Seed TensorFlow's global RNG before any dataset or model is created, so weight
# initialisation and tf.data shuffling repeat across "Restart & Run All"
# (some GPU kernels may still be non-deterministic).
# Falls back to 42 when config.yaml has no top-level 'seed' key.
# NOTE(review): assumes read_config returns a dict-like object supporting .get — confirm.
seed = cfg.get('seed', 42)
tf.random.set_seed(seed)

# constants
image_height = cfg['image_height']
image_width = cfg['image_width']
batch_size = cfg['batch_size']['tra']  # the 'tra' entry of the batch_size section
num_epochs = cfg['trainParams']['num_epochs']
lr_rate = cfg['adamParams']['lr']  # learning rate from the 'adamParams' section
num_classes = cfg['num_classes']

# dataset directories
path_train = cfg['Path']['path_train']
path_val = cfg['Path']['path_val']

# load the training and validation datasets
ds_train, ds_val = tvt.import_dataset_train_val(
    path_train, path_val, image_height, image_width, batch_size)

# Input pipeline: cache elements in memory on the first pass, shuffle the training
# data, and prefetch so input preparation overlaps with training.
# NOTE(review): if import_dataset_train_val already returns batched datasets, shuffle()
# reorders whole batches rather than individual images — confirm.
AUTOTUNE = tf.data.AUTOTUNE
shuffle_buffer_size = 1000
ds_train = ds_train.cache().shuffle(shuffle_buffer_size).prefetch(buffer_size=AUTOTUNE)
ds_val = ds_val.cache().prefetch(buffer_size=AUTOTUNE)

# Per-class weights computed from the training directory; every training cell below
# passes them as class_weights=.
class_weights_train = ci.class_weights_4(path_train)

# Output locations for saved models and checkpoints, one list per training strategy.
# The training cells below index every list the same way:
# 0=MobileNetV2, 1=SqueezeNet, 2=ResNet50, 3=InceptionV3, 4=DenseNet121.
save_model_path_fromscratch = cfg['Path']['save_model_path_fromscratch']
save_ckp_path_fromscratch = cfg['Path']['save_ckp_path_fromscratch']
save_model_path_TL_classifier = cfg['Path']['save_model_path_TL_classifier']
save_ckp_path_TL_classifier = cfg['Path']['save_ckp_path_TL_classifier']
save_model_path_TL_all = cfg['Path']['save_model_path_TL_all']
save_ckp_path_TL_all = cfg['Path']['save_ckp_path_TL_all']

## Training models

### (1) Training models from scratch

In [None]:
# SqueezeNet trained from scratch ('scratch' mode; index 1 in the save-path lists).
# Bound to a strategy-specific name so the FTC/FTAL SqueezeNet cells further down do not
# overwrite this run; `r_squeeze` stays an alias for the most recently run SqueezeNet cell.
r_squeeze_s = models.SqueezeN(num_classes, 'scratch', class_weights=class_weights_train, save_model_path=save_model_path_fromscratch[1],
    save_ckp_path=save_ckp_path_fromscratch[1],
    image_height=image_height, image_width=image_width,
    ds_train=ds_train, ds_val=ds_val, lr_rate=lr_rate, num_epochs=num_epochs)
r_squeeze = r_squeeze_s
# presumably a Keras History object: .history maps metric names to per-epoch values
val_acc = r_squeeze_s.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

In [None]:
# ResNet50 trained from scratch ('scratch' mode; index 2 in the save-path lists).
r_resnet_s = models.ResN50(
    num_classes,
    'scratch',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_fromscratch[2],
    save_ckp_path=save_ckp_path_fromscratch[2],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_resnet_s.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

In [None]:
# InceptionV3 trained from scratch ('scratch' mode; index 3 in the save-path lists).
r_inception_s = models.IncV3(
    num_classes,
    'scratch',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_fromscratch[3],
    save_ckp_path=save_ckp_path_fromscratch[3],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_inception_s.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))


In [None]:
# DenseNet121 trained from scratch ('scratch' mode; index 4 in the save-path lists).
r_dense_s = models.DenseN121(
    num_classes,
    'scratch',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_fromscratch[4],
    save_ckp_path=save_ckp_path_fromscratch[4],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_dense_s.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

In [None]:
# MobileNetV2 trained from scratch ('scratch' mode; index 0 in the save-path lists).
# Bound to a strategy-specific name so the FTC/FTAL MobileNetV2 cells further down do not
# overwrite this run; `r_mobileV2` stays an alias for the most recently run MobileNetV2 cell.
r_mobileV2_s = models.MNetV2(num_classes, 'scratch', class_weights=class_weights_train, save_model_path=save_model_path_fromscratch[0],
   save_ckp_path=save_ckp_path_fromscratch[0],
   image_height=image_height, image_width=image_width,
   ds_train=ds_train, ds_val=ds_val, lr_rate=lr_rate, num_epochs=num_epochs)
r_mobileV2 = r_mobileV2_s
# presumably a Keras History object: .history maps metric names to per-epoch values
val_acc = r_mobileV2_s.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

### (2) Training models using the FTC strategy
FTC (fine-tune classifier): start from weights pre-trained on ImageNet, freeze the convolutional base, and fine-tune only the classifier on the training set.

In [None]:
# SqueezeNet with the FTC strategy ('TL_classifier' mode; index 1 in the save-path lists).
# Bound to a strategy-specific name so the scratch/FTAL SqueezeNet cells do not clobber
# this run; `r_squeeze` stays an alias for the most recently run SqueezeNet cell.
r_squeeze_c = models.SqueezeN(num_classes, 'TL_classifier', class_weights=class_weights_train, save_model_path=save_model_path_TL_classifier[1],
    save_ckp_path=save_ckp_path_TL_classifier[1],
    image_height=image_height, image_width=image_width,
    ds_train=ds_train, ds_val=ds_val, lr_rate=lr_rate, num_epochs=num_epochs)
r_squeeze = r_squeeze_c
# presumably a Keras History object: .history maps metric names to per-epoch values
val_acc = r_squeeze_c.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

In [None]:
# ResNet50 with the FTC strategy ('TL_classifier' mode; index 2 in the save-path lists).
r_resnet_c = models.ResN50(
    num_classes,
    'TL_classifier',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_TL_classifier[2],
    save_ckp_path=save_ckp_path_TL_classifier[2],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_resnet_c.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))


In [None]:
# InceptionV3 with the FTC strategy ('TL_classifier' mode; index 3 in the save-path lists).
# Named r_inception_c to match the _s/_c/_a scheme used by the other runs; the old
# misspelled name r_inceptionr_c is kept as an alias for code that still refers to it.
r_inception_c = models.IncV3(num_classes, 'TL_classifier', class_weights=class_weights_train, save_model_path=save_model_path_TL_classifier[3],
    save_ckp_path=save_ckp_path_TL_classifier[3],
    image_height=image_height, image_width=image_width,
    ds_train=ds_train, ds_val=ds_val, lr_rate=lr_rate, num_epochs=num_epochs)
r_inceptionr_c = r_inception_c
# presumably a Keras History object: .history maps metric names to per-epoch values
val_acc = r_inception_c.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))


In [None]:
# DenseNet121 with the FTC strategy ('TL_classifier' mode; index 4 in the save-path lists).
r_dense_c = models.DenseN121(
    num_classes,
    'TL_classifier',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_TL_classifier[4],
    save_ckp_path=save_ckp_path_TL_classifier[4],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_dense_c.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

In [None]:
# MobileNetV2 with the FTC strategy ('TL_classifier' mode; index 0 in the save-path lists).
# Bound to a strategy-specific name so the scratch/FTAL MobileNetV2 cells do not clobber
# this run; `r_mobileV2` stays an alias for the most recently run MobileNetV2 cell.
r_mobileV2_c = models.MNetV2(num_classes, 'TL_classifier', class_weights=class_weights_train, save_model_path=save_model_path_TL_classifier[0],
   save_ckp_path=save_ckp_path_TL_classifier[0],
   image_height=image_height, image_width=image_width,
   ds_train=ds_train, ds_val=ds_val, lr_rate=lr_rate, num_epochs=num_epochs)
r_mobileV2 = r_mobileV2_c
# presumably a Keras History object: .history maps metric names to per-epoch values
val_acc = r_mobileV2_c.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

### (3) Training models using the FTAL strategy
FTAL (fine-tune all layers): start from weights pre-trained on ImageNet and fine-tune all layers on the training set.

In [None]:
# SqueezeNet with the FTAL strategy ('TL_all' mode; index 1 in the save-path lists).
# Bound to a strategy-specific name so the other SqueezeNet cells do not clobber this
# run; `r_squeeze` stays an alias for the most recently run SqueezeNet cell.
r_squeeze_a = models.SqueezeN(num_classes, 'TL_all', class_weights=class_weights_train, save_model_path=save_model_path_TL_all[1],
    save_ckp_path=save_ckp_path_TL_all[1],
    image_height=image_height, image_width=image_width,
    ds_train=ds_train, ds_val=ds_val, lr_rate=lr_rate, num_epochs=num_epochs)
r_squeeze = r_squeeze_a
# presumably a Keras History object: .history maps metric names to per-epoch values
val_acc = r_squeeze_a.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

In [None]:
# ResNet50 with the FTAL strategy ('TL_all' mode; index 2 in the save-path lists).
r_resnet_a = models.ResN50(
    num_classes,
    'TL_all',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_TL_all[2],
    save_ckp_path=save_ckp_path_TL_all[2],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_resnet_a.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))


In [None]:
# InceptionV3 with the FTAL strategy ('TL_all' mode; index 3 in the save-path lists).
r_inception_a = models.IncV3(
    num_classes,
    'TL_all',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_TL_all[3],
    save_ckp_path=save_ckp_path_TL_all[3],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_inception_a.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))


In [None]:
# DenseNet121 with the FTAL strategy ('TL_all' mode; index 4 in the save-path lists).
r_dense_a = models.DenseN121(
    num_classes,
    'TL_all',
    # input data
    ds_train=ds_train,
    ds_val=ds_val,
    image_height=image_height,
    image_width=image_width,
    class_weights=class_weights_train,
    # optimisation
    lr_rate=lr_rate,
    num_epochs=num_epochs,
    # where the trained model and checkpoints are written
    save_model_path=save_model_path_TL_all[4],
    save_ckp_path=save_ckp_path_TL_all[4],
)
# per-epoch validation accuracy; report the best epoch's value
val_acc = r_dense_a.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))

In [None]:
# MobileNetV2 with the FTAL strategy ('TL_all' mode; index 0 in the save-path lists).
# Bound to a strategy-specific name so the other MobileNetV2 cells do not clobber this
# run; `r_mobileV2` stays an alias for the most recently run MobileNetV2 cell.
r_mobileV2_a = models.MNetV2(num_classes, 'TL_all', class_weights=class_weights_train, save_model_path=save_model_path_TL_all[0],
   save_ckp_path=save_ckp_path_TL_all[0],
   image_height=image_height, image_width=image_width,
   ds_train=ds_train, ds_val=ds_val, lr_rate=lr_rate, num_epochs=num_epochs)
r_mobileV2 = r_mobileV2_a
# presumably a Keras History object: .history maps metric names to per-epoch values
val_acc = r_mobileV2_a.history['val_accuracy']
print("Best Validation Accuracy is", max(val_acc))


### Plotting accuracy and loss

In [None]:
# Plot model accuracy and loss for the training runs above.
#
# Every call is commented out. Before uncommenting one, run the training cell that
# creates its history variable in the current kernel; otherwise the call raises NameError.
#
# Naming: the _s / _c / _a suffixes mark the scratch / FTC ('TL_classifier') /
# FTAL ('TL_all') runs of an architecture.
# NOTE(review): r_mobileV2 and r_squeeze are assigned in the scratch, FTC and FTAL
# cells alike, so here they refer to whichever of those cells ran last.
# NOTE(review): presumably model_name only labels the figure. Runs of the same
# architecture share a label below, so consider adding the strategy if you plot
# several of them — confirm against utils.plots.plot_hist.

# plot.plot_hist(hist=r_mobileV2, model_name="MobileNetV2")
# plot.plot_hist(hist=r_squeeze, model_name="SqueezeNet")
# plot.plot_hist(hist=r_resnet_s, model_name="ResNet50")
# plot.plot_hist(hist=r_resnet_c, model_name="ResNet50")
# plot.plot_hist(hist=r_resnet_a, model_name="ResNet50")
# plot.plot_hist(hist=r_inception_s, model_name='InceptionV3')
# plot.plot_hist(hist=r_inceptionr_c, model_name='InceptionV3')
# plot.plot_hist(hist=r_inception_a, model_name='InceptionV3')
# plot.plot_hist(hist=r_dense_s, model_name="DenseNet121")
# plot.plot_hist(hist=r_dense_c, model_name="DenseNet121")
# plot.plot_hist(hist=r_dense_a, model_name="DenseNet121")

### TensorBoard

In [None]:
# Load the TensorBoard notebook extension.
#
# NOTE(review): %load_ext only registers the %tensorboard magic. This cell alone does
# not start a TensorBoard server, so nothing is served on localhost:6006 yet.
# Start one with e.g. `%tensorboard --logdir <log dir>`. TODO confirm which directory,
# if any, the training functions in models.Models write TensorBoard logs to.

%load_ext tensorboard

# If the TensorBoard view inside VS Code is hard to read, open localhost:6006
# (TensorBoard's default port) in a web browser once TensorBoard is running.