In [1]:
from __future__ import print_function

import numpy as np
np.random.seed(1337)

from keras.engine import Model
from keras_vggface.vggface import VGGFace
from keras.preprocessing import image
from keras.models import Sequential
from keras.layers.recurrent import LSTM, GRU
from keras.layers.core import Dense
from keras.layers.wrappers import Bidirectional

from pre_processing import load_data, split_train_test

Using TensorFlow backend.


In [58]:
# Load the data for speaker 2, shuffled, with fold masks for k=5 cross-validation
# (one mask per fold; see the fold-selection cell).
# NOTE(review): use_delta_frames is a pre_processing option whose exact feature
# transform is not visible here — TODO document it.
x, y, kmasks, vocab = load_data(
    speakers=[2],
    k=5,
    shuffle=True,
    use_delta_frames=True,
)

Skipped 8 empty data files.


In [25]:
# Per-word occurrence counts in the loaded data (rendered below). The classes are
# imbalanced: each single letter occurs 40 times, versus 100-256 for the other
# words, which is worth keeping in mind when reading the overall accuracy.
vocab.occurrences()

{'a': 40,
 'again': 250,
 'at': 252,
 'b': 40,
 'bin': 240,
 'blue': 250,
 'by': 248,
 'c': 40,
 'd': 40,
 'e': 40,
 'eight': 100,
 'f': 40,
 'five': 100,
 'four': 100,
 'g': 40,
 'green': 250,
 'h': 40,
 'i': 40,
 'in': 252,
 'j': 40,
 'k': 40,
 'l': 40,
 'lay': 248,
 'm': 40,
 'n': 40,
 'nine': 100,
 'now': 250,
 'o': 40,
 'one': 100,
 'p': 40,
 'place': 256,
 'please': 250,
 'q': 40,
 'r': 40,
 'red': 250,
 's': 40,
 'set': 256,
 'seven': 100,
 'six': 100,
 'soon': 250,
 't': 40,
 'three': 100,
 'two': 100,
 'u': 40,
 'v': 40,
 'white': 250,
 'with': 248,
 'x': 40,
 'y': 40,
 'z': 40,
 'zero': 100}

In [59]:
# Choose which of the k cross-validation folds to hold out.
# NOTE(review): kmasks[FOLD] presumably marks the held-out samples for that
# fold (the printed counts give a 4:1 train/test ratio for k=5) — TODO confirm.
FOLD = 0

train, test = split_train_test(x, y, kmasks[FOLD])
(x_train, y_train), (x_test, y_test) = train, test

# Report split sizes and per-sample shape; each entry is (caption, array, axis),
# evaluated one at a time so output order matches the array dimensions.
for caption, arr, axis in (
    ('Number of words for training: ', x_train, 0),
    ('Number of words for testing: ', x_test, 0),
    ('Frames per word: ', x_train, 1),
    ('Features per frame: ', x_train, 2),
):
    print(caption, arr.shape[axis])

Number of words for training:  4800
Number of words for testing:  1200
Frames per word:  6
Features per frame:  512


In [57]:
# Build the LSTM model: two stacked 128-unit LSTMs (the first returns the whole
# sequence so the second can consume it) followed by one output unit per vocab word.
# NOTE(review): a tanh output trained with mean squared error is unusual for
# choosing one word out of the vocabulary. If y is one-hot, softmax +
# categorical_crossentropy is the conventional pairing — TODO confirm the
# target encoding produced by load_data before switching.
model = Sequential([
    LSTM(128, input_shape=x_train[0].shape, return_sequences=True),
    LSTM(128),
    Dense(len(vocab), activation='tanh'),
])
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['accuracy'])

In [54]:
# Train for a fixed 30 epochs, reporting held-out loss/accuracy after each epoch.
# NOTE(review): the held-out fold serves both as validation data here and as the
# test set in the evaluation cell, so any tuning against these curves leaks into
# the reported test accuracy.
model.fit(
    x_train, y_train,
    epochs=30,
    batch_size=32,
    validation_data=(x_test, y_test),
    verbose=1,
)

Train on 4800 samples, validate on 1200 samples
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<keras.callbacks.History at 0x7f66455eda50>

In [52]:
# Evaluate on the held-out fold. With a single output and metrics=['accuracy'],
# evaluate() returns the loss followed by the accuracy.
# NOTE(review): the saved output of this cell (In [52]) predates the fit cell
# (In [54]); restart and run all cells in order before trusting the number.
score, acc = model.evaluate(x_test, y_test, batch_size=32)
for caption, value in (('Test score: ', score), ('Test accuracy: ', acc)):
    print(caption, value)

Test accuracy:  0.763333333333


In [24]:
# Spot-check predictions: for the first few test samples, print the correct word
# next to the model's top-ranked guesses (best first).
frames_per_word = x_test.shape[1]
features_per_frame = x_test.shape[2]

RESULTS_TO_CHECK = 50
MATCHES_PER_WORD = 4

# Never index past the end of the test set.
n_check = min(RESULTS_TO_CHECK, len(x_test))

# One batched predict call instead of one call per sample; the reshape makes the
# expected (samples, frames, features) input shape explicit.
batch = x_test[:n_check].reshape(n_check, frames_per_word, features_per_frame)
predictions = model.predict(batch)

for sample_idx in range(n_check):
    # Vocab indices of the highest-scoring outputs, best first.
    predicted_indexes = np.argsort(predictions[sample_idx])[::-1][:MATCHES_PER_WORD]
    # The largest entry of the target vector marks the correct word.
    correct_index = np.argmax(y_test[sample_idx])

    correct_word = vocab[correct_index]
    # Use a distinct comprehension variable: on Python 2 a list comprehension
    # leaks its variable, so reusing the loop index name here clobbered it.
    predicted_words = [vocab[word_idx] for word_idx in predicted_indexes]

    print('{} : {}'.format(correct_word, predicted_words))

soon : ['soon', 'seven', 'at', 'place']
set : ['set', 'g', 'again', 'by']
six : ['six', 't', 'j', 'g']
please : ['please', 'x', 'red', 'in']
please : ['please', 'place', 'seven', 'c']
red : ['red', 'eight', 'g', 'i']
eight : ['eight', 'at', 'one', 'd']
seven : ['seven', 'now', 'k', 'four']
five : ['five', 'at', 'j', 'six']
n : ['f', 'n', 'l', 'a']
set : ['set', 'c', 'again', 'five']
x : ['x', 'n', 'i', 'e']
again : ['again', 'one', 'p', 'red']
u : ['h', 'u', 'q', 'o']
one : ['one', 'r', 'three', 'nine']
r : ['now', 'r', 'soon', 'l']
again : ['again', 'v', 'x', 'q']
at : ['at', 'k', 'j', 'bin']
lay : ['lay', 'i', 'q', 'n']
now : ['now', 'l', 'g', 'two']
lay : ['lay', 'place', 'nine', 'soon']
place : ['place', 'please', 'e', 'f']
u : ['o', 'r', 'u', 'l']
at : ['k', 'at', 'a', 'f']
zero : ['zero', 'seven', 'by', 'bin']
three : ['three', 'c', 'nine', 'zero']
at : ['at', 'k', 'eight', 'f']
lay : ['lay', 'nine', 'q', 'l']
place : ['place', 'a', 'white', 'eight']
two : ['two', 'nine', 'm', 'u