Skip to content

Commit

Permalink
Remove decode unroll #14
Browse files Browse the repository at this point in the history
  • Loading branch information
zeynepakkalyoncu committed Dec 17, 2018
1 parent c90066e commit c4f63fc
Showing 1 changed file with 0 additions and 11 deletions.
11 changes: 0 additions & 11 deletions lib/model/seq2seq.py
Expand Up @@ -25,7 +25,6 @@ def __init__(self, config):
# Decoder
with tf.device(devices[1]):
decoder_inputs = Input(shape=(None, ))
# decoder_outputs = Lambda(self.decode, arguments={'prev_states': encoder_states})(decoder_inputs)
decoder_outputs = self.decode(decoder_inputs, encoder_states)

# Input: Source and target sentence, Output: Predicted translation
Expand All @@ -45,16 +44,6 @@ def encode(self, initial_weights, encoder_inputs):
_, state_h, state_c = encoder
return [state_h, state_c]

def decode_unroll(self, decoder_inputs, prev_states):
    """Greedily unroll the decoder for a fixed number of time steps.

    At each step the single-step decoder is invoked, the highest-scoring
    vocabulary index is taken (greedy decoding), and the decoder state is
    carried forward into the next step.

    Args:
        decoder_inputs: input fed to ``self.decode_step`` at every step.
        prev_states: initial decoder state (e.g. the encoder's final states).

    Returns:
        A Keras backend variable of shape
        ``(config.batch_size, config.max_target_len)`` with dtype
        ``tf.int64``, holding the greedily chosen token indices.
    """
    # NOTE(review): this overwrites any max_target_len already stored on the
    # shared config with a magic number — confirm the override is intended.
    self.config.max_target_len = 10
    num_steps = self.config.max_target_len
    unrolled = np.zeros((self.config.batch_size, num_steps))
    states = prev_states
    for step in range(num_steps):
        step_output, states = self.decode_step(decoder_inputs, states)
        # Greedy choice: keep only the argmax token index per batch row.
        unrolled[:, step] = np.argmax(step_output, axis=-1)

    return K.variable(unrolled, dtype=tf.int64)

def decode(self, decoder_inputs, encoder_states):
decoder_embedding = Embedding(self.config.target_vocab_size, self.config.embedding_dim,
weights=[self.config.target_embedding_map], trainable=False)
Expand Down

0 comments on commit c4f63fc

Please sign in to comment.