In [1]:
import numpy as np
import tensorflow as tf

import tempfile
# Create a fresh temporary directory (later handed to the summary writer)
# and print its path so the written event files can be located.
logdir = tempfile.mkdtemp()
print(logdir)

/tmp/tmpo11htsq7


In [2]:
# Clear the default graph before building the model, then open an
# interactive session that the rest of the notebook uses via `sess`.
tf.reset_default_graph()
sess = tf.InteractiveSession()

In [3]:
# Number of time steps per sequence; one placeholder is created per step.
seq_length = 5
# Number of sequences sampled for each training step.
batch_size = 64

# Token ids are sampled from range(vocab_size).
vocab_size = 7
# Passed as the last positional argument to embedding_rnn_seq2seq
# (presumably the embedding size — confirm against the installed TF version).
embedding_dim = 50

# Number of units given to the GRU cell.
memory_dim = 100

In [4]:
# Encoder inputs: one int32 placeholder per time step.  Each placeholder is
# a 1-D vector whose length (the batch dimension) is left unspecified.
enc_inp = [
    tf.placeholder(tf.int32, shape=(None,), name="inp%i" % step)
    for step in range(seq_length)
]

# Target token ids, laid out exactly like the encoder inputs.
labels = [
    tf.placeholder(tf.int32, shape=(None,), name="labels%i" % step)
    for step in range(seq_length)
]

# Every target position gets the same loss weight of 1.0.
weights = [tf.ones_like(target, dtype=tf.float32) for target in labels]

# Decoder inputs: an all-zeros "GO" step followed by the encoder inputs
# shifted right by one, so the final encoder token is dropped.
# NOTE(review): the GO id (0) is also a token id the sampler can produce.
dec_inp = [tf.zeros_like(enc_inp[0], dtype=np.int32, name="GO")] + enc_inp[:-1]

# Initial memory value for recurrence.
# NOTE(review): not referenced anywhere else in the visible notebook.
prev_mem = tf.zeros((batch_size, memory_dim))

In [5]:
# GRU cell with `memory_dim` units, handed to the seq2seq builder below.
cell = tf.nn.rnn_cell.GRUCell(memory_dim)

# Build the embedding encoder/decoder graph.  `vocab_size` is passed twice
# (presumably the encoder and decoder symbol counts — confirm against the
# installed TensorFlow version), followed by `embedding_dim`.
# Returns the per-step decoder outputs and the decoder memory/state.
dec_outputs, dec_memory = tf.nn.seq2seq.embedding_rnn_seq2seq(
    enc_inp, dec_inp, cell, vocab_size, vocab_size, embedding_dim)

In [6]:
# Weighted cross-entropy of the decoder outputs against the targets.
# The former fourth positional argument (`vocab_size`) has been dropped: in
# the TensorFlow generation that provides `tf.summary` / `tf.nn.seq2seq`,
# `sequence_loss` takes (logits, targets, weights,
# average_across_timesteps=True, ...), so the vocabulary size was being bound
# to `average_across_timesteps` and only worked because a non-zero int is
# truthy.  Relying on the default keeps the computed loss unchanged.
loss = tf.nn.seq2seq.sequence_loss(dec_outputs, labels, weights)

In [7]:
# Register the loss as a scalar summary under the tag "loss".
tf.summary.scalar("loss", loss)

<tf.Tensor 'loss:0' shape=() dtype=string>

In [8]:
# L2 norm of `dec_memory[1]`, logged as a scalar for monitoring.
# NOTE(review): if embedding_rnn_seq2seq returns only the final decoder state
# (one row per batch element) in this TensorFlow version, index 1 selects the
# second batch row rather than time step 1 — confirm.
magnitude = tf.sqrt(tf.reduce_sum(tf.square(dec_memory[1])))
# Use a tag that is already a legal summary name.  The previous tag
# "magnitude at t=1" was rewritten to exactly this name at runtime and
# triggered an INFO message on every graph build.
tf.summary.scalar("magnitude_at_t_1", magnitude)

INFO:tensorflow:Summary name magnitude at t=1 is illegal; using magnitude_at_t_1 instead.


<tf.Tensor 'magnitude_at_t_1:0' shape=() dtype=string>

In [9]:
# Single op that evaluates every summary registered so far.
summary_op = tf.summary.merge_all()

In [10]:
# Momentum optimiser settings.
learning_rate = 0.05
momentum = 0.9

# Build the optimiser and the op that applies one update step to `loss`.
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                       momentum=momentum)
train_op = optimizer.minimize(loss)

In [11]:
# Writer that stores the session graph and later summaries under `logdir`.
summary_writer = tf.summary.FileWriter(logdir, sess.graph)

In [12]:
# Initialise all global variables; must run after the optimiser is built so
# that any variables it created are included.
init = tf.global_variables_initializer()
sess.run(init)

In [13]:
def train_batch(batch_size):
    """Run one training step on a freshly sampled random batch.

    Each of the `batch_size` sequences consists of `seq_length` distinct
    token ids sampled without replacement from `range(vocab_size)`.  The
    target for a sequence is that same sequence in reverse order.

    Args:
        batch_size: number of sequences to sample for this step.

    Returns:
        A `(loss_t, summary)` pair: the loss evaluated on the batch and the
        result of evaluating `summary_op` on it.
    """
    samples = np.array([
        np.random.choice(vocab_size, size=(seq_length,), replace=False)
        for _ in range(batch_size)
    ])

    # Time-major layout (seq_len * batch_size): row t feeds placeholder t.
    X = samples.T
    # Targets are the inputs with the time axis reversed.
    Y = X[::-1]

    feed_dict = {}
    for t in range(seq_length):
        feed_dict[enc_inp[t]] = X[t]
        feed_dict[labels[t]] = Y[t]

    # The train op's own result is not needed, only its side effect.
    _, batch_loss, batch_summary = sess.run(
        [train_op, loss, summary_op], feed_dict)
    return batch_loss, batch_summary

In [14]:
# Train for 500 steps, recording the merged summaries at every step.
for t in range(500):
    loss_t, summary = train_batch(batch_size)
    summary_writer.add_summary(summary, t)
# Flush once at the end so all recorded summaries reach the event file.
summary_writer.flush()

In [15]:
# Sample 10 fresh sequences and lay them out time-major (seq_length x 10).
X_batch = np.array([
    np.random.choice(vocab_size, size=(seq_length,), replace=False)
    for _ in range(10)
]).T

# Only the encoder placeholders are fed: the decoder inputs are built from
# them in the graph, and no labels are needed to evaluate the outputs.
feed_dict = dict((enc_inp[t], X_batch[t]) for t in range(seq_length))
dec_outputs_batch = sess.run(dec_outputs, feed_dict)

In [16]:
# Inputs fed to the encoder: one row per time step, one column per sequence.
X_batch

array([[0, 5, 2, 0, 3, 0, 6, 6, 2, 4],
       [5, 2, 4, 6, 1, 2, 5, 0, 4, 1],
       [4, 3, 5, 4, 5, 1, 4, 4, 6, 2],
       [6, 0, 1, 3, 4, 4, 3, 2, 0, 0],
       [1, 6, 0, 2, 0, 3, 0, 3, 3, 5]])

In [17]:
# Predicted token id per decoder step (argmax over axis 1 of each step's
# output).  In the recorded run these rows are the rows of X_batch in
# reverse order, i.e. each input sequence is reproduced reversed.
[logits_t.argmax(axis=1) for logits_t in dec_outputs_batch]

[array([1, 6, 0, 2, 0, 3, 0, 3, 3, 5]),
 array([6, 0, 1, 3, 4, 4, 3, 2, 0, 0]),
 array([4, 3, 5, 4, 5, 1, 4, 4, 6, 2]),
 array([5, 2, 4, 6, 1, 2, 5, 0, 4, 1]),
 array([0, 5, 2, 0, 3, 0, 6, 6, 2, 4])]