In [13]:
import keras
import matplotlib.pyplot as plt
import sys
import argparse
import os
import tensorflow as tf
from keras.models import Model
from keras.models import load_model
from keras.layers.normalization import BatchNormalization
from keras.layers import Activation, Dropout, Flatten, Dense
from keras.layers.core import Dense, Flatten
from keras.optimizers import Adam, SGD, rmsprop
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
from sklearn.metrics import confusion_matrix

In [20]:
def train():
    global model_name_to_save
    model_name_to_save = "Model2_" + str(model_name) + "_" + str(learning_rate) + "_" + str(batch_size) + "_" + str(epochs)

#     print("Using {0} with lr = {1} and batch size = {2}".format(model_name,learning_rate, batch_size))

    # Provide path names for data and folder names for classes
    train_batches = ImageDataGenerator(rescale = 1.0/255.0).flow_from_directory('../misc/data/images/train', class_mode = 'categorical', classes = ['Gesture_0', 'Gesture_1', 'Gesture_2', 'Gesture_3', 'Gesture_4', 'Gesture_5', 'Gesture_6', 'Gesture_7', 'Gesture_8', 'Gesture_9'], batch_size = batch_size, target_size = (224, 224), shuffle=True)
    validation_batches = ImageDataGenerator(rescale = 1.0/255.0).flow_from_directory('../misc/data/images/validate', class_mode = 'categorical', classes = ['Gesture_0', 'Gesture_1', 'Gesture_2', 'Gesture_3', 'Gesture_4', 'Gesture_5', 'Gesture_6', 'Gesture_7', 'Gesture_8', 'Gesture_9'], batch_size = batch_size, target_size = (224, 224), shuffle=True)

    # CNN architectures
    if model_name == 'ResNet152':
        model = keras.applications.ResNet152(include_top=False, input_shape=(224, 224, 3))
    elif model_name == 'InceptionV3':
        model = keras.applications.InceptionV3(include_top=False, input_shape=(224, 224, 3))
    elif model_name == 'NASNetLarge':
        model = keras.applications.NASNetLarge(include_top=False, input_shape=(331, 331, 3))
    elif model_name == 'VGG16':
        model = keras.applications.vgg16.VGG16(include_top=False, input_shape=(224, 224, 3))
    else:
        model = keras.applications.resnet50.ResNet50(include_top=False, input_shape=(224, 224, 3))
    print(model)
    
    # Add additional layers
    flat1 = Flatten(input_shape = model.output_shape[1:])(model.layers[-1].output)
    dense1 = Dense(256, activation = 'relu')(flat1)
    output = Dense(10, activation = 'softmax')(dense1)

    # Define new model
    model = Model(inputs = model.inputs, outputs = output)
    #model.summary()

    # lr comes from cmd line args
    model.compile(SGD(lr = learning_rate, momentum = 0.9), loss = 'categorical_crossentropy', metrics = ['accuracy'])

    # fit model
    history = model.fit_generator(train_batches, steps_per_epoch = len(train_batches), validation_data = validation_batches, validation_steps = len(validation_batches), epochs = epochs, verbose = 0)

    # validation
    acc = model.evaluate_generator(validation_batches, steps = len(validation_batches), verbose = 0)
    print('Loss, Accuracy: ' + str(acc))

    # save model
    model.save(model_name_to_save)

In [21]:
def test():
    # Provide test folder and classes
    test_batches = ImageDataGenerator(rescale = 1.0/255.0).flow_from_directory('../misc/data/images/test', class_mode = 'categorical', classes = ['Gesture_0', 'Gesture_1', 'Gesture_2', 'Gesture_3', 'Gesture_4', 'Gesture_5', 'Gesture_6', 'Gesture_7', 'Gesture_8', 'Gesture_9'], batch_size = batch_size, target_size = (224, 224), shuffle = True)

    # load model
    model = load_model(model_name_to_save)
    model.compile(SGD(lr = learning_rate, momentum = 0.9), loss = 'categorical_crossentropy', metrics = ['accuracy'])

    # testing
    acc = model.evaluate_generator(test_batches, steps = len(test_batches), verbose = 0)
    print('Loss, Accuracy: ' + str(acc))

In [22]:
model_name="ResNet52"
batch_size=4
learning_rate=0.0001
epochs=1
train()
test()

Found 10000 images belonging to 10 classes.
Found 5794 images belonging to 10 classes.




Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
<keras.engine.training.Model object at 0x1857d1780>


KeyboardInterrupt: 