In [33]:
import tensorflow as tf
import numpy as np

In [80]:
# Referenced from https://github.com/ikostrikov/TensorFlow-Pointer-Networks/blob/master/dataset.py
class DataGenerator(object):

    def __init__(self):
        """Construct a DataGenerator."""
        pass

    def next_batch(self, batch_size, N, train_mode=True):
        """Return the next `batch_size` examples from this data set."""

        # A sequence of random numbers from [0, 1]
        reader_input_batch = []

        # Sorted sequence that we feed to encoder
        # In inference we feed an unordered sequence again
        decoder_input_batch = []

        # Ordered sequence where one hot vector encodes position in the input array
        writer_outputs_batch = []
        for _ in range(N):
            reader_input_batch.append(np.zeros([batch_size, 1]))
        for _ in range(N + 1):
            decoder_input_batch.append(np.zeros([batch_size, 1]))
            writer_outputs_batch.append(np.zeros([batch_size, N + 1]))

        for b in range(batch_size):
            shuffle = np.random.permutation(N)
            sequence = np.sort(np.random.randint(0, 100, N))
            shuffled_sequence = sequence[shuffle]

            for i in range(N):
                reader_input_batch[i][b] = shuffled_sequence[i]
                if train_mode:
                    decoder_input_batch[i + 1][b] = sequence[i]
                else:
                    decoder_input_batch[i + 1][b] = shuffled_sequence[i]
                writer_outputs_batch[shuffle[i]][b, i + 1] = 1.0

            # Points to the stop symbol
            writer_outputs_batch[N][b, 0] = 1.0

        return reader_input_batch, decoder_input_batch, writer_outputs_batch

In [75]:
lstm_cell = tf.contrib.rnn.LSTMCell
static_rnn = tf.nn.static_rnn

In [76]:
lstm_size = 256
batch_size = 100
input_max_len = 10
learning_rate = .01

In [60]:
tf.reset_default_graph()

uniform_initializer = tf.random_uniform_initializer(-.1, .1)
normal_initializer = tf.truncated_normal_initializer(.0, .1)

inputs = [tf.placeholder(tf.float32, [batch_size, 1], "enc_input_{}".format(i)) for i in range(input_max_len)]
targets = [tf.placeholder(tf.float32, [batch_size, input_max_len+1], "dec_target_{}".format(i)) for i in range(input_max_len+1)]
target_weights = [tf.placeholder(tf.float32, [batch_size, 1], "target_weights_{}".format(i)) for i in range(input_max_len+1)]

# Encoding
enc_cell = lstm_cell(lstm_size, initializer=uniform_initializer)
enc_outputs, enc_state = static_rnn(enc_cell, inputs, dtype=tf.float32)

end_of_sentence = tf.zeros([batch_size, lstm_size])
enc_outputs = [end_of_sentence] + enc_outputs

dec_state = enc_state

# enc_output.shape : batch_size * input_max_len * lstm_size
dec_inputs = [tf.placeholder(tf.float32, [batch_size, 1], "dec_inputs_{}".format(i)) for i in range(input_max_len+1)]
dec_cell = lstm_cell(lstm_size, initializer=uniform_initializer)

W1 = tf.get_variable("W1", [input_max_len, lstm_size], initializer=normal_initializer)
W2 = tf.get_variable("W2", [input_max_len, lstm_size], initializer=normal_initializer)
v = tf.get_variable("v", [input_max_len, 1], initializer=normal_initializer)

with tf.variable_scope("decode") as scope:
    predictions = []
    for i, dec_input in enumerate(dec_inputs):
        if i > 0:
            scope.reuse_variables()
        
        w_d = tf.get_variable("w_d", [1, lstm_size], initializer=normal_initializer)
        b_d = tf.get_variable("b_d", [lstm_size], initializer=tf.zeros_initializer())

        cell_input = tf.nn.elu(tf.matmul(dec_input, w_d) + b_d)
        dec_output, dec_state = dec_cell(cell_input, dec_state)
        
        u_i = []
        for j, enc_output in enumerate(enc_outputs):
            W1ej = tf.matmul(enc_output, tf.transpose(W1))
            W2di = tf.matmul(dec_output, tf.transpose(W2))
            u_i_j = tf.matmul(tf.tanh(W1ej +W2di), v)
            u_i.append(u_i_j)
        u_i = tf.concat(u_i, axis=1)
        p_c_i = tf.nn.softmax(u_i)
        predictions.append(p_c_i)

# For prediction
predictions_idx = tf.argmax(predictions, axis=2)
        
loss = tf.sqrt(tf.reduce_mean(tf.pow(tf.stack(predictions) - tf.stack(targets), 2.0)))
optimizer = tf.train.AdamOptimizer(learning_rate)
train_op = optimizer.minimize(loss)

In [88]:
init= tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    
    datagen = DataGenerator()
    
    def create_feed_dict(input_data, dec_input_data, dec_target_data):
        feed_dict = {}
        for placeholder, data in zip(inputs, input_data):
            feed_dict[placeholder] = data
        for placeholder, data in zip(dec_inputs, dec_input_data):
            feed_dict[placeholder] = data
        for placeholder, data in zip(targets, dec_target_data):
            feed_dict[placeholder] = data
        for placeholder in target_weights:
            feed_dict[placeholder] = np.ones([batch_size, 1])
        
        return feed_dict
    
    for i in range(10000):
        input_data, dec_input_data, dec_target_data = datagen.next_batch(batch_size, input_max_len)
        train_feed_dict = create_feed_dict(input_data, dec_input_data, dec_target_data)
        
        loss_val, _ = sess.run([loss, train_op], feed_dict=train_feed_dict)
        
        if i % 200 == 0:
            test_input_data, test_dec_input_data, test_dec_target_data = datagen.next_batch(batch_size, input_max_len)
            test_feed_dict = create_feed_dict(test_input_data, test_dec_input_data, test_dec_target_data)
            test_loss_val, idx = sess.run([loss, predictions_idx], feed_dict=test_feed_dict)
            
            print("Loss: {} / Validation: {}".format(loss_val, test_loss_val))
            
            sample_test_input_data = np.transpose(test_input_data, (1, 0, 2))[0].flatten()
            sample_idx = np.transpose(idx)[0][:-1] - 1
            predict_arr = [sample_test_input_data[i] for i in sample_idx]
            print("\t{:10}: {}".format("Input", np.array(sample_test_input_data).astype(int)))
            print("\t{:10}: {}".format("Prediction", np.array(predict_arr).astype(int)))

Loss: 0.28756311535835266 / Validation: 0.28734520077705383
	Input     : [10 37 56 22 71 23 89 85 63  9]
	Prediction: [9 9 9 9 9 9 9 9 9 9]
Loss: 0.16160115599632263 / Validation: 0.1540466994047165
	Input     : [47  5 90  3 61 83  7 87 55 10]
	Prediction: [ 3  5  7 10 47 61 61 83 90 90]
Loss: 0.14069871604442596 / Validation: 0.12411071360111237
	Input     : [79 43 78 60 46 74 10 70  6 10]
	Prediction: [ 6 10 10 43 46 60 70 74 78 78]
Loss: 0.12289181351661682 / Validation: 0.11063995212316513
	Input     : [47 26 28 40 40 61 23 39 78 72]
	Prediction: [23 26 28 39 39 40 47 61 72 78]
Loss: 0.11292976140975952 / Validation: 0.12545043230056763
	Input     : [44 47 77 18 57 61 79 29 37 87]
	Prediction: [18 29 37 44 44 57 61 77 79 87]
Loss: 0.10825784504413605 / Validation: 0.10536710172891617
	Input     : [60 62 62 24 13 10 82 18 36  6]
	Prediction: [ 6 10 13 18 24 36 60 62 62 82]
Loss: 0.10708148777484894 / Validation: 0.11803729832172394
	Input     : [15  6 72 13  9 37 76 71 67 31]
	Predi