In [20]:
import tensorflow as tf
import pathlib
import data
import time

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications        import InceptionV3
from tensorflow.keras.optimizers          import Adam
from tensorflow.keras.callbacks           import ModelCheckpoint, TensorBoard
from tensorflow.keras.layers              import Dense, GlobalAveragePooling2D
from tensorflow.keras.losses              import CategoricalCrossentropy
from tensorflow.keras                     import Sequential

In [21]:
def getImageDataGenerators(train_data_directory, validation_data_directory):
    """Create directory-backed image iterators for the training and validation splits.

    Each split gets its own ``ImageDataGenerator`` that only rescales pixel
    values by 1/255; images are resized to 299x299 with Lanczos interpolation.

    Parameters
    ----------
    train_data_directory, validation_data_directory : str or pathlib.Path
        Folders laid out as ``<dir>/<class_name>/<image files>``, as consumed by
        ``flow_from_directory``.

    Returns
    -------
    tuple
        ``(train_iterator, validation_iterator)`` in that order.
    """
    shared_flow_options = {"target_size": (299, 299), "interpolation": "lanczos"}

    def _iterator_for(split_directory):
        # One independent rescaling generator per split, as in the original cell.
        rescaler = ImageDataGenerator(rescale = 1. / 255)
        return rescaler.flow_from_directory(directory = str(split_directory), **shared_flow_options)

    train_iterator      = _iterator_for(train_data_directory)
    validation_iterator = _iterator_for(validation_data_directory)
    return train_iterator, validation_iterator

In [22]:
def getModel(numClasses, learning_rate = 0.0001):
    """Build and compile an InceptionV3 transfer-learning classifier.

    The ImageNet-pretrained InceptionV3 base (without its top) is frozen, and a
    global-average-pooling + softmax Dense head is stacked on it.

    Parameters
    ----------
    numClasses : int
        Number of output classes (units of the final Dense layer).
    learning_rate : float, optional
        Adam learning rate for training the head (default 1e-4, unchanged).

    Returns
    -------
    tf.keras.Sequential
        Compiled model whose ``layers[0]`` is the InceptionV3 base
        (``fineTuneModel`` relies on that position).
    """
    base_model = InceptionV3(input_shape = (299, 299, 3),
                             include_top = False,
                             weights     = 'imagenet')
    base_model.trainable = False  # only the new head is trained in this phase

    global_average_layer = GlobalAveragePooling2D()
    prediction_layer     = Dense(numClasses, activation = 'softmax')

    model = Sequential([base_model, global_average_layer, prediction_layer])

    # The head already applies softmax, so the loss receives probabilities, not
    # logits. from_logits=True would apply softmax a second time and distort the
    # loss and its gradients.
    model.compile(optimizer = Adam(learning_rate = learning_rate),
                  loss      = CategoricalCrossentropy(from_logits = False),
                  metrics   = ['accuracy'])
    return model

In [23]:
def fineTuneModel(model, fineTuneAt, learning_rate = 0.00001):
    """Unfreeze the top of the convolutional base and recompile for fine-tuning.

    The model is modified in place and also returned.

    Parameters
    ----------
    model : tf.keras.Model
        Model whose first layer (``model.layers[0]``) is the pretrained base,
        as built by ``getModel``.
    fineTuneAt : int
        Base layers with index < ``fineTuneAt`` stay frozen; later layers become
        trainable.
    learning_rate : float, optional
        Adam learning rate for fine-tuning (default 1e-5, unchanged). It is kept
        small so the pretrained weights are only nudged.

    Returns
    -------
    tf.keras.Model
        The same ``model`` object, recompiled.
    """
    base_model = model.layers[0]
    base_model.trainable = True

    # Re-freeze the lower layers; only layers from fineTuneAt onward adapt.
    for layer in base_model.layers[:fineTuneAt]:
        layer.trainable = False

    # Recompile so the new trainable flags take effect. The output layer is
    # softmax, so the loss must not treat its outputs as logits.
    model.compile(optimizer = Adam(learning_rate = learning_rate),
                  loss      = CategoricalCrossentropy(from_logits = False),
                  metrics   = ['accuracy'])
    return model

In [24]:
def trainModel(model, initial_epoch, epochs, train_data_gen, validation_data_gen, callbacks):
    """Fit ``model`` on the training iterator and hand back the model and its history.

    ``initial_epoch`` and ``epochs`` are forwarded unchanged to ``model.fit``
    (Keras treats ``epochs`` as the index to stop at, not a count).

    Returns
    -------
    tuple
        ``(model, history)``, where ``history`` is whatever ``model.fit`` returned.
    """
    fit_options = dict(initial_epoch   = initial_epoch,
                       epochs          = epochs,
                       validation_data = validation_data_gen,
                       callbacks       = callbacks)
    return model, model.fit(train_data_gen, **fit_options)

In [25]:
def main(root_directory = r"D:\ActionRecognition"):
    """Train the frame-level CNN in two phases: head-only training, then fine-tuning.

    Phase 1 trains only the new classification head for ``epochs`` epochs.
    Phase 2 unfreezes the base from layer ``fineTuneAt`` onward and trains for
    ``fine_tune_epochs`` more epochs, with ModelCheckpoint and TensorBoard callbacks.

    Parameters
    ----------
    root_directory : str or pathlib.Path, optional
        Project root containing ``Frames/Train``, ``Frames/Test`` and
        ``Callbacks/CNN``. Defaults to the originally hard-coded location.

    Returns
    -------
    tuple
        ``(trained_model, history, history_fine)``: the fine-tuned model and the
        histories of both training phases.
    """
    root_directory = pathlib.Path(root_directory)

    dataObj    = data.Data()
    numClasses = dataObj.numClasses

    initial_epoch    = 0
    epochs           = 5   # phase-1 stop index (Keras `epochs` semantics)
    fine_tune_epochs = 9   # number of additional fine-tuning epochs
    fineTuneAt       = 300

    train_data_directory      = root_directory / "Frames" / "Train"
    validation_data_directory = root_directory / "Frames" / "Test"
    train_data_gen, validation_data_gen = getImageDataGenerators(train_data_directory, validation_data_directory)

    modelCheckpointDirectory = root_directory / "Callbacks" / "CNN" / "ModelCheckpoint"
    tensorboardDirectory     = root_directory / "Callbacks" / "CNN" / "Tensorboard"

    modelCheckpoint = ModelCheckpoint(filepath       = str(modelCheckpointDirectory/'CNN_{epoch:03d}_{val_loss:.2f}'),
                                      save_best_only = True)
    tensorboard     = TensorBoard(log_dir = str(tensorboardDirectory/f'{int(time.time())}'))

    model                  = getModel(numClasses)
    trained_model, history = trainModel(model, initial_epoch, epochs, train_data_gen, validation_data_gen, [])

    # Resume at the epoch *after* the last completed one. Passing
    # history.epoch[-1] directly would re-run that epoch index, and it would
    # raise IndexError if phase 1 ran no epochs.
    fine_tune_start = history.epoch[-1] + 1 if history.epoch else initial_epoch
    # Model.fit's `epochs` is an absolute stop index, so add the fine-tune count.
    fine_tune_end   = fine_tune_start + fine_tune_epochs

    fine_tuned_model            = fineTuneModel(trained_model, fineTuneAt)
    trained_model, history_fine = trainModel(fine_tuned_model,
                                             fine_tune_start, fine_tune_end,
                                             train_data_gen, validation_data_gen,
                                             [modelCheckpoint, tensorboard])
    return trained_model, history, history_fine

In [26]:
# Entry point: runs the full two-phase training (long-running).
if __name__ == "__main__":
    main()

Found 5740 images belonging to 3 classes.
Found 2331 images belonging to 3 classes.
  ...
    to  
  ['...']
  ...
    to  
  ['...']
Train for 180 steps, validate for 73 steps
Epoch 1/5


KeyboardInterrupt: 