In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from google.colab import drive
# Mount Google Drive at /content/drive; the dataset and model paths configured
# in BASE_PATH point inside this mount, so this must run before loading data.
drive.mount('/content/drive')

## Parameter Configuration

In [None]:
# Input resolution fed to MobileNetV2: (height, width, 3 color channels).
IMG_SIZE = 224
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

BATCH_SIZE = 32
# Number of (already preprocessed) examples held in the training shuffle buffer.
SHUFFLE_BUFFER_SIZE = 1000

EPOCHS = 8

# Adjust the paths as needed.
BASE_PATH = '/content/drive/MyDrive/Computer-Vision'
DATASET_PATH = BASE_PATH + '/Dataset/'
# The checkpoint filename records the batch size; derive it from BATCH_SIZE so
# the name cannot drift out of sync when the batch size is changed.
# NOTE(review): the 'V1' tag must be kept in sync with VERSION in the
# model-definition cell.
# NOTE(review): newer Keras releases may require a '.keras' extension for
# full-model checkpoints — confirm against the installed Keras version.
MODEL_SAVE_PATH = BASE_PATH + '/Models/FineTuning/V1-Batch{}.h5'.format(BATCH_SIZE)

## Pre-Processing Function

In [None]:
def format_image(image, label):
    """Convert one (image, label) pair into MobileNetV2's input format.

    The image is cast to float32, rescaled with ``x / 127.5 - 1`` (which maps
    [0, 255] to [-1, 1], assuming 8-bit pixel input — TODO confirm), and then
    resized to ``(IMG_SIZE, IMG_SIZE)``. The label is returned untouched.
    """
    as_float = tf.cast(image, tf.float32)
    rescaled = as_float / 127.5 - 1
    resized = tf.image.resize(rescaled, (IMG_SIZE, IMG_SIZE))
    return resized, label

## Data Loading

In [None]:
# DATASET_PATH is expected to contain three split subfolders, each holding one
# subfolder per class (tfds ImageFolder layout):
#  - train
#  - val   (validation data)
#  - test  (not loaded here)
builder = tfds.folder_dataset.ImageFolder(DATASET_PATH)
raw_train = builder.as_dataset(split='train', as_supervised=True, shuffle_files=True)
raw_validation = builder.as_dataset(split='val', as_supervised=True)

info = builder.info
label_names = info.features['label'].names

print("Total training images: {}  ".format(len(raw_train)))
print("Total validation images: {} ".format(len(raw_validation)))
print("Label names: {}".format(label_names))
# NOTE(review): the model's final softmax layer must have len(label_names)
# units — confirm it matches the model-definition cell.

# Preprocess every image to match the input MobileNetV2 expects. Mapping in
# parallel and prefetching lets the input pipeline overlap with training
# instead of stalling each step.
AUTOTUNE = tf.data.AUTOTUNE
train = raw_train.map(format_image, num_parallel_calls=AUTOTUNE)
validation = raw_validation.map(format_image, num_parallel_calls=AUTOTUNE)

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE).prefetch(AUTOTUNE)
validation_batches = validation.batch(BATCH_SIZE).prefetch(AUTOTUNE)

## CNN Model Definition

In [None]:
# Hidden Dense-layer widths of the classification head for each model version.
# Every head is GlobalAveragePooling2D -> these ReLU Dense layers -> softmax Dense.
HEAD_HIDDEN_UNITS = {
    'V1': (256, 128),
    'V2': (128,),
    'V3': (),
}


def build_model(model_version, layer_to_freeze, num_classes=5):
    """Build a MobileNetV2-based image classifier for fine-tuning.

    The ImageNet-pretrained MobileNetV2 base (without its top) is frozen except
    for its last ``layer_to_freeze`` layers, and a classification head is
    stacked on top of it.

    Args:
        model_version: 'V1' (Dense 256 -> Dense 128), 'V2' (Dense 128) or
            'V3' (no hidden layer); selects the head placed after global
            average pooling.
        layer_to_freeze: Despite its name, the number of layers at the END of
            the base model that are made trainable. 0 keeps the whole base
            model frozen; a value >= the number of base layers unfreezes all.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.

    Raises:
        ValueError: If ``model_version`` is unknown, ``layer_to_freeze`` is
            negative, or ``num_classes`` is smaller than 1.
    """
    # Validate before building anything: constructing MobileNetV2 with
    # weights='imagenet' loads pretrained weights, wasted work on a bad argument.
    if model_version not in HEAD_HIDDEN_UNITS:
        raise ValueError("Unknown model_version {!r}; expected one of {}".format(
            model_version, sorted(HEAD_HIDDEN_UNITS)))
    num_layers_to_unfreeze = layer_to_freeze
    if num_layers_to_unfreeze < 0:
        raise ValueError("layer_to_freeze must be >= 0, got {}".format(layer_to_freeze))
    if num_classes < 1:
        raise ValueError("num_classes must be >= 1, got {}".format(num_classes))

    base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
        include_top=False, weights='imagenet', input_shape=IMG_SHAPE)

    # Freeze all base layers except the last N. Using a cut-off index avoids the
    # slicing pitfall layers[-0:], which is the WHOLE list and would unfreeze
    # every layer when N == 0.
    first_trainable = len(base_model.layers) - num_layers_to_unfreeze
    for index, layer in enumerate(base_model.layers):
        layer.trainable = index >= first_trainable
    # NOTE(review): the unfrozen tail may include BatchNormalization layers; the
    # Keras fine-tuning guide recommends keeping those frozen (inference mode)
    # — confirm whether training them is intended.

    head = [tf.keras.layers.GlobalAveragePooling2D()]
    head.extend(tf.keras.layers.Dense(units=units, activation='relu')
                for units in HEAD_HIDDEN_UNITS[model_version])
    head.append(tf.keras.layers.Dense(units=num_classes, activation='softmax'))

    return tf.keras.Sequential([base_model] + head)

In [None]:
VERSION = 'V1'
# Number of layers at the end of MobileNetV2 left trainable for fine-tuning
# (passed as build_model's second argument).
NUM_LAYERS_TO_UNFREEZE = 5

new_model = build_model(VERSION, NUM_LAYERS_TO_UNFREEZE)
# The Keras class is SparseCategoricalCrossentropy (lower-case 'e' in
# 'entropy'); its default from_logits=False matches the softmax output layer.
# NOTE(review): Adam's default learning rate is used; fine-tuning pretrained
# layers often calls for a lower rate — confirm this is intended.
new_model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                  metrics=['accuracy'])

new_model.summary()

## Model Training

In [None]:
# Fine-tuning run. Early stopping watches the validation loss and, when it
# triggers, restores the weights of the best epoch.
# NOTE(review): with patience=5 and EPOCHS=8 it can only trigger late in
# training; confirm the intended interplay of the two settings.
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Write the model only when val_loss improves. With save_best_only=False the
# file was overwritten every epoch (monitor/mode had no effect), so it held the
# last epoch's model rather than the best one.
check_point_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=MODEL_SAVE_PATH,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False,
    mode='min',
    verbose=1
)

history = new_model.fit(train_batches, epochs=EPOCHS, validation_data=validation_batches, verbose=1, callbacks=[early_stop_callback, check_point_callback])
# Per-epoch training accuracy, shown as the cell's output.
acc = history.history['accuracy']
acc