In [1]:
import glob
import math
import os

import numpy as np
import tensorflow as tf
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import ConvLSTM2D, BatchNormalization, Conv3D, Dense, Dropout, Flatten
from keras.models import Sequential, load_model
from keras.utils.vis_utils import plot_model

from constants import BATCH_SIZE, FACE_SHAPE, FRAMES_IN_SEQ, NUM_EPOCHS
from helpers import batch_gen, num_steps_per_epoch

Using TensorFlow backend.


In [2]:
# For each data split ("train", "val", "test") build a batch generator and the
# number of steps that make up one pass over it, all using the same BATCH_SIZE.
train_gen, train_steps = (batch_gen("train", BATCH_SIZE),
                          num_steps_per_epoch("train", BATCH_SIZE))
val_gen, val_steps = (batch_gen("val", BATCH_SIZE),
                      num_steps_per_epoch("val", BATCH_SIZE))
test_gen, test_steps = (batch_gen("test", BATCH_SIZE),
                        num_steps_per_epoch("test", BATCH_SIZE))

In [3]:
# Binary classifier over image sequences: ConvLSTM2D + BatchNorm under a
# "/gpu:0" device scope, then Conv3D + dense head under a "/gpu:1" scope.
# NOTE(review): "/gpu:1" assumes a second GPU is visible; on smaller machines
# this presumably relies on TensorFlow soft placement -- confirm.
model = Sequential()
with tf.device("/gpu:0"):
    # Input: FRAMES_IN_SEQ frames of 64x64x3. A 3x3 'valid' kernel shrinks each
    # frame to 62x62 with 32 channels; return_sequences=True keeps the time axis
    # so the Conv3D below can convolve across frames.
    model.add(ConvLSTM2D(filters=32, kernel_size=(3, 3),
                         input_shape=(FRAMES_IN_SEQ, 64, 64, 3),
                         padding='valid', return_sequences=True, activation="elu",
                         dropout=0.5))
    model.add(BatchNormalization())

with tf.device("/gpu:1"):
    # 5x5x5 'valid' kernel: removes 4 steps from time and 4 pixels from each
    # spatial axis (62 -> 58), so FRAMES_IN_SEQ must be at least 5.
    model.add(Conv3D(filters=1, kernel_size=(5, 5, 5),
                     activation='elu',
                     padding='valid', data_format='channels_last'))
    model.add(Flatten())
    model.add(Dense(512, activation="elu"))
    model.add(Dropout(0.5))
    model.add(Dense(128, activation="elu"))
    # Single sigmoid unit, paired with binary_crossentropy below.
    model.add(Dense(1, activation="sigmoid"))

# Name the metric "acc" explicitly so it is logged as "acc" / "val_acc" on every
# Keras version. From Keras 2.3 the logged name follows the string passed, so
# "accuracy" would be logged as "val_accuracy" and the checkpoint filename
# template "{val_acc:.2f}" used for training would raise a KeyError.
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=["acc"])
model.summary()
plot_model(model, show_shapes=True, to_file="model.png")

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_lst_m2d_1 (ConvLSTM2D)  (None, 7, 62, 62, 32)     40448     
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 62, 62, 32)     128       
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 3, 58, 58, 1)      4001      
_________________________________________________________________
flatten_1 (Flatten)          (None, 10092)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               5167616   
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               65664     
__________

In [4]:
# Callbacks: write a checkpoint after every epoch, so any epoch can be reloaded
# later. Stop training once val_loss has not improved for 3 consecutive epochs.
CHECKPOINT_DIR = "checkpoints"
# Saving to HDF5 does not necessarily create missing parent directories, so make
# sure the checkpoint directory exists (e.g. on a fresh checkout).
if not os.path.isdir(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

# The {val_acc} field must match the metric name Keras logs (see the compile call).
checkpoint = ModelCheckpoint(
    filepath=os.path.join(CHECKPOINT_DIR, "weights-{epoch:02d}-{val_acc:.2f}-{val_loss:.2f}.hdf5"),
    monitor='val_loss', verbose=1, save_best_only=False)
early_stopping = EarlyStopping(monitor="val_loss", patience=3, mode="min", verbose=1)
callbacks = [checkpoint, early_stopping]

# Keep the returned History so per-epoch losses/metrics can be inspected later
# instead of only being displayed as an object repr.
history = model.fit_generator(train_gen, steps_per_epoch=train_steps, epochs=NUM_EPOCHS, verbose=1,
                              callbacks=callbacks, validation_data=val_gen, validation_steps=val_steps)

Epoch 1/50

Epoch 00001: saving model to checkpoints/weights-01-0.95-0.74.hdf5
Epoch 2/50

Epoch 00002: saving model to checkpoints/weights-02-0.95-0.82.hdf5
Epoch 3/50

Epoch 00003: saving model to checkpoints/weights-03-0.94-0.93.hdf5
Epoch 4/50

Epoch 00004: saving model to checkpoints/weights-04-0.98-0.35.hdf5
Epoch 5/50

Epoch 00005: saving model to checkpoints/weights-05-0.99-0.15.hdf5
Epoch 6/50

Epoch 00006: saving model to checkpoints/weights-06-0.99-0.12.hdf5
Epoch 7/50

Epoch 00007: saving model to checkpoints/weights-07-0.55-7.23.hdf5
Epoch 8/50

Epoch 00008: saving model to checkpoints/weights-08-0.54-7.48.hdf5
Epoch 9/50

Epoch 00009: saving model to checkpoints/weights-09-0.59-6.57.hdf5
Epoch 00009: early stopping


<keras.callbacks.History at 0x7fde24168048>

In [6]:
# Evaluate the checkpoint with the lowest validation loss on the test split.
# The path is discovered from the files the training cell wrote rather than
# copied from its log. Checkpoint names embed the epoch and metric values, so a
# hard-coded name does not exist after a fresh "Restart & Run All".
# Caveats:
# - val_loss is rounded to 2 decimals in the name. Ties fall back to
#   lexicographic path order, i.e. the earlier epoch.
# - Stale checkpoints left in the directory by earlier runs also take part in
#   the comparison.
def _val_loss_from_checkpoint_name(path):
    """Parse the val_loss field from a checkpoint file name.

    Expects the ``weights-{epoch}-{val_acc}-{val_loss}.hdf5`` naming used by the
    training cell, i.e. val_loss is the last '-'-separated field of the stem.
    Returns None when that field is not a finite number.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        val_loss = float(stem.rsplit("-", 1)[-1])
    except ValueError:
        return None
    if math.isnan(val_loss) or math.isinf(val_loss):
        return None
    return val_loss


scored_checkpoints = []
for checkpoint_path in sorted(glob.glob(os.path.join("checkpoints", "weights-*.hdf5"))):
    checkpoint_val_loss = _val_loss_from_checkpoint_name(checkpoint_path)
    if checkpoint_val_loss is not None:
        scored_checkpoints.append((checkpoint_val_loss, checkpoint_path))
if not scored_checkpoints:
    raise IOError("No checkpoints matching checkpoints/weights-*.hdf5 were found; "
                  "run the training cell first.")
best_val_loss, best_checkpoint_path = min(scored_checkpoints)
print("Evaluating {} (val_loss {})".format(best_checkpoint_path, best_val_loss))

model = load_model(best_checkpoint_path)
eval_results = model.evaluate_generator(test_gen, steps=test_steps, verbose=1)



In [8]:
# eval_results is assumed to be [loss, accuracy]: the loss followed by the single
# metric the model was compiled with.
test_loss, test_accuracy = eval_results[0], eval_results[1]
print("Loss {}, accuracy {}".format(test_loss, test_accuracy))

Loss 0.08303335812002564, accuracy 0.9947916666666666
