In [None]:
import os
import sys

# Put ./utils (resolved against the current working directory) on sys.path,
# adding it at most once so re-running this cell does not duplicate the entry.
# os.path.join('utils') is simply 'utils', so the path is resolved directly.
module_path = os.path.abspath('utils')
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import tensorflow as tf
from tensorflow import keras
# NOTE(review): `from keras import ...` resolves the top-level `keras` package,
# not the `tensorflow.keras` alias bound on the previous line. Whether the two
# are the same Keras depends on the installed TF/Keras versions (they can
# diverge, e.g. when the legacy tf_keras package is in use). Using
# `callbacks = keras.callbacks` would tie the callbacks to the same Keras as
# `keras` above — confirm which Keras utils.models builds its models with.
from keras import callbacks
import utils.data_utils as data_utils
from utils.models import dcnn_model, mobilenetv3_model, compile_model
import numpy as np
from sklearn.model_selection import KFold
import time

# Only let messages at ERROR level or above through TensorFlow's Python logger.
tf.get_logger().setLevel('ERROR')

In [None]:
BASE_LEARNING_RATE = .0001
EPOCHS = 150

# One timestamp for the whole run so both models' TensorBoard run names share
# the same suffix (two separate time.time() calls could straddle a second).
RUN_TIMESTAMP = int(time.time())
INSTRU_VISION_CNN = f"instru-vision-cnn-{RUN_TIMESTAMP}"
MOBILENETV3_FINE_TUNED = f"mobilenetv3-fine-tuned-{RUN_TIMESTAMP}"

# Stop when val_loss has not improved by at least 0.01 for 15 consecutive
# epochs; restore_best_weights=True asks Keras to restore the weights of the
# best epoch seen.
EARLY_STOPPING = callbacks.EarlyStopping(monitor="val_loss",
                                         mode="min",
                                         min_delta=.01,
                                         patience=15,
                                         restore_best_weights=True)


def _logging_callbacks(model_dir, run_name, checkpoint_threshold):
    """Build the CSV, TensorBoard and checkpoint callbacks used for one model.

    Args:
        model_dir: sub-directory of ``../models`` holding this model's CSV
            history and checkpoint.
        run_name: TensorBoard run name, written under ``../models/logs``.
        checkpoint_threshold: ``initial_value_threshold`` for ModelCheckpoint;
            nothing is saved until val_loss gets below this value.

    Returns:
        A list ``[CSVLogger, TensorBoard, ModelCheckpoint]``.
    """
    return [
        callbacks.CSVLogger(
            f"../models/{model_dir}/csv/training_history.csv",
            separator=",",
            append=False),
        callbacks.TensorBoard(log_dir=f"../models/logs/{run_name}"),
        callbacks.ModelCheckpoint(
            filepath=f"../models/{model_dir}/saved_model/",
            monitor="val_loss",
            verbose=1,
            save_best_only=True,
            initial_value_threshold=checkpoint_threshold),
    ]


# These callback instances are created once and reused by every fit() call.
# Checkpoint thresholds (0.9 for the DCNN, 0.6 for MobileNetV3) are the
# val_loss values a model must beat before its first checkpoint is written.
DCNN_CALLBACKS = _logging_callbacks("instru-vision-cnn", INSTRU_VISION_CNN, 0.9) + [EARLY_STOPPING]

# NOTE(review): unlike DCNN_CALLBACKS this list has no EARLY_STOPPING, so
# MobileNetV3 always runs for the full EPOCHS — confirm this is intentional.
PRETRAINED_CALLBACKS = _logging_callbacks("mobilenetv3", MOBILENETV3_FINE_TUNED, 0.6)

# Dataset loading

In [None]:
# get_datasets() hands back a (training, validation) pair; the K-fold loop
# further down re-splits only the training part.
(training_dataset, validation_dataset) = data_utils.get_datasets()

# Custom CNN based on the GoogLeNet architecture (trained from scratch, not pretrained)

In [None]:

# Build the custom GoogLeNet-style CNN, print its architecture, compile it and
# snapshot its initial weights so every cross-validation fold can be reset to
# the same starting point.
# This cell binds the name `dcnn_model` to the model *instance*, shadowing the
# factory imported at the top. The factory is therefore re-imported under its
# own name so the cell still works when it is re-run.
from utils.models import dcnn_model as build_dcnn_model

dcnn_model = build_dcnn_model()
dcnn_model.summary()
compile_model(dcnn_model)
initial_dcnn_weights = dcnn_model.get_weights()


# TensorFlow MobileNetV3-Large (pretrained)


In [None]:
# Pretrained MobileNetV3-Large model from utils.models: build it, print the
# layer summary, compile it, then keep a copy of its starting weights so each
# cross-validation fold can be reset to the same point.
pretrained_model = mobilenetv3_model()
pretrained_model.summary()

compile_model(pretrained_model)
initial_pretrained_weights = pretrained_model.get_weights()

In [None]:

# Materialise the training split as in-memory arrays so KFold can index it.
# The dataset is iterated exactly ONCE and each (image, label) pair is kept
# together. Two separate passes (one for images, one for labels) only line up
# if the dataset yields the same order on every iteration. A shuffled tf.data
# pipeline (reshuffle_each_iteration defaults to True) does not guarantee
# that, so labels would silently be misaligned with their images.
training_set = training_dataset.unbatch()
image_list, label_list = [], []
for image, label in training_set.as_numpy_iterator():
    image_list.append(image)
    label_list.append(label)
images = np.asarray(image_list)
labels = np.asarray(label_list)
del image_list, label_list  # the arrays now hold the data; drop the per-sample lists

# Fixed random_state so the fold assignment is reproducible between runs.
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

dcnn_cross_validation_history = []
pretrained_cross_validation_history = []

# Output folders for the per-fold weight files below; create them up front so
# save_weights() does not fail on a fresh checkout.
os.makedirs("../models/instru-vision-cnn/k-fold", exist_ok=True)
os.makedirs("../models/mobilenetv3/k-fold", exist_ok=True)

for fold, (train, validation) in enumerate(kfold.split(images, labels), start=1):
    print(f"Starting {fold}. fold...")

    train_folds = (images[train], labels[train])
    validation_fold = (images[validation], labels[validation])

    # Every fold starts from the same initial weights so folds are independent.
    dcnn_model.set_weights(initial_dcnn_weights)
    pretrained_model.set_weights(initial_pretrained_weights)
    # set_weights() restores only the model weights, not the optimizer's own
    # state (slot variables, iteration counter). Recompiling attaches a new
    # optimizer so that state does not leak from one fold into the next.
    # NOTE(review): assumes compile_model() creates a new optimizer on each
    # call — confirm in utils/models.py.
    compile_model(dcnn_model)
    compile_model(pretrained_model)

    # NOTE(review): DCNN_CALLBACKS / PRETRAINED_CALLBACKS are the same instances
    # for every fold. CSVLogger(append=False) rewrites its file on each fit()
    # call, and TensorBoard logs every fold into one run directory. As far as I
    # know, ModelCheckpoint does not reset its best val_loss between fit()
    # calls, so saved_model/ only changes when a fold beats all earlier folds.
    # Consider building fresh callbacks per fold.
    dcnn_history = dcnn_model.fit(train_folds[0], train_folds[1], validation_data=validation_fold,
                                  verbose=0, epochs=EPOCHS, callbacks=DCNN_CALLBACKS)

    pretrained_history = pretrained_model.fit(train_folds[0], train_folds[1], validation_data=validation_fold,
                                              verbose=0, epochs=EPOCHS, callbacks=PRETRAINED_CALLBACKS)

    # Per-epoch validation curves of this fold.
    dcnn_val_loss = dcnn_history.history['val_loss']
    dcnn_val_accuracy = dcnn_history.history['val_accuracy']

    pretrained_val_loss = pretrained_history.history['val_loss']
    pretrained_val_accuracy = pretrained_history.history['val_accuracy']

    # Score the weights each model actually ends the fold with, which are the
    # weights saved below. EarlyStopping(restore_best_weights=True) may have
    # rolled the DCNN back to an earlier epoch, so the last history entry need
    # not describe these weights.
    dcnn_scores = dcnn_model.evaluate(*validation_fold, verbose=0, return_dict=True)
    pretrained_scores = pretrained_model.evaluate(*validation_fold, verbose=0, return_dict=True)

    print(f"Fold {fold} Instru-vision-cnn score:\n\tval accuracy: {dcnn_scores['accuracy']}\n\tval loss: {dcnn_scores['loss']} ")
    print(f"Fold {fold} MobileNetV3 score:\n\tval accuracy: {pretrained_scores['accuracy']}\n\tval loss: {pretrained_scores['loss']} ")

    dcnn_model.save_weights(f"../models/instru-vision-cnn/k-fold/fold{fold}_weights.hdf5")
    pretrained_model.save_weights(f"../models/mobilenetv3/k-fold/fold{fold}_weights.hdf5")

    dcnn_cross_validation_history.append(dcnn_history.history)
    pretrained_cross_validation_history.append(pretrained_history.history)

In [None]:
# Side-by-side comparison of every fold. Keys are "<model> Fold <n>", and the
# two models are interleaved per fold. The dict is built from the recorded
# histories, so it matches however many folds actually ran instead of
# assuming exactly five.
fold_results = {
    f"{model_name} Fold {fold_number}": fold_history
    for fold_number, fold_histories in enumerate(
        zip(dcnn_cross_validation_history, pretrained_cross_validation_history), start=1)
    for model_name, fold_history in zip(("Instru-vision-cnn", "MobileNetV3"), fold_histories)
}

data_utils.compare_and_display_model_results(fold_results)