Skip to content

Commit

Permalink
Add extra layers for Elephas model
Browse files — browse the repository at this point in the history
  • Loading branch information
zeynepakkalyoncu committed Dec 19, 2018
1 parent 2f6edb1 commit ce3849b
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions lib/model/seq2seq.py
Expand Up @@ -24,14 +24,18 @@ def __init__(self, config):
else:
devices = list('/gpu:' + x for x in config.devices)

if config.ensemble:
inputs = Input(shape=(None,))
# TODO: set indices dynamically
reconstructed_inputs = Reshape((128,), input_shape=(config.dataset_size,))(inputs)
encoder_inputs = Lambda(lambda x: x[:, :50])(reconstructed_inputs)
decoder_inputs = Lambda(lambda x: x[:, 50:])(reconstructed_inputs)

with tf.device(devices[0]):
initial_weights = RandomUniform(minval=-0.08, maxval=0.08, seed=config.seed)
encoder_inputs = Input(shape=(None,))
reconstructed_inputs = Reshape((128,), input_shape=(config.dataset_size,))
encoder_embedding = Embedding(config.source_vocab_size, config.embedding_dim,
weights=[config.source_embedding_map], trainable=False)
# TODO: set indices dynamically
encoder_inputs = Lambda(lambda x: x[:, :50])(reconstructed_inputs)
if not config.ensemble: encoder_inputs = Input(shape=(None, ))
encoder_embedded = encoder_embedding(encoder_inputs)

if recurrent_unit == 'lstm':
Expand All @@ -50,7 +54,7 @@ def __init__(self, config):
encoder_states = [state_h]

with tf.device(devices[1]):
decoder_inputs = Input(shape=(None,))
if not config.ensemble: decoder_inputs = Input(shape=(None, ))
decoder_embedding = Embedding(config.target_vocab_size, config.embedding_dim,
weights=[config.target_embedding_map], trainable=False)
decoder_embedded = decoder_embedding(decoder_inputs)
Expand All @@ -70,6 +74,12 @@ def __init__(self, config):
decoder) # Use the final encoder state as context
decoder_outputs, decoder_states = decoder[0], decoder[1]

# if config.ensemble:
# decoder_reshape = Reshape((128, self.config.target_vocab_size)) #?
# decoder_slice = Lambda(lambda x: x[:, 50:, :])
# decoder_outputs = decoder_reshape(decoder_outputs)
# decoder_outputs = decoder_slice(decoder_outputs)

decoder_dense = Dense(config.target_vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

Expand Down

0 comments on commit ce3849b

Please sign in to comment.