In [None]:
import math

import matplotlib.pyplot as plt
from keras import applications
from keras import optimizers
from keras.layers import Dropout, Flatten, Dense
from keras.models import Model
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator

In [None]:
# Weights file loaded into the classifier head, and the file the fine-tuned
# model's weights are written to at the end of training.
top_model_weights_path = 'fc_model.h5'
cat_dog_vgg16 = 'cat_dog_vgg16_weights'

# Spatial size every image is resized to before entering VGG16.
img_width = 150
img_height = 150

In [None]:
# Image directories for each data split.
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'
test_data_dir = 'data/test'

# Training schedule: batch size, number of epochs, and the image counts used
# to derive how many batches make up one training / validation pass.
batch_size = 20
epochs = 20
nb_train_samples = 8000
nb_validation_samples = 200

In [None]:
# VGG16 convolutional base with ImageNet weights and without its own dense
# classifier (include_top=False). With Keras' default channels_last format the
# input shape is (height, width, channels), matching the
# target_size=(img_height, img_width) used by the data generators; passing
# (width, height) would silently mismatch for non-square images.
base_model = applications.VGG16(weights='imagenet',
                                include_top=False,
                                input_shape=(img_height, img_width, 3))
base_model.summary()

In [None]:
# Classifier head stacked on the VGG16 features:
# flatten -> 256-unit ReLU -> dropout(0.5) -> single sigmoid unit (binary output).
top_model = Sequential([
    Flatten(input_shape=base_model.output_shape[1:]),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid'),
])
top_model.summary()

# Start the head from previously saved weights rather than a random init.
top_model.load_weights(top_model_weights_path)

In [None]:
# Chain the classifier head onto the convolutional base to get a single
# end-to-end model from input image to prediction.
head_output = top_model(base_model.output)
model = Model(inputs=base_model.input, outputs=head_output)
model.summary()

In [None]:
# Freeze the first 15 layers of the combined model; only the layers after them
# stay trainable during fine-tuning. This must run before model.compile for
# the trainable flags to take effect.
# NOTE(review): for VGG16(include_top=False) this is presumably the input layer
# plus conv blocks 1-4 (through block4_pool) — confirm with model.summary().
frozen_layers = model.layers[:15]
for frozen in frozen_layers:
    frozen.trainable = False

In [None]:
# SGD with momentum and a small learning rate, so fine-tuning only makes small
# adjustments to the pretrained weights.
sgd = optimizers.SGD(lr=1e-4, momentum=0.5)
model.compile(loss='binary_crossentropy',
              optimizer=sgd,
              metrics=['accuracy'])

In [None]:
# Training images are augmented on the fly. They must also be rescaled to
# [0, 1] exactly like the validation images below; without `rescale` the
# network would be trained on 0-255 pixel values but validated on 0-1 values.
# NOTE(review): presumably fc_model.h5 was also trained on 1/255-rescaled
# inputs — confirm this matches how those weights were produced.
train_datagen = ImageDataGenerator(
                rescale=1. / 255,
                rotation_range=40,
                width_shift_range=0.2,
                height_shift_range=0.2,
                shear_range=0.2,
                zoom_range=0.2,
                horizontal_flip=True,
                fill_mode='nearest')

# Validation images are only rescaled (no augmentation), so metrics are
# computed on the unmodified images.
validation_datagen = ImageDataGenerator(rescale=1. / 255)

In [None]:
# Both splits are loaded identically: resized to the network input size,
# batched, and labelled for binary classification.
flow_options = dict(target_size=(img_height, img_width),
                    batch_size=batch_size,
                    class_mode='binary')

train_generator = train_datagen.flow_from_directory(train_data_dir, **flow_options)
validation_generator = validation_datagen.flow_from_directory(validation_data_dir, **flow_options)

In [None]:
# Fine-tune the model. The returned History is kept in `history` rather than
# relying on the model's `history` attribute, which the next fit call replaces.
# Step counts are rounded up so that, if a sample count is not a multiple of
# batch_size, the final smaller batch is still included and every image is
# seen once per epoch (identical to floor division for the current config).
history = model.fit_generator(
                train_generator,
                steps_per_epoch=math.ceil(nb_train_samples / batch_size),
                epochs=epochs,
                validation_data=validation_generator,
                validation_steps=math.ceil(nb_validation_samples / batch_size),
                workers=5,
                verbose=1)

In [None]:
# Persist the fine-tuned weights of the full model (VGG16 base + head).
model.save_weights(filepath=cat_dog_vgg16)

In [None]:
def plot_history(history):
    """Plot loss and accuracy curves recorded during Keras training.

    Args:
        history: The ``History`` object returned by ``fit``/``fit_generator``
            (its ``.history`` dict is read), or that dict itself, mapping each
            metric name to a list with one value per epoch.

    Draws a loss figure and, if any accuracy metric was recorded, an accuracy
    figure. Training series are blue and validation series green, each
    labelled with its final value to five decimals. If no training loss is
    present, prints a message and returns without plotting.
    """
    # Accept either a History object or its plain metrics dict.
    metrics = getattr(history, 'history', history)

    def _split(fragment):
        # Partition metric names containing `fragment` into (training,
        # validation). Validation metrics are identified by the 'val_' prefix
        # instead of any occurrence of 'val', so unrelated names are not
        # misclassified.
        train = [k for k in metrics if fragment in k and not k.startswith('val_')]
        val = [k for k in metrics if fragment in k and k.startswith('val_')]
        return train, val

    loss_keys, val_loss_keys = _split('loss')
    acc_keys, val_acc_keys = _split('acc')

    if not loss_keys:
        print('Loss is missing in history')
        return

    epochs = range(1, len(metrics[loss_keys[0]]) + 1)

    def _plot_panel(title, ylabel, series):
        # Draw one figure; `series` is a list of (metric_key, color, label).
        # A fresh figure per call (rather than fixed figure numbers) so repeated
        # calls never draw over curves left from an earlier run.
        plt.figure()
        for key, color, label in series:
            final = format(metrics[key][-1], '.5f')
            plt.plot(epochs, metrics[key], color, label=label + ' (' + final + ')')
        plt.title(title)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.legend()

    _plot_panel('Loss', 'Loss',
                [(k, 'b', 'Training loss') for k in loss_keys]
                + [(k, 'g', 'Validation loss') for k in val_loss_keys])

    # Skip the accuracy figure when no accuracy metric was logged, instead of
    # showing an empty plot with an empty legend.
    if acc_keys or val_acc_keys:
        _plot_panel('Accuracy', 'Accuracy',
                    [(k, 'b', 'Training accuracy') for k in acc_keys]
                    + [(k, 'g', 'Validation accuracy') for k in val_acc_keys])

    plt.show()



In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
plot_history(model.history)