Text Encoder Decoder
Inspired by https://arxiv.org/abs/1802.01817

In [5]:
import tensorflow as tf
import numpy as np
from IPython.display import display, HTML, clear_output
import functools
import random
from collections import namedtuple

In [55]:
def convert_example(text_example, temperature):
    """Takes input string, splits it into suffix and """
    split_point = (len(text_example)*3)//4
    range_max = min(split_point+2, len(text_example)-1)
    range_min = max(split_point-2, 3)
    split_point = random.randint(range_min, range_max) if range_min < range_max else range_min
    return _text_to_codes(text_example[:split_point]), _text_to_codes(text_example[split_point:])
    

In [16]:
def InputDatastream(params):
    dataset = tf.data.TextLineDataset("big_set.en")
    dataset = dataset.repeat(params.num_epochs)
    dataset = dataset.shuffle(params.shuffle_buffer)
    dataset = dataset.map(lambda text_string: 
                          tuple(tf.py_func(convert_examples,[text_string, temperature], [tf.uint8, tf.uint8])),8)    
    dataset = dataset.padded_batch(params.batch_size,
                                   padded_shapes=([params.max_string_length], [params.max_string_length]))
    dataset = dataset.map(byte_hot_encoding)
    training_iterator = dataset.make_initializable_iterator()
 

In [46]:
# model parameters
params = tf.contrib.training.HParams(
    max_string_length = 64,
    model_n = 8,
    
    batch_size = 256,
    shuffle_buffer = 100000,
    num_epochs = 10,
    
    conv_n_filters = 256,
    conv_filter_size = 3,
    
    gradient_noise_scale=None,
)

In [42]:
# Converts a string of text into a vector of char codes.
def _text_to_codes(text_string):
    text_string = str(text_string.strip()[:params.max_string_length])
    encoded = [ord(c) for c in text_string][:params.max_string_length]
    in_array = np.array(encoded, dtype=np.uint8)
    return in_array

def byte_hot_encoding(text_codes, corrupted_codes):
    return tf.one_hot(text_codes, 256), tf.one_hot(corrupted_codes, 256)

In [10]:
# Converts a prediction back to a string.
def _codes_to_string(codes):
    return "".join((chr(min(127, c)) for c in codes))

def _argmax_to_string(argmaxes):
    return "".join((chr(min(127, np.argmax(c))) for c in argmaxes))

def decode_text(decoded_tensor):
    from_max = tf.argmax(decoded_tensor, axis=-1)
    return from_max

In [11]:
def encoder_transform_to_internal(input_data, layer_id, training_flag):
    with tf.variable_scope("input_transform_%d" % layer_id):
      layer_outputs = [input_data]
      for _ in range(params.model_n):
        layer_output = tf.layers.conv1d(layer_outputs[-1],filters=params.conv_n_filters, kernel_size=params.conv_filter_size, padding="same", activation=None)
        layer_output = tf.layers.batch_normalization(layer_output, training=training_flag)
        layer_output = tf.nn.relu(layer_output)
        if len(layer_outputs) > 2:
          # add residual connection
          layer_output += layer_outputs[-2]
        layer_outputs.append(layer_output)
    return layer_outputs[-1]
    
def encoder_recursion_layer(input_data, layer_id, training_flag):
    processed = encoder_transform_to_internal(input_data, layer_id, training_flag)
    return tf.layers.max_pooling1d(processed, pool_size=2, strides = 2)

def encoder_fully_connected(input_data):
    with tf.variable_scope("encoder_fully_connected"):
        flattened = tf.contrib.layers.flatten(input_data)
        layer_size = flattened.get_shape()[-1].value
        layers_output = [flattened]
        for _ in range(params.model_n):
            layer_output = tf.contrib.layers.fully_connected(layers_output[-1], layer_size)
            if len(layers_output) > 2:
                layer_output += layers_output[-2]
            layers_output.append(layer_output)
    return layers_output[-1]
    

def build_encoder(input_data, training_flag):
    data = [encoder_transform_to_internal(input_data, 0, training_flag)]
    for i in range(1, 5):
      data.append(encoder_recursion_layer(data[-1], i, training_flag))
    # return encoder_fully_connected(data[-1])    
    return data[-1]

In [20]:
def decoder_fully_connected(input_data):
    with tf.variable_scope("decoder_fully_connected"):
        layer_size = input_data.get_shape()[-1].value
        layer_outputs = [input_data]
        for _ in range(params.model_n):
            layer_output = tf.contrib.layers.fully_connected(layer_outputs[-1], layer_size)
            if len(layer_outputs) > 2:
                layer_output += layer_outputs[-2]
            layer_outputs.append(layer_output)
        # Unflatten data.
        output = tf.reshape(layer_outputs[-1], shape=[-1,layer_size/params.conv_n_filters, params.conv_n_filters])
    return output

def decoder_transform_to_external(input_data, depth, layer_id, training_flag):
    with tf.variable_scope("decoder_transform_%d" % layer_id):
      layer_outputs = [input_data]
      for _ in range(depth):
        layer_output = tf.layers.conv1d(layer_outputs[-1],filters=params.conv_n_filters, kernel_size=params.conv_filter_size, padding="same", activation=None)
        layer_output = tf.layers.batch_normalization(layer_output, training=training_flag)
        layer_output = tf.nn.relu(layer_output)
        # Extend size 
        if len(layer_outputs) > 2:
          # add residual connection
          layer_output += layer_outputs[-2]
        layer_outputs.append(layer_output)
    return layer_outputs[-1]

def decoder_recursion_layer(input_data, layer_id, training_flag):
    with tf.variable_scope("decoder_expansion_%d" % layer_id):
        processed = decoder_transform_to_external(input_data, params.model_n-1,layer_id, training_flag)
        # Expand convolution.
        expanded = tf.layers.conv1d(processed,filters=2*params.conv_n_filters, kernel_size=params.conv_filter_size, padding="same", activation=None)
        expanded = tf.layers.batch_normalization(expanded, training=training_flag)
        expanded = tf.nn.relu(expanded)
        # Un-sampling is done by resize.
        expanded_shape = expanded.get_shape()
        expanded = tf.reshape(expanded, shape = [-1,expanded_shape[1].value*2,expanded_shape[2].value//2])
        return expanded

def build_decoder(output_data, training_flag):
    #data = [decoder_fully_connected(output_data)]
    data = [output_data]
    for i in range(5,1, -1):
        data.append(decoder_recursion_layer(data[-1], i, training_flag))
    return decoder_transform_to_external(data[-1], params.model_n, 0, training_flag)

In [None]:
with tf.Graph().as_default():
    temperature = tf.placeholder_with_default(0.1, shape=[], name='temperature')
    temperature = tf.Print(temperature, [temperature], message="Temperature")
    dataset = tf.data.TextLineDataset("big_set.en")
    dataset = dataset.repeat(params.num_epochs)
    dataset = dataset.shuffle(params.shuffle_buffer)
    dataset = dataset.map(lambda text_string: 
                          tuple(tf.py_func(convert_example,[text_string, temperature], [tf.uint8, tf.uint8])),8)    
    dataset = dataset.padded_batch(params.batch_size,
                                   padded_shapes=([params.max_string_length], [params.max_string_length]))
    dataset = dataset.map(byte_hot_encoding)
    training_iterator = dataset.make_initializable_iterator()
    training_flag = tf.placeholder(tf.bool, shape=(), name="training_flag")
    next_element = training_iterator.get_next()
    encoded = build_encoder(next_element[0], training_flag)
    predicted = build_decoder(encoded, training_flag)
    predicted_text = decode_text(predicted)
    loss = tf.losses.softmax_cross_entropy(tf.reshape(next_element[1], shape=[-1,256]),
                                           tf.reshape(predicted, shape=[-1, 256]))
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.Variable(0, trainable=False, name="global_step")
    exponential_decay_fn = functools.partial(
      tf.train.exponential_decay,
      decay_steps=10000,
      decay_rate=0.5,
      staircase=True)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        training_op = tf.contrib.layers.optimize_loss(
            loss,
            global_step,
            learning_rate=None,
            optimizer=optimizer,
            learning_rate_decay_fn=None,
            gradient_noise_scale=params.gradient_noise_scale
        )
    scaffold = tf.train.Scaffold(
        local_init_op=tf.group(tf.local_variables_initializer(),training_iterator.initializer),
        init_op=tf.global_variables_initializer())
    with tf.train.MonitoredTrainingSession(checkpoint_dir="./checkpoint_predictor",scaffold=scaffold,
                                           save_checkpoint_secs=120) as sess:
        print("Start training")
        while not sess.should_stop():
            o_loss, _, o_step = sess.run([loss, training_op, global_step],  {training_flag:True, temperature:0.1})
            if o_step%100 == 1:
                clear_output(True)
                output_string = """Loss {0}, step {1} <BR><table>
                <tr><th>Source</th><th>Decoded</th></tr>""".format(o_loss, o_step)
                o_source_text, o_corrupted_text, o_predicted_text = sess.run([next_element[0], next_element[1], predicted_text], 
                                                                             {training_flag:False,temperature:0.1})
                for t_source, t_corrupted, t_predicted in zip((x for x in o_source_text),
                                                              (y for y in o_corrupted_text),
                                                              (z for z in o_predicted_text)):
                    output_string += """<tr><td>{source}<br>{corrupted}</td><td>{decoded}</td></tr>""".format(
                        source=_argmax_to_string(t_source),
                        corrupted=_argmax_to_string(t_corrupted),
                        decoded=_codes_to_string(t_predicted))
                output_string += "</table>"
                display(HTML(output_string))
                
                


Source,Decoded
b'to the abutments. This transmi'������������������������������� b'tted the load'������������������������������������������������,b'ssio e''��������������������������������������������������
b'the garment. The edges of the bre'���������������������������� b'ast prostheses'�����������������������������������������������,b'a ae'����������������������������������������������������
b'Regensburg Hauptbahnhof and'���������������������������������� b'also welcome'�������������������������������������������������,b'the a�'�������������������������������������������������
b'that would later become Grea'��������������������������������� b't St. Martin.'������������������������������������������������,b't roree on'�������������������������������������������������
b'high school sports. Broadcasts'������������������������������� b'of basketball'������������������������������������������������,b'of e ''���������������������������������������������������
"b'Dragons"", also published by Wiza'����������������������������� b'rds of the'���������������������������������������������������",b'na '�����������������������������������������������������
b'War II. He was a graduate of t'������������������������������� b'he University'������������������������������������������������,b'he sort a'���������������������������������������������������
b'altitude'����������������������������������������������������� b's.'�����������������������������������������������������������,b'.'������������������������������������������������������������
"b'of Melbourne, graduating with B'������������������������������ b'achelor of'���������������������������������������������������",b'art '����������������������������������������������������
b'in the second qualifying round held at'����������������������� b'Terenzano'����������������������������������������������������,b'aareeaiin'����������������������������������������������������


INFO:tensorflow:Saving checkpoints for 507203 into ./checkpoint_predictor/model.ckpt.
INFO:tensorflow:global_step/sec: 2.83051
