
Add time history callback
achyudh committed Dec 18, 2018
1 parent 9a9ab09 commit c8510a0
Showing 2 changed files with 28 additions and 12 deletions.
25 changes: 14 additions & 11 deletions lib/model/seq2seq.py
@@ -11,7 +11,7 @@
from keras.callbacks import ModelCheckpoint

from lib.model.metrics import bleu_score
-from lib.model.util import lr_scheduler
+from lib.model.util import lr_scheduler, TimeHistory


class Seq2Seq:
@@ -53,13 +53,13 @@ def encode(self, encoder_inputs, recurrent_unit='lstm'):
encoder_embedding = Embedding(self.config.source_vocab_size, self.config.embedding_dim,
weights=[self.config.source_embedding_map], trainable=False)
encoder_embedded = encoder_embedding(encoder_inputs)
-if recurrent_unit == 'lstm':
+if recurrent_unit.lower() == 'lstm':
encoder = LSTM(self.config.hidden_dim, return_state=True, return_sequences=True, recurrent_initializer=initial_weights)(encoder_embedded)
for i in range(1, self.config.num_encoder_layers):
encoder = LSTM(self.config.hidden_dim, return_state=True, return_sequences=True)(encoder)
_, state_h, state_c = encoder
return [state_h, state_c]
-else: # GRU
+else:
encoder = GRU(self.config.hidden_dim, return_state=True, return_sequences=True, recurrent_initializer=initial_weights)(encoder_embedded)
for i in range(1, self.config.num_encoder_layers):
encoder = GRU(self.config.hidden_dim, return_state=True, return_sequences=True)(encoder)
@@ -70,24 +70,25 @@ def decode(self, decoder_inputs, encoder_states, recurrent_unit='lstm'):
decoder_embedding = Embedding(self.config.target_vocab_size, self.config.embedding_dim,
weights=[self.config.target_embedding_map], trainable=False)
decoder_embedded = decoder_embedding(decoder_inputs)
-if recurrent_unit == 'lstm':
+if recurrent_unit.lower() == 'lstm':
decoder = LSTM(self.config.hidden_dim, return_state=True, return_sequences=True)(decoder_embedded, initial_state=encoder_states) # Accepts concatenated encoder states as input
for i in range(1, self.config.num_decoder_layers):
-decoder = LSTM(self.config.hidden_dim, return_state=True, return_sequences=True)(decoder) # Use the final encoder state as context
+decoder = LSTM(self.config.hidden_dim, return_state=True, return_sequences=True)(decoder) # Use the final encoder state as context
decoder_outputs, decoder_states = decoder[0], decoder[1:]
-else: # GRU
+else:
decoder = GRU(self.config.hidden_dim, return_state=True, return_sequences=True)(decoder_embedded, initial_state=encoder_states) # Accepts concatenated encoder states as input
for i in range(1, self.config.num_decoder_layers):
-decoder = GRU(self.config.hidden_dim, return_state=True, return_sequences=True)(decoder) # Use the final encoder state as context
+decoder = GRU(self.config.hidden_dim, return_state=True, return_sequences=True)(decoder) # Use the final encoder state as context
decoder_outputs, decoder_states = decoder[0], decoder[1]
decoder_dense = Dense(self.config.target_vocab_size, activation='softmax')
return decoder_dense(decoder_outputs)

def train(self, encoder_train_input, decoder_train_input, decoder_train_target):
checkpoint_filename = \
'ep{epoch:02d}_el%d_dl%d_ds%d_sv%d_tv%d.hdf5' % (self.config.num_encoder_layers, self.config.num_decoder_layers, self.config.dataset_size,
-self.config.source_vocab_size, self.config.target_vocab_size)
-callbacks = [lr_scheduler(initial_lr=self.config.lr, decay_factor=self.config.decay),
+self.config.source_vocab_size, self.config.target_vocab_size)
+time_callback = TimeHistory()
+callbacks = [lr_scheduler(initial_lr=self.config.lr, decay_factor=self.config.decay), time_callback,
ModelCheckpoint(os.path.join(os.getcwd(), 'data', 'checkpoints', self.config.dataset, checkpoint_filename),
monitor='val_loss', verbose=1, save_best_only=False,
save_weights_only=True, mode='auto', period=1)]
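Side note (illustrative, not part of the commit): the checkpoint filename mixes immediate %-formatting of the config fields with an {epoch:02d} placeholder that Keras' ModelCheckpoint fills in when it saves. A minimal sketch with hypothetical config values:

# Illustrative only: how the filename pattern above is resolved (values are hypothetical)
pattern = 'ep{epoch:02d}_el%d_dl%d_ds%d_sv%d_tv%d.hdf5' % (2, 2, 10000, 15000, 15000)
# pattern == 'ep{epoch:02d}_el2_dl2_ds10000_sv15000_tv15000.hdf5'
print(pattern.format(epoch=3))  # ep03_el2_dl2_ds10000_sv15000_tv15000.hdf5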
@@ -100,13 +101,15 @@ def train(self, encoder_train_input, decoder_train_input, decoder_train_target):
def train_generator(self, training_generator, validation_generator):
checkpoint_filename = \
'ep{epoch:02d}_el%d_dl%d_ds%d_sv%d_tv%d.hdf5' % (self.config.num_encoder_layers, self.config.num_decoder_layers, self.config.dataset_size,
-self.config.source_vocab_size, self.config.target_vocab_size)
-callbacks = [lr_scheduler(initial_lr=self.config.lr, decay_factor=self.config.decay),
+self.config.source_vocab_size, self.config.target_vocab_size)
+time_callback = TimeHistory()
+callbacks = [lr_scheduler(initial_lr=self.config.lr, decay_factor=self.config.decay), time_callback,
ModelCheckpoint(os.path.join(os.getcwd(), 'data', 'checkpoints', self.config.dataset, checkpoint_filename),
monitor='val_loss', verbose=1, save_best_only=False,
save_weights_only=True, mode='auto', period=1)]
self.model.fit_generator(training_generator, epochs=self.config.epochs, callbacks=callbacks,
validation_data=validation_generator)
print("Training time (in seconds):", time_callback.times)

def predict(self, encoder_predict_input, decoder_predict_input):
return self.model.predict([encoder_predict_input, decoder_predict_input])
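Note on reading the timings (illustrative, not part of the commit): TimeHistory records one wall-clock duration per completed epoch, so the list printed above has one entry per epoch and its sum is the total training time. A minimal sketch, reusing the time_callback from either train method:

# Illustrative only: summarize the per-epoch timings collected by TimeHistory
total_seconds = sum(time_callback.times)
print("Total training time: %.1f s over %d epochs" % (total_seconds, len(time_callback.times)))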
15 changes: 14 additions & 1 deletion lib/model/util.py
@@ -1,12 +1,25 @@
import codecs
import os
+import time

import dill
+import keras
import numpy as np
-from keras.callbacks import LearningRateScheduler
+import keras.backend as K
from tqdm import tqdm


+class TimeHistory(keras.callbacks.Callback):
+    def on_train_begin(self, logs=dict()):
+        self.times = []
+
+    def on_epoch_begin(self, epoch, logs=dict()):
+        self.epoch_time_start = time.time()
+
+    def on_epoch_end(self, epoch, logs=dict()):
+        self.times.append(time.time() - self.epoch_time_start)


def lr_scheduler(initial_lr, decay_factor):
def schedule(epoch):
if epoch and epoch < 5:
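For reference, a self-contained usage sketch of the new callback (illustrative, not part of the commit; the toy data and model are assumptions, only TimeHistory comes from this repository):

# Illustrative only: check that TimeHistory records one duration per epoch
import numpy as np
import keras
from keras.layers import Dense

from lib.model.util import TimeHistory

x = np.random.rand(128, 10)   # toy inputs
y = np.random.rand(128, 1)    # toy targets

model = keras.models.Sequential([Dense(16, activation='relu', input_shape=(10,)), Dense(1)])
model.compile(optimizer='adam', loss='mse')

time_callback = TimeHistory()
model.fit(x, y, epochs=3, batch_size=32, callbacks=[time_callback], verbose=0)

print(time_callback.times)  # e.g. [0.42, 0.11, 0.10] -- one entry per epoch, in seconds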
