In [37]:
import numpy as np
import tensorflow as tf
import helper

# Clear the default graph so re-running this notebook does not pile duplicate
# ops on top of the previous run, then open the session used by every later cell.
tf.reset_default_graph()
sess = tf.InteractiveSession()

In [38]:
# Load the English and French sentence files through the project-local helper.
en_sentences = helper.read_sentences('data/small_vocab_en')
fr_sentences = helper.read_sentences('data/small_vocab_fr')

# NOTE(review): presumably X / Y are the index-encoded English / French sentences
# and the remaining values are the per-language lookup tables and vocabularies
# -- confirm against helper.create_dataset.
X, Y, en_word2idx, en_idx2word, en_vocab, fr_word2idx, fr_idx2word, fr_vocab = helper.create_dataset(en_sentences, fr_sentences)

In [40]:
# vocab_size sets both the number of rows of the shared embedding table and the
# width of the output projection (W, b).
# NOTE(review): the same embedding table is also indexed with the English token
# ids fed through encoder_inputs, so this assumes every English id is smaller
# than len(fr_vocab) -- confirm.
vocab_size = len(fr_vocab)
input_embedding_size = vocab_size * 2

# The decoder is twice as wide as the encoder because its initial state is the
# concatenation of the encoder's forward and backward final states.
encoder_hidden_units = 20
decoder_hidden_units = encoder_hidden_units * 2

In [41]:
# Graph inputs. The token-id matrices are time-major, shape (max_time, batch_size),
# and each one comes with a vector holding the true (unpadded) length of every
# sequence in the batch.
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
# This must be its own placeholder. A chained assignment here would rebind
# encoder_inputs_length to the same tensor, collapsing both entries of the feed
# dict into one key so the encoder would silently receive the decoder lengths.
decoder_targets_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='decoder_targets_length')

In [42]:
# Trainable embedding table of shape (vocab_size, input_embedding_size), drawn
# uniformly from [-1, 1). It is used for the encoder inputs here and for the
# decoder's EOS / PAD / predicted-token inputs further down.
embeddings = tf.Variable(tf.random_uniform([vocab_size, input_embedding_size], -1.0, 1.0), dtype = tf.float32)

# Replace every token id in encoder_inputs with its embedding vector.
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

In [43]:
from tensorflow.python.ops.rnn_cell import LSTMCell, LSTMStateTuple

In [44]:
# LSTM cell for the encoder, with encoder_hidden_units units.
encoder_cell = LSTMCell(encoder_hidden_units)

In [45]:
# The backward direction gets its own cell instance. Passing one cell object as
# both cell_fw and cell_bw depends on version-specific variable-scope behaviour:
# later TensorFlow 1.x releases either raise on the second use or share weights
# between directions. Two instances always give each direction its own weights.
encoder_bw_cell = LSTMCell(encoder_hidden_units)

# Run the bidirectional encoder over the time-major embedded inputs.
# sequence_length stops each example at its true length, so padding steps do
# not affect the final states.
((encoder_fw_outputs,
  encoder_bw_outputs),
 (encoder_fw_final_state,
  encoder_bw_final_state)) = (
    tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                    cell_bw=encoder_bw_cell,
                                    inputs=encoder_inputs_embedded,
                                    sequence_length=encoder_inputs_length,
                                    dtype=tf.float32, time_major=True)
    )

In [46]:
# Join the two directions along the feature axis (axis 2 of the time-major outputs).
encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2)
# Join the forward/backward cell (c) and hidden (h) states along axis 1, then
# wrap them in a single LSTMStateTuple that seeds the decoder.
encoder_final_state_c = tf.concat((encoder_fw_final_state.c, encoder_bw_final_state.c), 1)
encoder_final_state_h = tf.concat((encoder_fw_final_state.h, encoder_bw_final_state.h), 1)
encoder_final_state = LSTMStateTuple(c = encoder_final_state_c, h = encoder_final_state_h)

In [47]:
# Decoder LSTM cell; decoder_hidden_units is twice encoder_hidden_units so the
# concatenated encoder final state can be used as its initial state.
decoder_cell = LSTMCell(decoder_hidden_units)

In [48]:
# encoder_inputs is time-major, so its dynamic shape unpacks to
# (max_time, batch_size). batch_size sizes the EOS / PAD slices built below.
encoder_max_time, batch_size = tf.unstack(tf.shape(encoder_inputs))
# Backward-compatible alias for the original, misspelled name.
ecoder_max_time = encoder_max_time

In [49]:
# Output projection weights: map a decoder output vector (decoder_hidden_units)
# to one logit per vocabulary entry. Used inside the decoding loop and again
# when computing the final logits.
W = tf.Variable(tf.random_uniform([decoder_hidden_units, vocab_size], -1, 1), dtype = tf.float32)

In [50]:
# Output projection bias, one entry per vocabulary item, initialised to zero.
b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32)

In [51]:
# One token id per batch element: all ones (named EOS) and all zeros (named PAD).
# NOTE(review): assumes id 1 is the end-of-sentence token and id 0 the padding
# token in the vocabulary built by helper.create_dataset -- confirm.
eos_time_slice = tf.ones([batch_size], dtype=tf.int32, name='EOS')
pad_time_slice = tf.zeros([batch_size], dtype=tf.int32, name='PAD')

# Embedded versions, ready to be fed to the decoder cell as a single time step.
eos_step_embedded = tf.nn.embedding_lookup(embeddings, eos_time_slice)
pad_step_embedded = tf.nn.embedding_lookup(embeddings, pad_time_slice)

In [52]:
def loop_fn_initial():
    """Return the raw_rnn loop tuple for the first decoder step.

    Returns:
        A 5-tuple ``(elements_finished, input, cell_state, cell_output,
        loop_state)``: the per-example finished mask at time 0, the embedded
        EOS step as first input, the encoder final state as initial cell
        state, and ``None`` for both the cell output and the loop state.
    """
    return (
        (0 >= decoder_targets_length),  # finished only for zero-length targets
        eos_step_embedded,              # decoding starts from the EOS embedding
        encoder_final_state,            # decoder is seeded with the encoder state
        None,                           # no cell output exists before step 0
        None,                           # loop state is not used
    )

In [53]:
def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
    """Return the raw_rnn loop tuple for every decoder step after the first.

    The next input is the embedding of the token the model just predicted
    (greedy argmax over the projected previous output). Once every sequence in
    the batch is finished, the PAD embedding is fed instead.

    Args:
        time: scalar tensor holding the current step index.
        previous_output: decoder cell output of the previous step.
        previous_state: decoder cell state of the previous step.
        previous_loop_state: unused.

    Returns:
        A 5-tuple ``(elements_finished, input, cell_state, cell_output,
        loop_state)``; state and output are passed through unchanged and the
        loop state is always ``None``.
    """

    def embed_greedy_prediction():
        # Project the previous output to vocabulary logits and embed the best token.
        logits = tf.add(tf.matmul(previous_output, W), b)
        token_ids = tf.argmax(logits, axis=1)
        return tf.nn.embedding_lookup(embeddings, token_ids)

    elements_finished = (time >= decoder_targets_length)
    all_finished = tf.reduce_all(elements_finished)
    next_step_input = tf.cond(all_finished, lambda: pad_step_embedded, embed_greedy_prediction)

    return (elements_finished,
            next_step_input,
            previous_state,
            previous_output,
            None)

In [54]:
def loop_fn(time, previous_output, previous_state, previous_loop_state):
    """Dispatch a raw_rnn callback to the initial or the transition handler.

    A ``previous_state`` of ``None`` marks the first call, which is answered by
    ``loop_fn_initial``; every other call goes to ``loop_fn_transition``.
    """
    if previous_state is not None:
        return loop_fn_transition(time, previous_output, previous_state, previous_loop_state)

    # First call: there is no output from a previous step either.
    assert previous_output is None and previous_state is None
    return loop_fn_initial()

In [55]:
# Run the decoder with the custom loop function. raw_rnn returns the outputs as
# a TensorArray, the final cell state, and the final loop state (discarded here).
decoder_outputs_ta, decoder_final_state,_ = tf.nn.raw_rnn(decoder_cell, loop_fn)
# Stack the per-step outputs into a single tensor.
decoder_outputs = decoder_outputs_ta.stack()

In [56]:
# Apply the output projection to every time step at once: flatten the rank-3
# decoder outputs to 2-D, multiply by W and add b, then restore the leading two
# dimensions with vocab_size as the last one.
decoder_max_steps, decoder_batch_size, decoder_dim = tf.unstack(tf.shape(decoder_outputs))
decoder_outputs_flat = tf.reshape(decoder_outputs, (-1,decoder_dim))
decoder_logits_flat = tf.add(tf.matmul(decoder_outputs_flat, W), b)
decoder_logits = tf.reshape(decoder_logits_flat, (decoder_max_steps, decoder_batch_size, vocab_size))

In [57]:
# Greedy prediction: the highest-scoring vocabulary id along the last axis.
decoder_prediction = tf.argmax(decoder_logits, 2)

In [58]:
# Cross-entropy between the decoder logits and the integer target ids, one value
# per (time step, batch element). The sparse variant takes the ids directly and
# gives the same value as one-hot targets with softmax_cross_entropy_with_logits,
# without building a (max_time, batch, vocab_size) one-hot tensor.
# Target ids must lie in [0, vocab_size).
stepwise_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=decoder_targets,
    logits=decoder_logits,
)
learning_rate = 5e-3
# Loss: mean over every position of the batch.
# NOTE(review): padded target positions are included in this mean.
loss = tf.reduce_mean(stepwise_cross_entropy)
# Training step: RMSProp on the mean cross-entropy.
train_op = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

In [89]:
# Initialise every variable in the graph; this must run after the optimizer has
# been created, because the optimizer adds its own variables.
sess.run(tf.global_variables_initializer())

In [83]:
import numpy as np
def c_batch(inputs):
    """Pad a batch of variable-length id sequences into a time-major matrix.

    Args:
        inputs: sized collection of sequences of integer token ids.

    Returns:
        A tuple ``(inputs_time_major, sequence_lengths)`` where
        ``inputs_time_major`` is an int32 array of shape
        ``(max_sequence_length, batch_size)`` padded with 0, and
        ``sequence_lengths`` is the list of original sequence lengths.
        An empty batch yields a ``(0, 0)`` array and an empty list.
    """
    sequence_lengths = [len(seq) for seq in inputs]
    batch_size = len(inputs)

    # default=0 keeps an empty batch from raising in max().
    max_sequence_length = max(sequence_lengths, default=0)
    inputs_batch_major = np.zeros(shape=[batch_size, max_sequence_length], dtype=np.int32)  # 0 == PAD

    # Copy each sequence into the start of its row; the tail stays padded.
    for row, (seq, length) in enumerate(zip(inputs, sequence_lengths)):
        inputs_batch_major[row, :length] = seq

    inputs_time_major = inputs_batch_major.swapaxes(0, 1)
    return inputs_time_major, sequence_lengths

In [86]:
def next_feed(start=0, size=100):
    """Build the feed dict for one training batch.

    Args:
        start: index of the first example of X / Y to include.
        size: number of examples in the batch.

    Returns:
        A dict mapping the four graph placeholders to the padded, time-major
        id matrices and the per-example lengths of ``X[start:start + size]``
        and ``Y[start:start + size]``. With the defaults this is the first
        100 examples on every call.
    """
    stop = start + size
    encoder_inputs_, encoder_input_lengths_ = c_batch(X[start:stop])
    decoder_targets_, decoder_targets_length_ = c_batch(Y[start:stop])
    # Each length vector needs its own placeholder key: if two names refer to
    # the same placeholder, the later dict entry replaces the earlier one.
    return {
        encoder_inputs: encoder_inputs_,
        encoder_inputs_length: encoder_input_lengths_,
        decoder_targets: decoder_targets_,
        decoder_targets_length: decoder_targets_length_
    }


In [92]:
def c_decode(sequence, idx2word):
    """Turn a sequence of token ids into a space-separated sentence.

    Args:
        sequence: iterable of token ids.
        idx2word: lookup (indexable by id) from token id to word.

    Returns:
        The looked-up words joined by single spaces.
    """
    return ' '.join(idx2word[token_id] for token_id in sequence)

In [94]:
# Training driver: run max_batches optimisation steps, recording the loss of
# every step, and print a progress report every batches_in_epoch steps.
max_batches = 3001
batches_in_epoch = 1000
loss_track = []
try:
    print('start training')
    for batch in range(max_batches):
        fd = next_feed()
        _, l = sess.run([train_op, loss], fd)
        loss_track.append(l)

        # batch 0 already satisfies the modulo test, so no separate check is needed.
        if batch % batches_in_epoch == 0:
            # Fetch the post-update loss and the predictions in one forward
            # pass instead of two separate sess.run calls.
            minibatch_loss, predict_ = sess.run([loss, decoder_prediction], fd)
            print('batch {}'.format(batch))
            print('  minibatch loss: {}'.format(minibatch_loss))
            # The fed arrays are time-major; transposing iterates per example.
            samples = zip(fd[encoder_inputs].T, fd[decoder_targets].T, predict_.T)
            for i, (inp, out, pred) in enumerate(samples):
                print('  sample {}:'.format(i + 1))
                print('    input              > {}'.format(inp))
                print('    input              > {}'.format(c_decode(inp, en_idx2word)))
                print('    output             > {}'.format(out))
                print('    output sentence    > {}'.format(c_decode(out, fr_idx2word)))
                print('    predicted          > {}'.format(pred))
                print('    predicted sentence > {}'.format(c_decode(pred, fr_idx2word)))
                # Only the first three examples are shown.
                if i >= 2:
                    break
            print()
    print('end training')
except KeyboardInterrupt:
    # Allow stopping a long run from the notebook without a traceback.
    print('training interrupted')

start training
batch 0
  minibatch loss: 6.114712238311768
  sample 1:
    input              > [24  3 11 69  6 40  2  9  5  3 11 70  4 36  2]
    input              > california is usually quiet during march  and it is usually hot in june 
    output             > [104   5  16  70   6  48   4  10   7   5  16  25   6  44   4]
    output sentence    > california est généralement calme en mars  et il est généralement chaud en juin 
    predicted          > [ 86 219 178 325  80 243  92 166 170 139 343 262 245  30  52]
    predicted sentence > pomme cheval préférés petits raisins voulaient fraise nouveau se visiter redoutée chevaux détendre à juillet
  sample 2:
    input              > [33 13 15  3  7 86  2  8 32 13  3  7 84  2  0]
    input              > his favorite fruit is the orange  but my favorite is the grape  <ukn>
    output             > [24 20 21  5 87  4  9 42 21  5 14 85  4  0  0]
    output sentence    > son fruit préféré est l'orange  mais mon préféré est le raisin  <ukn>