Text Encoder Decoder
Inspired by https://arxiv.org/abs/1802.01817

In [11]:
import functools
try:
    from html import escape as html_escape  # Python 3
except ImportError:
    from cgi import escape as html_escape  # Python 2

import numpy as np
import tensorflow as tf
from IPython.display import display, HTML, clear_output

In [3]:
# Model and training hyperparameters shared by the cells below.
params = tf.contrib.training.HParams(
    # Input strings are truncated to, and batches padded to, this many codes.
    max_string_length=64,
    # Number of layers stacked inside each transform block.
    model_n=8,

    # Input pipeline: batch size, shuffle buffer size and dataset repeats.
    batch_size=256,
    shuffle_buffer=10000,
    num_epochs=10,

    # 1-D convolution settings used by both the encoder and the decoder.
    conv_n_filters=256,
    conv_filter_size=3,
)

In [4]:
# Converts a string of text into a vector of byte codes.
def _text_to_codes(text_string):
    """Convert one line of text into a uint8 vector of byte codes.

    Args:
        text_string: a byte string (what ``tf.py_func`` passes for a
            ``tf.string`` element) or a unicode string, which is encoded as
            UTF-8 first so that every code fits in a byte.

    Returns:
        A 1-D ``np.uint8`` array holding at most ``params.max_string_length``
        byte codes of the whitespace-stripped input.
    """
    if not isinstance(text_string, bytes):
        # Unicode code points may exceed 255 and would overflow uint8.
        text_string = text_string.encode("utf-8")
    trimmed = text_string.strip()[:params.max_string_length]
    # bytearray yields ints on both Python 2 and Python 3, unlike iterating
    # over the string itself (ord() fails on Python 3 bytes elements).
    return np.array(list(bytearray(trimmed)), dtype=np.uint8)

def byte_hot_encoding(byte_codes):
    """One-hot encode byte codes over a depth of 256 (one slot per byte value)."""
    n_byte_values = 256
    return tf.one_hot(byte_codes, depth=n_byte_values)

In [18]:
# Converts a prediction back to a string.
def _codes_to_string(codes):
    return "".join((chr(min(127, c)) for c in codes))

def _argmax_to_string(argmaxes):
    return "".join((chr(min(127, np.argmax(c))) for c in argmaxes))

def decode_text(decoded_tensor):
    """Return the index of the largest value along the last axis of the tensor."""
    return tf.argmax(decoded_tensor, axis=-1)

In [6]:
def encoder_transform_to_internal(input_data, layer_id, training_flag):
    """Apply a stack of conv1d -> batch norm -> ReLU layers to the input.

    The stack has ``params.model_n`` layers, created under the variable scope
    ``input_transform_<layer_id>``. From the third layer onwards, the output
    from two layers earlier is added as a residual connection.

    Args:
        input_data: tensor fed to the first convolution.
        layer_id: integer used to build the variable scope name.
        training_flag: passed to batch normalization as ``training``.

    Returns:
        The output of the last layer (``input_data`` itself if
        ``params.model_n`` is 0).
    """
    with tf.variable_scope("input_transform_%d" % layer_id):
        history = [input_data]
        for _ in range(params.model_n):
            conv = tf.layers.conv1d(
                history[-1],
                filters=params.conv_n_filters,
                kernel_size=params.conv_filter_size,
                padding="same",
                activation=None)
            normed = tf.layers.batch_normalization(conv, training=training_flag)
            block_out = tf.nn.relu(normed)
            if len(history) > 2:
                # Residual connection from two layers back.
                block_out = block_out + history[-2]
            history.append(block_out)
    return history[-1]
    
def encoder_recursion_layer(input_data, layer_id, training_flag):
    """Run one encoder transform block, then max-pool with size 2 and stride 2.

    Args:
        input_data: tensor fed to the transform block.
        layer_id: integer identifying the block's variable scope.
        training_flag: passed through to batch normalization.

    Returns:
        The pooled output of the transform block.
    """
    return tf.layers.max_pooling1d(
        encoder_transform_to_internal(input_data, layer_id, training_flag),
        pool_size=2,
        strides=2)

def encoder_fully_connected(input_data):
    """Flatten the input and pass it through a stack of fully connected layers.

    ``params.model_n`` layers are created under the ``encoder_fully_connected``
    variable scope, each as wide as the flattened input. From the third layer
    onwards, the output from two layers earlier is added as a residual.

    Args:
        input_data: tensor to flatten; its flattened size must be statically
            known because it is read from the static shape.

    Returns:
        The output of the last fully connected layer.
    """
    with tf.variable_scope("encoder_fully_connected"):
        flat = tf.contrib.layers.flatten(input_data)
        width = flat.get_shape()[-1].value
        activations = [flat]
        for _ in range(params.model_n):
            dense = tf.contrib.layers.fully_connected(activations[-1], width)
            if len(activations) > 2:
                # Residual connection from two layers back.
                dense = dense + activations[-2]
            activations.append(dense)
    return activations[-1]
    

def build_encoder(input_data, training_flag):
    """Build the encoder: one transform block followed by four pooled blocks.

    Block 0 is a plain transform; blocks 1 through 4 each apply a transform
    and then halve the sequence axis via max pooling.

    Args:
        input_data: input tensor for the first transform block.
        training_flag: passed through to batch normalization.

    Returns:
        The output of the last pooled block.
    """
    encoded = encoder_transform_to_internal(input_data, 0, training_flag)
    for layer_id in range(1, 5):
        encoded = encoder_recursion_layer(encoded, layer_id, training_flag)
    # The encoder_fully_connected stage is currently not applied.
    return encoded

In [7]:
def decoder_fully_connected(input_data):
    """Pass a flat input through fully connected layers, then un-flatten it.

    ``params.model_n`` layers are created under the ``decoder_fully_connected``
    variable scope, each as wide as the input's last dimension. From the third
    layer onwards, the output from two layers earlier is added as a residual.

    Args:
        input_data: tensor whose last dimension is statically known.

    Returns:
        The last layer's output reshaped to
        ``[-1, layer_size // params.conv_n_filters, params.conv_n_filters]``.
    """
    with tf.variable_scope("decoder_fully_connected"):
        layer_size = input_data.get_shape()[-1].value
        layer_outputs = [input_data]
        for _ in range(params.model_n):
            layer_output = tf.contrib.layers.fully_connected(layer_outputs[-1], layer_size)
            if len(layer_outputs) > 2:
                # Residual connection from two layers back.
                layer_output += layer_outputs[-2]
            layer_outputs.append(layer_output)
        # Unflatten data. Floor division keeps the dimension an int even when
        # true division is in effect (Python 3 / __future__.division).
        output = tf.reshape(
            layer_outputs[-1],
            shape=[-1, layer_size // params.conv_n_filters, params.conv_n_filters])
    return output

def decoder_transform_to_external(input_data, depth, layer_id, training_flag):
    """Apply ``depth`` conv1d -> batch norm -> ReLU layers to the input.

    Layers are created under the variable scope ``decoder_transform_<layer_id>``.
    From the third layer onwards, the output from two layers earlier is added
    as a residual connection.

    Args:
        input_data: tensor fed to the first convolution.
        depth: number of layers in the stack.
        layer_id: integer used to build the variable scope name.
        training_flag: passed to batch normalization as ``training``.

    Returns:
        The output of the last layer (``input_data`` itself if ``depth`` is 0).
    """
    with tf.variable_scope("decoder_transform_%d" % layer_id):
        history = [input_data]
        for _ in range(depth):
            conv = tf.layers.conv1d(
                history[-1],
                filters=params.conv_n_filters,
                kernel_size=params.conv_filter_size,
                padding="same",
                activation=None)
            normed = tf.layers.batch_normalization(conv, training=training_flag)
            block_out = tf.nn.relu(normed)
            if len(history) > 2:
                # Residual connection from two layers back.
                block_out = block_out + history[-2]
            history.append(block_out)
    return history[-1]

def decoder_recursion_layer(input_data, layer_id, training_flag):
    """Run one decoder stage that doubles the sequence length.

    Inside the ``decoder_expansion_<layer_id>`` variable scope the input goes
    through a transform block of ``params.model_n - 1`` layers, then a
    conv1d -> batch norm -> ReLU layer with ``2 * params.conv_n_filters``
    filters. The result is reshaped so the sequence axis doubles and the
    channel axis halves.

    Args:
        input_data: tensor whose sequence and channel dimensions are
            statically known after the expansion convolution.
        layer_id: integer used to build the variable scope names.
        training_flag: passed to batch normalization as ``training``.

    Returns:
        The expanded tensor with twice the sequence length of ``input_data``.
    """
    with tf.variable_scope("decoder_expansion_%d" % layer_id):
        processed = decoder_transform_to_external(input_data, params.model_n-1,layer_id, training_flag)
        # Expand convolution: produce twice as many channels as the target.
        expanded = tf.layers.conv1d(processed,filters=2*params.conv_n_filters, kernel_size=params.conv_filter_size, padding="same", activation=None)
        expanded = tf.layers.batch_normalization(expanded, training=training_flag)
        expanded = tf.nn.relu(expanded)
        # Up-sampling is done by reshape: trade half the channels for double
        # the length. Floor division keeps the dimension an int even when
        # true division is in effect (Python 3 / __future__.division).
        expanded_shape = expanded.get_shape()
        expanded = tf.reshape(
            expanded,
            shape=[-1, expanded_shape[1].value * 2, expanded_shape[2].value // 2])
        return expanded

def build_decoder(output_data, training_flag):
    """Build the decoder: four expansion stages and a final transform block.

    Expansion stages are created with layer ids 5, 4, 3 and 2, followed by a
    transform block of ``params.model_n`` layers with layer id 0.

    Args:
        output_data: encoded tensor to decode.
        training_flag: passed through to batch normalization.

    Returns:
        The output of the final transform block.
    """
    # The decoder_fully_connected stage is currently not applied.
    decoded = output_data
    for layer_id in range(5, 1, -1):
        decoded = decoder_recursion_layer(decoded, layer_id, training_flag)
    return decoder_transform_to_external(decoded, params.model_n, 0, training_flag)

In [22]:
# Build the autoencoder graph and train it, periodically rendering a table of
# source strings next to their decoded reconstructions.
with tf.Graph().as_default():
    # Input pipeline: text lines -> uint8 byte codes -> padded batches -> one-hot.
    dataset = tf.contrib.data.TextLineDataset("train.en")
    dataset = dataset.repeat(params.num_epochs)
    dataset = dataset.shuffle(params.shuffle_buffer)
    dataset = dataset.map(lambda text_string: tf.py_func(_text_to_codes,[text_string], tf.uint8))
    dataset = dataset.padded_batch(params.batch_size, [params.max_string_length])
    dataset = dataset.map(byte_hot_encoding)
    training_iterator = dataset.make_initializable_iterator()
    # Fed per sess.run call; forwarded to every batch normalization layer.
    training_flag = tf.placeholder(tf.bool, shape=(), name="training_flag")
    next_element = training_iterator.get_next()

    # Model: the decoder reconstructs the one-hot input from the encoding.
    encoded = build_encoder(next_element, training_flag)
    predicted = build_decoder(encoded, training_flag)
    predicted_text = decode_text(predicted)
    # Per-position cross entropy between the one-hot input and the decoder output.
    loss = tf.losses.softmax_cross_entropy(tf.reshape(next_element, shape=[-1,256]), tf.reshape(predicted, shape=[-1, 256]))
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.Variable(0, trainable=False, name="global_step")
    exponential_decay_fn = functools.partial(
      tf.train.exponential_decay,
      decay_steps=10000,
      decay_rate=0.5,
      staircase=True)
    # Batch normalization registers its moving-average updates in UPDATE_OPS.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # NOTE(review): with learning_rate=None and an already constructed
        # Optimizer instance, optimize_loss appears to ignore
        # learning_rate_decay_fn, so Adam would run at its default rate with
        # no decay -- confirm against the optimize_loss documentation.
        training_op = tf.contrib.layers.optimize_loss(
            loss,
            global_step,
            learning_rate=None,
            optimizer=optimizer,
            learning_rate_decay_fn=exponential_decay_fn)
    # The iterator is initialized together with the local variables.
    scaffold = tf.train.Scaffold(
        local_init_op=tf.group(tf.local_variables_initializer(),training_iterator.initializer),
        init_op=tf.global_variables_initializer())
    with tf.train.MonitoredTrainingSession(checkpoint_dir="./checkpoint",scaffold=scaffold, save_checkpoint_secs=120, save_summaries_secs=60) as sess:
        # Call form prints the same text on Python 2 and is valid on Python 3.
        print("Start training")
        while not sess.should_stop():
            o_loss, _, o_step = sess.run([loss, training_op, global_step],  {training_flag:True})
            if o_step%100 == 1:
                clear_output(True)
                output_string = """Loss {0}, step {1} <BR><table>
                <tr><th>Source</th><th>Decoded</th></tr>""".format(o_loss, o_step)
                # Evaluates the model on another batch from the iterator with
                # batch normalization in inference mode.
                o_source_text, o_predicted_text = sess.run([next_element, predicted_text], {training_flag:False})
                for t_source, t_predicted in zip(o_source_text, o_predicted_text):
                    # Escape the text so characters such as '<' and '&' are
                    # shown literally instead of being parsed as HTML.
                    output_string += """<tr><td>{source}</td><td>{decoded}</td></tr>""".format(
                        source=html_escape(_argmax_to_string(t_source)),
                        decoded=html_escape(_codes_to_string(t_predicted)))
                output_string += "</table>"
                display(HTML(output_string))
                
                


Source,Decoded
What of the Brand and the Label?��������������������������������,What if the Brank and the polel?��������������������������������
And this still goes on as I think I told you last year.���������,And this still goes on as I think I told you last year.���������
But it is strange.����������������������������������������������,But it is strange.����������������������������������������������
Would you all like to stand up for a moment?��������������������,Could you all like to staod up for a foment?��������������������
"In the middle of my traveling, I turned 40 and I began to hate","In the middle of my traveling, I turned up and I began to ease"
"So it was always the satirist, like Juvenal or Martial, repres","So it was always the sytinist, like Nagiral of carhfal, deflec"
"Right now, AIDG is working with KPFF Consulting Engineers, Arch","Right now, AID is working with PASA conselting engineers, such"
"Now, why do you get a paralyzed phantom limb?�������������������","Now, why do you get a paralyzed phantom lifp?�������������������"
And so we ended up with this wild patchwork of regulations all,And so we ended up with this wold Borthwork of regulations all
"Using the power of the smartphone, it can examine the cube and","Using the power of the smartphone, it can exalize the ciny and"


INFO:tensorflow:global_step/sec: 3.48744
INFO:tensorflow:Saving checkpoints for 40636 into ./checkpoint/model.ckpt.
INFO:tensorflow:Saving checkpoints for 40654 into ./checkpoint/model.ckpt.
