Text Encoder-Decoder: a byte-level convolutional auto-encoder that reconstructs lines of text
Inspired by https://arxiv.org/abs/1802.01817

In [11]:
import tensorflow as tf
import numpy as np
from IPython.display import clear_output, display_html
import functools

In [12]:
# Model and training hyper-parameters, shared by every cell below.
params = tf.contrib.training.HParams(
    # Lines are cut to at most this many characters and padded to this length.
    max_string_length=64,
    # Number of layers in each residual conv / dense stack.
    model_n=8,

    # Input pipeline.
    batch_size=16,
    shuffle_buffer=10000,
    num_epochs=10,

    # Convolution layers.
    conv_n_filters=256,
    conv_filter_size=3,
)

In [13]:
# Converts a string of text into a vector of char codes.
def _text_to_codes(text_string):
    text_string = text_string.strip()[:params.max_string_length]
    in_array = np.array([ord(c) for c in text_string], dtype=np.uint8)
    return in_array

def byte_hot_encoding(byte_codes):
    """One-hot encode byte values over all 256 possible bytes.

    Adds a trailing axis of size 256 to ``byte_codes``.
    """
    vocabulary_size = 256  # one slot per possible byte value
    return tf.one_hot(indices=byte_codes, depth=vocabulary_size)

In [35]:
# Converts a prediction back to a string.
def _codes_to_string(codes):
    return "".join((chr(c) for c in codes))

def _argmax_to_string(argmaxes):
    return "".join((chr(np.argmax(c)) for c in argmaxes))

def decode_text(decoded_tensor):
    """Select the highest-scoring entry along the last axis of ``decoded_tensor``.

    Despite its name, this returns a tensor of integer indices rather than
    strings; the display code turns each row into text with ``_codes_to_string``.
    """
    return tf.argmax(decoded_tensor, axis=-1)

In [15]:
def encoder_transform_to_internal(input_data, layer_id):
    """Apply a residual stack of ``params.model_n`` same-padded ReLU conv1d layers.

    From the third layer on, each layer's output is summed with the output
    from two layers back (a skip over one convolution). "same" padding keeps
    the sequence length; the result has ``params.conv_n_filters`` channels
    (or is ``input_data`` itself when ``params.model_n`` is 0). Variables are
    created under the scope ``input_transform_<layer_id>``.
    """
    with tf.variable_scope("input_transform_%d" % layer_id):
        history = [input_data]
        for depth in range(params.model_n):
            conv_out = tf.layers.conv1d(
                history[-1],
                filters=params.conv_n_filters,
                kernel_size=params.conv_filter_size,
                padding="same",
                activation=tf.nn.relu)
            # Residual connection. Starting at the third layer means the raw
            # input (whose channel count may differ) is never added —
            # presumably the reason for the offset.
            if depth >= 2:
                conv_out += history[-2]
            history.append(conv_out)
    return history[-1]
    
def encoder_recursion_layer(input_data, layer_id):
    """One encoder level: a residual conv stack, then halve the sequence length.

    Runs ``encoder_transform_to_internal`` and applies width-2, stride-2 max
    pooling along the sequence axis (odd lengths round down).
    """
    return tf.layers.max_pooling1d(
        encoder_transform_to_internal(input_data, layer_id),
        pool_size=2,
        strides=2)

def encoder_fully_connected(input_data):
    """Flatten ``input_data`` and run it through a residual stack of dense layers.

    All ``params.model_n`` layers keep the flattened width. From the third
    layer on, each output is summed with the output from two layers back.
    Variables live under the scope ``encoder_fully_connected``.
    ``build_encoder`` does not use this optional bottleneck.
    """
    with tf.variable_scope("encoder_fully_connected"):
        flat = tf.contrib.layers.flatten(input_data)
        width = flat.get_shape()[-1].value
        history = [flat]
        for depth in range(params.model_n):
            dense_out = tf.contrib.layers.fully_connected(history[-1], width)
            if depth >= 2:
                dense_out += history[-2]  # residual connection
            history.append(dense_out)
    return history[-1]
    

def build_encoder(input_data, num_levels=4):
    """Encode a batch of byte sequences into a shorter internal representation.

    Runs one residual conv stack at full length (scope ``input_transform_0``),
    then ``num_levels`` recursion levels (scopes ``input_transform_1`` ..
    ``input_transform_<num_levels>``), each halving the sequence length with
    max pooling.

    Args:
      input_data: the batch to encode, e.g. the one-hot byte batch built by
        the training cell.
      num_levels: number of pooling levels. The default of 4 reproduces the
        original architecture and variable names. A different value needs a
        decoder with the same number of expansion levels (``build_decoder``
        uses 4 by default). The input length should be divisible by
        ``2 ** num_levels`` so the decoder restores it exactly.

    Returns:
      The encoded tensor with ``params.conv_n_filters`` channels (when
      ``params.model_n`` > 0).
    """
    encoded = encoder_transform_to_internal(input_data, 0)
    for level in range(1, num_levels + 1):
        encoded = encoder_recursion_layer(encoded, level)
    return encoded

In [16]:
def decoder_fully_connected(input_data):
    """Residual stack of dense layers, then reshape back into a sequence.

    All ``params.model_n`` dense layers keep the input width. From the third
    layer on, each output is summed with the output from two layers back.
    The result is reshaped to (batch, width // conv_n_filters, conv_n_filters).

    Expects a 2-D (batch, width) input, presumably the flattened encoder
    output — TODO confirm. ``width`` should be divisible by
    ``params.conv_n_filters``.
    """
    with tf.variable_scope("decoder_fully_connected"):
        width = input_data.get_shape()[-1].value
        history = [input_data]
        for depth in range(params.model_n):
            dense_out = tf.contrib.layers.fully_connected(history[-1], width)
            if depth >= 2:
                dense_out += history[-2]  # residual connection
            history.append(dense_out)
        # Unflatten data. Integer division keeps the shape integral under
        # Python 3, where "/" yields a float that tf.reshape rejects.
        output = tf.reshape(
            history[-1],
            shape=[-1, width // params.conv_n_filters, params.conv_n_filters])
    return output

def decoder_transform_to_external(input_data, depth, layer_id):
    """Apply ``depth`` same-padded ReLU conv1d layers with residual skips.

    From the third layer on, each layer's output is summed with the output
    from two layers back. The sequence length is unchanged. The result has
    ``params.conv_n_filters`` channels, or is ``input_data`` itself when
    ``depth`` is 0. Variables live under the scope
    ``decoder_transform_<layer_id>``.
    """
    with tf.variable_scope("decoder_transform_%d" % layer_id):
        # Sliding window over the last two outputs; only these are read.
        two_back, one_back = None, input_data
        for index in range(depth):
            conv_out = tf.layers.conv1d(
                one_back,
                filters=params.conv_n_filters,
                kernel_size=params.conv_filter_size,
                padding="same",
                activation=tf.nn.relu)
            if index >= 2:
                conv_out += two_back  # residual connection
            two_back, one_back = one_back, conv_out
    return one_back

def decoder_recursion_layer(input_data, layer_id):
    """One decoder level: a residual conv stack, then double the sequence length.

    Runs ``params.model_n - 1`` conv layers, then one conv layer producing
    ``2 * params.conv_n_filters`` channels. Finally it reshapes
    (batch, L, 2C) -> (batch, 2L, C): in each position, the first C channels
    become position 2i and the remaining C channels become position 2i + 1.
    Variables live under the scope ``decoder_expansion_<layer_id>``.
    """
    with tf.variable_scope("decoder_expansion_%d" % layer_id):
        processed = decoder_transform_to_external(input_data, params.model_n - 1, layer_id)
        # Expand convolution: twice the channels, folded into length below.
        expanded = tf.layers.conv1d(
            processed,
            filters=2 * params.conv_n_filters,
            kernel_size=params.conv_filter_size,
            padding="same",
            activation=tf.nn.relu)
        # Up-sampling is done by reshape. "//" keeps the channel count an int
        # under Python 3, where "/" yields a float that tf.reshape rejects.
        expanded_shape = expanded.get_shape()
        length = expanded_shape[1].value
        channels = expanded_shape[2].value
        return tf.reshape(expanded, shape=[-1, length * 2, channels // 2])

def build_decoder(output_data, num_levels=4):
    """Decode an encoded representation back to per-position scores.

    Applies ``num_levels`` expansion levels, each doubling the sequence
    length. Their scopes run from ``decoder_expansion_<num_levels + 1>`` down to
    ``decoder_expansion_2``. A final ``params.model_n``-layer residual conv stack
    follows under ``decoder_transform_0``.

    Args:
      output_data: the encoded tensor, e.g. the result of ``build_encoder``.
      num_levels: number of expansion levels. The default of 4 reproduces the
        original architecture and variable names. It should match the
        number of pooling levels used by the encoder (``build_encoder`` uses
        4 by default).

    Returns:
      A tensor with ``params.conv_n_filters`` channels (when
      ``params.model_n`` > 0).
    """
    decoded = output_data
    for level in range(num_levels + 1, 1, -1):
        decoded = decoder_recursion_layer(decoded, level)
    return decoder_transform_to_external(decoded, params.model_n, 0)

In [None]:
# Training: build the input pipeline, the auto-encoder and the optimizer, then
# train inside a MonitoredTrainingSession, periodically rendering a table of
# source lines next to their reconstructions.

# Imported here to keep this cell self-contained: html.escape on Python 3,
# cgi.escape as the Python 2 equivalent.
try:
    from html import escape as html_escape
except ImportError:
    from cgi import escape as html_escape

with tf.Graph().as_default():
    # Input pipeline: each line of train.en -> byte codes -> zero-padded batch
    # of length params.max_string_length -> one-hot over 256 byte values.
    dataset = tf.contrib.data.TextLineDataset("train.en")
    dataset = dataset.repeat(params.num_epochs)
    dataset = dataset.shuffle(params.shuffle_buffer)
    dataset = dataset.map(lambda text_string: tf.py_func(_text_to_codes, [text_string], tf.uint8))
    dataset = dataset.padded_batch(params.batch_size, [params.max_string_length])
    dataset = dataset.map(byte_hot_encoding)
    training_iterator = dataset.make_initializable_iterator()
    next_element = training_iterator.get_next()

    encoded = build_encoder(next_element)
    predicted = build_decoder(encoded)
    predicted_text = decode_text(predicted)
    # NOTE(review): `predicted` comes from a ReLU conv with
    # params.conv_n_filters channels and is used directly as 256-way logits.
    # This only lines up while conv_n_filters == 256.
    loss = tf.losses.softmax_cross_entropy(
        tf.reshape(next_element, shape=[-1, 256]),
        tf.reshape(predicted, shape=[-1, 256]))

    global_step = tf.Variable(0, trainable=False, name="global_step")
    # optimize_loss only applies learning_rate_decay_fn when learning_rate is
    # not None. The previous learning_rate=None plus an optimizer *instance*
    # silently skipped the decay. Pass the base rate (AdamOptimizer's default
    # of 0.001) and the optimizer class so Adam is built with the decayed rate.
    # The rate is a scalar tensor rather than a float so that no extra
    # "learning_rate" variable is added to existing checkpoints.
    initial_learning_rate = tf.constant(0.001)
    exponential_decay_fn = functools.partial(
        tf.train.exponential_decay,
        decay_steps=1000,
        decay_rate=0.5,
        staircase=True)
    training_op = tf.contrib.layers.optimize_loss(
        loss,
        global_step,
        learning_rate=initial_learning_rate,
        optimizer=tf.train.AdamOptimizer,
        learning_rate_decay_fn=exponential_decay_fn)
    # The iterator initializer runs as part of local init, i.e. whenever the
    # monitored session is (re)created.
    scaffold = tf.train.Scaffold(
        local_init_op=tf.group(tf.local_variables_initializer(), training_iterator.initializer),
        init_op=tf.global_variables_initializer())
    with tf.train.MonitoredTrainingSession(checkpoint_dir="./checkpoint", scaffold=scaffold, save_checkpoint_secs=120, save_summaries_secs=60) as sess:
        print("Start training")
        while not sess.should_stop():
            o_loss, _, o_step = sess.run([loss, training_op, global_step])
            if o_step % 100 == 1:
                clear_output(True)
                output_string = """Loss {0}, step {1} <BR><table>
                <tr><th>Source</th><th>Decoded</th></tr>""".format(o_loss, o_step)
                # Draws a fresh batch; source and prediction come from the same run.
                o_source_text, o_predicted_text = sess.run([next_element, predicted_text])
                for t_source, t_predicted in zip(o_source_text, o_predicted_text):
                    # Escape so text containing "<" or "&" cannot break the table.
                    output_string += """<tr><td>{source}</td><td>{decoded}</td></tr>""".format(
                        source=html_escape(_argmax_to_string(t_source)),
                        decoded=html_escape(_codes_to_string(t_predicted)))
                output_string += "</table>"
                display_html(output_string, raw=True)
                
                


Source,Decoded
She sent money back to her family.������������������������������,"She wert modes man, so sen lately.������������������������������"
"So each sample gets us about 50,000 data points with repeat mea","So jast little yoen at thing 10,000 mate slatted with reaole teu"
We can think of older genome engineering technologies as simila,We can think of athen resere inporeating techniootine is lithte
"Five years ago, I responded to a motorcycle accident.�����������",Livh years arom a resirined to a seseeriole actunent.�����������
"And yet, I don't really think it is because when it comes down",And you're don't reaole there is it because then it tomes rear
What I didn't tell David at the time was I myself wasn't convin,"What I witn't tell right at the lime, was a reaors hasn't loney"
"And so DNA didn't become a useful molecule, and the lawyers did","And so yn, witn't becore I ""nited portiane, and the lanter was"
"In order to do this, we need to introduce new forces with new c",In order to do thite we need to inerete a now working mach now i
"So you can take a picture with an iPhone and get all the names,",So tou can take a pasteng wath an insin and not and the compnn
"With the split wing, we get the lift at the upper wing, and we",With the anain wathe we get the fast it the areen wathe and to


INFO:tensorflow:global_step/sec: 31.1972
