# In [None]:
import os 
# Expose only GPU 0 to this process. This is set before TensorFlow is imported
# (next cell) so that CUDA only ever enumerates this one device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# In [None]:
import tensorflow as tf

# Report which GPU TensorFlow can see. tf.test.gpu_device_name() returns an
# empty string when no GPU is available, so make that case explicit instead of
# leaving a bare expression whose value is discarded outside a notebook.
gpu_name = tf.test.gpu_device_name()
if gpu_name:
    print("Using GPU device:", gpu_name)
else:
    print("WARNING: no GPU device found; TensorFlow will run on CPU.")

# In [None]:
#######################################################
# Model Training
# 11/09/2019 
#
#######################################################
import numpy as np
import sklearn.metrics as metrics
import os
from math import floor


from tensorflow.keras.callbacks import ModelCheckpoint,CSVLogger
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from ISY5002_CA2_02_ModelDefinitions import createModel, my_preprocess

from tensorflow.keras.applications.vgg16 import preprocess_input as VGG16_preprocess_input
from tensorflow.keras.applications.resnet50 import preprocess_input as ResNet50_preprocess_input
from tensorflow.keras.applications.inception_v3 import preprocess_input as InceptionV3_preprocess_input

# import sys

# defining global variables
# --------------------------------------------------------------------------
# Global configuration
# --------------------------------------------------------------------------
DEBUG_MODE = False
image_path = "./Resized"      # root folder holding the train/ and validation/ splits

# Reproducibility: seed NumPy's global RNG once, at import time.
seed = 29
np.random.seed(seed)

optmz = 'adam'                # alternative tried: optimizers.RMSprop(lr=0.0001)
modelnameBase = 'CA2'         # prefix of the checkpoint (.hdf5) and log (.csv) names
num_classes = 3
channel = 3                   # input channels (images are loaded as RGB)
num_epochs = 100

# --------------------------------------------------------------------------
# Hyperparameters: batch size and data-augmentation settings
# --------------------------------------------------------------------------
bsize = 32
rng_rot, rng_zoom = 0, 0.1
# NOTE(review): if used as Keras `brightness_range`, [0.01, 0] scales brightness
# to between 0% and 1% of the original (near-black images), with lower > upper.
# Confirm the intended range (e.g. something around 1.0) before enabling it.
rng_bright = [0.01, 0]
rng_channel = 0.1
opt_hflip, opt_vflip = True, False

def create_summarise_plot(size, channel, index=0):
    """Build the model for square inputs and print its layer summary.

    Args:
        size: Height and width, in pixels, of the square input images.
        channel: Number of input channels.
        index: Model-variant selector, forwarded unchanged to ``createModel``.

    Returns:
        The model object produced by ``createModel``.
    """
    net = createModel(size, size, channel, index)
    net.summary()
    # To export an architecture diagram, tensorflow.keras.utils.plot_model
    # can be called on `net` here.
    return net

def createIterators(size, index=0):
    """Create the training and validation image iterators for square images.

    Images are read with ``flow_from_directory`` (categorical labels, RGB) from
    ``<image_path>/train`` and ``<image_path>/validation``. Only the training
    iterator applies the module-level augmentation settings (rotation, zoom,
    horizontal/vertical flip).

    Args:
        size: Height and width, in pixels, that images are resized to.
        index: Model-variant selector. 90, 91 and 92 select the VGG16,
            ResNet50 and InceptionV3 ``preprocess_input`` respectively; any
            other value uses ``my_preprocess``.

    Returns:
        Tuple ``(train_it, val_it, n_train, n_val)``: the two iterators and the
        number of images each of them actually found.
    """
    target_size = (size, size)

    # Pretrained backbones must receive the preprocessing they were trained
    # with. Index 90 (VGG16) is included; a `> 90` range check would skip it.
    pretrained_preprocess = {
        90: VGG16_preprocess_input,
        91: ResNet50_preprocess_input,
        92: InceptionV3_preprocess_input,
    }
    preprocessing_fn = pretrained_preprocess.get(index, my_preprocess)

    datagen = ImageDataGenerator(preprocessing_function=preprocessing_fn)
    datagenTrain = ImageDataGenerator(
        preprocessing_function=preprocessing_fn,
        rotation_range=rng_rot, zoom_range=rng_zoom,
        horizontal_flip=opt_hflip, vertical_flip=opt_vflip)
        # Optional, currently disabled: brightness_range=rng_bright,
        # channel_shift_range=rng_channel

    train_dir = os.path.join(image_path, 'train')
    val_dir = os.path.join(image_path, 'validation')

    train_it = datagenTrain.flow_from_directory(
        train_dir, class_mode='categorical', target_size=target_size,
        batch_size=bsize, color_mode='rgb')
    val_it = datagen.flow_from_directory(
        val_dir, class_mode='categorical', target_size=target_size,
        batch_size=bsize, color_mode='rgb')

    # Count what the iterators actually loaded. os.walk would also count
    # non-image files and files outside the class sub-folders, which
    # flow_from_directory ignores.
    n_train = train_it.samples
    n_val = val_it.samples

    # Sanity check of the preprocessing: report shape and value range of one batch.
    batchX, batchy = next(train_it)
    print('Batch shape=%s, min=%.3f, max=%.3f' % (batchX.shape, batchX.min(), batchX.max()))

    return train_it, val_it, n_train, n_val



def main():
    """Train one model variant and checkpoint its best epoch.

    Builds the iterators and the model for the ``(index, size)`` pair set
    below, then trains for ``num_epochs`` epochs. Whenever the monitored
    validation accuracy improves, the model is saved to ``<modelname>.hdf5``.
    Per-epoch metrics are logged to ``<modelname>.csv``. Here ``modelname`` is
    ``<modelnameBase>_<index>_<size>``.
    """
    # ------ CHANGE THESE ------
    index = 90
    size = 224
    # --------------------------

    modelname = "_".join([modelnameBase, str(index), str(size)])

    # Create the data iterators and the model (printing its summary).
    train_it, val_it, _, _ = createIterators(size, index)
    model = create_summarise_plot(size, channel, index)

    # Save the model whenever an epoch reaches a new best validation accuracy.
    # NOTE(review): the logged metric key depends on how createModel compiles
    # the model and on the TF version ('val_acc' in older tf.keras,
    # 'val_accuracy' in TF 2.x). If the key is missing, ModelCheckpoint only
    # warns and never saves. Confirm against the column names in the CSV log.
    filepath = modelname + ".hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=0,
                                 save_best_only=True, mode='max')

    # Log the per-epoch metrics to CSV.
    csv_logger = CSVLogger(modelname + '.csv')
    callbacks_list = [checkpoint, csv_logger]

    # len(iterator) == ceil(samples / batch_size), so every image, including
    # the final partial batch, is used each epoch. floor() silently dropped up
    # to bsize-1 images and gave 0 validation steps for < bsize images.
    model.fit_generator(
        train_it, steps_per_epoch=len(train_it),
        validation_data=val_it,
        validation_steps=len(val_it),
        epochs=num_epochs, callbacks=callbacks_list)

if __name__ == "__main__":
    # Train only when run directly (script or notebook kernel), not when this
    # module is imported to reuse createIterators / create_summarise_plot.
    main()