In [1]:
import tensorflow as tf
import numpy as np
import math
import gzip
import json
import datetime
import shutil
from tqdm import tqdm

In [2]:
def reset_tf():
    global sess
    sess.close()
    tf.reset_default_graph()
    tf.set_random_seed(0)
    sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=True))

In [3]:
sess = tf.InteractiveSession(config=tf.ConfigProto(log_device_placement=True))

In [4]:
class HyperParameters():
    # maximum number of symbols in an input sequence
    max_sequence_length = 50
    
    # number of symbols in vocabulary
    # (symbols are expected to be in range(vocab_size))
    vocab_size = 10000

    # number of dimensions in input embeddings
    embedding_size = 256
    
    # number of sequences per batch
    batch_size = 128
    
    # number of target classes
    num_target_classes = 2
    
    # number of parsing threads in data pipeline
    dataset_pipeline_parallel_calls = 4
    
    # size of prefetch in data pipeline
    dataset_pipeline_prefetch = batch_size * 16
    
    # shuffle buffer size
    dataset_pipeline_shuffle_buffer_size = 10000

hp = HyperParameters()

In [5]:
def parse_example(example_proto, max_sequence_length=hp.max_sequence_length):
    features = {
        'inputs': tf.VarLenFeature(tf.int64),
        'word_endings': tf.VarLenFeature(tf.int64),
        'targets': tf.VarLenFeature(tf.int64)
    }
    
    parsed = tf.parse_single_example(example_proto, features)
    
    def convert_and_pad(sparse_tensor):
        result = tf.sparse_tensor_to_dense(sparse_tensor)
        # TODO: properly ignore elements which are too large (right now we just clip)
        result = result[:max_sequence_length]
        result = tf.pad(result, [[0, max_sequence_length - tf.shape(result)[0]]])
        return result
    
    return (convert_and_pad(parsed['inputs']),
            tf.shape(parsed['inputs'])[0],
            convert_and_pad(parsed['word_endings']),
            convert_and_pad(parsed['targets']))

In [6]:
reset_tf()

# Data pipeline
# -------------

dataset_filenames = tf.placeholder(tf.string, shape=[None])

dataset = tf.data.TFRecordDataset(dataset_filenames)
dataset = dataset.map(parse_example, 
                      num_parallel_calls = hp.dataset_pipeline_parallel_calls)
dataset = dataset.shuffle(hp.dataset_pipeline_shuffle_buffer_size)
dataset = dataset.prefetch(hp.dataset_pipeline_prefetch)
dataset = dataset.batch(hp.batch_size)

dataset_iterator = dataset.make_initializable_iterator()
input_sequences, input_lengths, input_word_endings, target_sequences = dataset_iterator.get_next()

# Placeholders
# ------------

# sequences of input positions (not a placeholder)
input_positions = tf.range(hp.max_sequence_length, dtype=tf.int32)
input_positions = tf.tile(input_positions, [tf.shape(input_sequences)[0]])
input_positions = tf.reshape(input_positions, 
                             (tf.shape(input_sequences)[0], hp.max_sequence_length), 
                             name = 'input_positions')

# Embeddings
# ----------

# sequences of input embeddings w/ shape:
#   (hp.batch_size, hp.max_sequence_length, hp.embedding_size)
input_sequence_embeddings = tf.get_variable('input_sequence_embeddings', 
                                            (hp.vocab_size, hp.embedding_size))
input_sequences_embedded = tf.nn.embedding_lookup(input_sequence_embeddings, 
                                                  input_sequences,
                                                  name = 'input_sequences_embedded')

# sequences of input position embeddings w/ shape:
#   (hp.batch_size, hp.max_sequence_length, hp.embedding_size)
input_position_embeddings = tf.get_variable('input_position_embeddings', 
                                            (hp.max_sequence_length, hp.embedding_size))
input_positions_embedded = tf.nn.embedding_lookup(input_position_embeddings, input_positions)

# sequences of word ending embeddings w/ shape:
#   (hp.batch_size, hp.max_sequence_length, hp.embedding_size)
input_word_ending_embeddings = tf.get_variable('input_word_ending_embeddings',
                                               (2, hp.embedding_size))
input_word_endings_embedded = tf.nn.embedding_lookup(input_word_ending_embeddings, 
                                                     input_word_endings,
                                                     name = 'input_word_endings_embedded')

# Sequence mask
# -------------

sequence_mask = tf.sequence_mask(input_lengths,
                                 hp.max_sequence_length,
                                 dtype = tf.float32)
# expand dimensions to support broadcasting
expanded_sequence_mask = tf.expand_dims(sequence_mask, 
                                        2, 
                                        name = 'sequence')

input_combined_embedded = tf.add_n([input_sequences_embedded, 
                                    input_positions_embedded, 
                                    input_word_endings_embedded])
# TODO: is this necessary?
input_combined_embedded = tf.multiply(input_combined_embedded,
                                      expanded_sequence_mask,
                                      name = 'input_combined_embedded')

# Layer normalization
# -------------------

def layer_norm(x, scope, reuse=None, epsilon=1e-6):
    with tf.variable_scope(scope, reuse=reuse):
        dim = x.get_shape()[-1]
        scale = tf.get_variable(
            "layer_norm_scale", [dim], initializer=tf.ones_initializer())
        bias = tf.get_variable(
            "layer_norm_bias", [dim], initializer=tf.zeros_initializer())
        result = layer_norm_compute(x, epsilon, scale, bias)
        return result

def layer_norm_compute(x, epsilon, scale, bias):
    # TODO: incorporate length into layer normalization?
    mean = tf.reduce_mean(x, axis=[-1], keep_dims=True)
    variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keep_dims=True)
    norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
    return norm_x * scale + bias
    
# Attention
# ---------

def attention_layer(A):
    A_T = tf.transpose(A, perm=[0, 2, 1])
    scaled_logits = tf.matmul(A, A_T) / tf.sqrt(tf.cast(tf.shape(A)[-1], tf.float32))
    return tf.matmul(tf.nn.softmax(scaled_logits), A)

input_attention_layer_1 = layer_norm(input_combined_embedded + attention_layer(input_combined_embedded),
                                     'input_attention_layer_1_norm')

# Feed-forward
# ------------

def feed_forward_layer(A, num_units, scope='feed_forward', reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        A = tf.layers.dense(A, num_units, activation=tf.nn.relu)
        return tf.layers.dense(A, num_units)

input_feed_forward_layer_1 = layer_norm(input_attention_layer_1 + feed_forward_layer(input_attention_layer_1, hp.embedding_size),
                                        'input_feed_forward_layer_1_norm')

input_feed_forward_layer_1 *= expanded_sequence_mask

# Softmax
# -------

output_logits = tf.layers.dense(input_feed_forward_layer_1, hp.num_target_classes)

# Loss
# ----

losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_sequences,
                                                        logits=output_logits)
losses *= sequence_mask

total_loss = tf.reduce_sum(losses)
total_input_length = tf.reduce_sum(input_lengths)
mean_loss  = total_loss / tf.cast(total_input_length, tf.float32)

# Training
# --------

global_step = tf.Variable(0, name='global_step', trainable=False)
optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
train_op = optimizer.minimize(mean_loss, global_step=global_step)

# Summaries
# ---------

tf.summary.scalar('mean_loss', mean_loss)

merged_summaries = tf.summary.merge_all()

In [7]:
total_parameters = 0
for variable in tf.trainable_variables():
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
        variable_parameters *= dim.value
    print('parameters for "%s": %d' % (variable.name, variable_parameters))
    total_parameters += variable_parameters
print('total parameters: %d' % total_parameters)

parameters for "input_sequence_embeddings:0": 2560000
parameters for "input_position_embeddings:0": 12800
parameters for "input_word_ending_embeddings:0": 512
parameters for "input_attention_layer_1_norm/layer_norm_scale:0": 256
parameters for "input_attention_layer_1_norm/layer_norm_bias:0": 256
parameters for "feed_forward/dense/kernel:0": 65536
parameters for "feed_forward/dense/bias:0": 256
parameters for "feed_forward/dense_1/kernel:0": 65536
parameters for "feed_forward/dense_1/bias:0": 256
parameters for "input_feed_forward_layer_1_norm/layer_norm_scale:0": 256
parameters for "input_feed_forward_layer_1_norm/layer_norm_bias:0": 256
parameters for "dense/kernel:0": 512
parameters for "dense/bias:0": 2
total parameters: 2706434


In [8]:
sess.run(tf.global_variables_initializer())

In [9]:
num_epochs = 1

# options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
# run_metadata = tf.RunMetadata()

for epoch in range(num_epochs):
    num_batches = 0
    cum_loss = 0
    cum_input_length = 0
    
    sess.run(dataset_iterator.initializer, feed_dict={
        dataset_filenames: ['../data/simplewiki/simplewiki-20171103.entity_recognition.tfrecords']
    })
    
    start = datetime.datetime.now()
    
    while True:
        num_batches += 1
        try:
            _, curr_loss, curr_input_length, _ = sess.run((train_op,
                                                           total_loss,
                                                           total_input_length,
                                                           merged_summaries))
        except tf.errors.OutOfRangeError:
            break
            
#         timeline = timeline.Timeline(run_metadata.step_stats)
#         trace = timeline.generate_chrome_trace_format()
#         with open('trace_%d.json' % num_batches, 'w') as f:
#             f.write(trace)
            
        cum_loss += curr_loss
        cum_input_length += curr_input_length
        
        if num_batches % 100 == 0:
            finish = datetime.datetime.now()
            elapsed = (finish - start).total_seconds()
            print('processed %d batches: loss=%g, rate=%g chars/s' % (num_batches, 
                                                                      cum_loss/cum_input_length, 
                                                                      float(cum_input_length)/elapsed))
            
    finish = datetime.datetime.now()
    elapsed = (finish - start).total_seconds()
    print('epoch %d: loss=%g, time=%g s' % (epoch, cum_loss/cum_input_length, elapsed))

processed 100 batches: loss=0.297045, rate=166597 chars/s
processed 200 batches: loss=0.261772, rate=208868 chars/s
processed 300 batches: loss=0.246022, rate=228627 chars/s
processed 400 batches: loss=0.238562, rate=239534 chars/s
processed 500 batches: loss=0.230077, rate=251411 chars/s
processed 600 batches: loss=0.223813, rate=258601 chars/s
processed 700 batches: loss=0.221072, rate=262410 chars/s
processed 800 batches: loss=0.218694, rate=265194 chars/s
processed 900 batches: loss=0.217063, rate=267772 chars/s
processed 1000 batches: loss=0.215631, rate=269627 chars/s
processed 1100 batches: loss=0.214681, rate=269136 chars/s
processed 1200 batches: loss=0.213761, rate=268199 chars/s
processed 1300 batches: loss=0.212945, rate=269559 chars/s
processed 1400 batches: loss=0.212298, rate=270843 chars/s
processed 1500 batches: loss=0.209311, rate=274446 chars/s
processed 1600 batches: loss=0.209058, rate=275122 chars/s
processed 1700 batches: loss=0.208685, rate=275541 chars/s
proces

KeyboardInterrupt: 