In [None]:
import tensorflow as tf
import automatic_speech_recognition as asr

# Load Data

In [None]:
SEED = 42          # seeds the training-set shuffle so epoch order is reproducible
BATCH_SIZE = 8
N_TRAIN, N_VAL, N_TEST = 6400, 800, 800

data_dir = '/blue/vbindschadler/hadi10102/train_data/speech_commands/'

# Collect the audio file paths for the speech-commands data set.
file_paths = asr.util.preprocess.read_simple_word(data_dir)

# Split data into train, val, and test.
# The test split is taken from the tail of the list; it only stays disjoint from
# train/val when at least N_TRAIN + N_VAL + N_TEST paths exist, so check that.
n_required = N_TRAIN + N_VAL + N_TEST
assert len(file_paths) >= n_required, (
    f"need at least {n_required} audio files for disjoint splits, found {len(file_paths)}")
# NOTE(review): contiguous slices are only a representative split if
# read_simple_word returns paths in random order — confirm. If paths come back
# grouped by word, shuffle file_paths with SEED before slicing.
train_files_paths = file_paths[:N_TRAIN]
val_files_paths = file_paths[N_TRAIN:N_TRAIN + N_VAL]
test_files_paths = file_paths[-N_TEST:]

# Reads the audio data for each split.
train_ds = asr.util.preprocess.preprocess_simple_word(train_files_paths)
val_ds = asr.util.preprocess.preprocess_simple_word(val_files_paths)
test_ds = asr.util.preprocess.preprocess_simple_word(test_files_paths)

# Pad variable-length examples to a common length within each batch.
PADDED_SHAPES = {'audio_input': (None, 1),
                 'y_true': (None,),
                 'y_true_length': ()}

# Model.fit ignores `shuffle=` for tf.data inputs, so the training set is
# shuffled here, before batching, and reshuffled every epoch. The buffer spans
# the whole training split for a uniform shuffle; lower it if memory is tight
# (each buffered element holds one example's audio).
train_ds_padded = (train_ds
                   .shuffle(buffer_size=N_TRAIN, seed=SEED, reshuffle_each_iteration=True)
                   .padded_batch(BATCH_SIZE, padded_shapes=PADDED_SHAPES))

val_ds_padded = val_ds.padded_batch(BATCH_SIZE, padded_shapes=PADDED_SHAPES)

# Load Model Architecture

In [None]:
LEARNING_RATE = 1e-4

# The architecture helper hands back the logits tensor together with the three
# model inputs: audio, label sequence and label length. The inputs are presumably
# matched to the dataset dict keys by layer name — confirm in get_model().
logits, input_audio, y_true, y_true_length = \
    asr.models.cnn_raw_speech.cnn_raw_speech.get_model()

# Build the CTC loss inside the graph from the logits and the label inputs.
ctc_loss_layer = asr.util.ctc_loss.get_ctc_layer(logits, y_true, y_true_length)

# Wire the functional model: the CTC loss tensor is output 0 and the raw logits
# are output 1 (kept for decoding).
model_inputs = [input_audio, y_true, y_true_length]
model_outputs = [ctc_loss_layer, logits]
model = tf.keras.Model(inputs=model_inputs, outputs=model_outputs)

opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# The 'ctc' output already *is* the loss, so the compiled loss just passes it
# through and ignores the target argument.
# NOTE(review): the padded datasets yield only an input dict (no targets). Keras
# may skip a compiled loss whose target is missing. This likely only trains if
# the CTC layer also registers its value via add_loss — confirm in get_ctc_layer.
model.compile(optimizer=opt,
              loss={'ctc': lambda _target, ctc_output: ctc_output})

# Train!

In [None]:
EPOCHS = 100
EARLY_STOP_PATIENCE = 10

# Save the weights that score best on the validation set. Monitoring training
# `loss` instead (while early stopping watches `val_loss`) would favour the
# latest, typically most over-fit, epoch.
# Callbacks are referenced through tf.keras so this cell needs no extra import.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_model.hdf5",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_freq='epoch',   # equivalent of the deprecated/removed `period=1`
    verbose=1,
)

# Stop once val_loss has failed to improve by at least 1e-4 for
# EARLY_STOP_PATIENCE consecutive epochs.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0.0001,
    patience=EARLY_STOP_PATIENCE,
    mode='min',
    verbose=1,
)

# `shuffle=` is deliberately not passed: Model.fit ignores it for
# tf.data.Dataset inputs, so shuffling must be applied to the dataset itself
# before batching. The History object is kept for plotting learning curves.
history = model.fit(
    x=train_ds_padded,
    validation_data=val_ds_padded,
    callbacks=[checkpoint, early_stop],
    epochs=EPOCHS,
)