diff --git a/lib/model/seq2seq.py b/lib/model/seq2seq.py index 7b36122..7b137d6 100644 --- a/lib/model/seq2seq.py +++ b/lib/model/seq2seq.py @@ -24,14 +24,18 @@ def __init__(self, config): else: devices = list('/gpu:' + x for x in config.devices) + if config.ensemble: + inputs = Input(shape=(None,)) + # TODO: set indices dynamically + reconstructed_inputs = Reshape((128,), input_shape=(config.dataset_size,))(inputs) + encoder_inputs = Lambda(lambda x: x[:, :50])(reconstructed_inputs) + decoder_inputs = Lambda(lambda x: x[:, 50:])(reconstructed_inputs) + with tf.device(devices[0]): initial_weights = RandomUniform(minval=-0.08, maxval=0.08, seed=config.seed) - encoder_inputs = Input(shape=(None,)) - reconstructed_inputs = Reshape((128,), input_shape=(config.dataset_size,)) encoder_embedding = Embedding(config.source_vocab_size, config.embedding_dim, weights=[config.source_embedding_map], trainable=False) - # TODO: set indices dynamically - encoder_inputs = Lambda(lambda x: x[:, :50])(reconstructed_inputs) + if not config.ensemble: encoder_inputs = Input(shape=(None, )) encoder_embedded = encoder_embedding(encoder_inputs) if recurrent_unit == 'lstm': @@ -50,7 +54,7 @@ def __init__(self, config): encoder_states = [state_h] with tf.device(devices[1]): - decoder_inputs = Input(shape=(None,)) + if not config.ensemble: decoder_inputs = Input(shape=(None, )) decoder_embedding = Embedding(config.target_vocab_size, config.embedding_dim, weights=[config.target_embedding_map], trainable=False) decoder_embedded = decoder_embedding(decoder_inputs) @@ -70,6 +74,12 @@ def __init__(self, config): decoder) # Use the final encoder state as context decoder_outputs, decoder_states = decoder[0], decoder[1] + # if config.ensemble: + # decoder_reshape = Reshape((128, self.config.target_vocab_size)) #? + # decoder_slice = Lambda(lambda x: x[:, 50:, :]) + # decoder_outputs = decoder_reshape(decoder_outputs) + # decoder_outputs = decoder_slice(decoder_outputs) + decoder_dense = Dense(config.target_vocab_size, activation='softmax') decoder_outputs = decoder_dense(decoder_outputs)