In [1]:
import os
import src.data.data as data

IMG_SIZE = 224
BATCH_SIZE = 64
NO_OF_OUTPUT_CLASSES = 27

TRAIN_DIR = data.get_output_dir(IMG_SIZE, IMG_SIZE, True, False, "train")
VALIDATION_DIR = data.get_output_dir(IMG_SIZE, IMG_SIZE, True, False, "test")

CHECKPOINT_DIR = os.path.join("data", "models", "cnn_mobilenetv2")
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "cp_{epoch:04d}.ckpt")

In [2]:
import tensorflow_hub as hub
from tensorflow.keras import layers, Sequential, regularizers

# Download mobilenet model from tensorflow hub
URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2"
feature_extractor = hub.KerasLayer(URL,
                                   input_shape=(IMG_SIZE, IMG_SIZE,3),
                                   trainable=True)

model = Sequential([
    layers.InputLayer(input_shape=(IMG_SIZE, IMG_SIZE, 3)),
    layers.Rescaling(1./255),
    feature_extractor,
    layers.Dropout(rate=0.2),
    layers.Dense(NO_OF_OUTPUT_CLASSES,
                 kernel_regularizer=regularizers.l2(0.0001))
])
model.build((None, IMG_SIZE, IMG_SIZE, 3))

In [7]:
import tensorflow as tf

train_ds = tf.keras.utils.image_dataset_from_directory(TRAIN_DIR,
                                                             image_size=(IMG_SIZE, IMG_SIZE),
                                                             batch_size=BATCH_SIZE,
                                                             label_mode="categorical")

val_ds = tf.keras.utils.image_dataset_from_directory(VALIDATION_DIR,
                                                                  image_size=(IMG_SIZE, IMG_SIZE),
                                                                  batch_size=BATCH_SIZE,
                                                                  label_mode="categorical")

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

Found 67932 files belonging to 27 classes.
Found 16984 files belonging to 27 classes.


In [15]:
print(int(steps_per_epoch))
print(steps_per_epoch * BATCH_SIZE)

1061
67932.0


In [17]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.train import latest_checkpoint

model.compile(
  optimizer=SGD(learning_rate=0.005, momentum=0.9), 
  loss=CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
  metrics=['accuracy'])


latest = latest_checkpoint(CHECKPOINT_DIR)
if(latest is not None):
    print("Loading checkpoint", latest)
    model.load_weights(latest)
else:
    print("No checkpoint to load")

# This callback will stop the training when there is no improvement in
# the accuracy for three consecutive epochs.
cp_callbacks = [
    EarlyStopping(monitor='val_accuracy', patience=3),
    ModelCheckpoint(CHECKPOINT_PATH,
                    save_weights_only=True,
                    verbose=1)
    ]

steps_per_epoch = int(67932 / BATCH_SIZE)
validation_steps = int(16984 / BATCH_SIZE)
hist = model.fit(
    train_ds.repeat(),
    epochs=100, steps_per_epoch=steps_per_epoch,
    validation_data=val_ds.repeat(),
    validation_steps=validation_steps,
    callbacks=cp_callbacks).history

Loading checkpoint data\models\cnn_mobilenetv2\cp_0002.ckpt
Epoch 1/100
Epoch 1: saving model to data\models\cnn_mobilenetv2\cp_0001.ckpt
Epoch 2/100
  51/1061 [>.............................] - ETA: 5:24 - loss: 1.3175 - accuracy: 0.8002

KeyboardInterrupt: 