In [None]:
import tensorflow as tf
import numpy      as np

import directoryFunctions
import pathlib
import config
import data
import time
import os

from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.optimizers   import Adam
from tensorflow.keras.callbacks    import CSVLogger, EarlyStopping, ModelCheckpoint, TensorBoard
from tensorflow.keras.layers       import Dense, GlobalAveragePooling2D
from tensorflow.keras.losses       import CategoricalCrossentropy
from tensorflow.keras.models       import load_model, Sequential

"""
Documentation:
- numpy
    1. array()
        - https://numpy.org/doc/stable/reference/generated/numpy.array.html?highlight=array#numpy.array
    2. ceil()
        - https://numpy.org/doc/stable/reference/generated/numpy.ceil.html?highlight=ceil#numpy.ceil
- os
    1. path.sep
- pathlib
    1. Path(), /, glob()
        - https://docs.python.org/3/library/pathlib.html
- tensorflow
    - data.Dataset
        1. batch(), cache(), list_files(), map(), prefetch(), repeat(), shuffle()
            - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/data/Dataset
    - keras
        - applications
            1. InceptionV3()
                - https://keras.io/api/applications/inceptionv3/
        - callbacks
            1. CSVLogger(), EarlyStopping(), ModelCheckpoint(), TensorBoard()
                - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/callbacks
        - layers
            1. Dense()
                - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/layers/Dense
            2. GlobalAveragePooling2D()
                - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/layers/GlobalAveragePooling2D
        - losses
            1. CategoricalCrossentropy()
                - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/losses/CategoricalCrossentropy
        - models
            1. load_model()
                - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/models/load_model
            2. Sequential()
                1. compile(), fit()
                    - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/Sequential
        - optimizers
            1. Adam()
                - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras/optimizers/Adam
    - image
        1. convert_image_dtype(), decode_jpeg(), resize()
            - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/image
    - io
        1. read_file()
            - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/io/read_file
    - strings
        1. split()
            - https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/strings/split
- time
    1. time()
        - https://docs.python.org/3/library/time.html#time.time

Sources:
    1. Input pipeline using tf.data
        - https://www.tensorflow.org/tutorials/load_data/images
    2. Transfer learning
        - https://www.tensorflow.org/tutorials/images/transfer_learning
    3. InceptionV3
        - https://keras.io/api/applications/
"""

In [None]:
"""
Function Name: getFrameCount
Number of parameters: 1
List of parameters:
    1. dataDirectory | pathlib.Path | Path to a directory.
Pre-condition:
    1. 'dataDirectory' exists.
Post-condition:
    1. Returns the number of frames (.jpg files) in the folders within 'dataDirectory'. 
"""
def getFrameCount(dataDirectory):
    """Count the '.jpg' files located exactly one folder below 'dataDirectory'.

    dataDirectory: pathlib.Path of the directory whose sub-folders hold the frames.
    Returns the number of matching paths as an int.
    """
    # '*/*.jpg' matches <sub-folder>/<name>.jpg only; files at the top level
    # or nested deeper are not counted.
    framePaths = dataDirectory.glob('*/*.jpg')
    return sum(1 for _ in framePaths)

In [None]:
"""
Function Name: getDataSet
Number of parameters: 1
List of parameters:
    1. dataDirectory | pathlib.Path | Path to a directory.
Pre-condition:
    1. 'dataDirectory' exists.
Post-condition:
    1. Returns a tf.data.Dataset object based on the files in the folders within 'dataDirectory'.
"""
def getDataSet(dataDirectory):
    """Build a dataset of the frame paths found in the sub-folders of 'dataDirectory'.

    dataDirectory: pathlib.Path of the directory whose sub-folders hold the frames.
    Returns the tf.data.Dataset produced by tf.data.Dataset.list_files().
    """
    # Use the same '*/*.jpg' pattern as getFrameCount() so the dataset holds
    # exactly the files that were counted (the count drives the shuffle buffer
    # and the steps per epoch), and so that stray non-JPEG files never reach
    # tf.image.decode_jpeg().
    framePattern = str(dataDirectory/'*/*.jpg')
    return tf.data.Dataset.list_files(framePattern)

In [None]:
"""
Function Name: getLabeledFrames
Number of parameters: 1
List of parameters:
    1. imagePath | tf.Tensor (string) | Tensor of paths to images.
Pre-condition:
    1. The path in 'imagePath' exists.
Post-condition:
    1. Returns a processed (decoded, converted, resized) image along with its label.
"""
def getLabeledFrames(imagePath):
    """Load one frame from disk and pair it with its label.

    imagePath: string tensor holding the path of a JPEG file whose parent
               folder name is the class name.
    Returns (image, label): the image decoded to 3 channels, converted to
    float32 and resized to 299x299; the label is the element-wise comparison
    of the parent folder name against the module-level 'classNames'.
    """
    # Split the path into its components; the second to last one is the
    # class directory.
    pathComponents = tf.strings.split(imagePath, os.path.sep)
    label = pathComponents[-2] == classNames

    # Read the raw file contents as a string.
    rawBytes = tf.io.read_file(imagePath)
    # Compressed string -> 3D uint8 tensor.
    frame = tf.image.decode_jpeg(rawBytes, channels=3)
    # `convert_image_dtype` converts to floats in the [0,1] range.
    frame = tf.image.convert_image_dtype(frame, tf.float32)
    # Resize to the 299x299 input size the model is built with.
    frame = tf.image.resize(frame, [299, 299])

    return frame, label

In [None]:
"""
Function Name: prepareTrainDataset
Number of parameters: 3
List of parameters:
    1. dataset           | tf.data.Dataset | Dataset object that contains the processed images and their labels.
    2. cache             | bool/str        | If False or an empty string, cache is not used. 
                                             If True then dataset is cached in memory.
                                             Otherwise, dataset is cached in a cache file.
    3. shuffleBufferSize | int             | Size of the shuffle buffer.
Pre-condition:
    1. If 'cache' is a non-empty string, then it is used as the path of the cache file and its parent directory must exist.
Post-condition:
    1. If specified, the dataset will be cached (either on memory or on disk).
    2. Prefetches the next batched dataset.
    3. Returns a shuffled and batched dataset.
"""
def prepareTrainDataset(dataset, cache, shuffleBufferSize):
    """Cache (optionally), shuffle, repeat, batch and prefetch a training dataset.

    dataset:           tf.data.Dataset of (image, label) elements.
    cache:             falsy (False, '' or None) -> no caching;
                       str or os.PathLike     -> cache on disk at that path;
                       any other truthy value -> cache in memory.
    shuffleBufferSize: size of the shuffle buffer.
    Returns the prepared dataset. It repeats forever, so the caller has to
    bound each epoch (e.g. with steps_per_epoch).
    """
    if cache:
        # Path-like objects (e.g. pathlib.Path) are treated like strings;
        # checking only for 'str' would silently fall back to the in-memory
        # cache when a Path is passed.
        if isinstance(cache, (str, os.PathLike)):
            dataset = dataset.cache(os.fspath(cache))
        else:
            dataset = dataset.cache()

    # Caching is applied before shuffling, so what gets cached is the
    # un-shuffled, un-batched data.
    dataset = dataset.shuffle(buffer_size = shuffleBufferSize)
    # Repeat forever
    dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE)
    # `prefetch` lets the dataset fetch batches in the background while the model is training.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)

    return dataset

In [None]:
"""
Function Name: getModel
Number of parameters: 0
List of parameters: n/a
Pre-condition: n/a
Post-condition:
    1. Returns a compiled model (tf.keras.models.Sequential object).
"""
def getModel():
    """Build and compile the transfer-learning classifier.

    The model is a Sequential stack of a frozen, ImageNet-pretrained
    InceptionV3 base (without its top), a global average pooling layer and a
    softmax Dense layer with 'numClasses' units.
    Returns the compiled tf.keras.models.Sequential object.
    """
    base_model = InceptionV3(input_shape = (299, 299, 3),
                             include_top = False,
                             weights     = 'imagenet')
    # Freeze the pretrained base; only the new classification head is trained.
    base_model.trainable = False
    global_average_layer = GlobalAveragePooling2D()
    prediction_layer     = Dense(numClasses, activation = 'softmax')

    model = Sequential([base_model, global_average_layer, prediction_layer])

    # The prediction layer already applies softmax, so the loss receives
    # probabilities, not logits: from_logits must be False, otherwise softmax
    # would be applied a second time inside the loss.
    # 'learning_rate' is the documented argument name; 'lr' is a deprecated alias.
    model.compile(optimizer = Adam(learning_rate = 0.0001),
                  loss      = CategoricalCrossentropy(from_logits = False),
                  metrics   = ['accuracy'])
    return model

In [None]:
"""
Function Name: fineTuneModel
Number of parameters: 2
List of parameters:
    1. model      | tf.keras.models | Model to be fine-tuned.
    2. fineTuneAt | int             | Layers indexed at and after this value are set to be trainable.
Pre-condition: n/a
Post-condition:
    1. Sets some of the layers of the first layer in 'model' to be trainable.
    2. Returns a compiled model (tf.keras.models.Sequential object).
"""
def fineTuneModel(model, fineTuneAt):
    """Unfreeze the upper part of the base network and recompile for fine-tuning.

    model:      model whose first layer (model.layers[0]) is the base network.
    fineTuneAt: index into the base network's layers; layers before this
                index stay frozen, layers at and after it become trainable.
    Returns the same model object, recompiled with a lower learning rate.
    """
    base_model = model.layers[0]
    base_model.trainable = True

    # Keep the layers below 'fineTuneAt' frozen.
    for layer in base_model.layers[:fineTuneAt]:
        layer.trainable = False

    # Recompile so the changed 'trainable' flags take effect. The model built
    # by getModel() ends in a softmax layer, so the loss is given
    # probabilities (from_logits = False); 'learning_rate' replaces the
    # deprecated 'lr' alias.
    model.compile(optimizer = Adam(learning_rate = 0.00001),
                  loss      = CategoricalCrossentropy(from_logits = False),
                  metrics   = ['accuracy'])
    return model

In [None]:
"""
Function Name: trainModel
Number of parameters: 8
List of parameters:
    1. model             | tf.keras.models | Model to be trained.
    2. initial_epoch     | int             | Epoch to start training.
    3. epochs            | int             | Epoch to stop training.
    4. trainDataset      | tf.data.Dataset | Data used for training.
    5. validationDataset | tf.data.Dataset | Data used for validation.
    6. steps_per_epoch   | int             | Total number of steps in a 'training' epoch.
    7. validation_steps  | int             | Total number of steps in a 'validation' epoch.
    8. callbacks         | list            | List of callbacks. (tf.keras.callbacks.Callback)
Pre-condition: n/a
Post-condition:
    1. Returns the model and History object.
"""
def trainModel(model, initial_epoch, epochs, trainDataset, validationDataset, steps_per_epoch, validation_steps, callbacks):
    """Run model.fit() with the given settings.

    model:             model to be trained.
    initial_epoch:     epoch at which training starts.
    epochs:            epoch at which training stops.
    trainDataset:      data used for training.
    validationDataset: data used for validation.
    steps_per_epoch:   number of steps in a training epoch.
    validation_steps:  number of steps in a validation epoch.
    callbacks:         list of callbacks handed to fit().
    Returns (model, history), where history is whatever fit() returned.
    """
    # Everything except the training data is forwarded by keyword.
    fitOptions = {
        'initial_epoch':    initial_epoch,
        'epochs':           epochs,
        'validation_data':  validationDataset,
        'steps_per_epoch':  steps_per_epoch,
        'validation_steps': validation_steps,
        'callbacks':        callbacks,
    }
    fitResult = model.fit(trainDataset, **fitOptions)
    return model, fitResult

In [None]:
# Batch size applied to both the training and the validation dataset.
BATCH_SIZE = 32
# Sentinel passed to map()/prefetch() so tf.data picks the values itself.
AUTOTUNE   = tf.data.experimental.AUTOTUNE

# Class information comes from the project's 'data' module.
dataObj    = data.Data()
# Array form, so a folder name can be compared against every class at once.
classNames = np.array(dataObj.classes)
numClasses = dataObj.numClasses

In [None]:
"""
Function Name: main
Number of parameters: 0
List of parameters: n/a
Pre-condition: n/a
Post-condition:
    1. Trains a model and saves callbacks to disk. 
    2. Nothing is returned.
"""
def main():
    """Train the CNN (frozen base first, then fine-tuning) and write callbacks to disk.

    Without a saved model path, a new model is built and trained with the
    base frozen, then fine-tuned. With a saved model path, that model is
    loaded and only the fine-tuning stage runs. Callback output
    (checkpoints, TensorBoard logs, CSV log) is written under
    <rootPath>/Callbacks/CNN/<numClasses>. The temporary dataset cache
    directory is removed afterwards, even if training fails.
    Nothing is returned.
    """
    cf         = config.Config()
    rootPath   = pathlib.Path(cf.rootPath)
    framesPath = pathlib.Path(cf.framesPath)

    trainDataDirectory      = framesPath/'Train'
    validationdataDirectory = framesPath/'Validation'

    cnnCallbacksDirectory    = rootPath/'Callbacks'/'CNN'/f'{numClasses}'
    modelCheckpointDirectory = cnnCallbacksDirectory/'ModelCheckpoint'
    tensorboardDirectory     = cnnCallbacksDirectory/'Tensorboard'
    csvLoggerDirectory       = cnnCallbacksDirectory/'CSVLogger'

    directoryFunctions.createDirectory(modelCheckpointDirectory)
    directoryFunctions.createDirectory(csvLoggerDirectory)

    trainDataset      = getDataSet(trainDataDirectory)
    validationDataset = getDataSet(validationdataDirectory)

    trainDataset      = trainDataset.map(getLabeledFrames,      num_parallel_calls=AUTOTUNE)
    validationDataset = validationDataset.map(getLabeledFrames, num_parallel_calls=AUTOTUNE)

    trainFrameCount      = getFrameCount(trainDataDirectory)
    validationFrameCount = getFrameCount(validationdataDirectory)

    cachePath     = pathlib.Path(r"./Cache")
    cacheFilePath = cachePath/'trainCNNDatasetCache'

    # Start from an empty cache directory so no stale cache file is reused.
    directoryFunctions.removeDirectory(cachePath)
    directoryFunctions.createDirectory(cachePath)

    try:
        # The whole training set fits in the shuffle buffer (buffer size == frame count).
        trainDataset      = prepareTrainDataset(trainDataset, str(cacheFilePath), trainFrameCount)
        validationDataset = validationDataset.batch(BATCH_SIZE)

        # np.ceil returns a float; step counts are passed to fit() as plain ints.
        steps_per_epoch  = int(np.ceil(trainFrameCount/BATCH_SIZE))
        validation_steps = int(np.ceil(validationFrameCount/BATCH_SIZE))

        currTime = int(time.time())
        # '{epoch}' and '{val_loss}' are left unformatted here on purpose:
        # ModelCheckpoint fills them in when it saves.
        modelCheckpointName = f'{currTime}' + '_CNN_{epoch:03d}_{val_loss:.2f}.h5'
        modelCheckpoint = ModelCheckpoint(filepath       = str(modelCheckpointDirectory/modelCheckpointName),
                                          save_best_only = True)
        tensorboard     = TensorBoard(log_dir = str(tensorboardDirectory/f'{currTime}'))
        csvLogger       = CSVLogger(str(csvLoggerDirectory/f'{currTime}.log'))
        earlyStopping   = EarlyStopping(monitor = 'val_loss', patience = 5)
        callbacks       = [modelCheckpoint, tensorboard, csvLogger, earlyStopping]

        initial_epoch    = 0
        epochs           = 1    # epochs with the base network frozen
        fine_tune_epochs = 1    # additional epochs with the upper base layers unfrozen
        fineTuneAt       = 249

        savedModelPath = "" # insert path to saved model (.h5 file) here
        if not savedModelPath:
            model                  = getModel()
            # The first stage runs without callbacks.
            trained_model, history = trainModel(model,
                                                initial_epoch, epochs,
                                                trainDataset, validationDataset,
                                                steps_per_epoch, validation_steps,
                                                [])
            fine_tuned_model       = fineTuneModel(trained_model, fineTuneAt)
            # Continue the epoch numbering right after the epochs that
            # actually ran (history.epoch lists them).
            initial_epoch         += len(history.epoch)
        else:
            fine_tuned_model = load_model(savedModelPath)

        # trainModel()'s 'epochs' argument is the epoch at which training
        # STOPS, so the fine-tuning stage must end at
        # initial_epoch + fine_tune_epochs. Passing fine_tune_epochs alone
        # would skip fine-tuning whenever initial_epoch >= fine_tune_epochs.
        trainModel(fine_tuned_model,
                   initial_epoch, initial_epoch + fine_tune_epochs,
                   trainDataset, validationDataset,
                   steps_per_epoch, validation_steps,
                   callbacks)
    finally:
        # Always drop the on-disk dataset cache, also when training fails.
        directoryFunctions.removeDirectory(cachePath)

In [None]:
# Entry point: run the full training pipeline defined in main().
main()