In [1]:
import tensorflow as tf
from tensorflow.python.ops import lookup_ops
from tensorflow.python.layers import core as layers_core

tf.reset_default_graph()

# Create training data

In [2]:
with open('/tmp/toy_data.txt', 'w') as data_file:
    for _ in range(1000):
        data_file.write("abc\tdefdef\n")
        data_file.write("def\tabcabc\n")

# Vocabulary as a lookup table

In [3]:
vocab = ['PAD', 'EOS', 'SOS'] + list("aábcdef")
EOS = 1  # end of sentence
SOS = 2  # start of sentence (GO symbol)
table = lookup_ops.index_table_from_tensor(tf.constant(vocab))
vocab = {k: i for i, k in enumerate(vocab)}
vocab_size = len(vocab)

table_initializer = tf.tables_initializer()

# Reading dataset

Format:

~~~
input TAB output
input TAB output
~~~

In [4]:
dataset = tf.contrib.data.TextLineDataset('/tmp/toy_data.txt')
dataset = dataset.map(lambda string: tf.string_split([string], delimiter='\t').values)
source = dataset.map(lambda string: string[0])
target = dataset.map(lambda string: string[1])

source = source.map(lambda string: tf.string_split([string], delimiter='').values)
source = source.map(lambda words: table.lookup(words))
target = target.map(lambda string: tf.string_split([string], delimiter='').values)
target = target.map(lambda words: table.lookup(words))

src_tgt_dataset = tf.contrib.data.Dataset.zip((source, target))
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat(([SOS], tgt), 0),
                      tf.concat((tgt, [EOS]), 0),)
)
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))
)

# Padded batch

In [5]:
# if I set the third padding shape to tf.TensorShape([5]),
# it fails if there is no 4 character long sample in the batch
# WHY???
batched = src_tgt_dataset.padded_batch(4, padded_shapes=(
    tf.TensorShape([6]), tf.TensorShape([9]), tf.TensorShape([None]),
         tf.TensorShape([]), tf.TensorShape([])))
batched_iter = batched.make_initializable_iterator()
src_ids, tgt_in_ids, tgt_out_ids, src_size, tgt_size = batched_iter.get_next()

# Encoder

In [6]:
embedding = tf.get_variable("embedding", [vocab_size, 3], dtype=tf.float32)

encoder_emb_inp = tf.nn.embedding_lookup(embedding, src_ids)
    
encoder_cell = tf.contrib.rnn.BasicLSTMCell(16)

encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, encoder_emb_inp, dtype=tf.float32,
                                                   sequence_length=src_size)

# Decoder

In [7]:
decoder_initial_state = encoder_state
decoder_cell = tf.contrib.rnn.BasicLSTMCell(16)
decoder_emb_inp = tf.nn.embedding_lookup(embedding, tgt_in_ids)
helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, tgt_size)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, decoder_initial_state)
outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
sample_id = outputs.sample_id
output_proj = layers_core.Dense(vocab_size, name="output_projection")
logits = output_proj(outputs.rnn_output)

# Loss

In [8]:
crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tgt_out_ids, logits=logits)
target_weights = tf.sequence_mask(tgt_size, tf.shape(tgt_out_ids)[1], tf.float32)
loss = tf.reduce_sum(crossent * target_weights) / tf.to_float(5)

# Optimizer and gradient update

In [9]:
optimizer =tf.train.AdamOptimizer(0.1)
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
update = optimizer.apply_gradients(zip(gradients, params))

# Starting session

In [10]:
sess = tf.InteractiveSession()
sess.run(table_initializer)
sess.run(batched_iter.initializer)
sess.run(tf.global_variables_initializer())

# Training

In [11]:
for i in range(100):
    sess.run(update)
    l = sess.run(loss)
    if i % 10 == 9:
        print("Iteration: {}, training loss: {}".format(i+1, l))

Iteration: 10, training loss: 2.7747371196746826
Iteration: 20, training loss: 0.1978912651538849
Iteration: 30, training loss: 0.030035044997930527
Iteration: 40, training loss: 0.01124502345919609
Iteration: 50, training loss: 0.006356580648571253
Iteration: 60, training loss: 0.004859971813857555
Iteration: 70, training loss: 0.004096614196896553
Iteration: 80, training loss: 0.003633463056758046
Iteration: 90, training loss: 0.003309317398816347
Iteration: 100, training loss: 0.003039924893528223


# Manual greedy decoding

NOTE: running logits iterates over the next batch in the dataset, so running this cell multiple times decodes a different batch in the dataset.

In [12]:
inv_vocab = {v: k for k, v in vocab.items()}
inv_vocab[-1] = 'UNK'
skip_symbols = ('PAD', 'SOS', 'EOS', 'UNK')

## Input and output labels

Greed: just take the highest probabilty along the last axis.

In [13]:
input_ids, out_probs = sess.run([src_ids, logits])
output_ids = out_probs.argmax(axis=-1)

output_ids.shape

(4, 7)

## Convert labels to characters

In [14]:
def decode_ids(input_ids, output_ids):
    decoded = []
    for sample_i in range(output_ids.shape[0]):
        input_sample = input_ids[sample_i]
        output_sample = output_ids[sample_i]
        input_decoded = [inv_vocab[s] for s in input_sample]
        input_decoded = ''.join(c for c in input_decoded if c not in skip_symbols)
        output_decoded = [inv_vocab[s] for s in output_sample]
        output_decoded = ''.join(c for c in output_decoded if c not in skip_symbols)
        decoded.append((input_decoded, output_decoded))
    return decoded
 
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defdef
def ---> abcabc
abc ---> defdef
def ---> abcabc


# Greedy decoding with `GreedyEmbeddingHelper`

The encoder stays the same but we need to redefine the decoder.

In [15]:
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(embedding, tf.fill([4], SOS), EOS)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, encoder_state,
                                         output_layer=output_proj)

outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, maximum_iterations=4)

In [16]:
input_ids, output_ids = sess.run([src_ids, outputs.sample_id])
decoded = decode_ids(input_ids, output_ids)
print('\n'.join(
    '{} ---> {}'.format(dec[0], dec[1]) for dec in decoded
))

abc ---> defd
def ---> abca
abc ---> defd
def ---> abca


# Beam search decoding

In [17]:
beam_width = 2
start_tokens = tf.fill([4], SOS)
bm_decoder_initial_state = tf.contrib.seq2seq.tile_batch(encoder_state, multiplier=beam_width)
bm_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
    cell=decoder_cell,
    embedding=embedding,
    start_tokens=start_tokens,
    initial_state=bm_decoder_initial_state,
    beam_width=beam_width,
    output_layer=output_proj,
    end_token=EOS,
)
bm_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(bm_decoder, maximum_iterations=100)

In [18]:
input_ids, output_ids = sess.run([src_ids, bm_outputs.predicted_ids])
output_ids

all_decoded = []
for beam_i in range(beam_width):
    inputs = []
    all_decoded.append([])
    decoded = decode_ids(input_ids, output_ids[:,:,beam_i])
    for dec in decoded:
        all_decoded[-1].append(dec[1])
        inputs.append(dec[0])

print('\n'.join(
    '{} ---> {}'.format(inputs[i], ' / '.join(d[i] for d in all_decoded))
                        for i in range(len(inputs))
))

abc ---> defdef / deffdef
def ---> abcabc / abcabca
abc ---> defdef / deffdef
def ---> abcabc / abcabca
