In [1]:
import tensorflow as tf
import json
import binascii
import gzip
from tqdm import tqdm_notebook

In [2]:
sess = None

In [3]:
def reset_tf(sess = None, log_device_placement = False):
    if sess:
        sess.close()
    tf.reset_default_graph()
    tf.set_random_seed(0)
    return tf.InteractiveSession(config = tf.ConfigProto(log_device_placement = log_device_placement))

In [4]:
class HyperParameters():
    # adam learning rate
    learning_rate = 1e-3

    # number of distinct terms (term indices are expected in 0..range(num_terms))
    vocab_size = 30000
    
    # number of dimensions in hidden layer
    hidden_size = 512

    # number of dimensions in document embedding
    embedding_size = 128
    
    # dropout rate
    dropout_rate = 0.1
    
    # number of sequences per batch
    pipeline_batch_size = 32
    
    # number of parsing threads in data pipeline
    pipeline_num_parallel_calls = 4
    
    # size of prefetch in data pipeline
    pipeline_prefetch_size = pipeline_batch_size * 16
    
    # shuffle buffer size
    pipeline_shuffle_size = 256

hp = HyperParameters()

In [5]:
sess = reset_tf(sess)

# Pipeline
# --------

# TODO: use SparseTensor / don't use dataset API for speed?

def parse_example(example_proto):
    features = {
        'page_id': tf.FixedLenFeature([1], dtype=tf.string),
        'para_id': tf.FixedLenFeature([1], dtype=tf.int64),
        'indices': tf.VarLenFeature(tf.int64),
        'freqs': tf.VarLenFeature(tf.int64)
    }
    parsed = tf.parse_single_example(example_proto, features)
    page_id = parsed['page_id']
    para_id = parsed['para_id']
    indices = tf.sparse_tensor_to_dense(parsed['indices'])
    freqs = tf.sparse_tensor_to_dense(parsed['freqs'])
    return page_id, para_id, tf.cast(tf.sparse_to_dense(indices, [hp.vocab_size], freqs), tf.float32)

dataset_filenames = tf.placeholder(tf.string, shape = [None], name = 'dataset_filenames')

dataset = tf.data.TFRecordDataset(dataset_filenames)
dataset = dataset.map(parse_example,
                      num_parallel_calls = hp.pipeline_num_parallel_calls)
dataset = dataset.shuffle(hp.pipeline_shuffle_size)
dataset = dataset.prefetch(hp.pipeline_prefetch_size)
dataset = dataset.batch(hp.pipeline_batch_size)

dataset_iterator = dataset.make_initializable_iterator()

input_page_id_iter, input_para_id_iter, input_tf_vector_iter = dataset_iterator.get_next()

input_page_id = tf.placeholder_with_default(input_page_id_iter, [None, 1], name = 'input_page_id')
input_para_id = tf.placeholder_with_default(input_para_id_iter, [None, 1], name = 'input_para_id')
input_tf_vector = tf.placeholder_with_default(input_tf_vector_iter, 
                                              [None, hp.vocab_size],
                                              name = 'input_tf_vector')
input_tf_vector_count = tf.shape(input_tf_vector)[0]

input_tf_vector_norm = tf.reduce_sum(input_tf_vector, axis = -1,  keep_dims = True)
input_tf_vector_normalized = input_tf_vector / (input_tf_vector_norm + 1e-8)

# Model
# -----

def layer_dense_with_norm(x, num_units, scope, reuse=None, epsilon=1e-6):
    x = tf.layers.dense(x, num_units, activation = tf.nn.relu, name=scope)
    return x

layer = input_tf_vector_normalized

layer = layer_dense_with_norm(layer, hp.hidden_size, 'input_hidden_layer')
layer = layer_dense_with_norm(layer, hp.embedding_size, 'input_embedding_layer')
layer = layer_dense_with_norm(layer, hp.hidden_size, 'output_hidden_layer')

with tf.variable_scope('output_layer'):
    output_tf_vector_normalized = tf.layers.dense(layer,hp.vocab_size)

# Loss
# ----

indiv_loss = tf.losses.mean_squared_error(input_tf_vector_normalized,
                                          output_tf_vector_normalized,
                                          reduction = tf.losses.Reduction.NONE)
total_loss = tf.reduce_sum(indiv_loss, name = 'total_loss')
mean_loss = tf.div(total_loss, 
                   (tf.cast(input_tf_vector_count, tf.float32) * hp.vocab_size),
                   name = 'mean_loss')

# Optimization
# ------------

optimizer = tf.train.AdamOptimizer(learning_rate = hp.learning_rate)
train_op = optimizer.minimize(mean_loss)
# gradients, variables = zip(*optimizer.compute_gradients(mean_loss))
# gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
# train_op = optimizer.apply_gradients(zip(gradients, variables))

# Stats
# -----

total_parameters = 0
for variable in tf.trainable_variables():
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
        variable_parameters *= dim.value
    print('parameters for "%s": %d' % (variable.name, variable_parameters))
    total_parameters += variable_parameters
print('total parameters: %d' % total_parameters)

parameters for "input_hidden_layer/kernel:0": 15360000
parameters for "input_hidden_layer/bias:0": 512
parameters for "input_embedding_layer/kernel:0": 131072
parameters for "input_embedding_layer/bias:0": 256
parameters for "output_hidden_layer/kernel:0": 131072
parameters for "output_hidden_layer/bias:0": 512
parameters for "output_layer/dense/kernel:0": 15360000
parameters for "output_layer/dense/bias:0": 30000
total parameters: 31013424


In [6]:
sess.run(tf.global_variables_initializer())

In [7]:
def evaluate_dataset(dataset_filename,
                     header = 'results',
                     train = False,
                     show_progress = True):
    cum_loss = 0
    cum_count = 0

    sess.run(dataset_iterator.initializer, feed_dict={
        dataset_filenames: [dataset_filename]
    })

    if show_progress:
        progress = tqdm_notebook()

    while True:
        try:
            (_,
             curr_loss,
             curr_count) = sess.run((train_op if train else [],
                                     total_loss,
                                     input_tf_vector_count))
        except tf.errors.OutOfRangeError:
            break

        if show_progress:
            progress.update(curr_count)

        cum_loss += curr_loss
        cum_count += curr_count

    if show_progress:
        progress.close()

    print('%s: loss=%g (%g/%d)' % (header, cum_loss/cum_count, cum_loss, cum_count))

In [8]:
for epoch in range(50):
    evaluate_dataset('../data/simplewiki/simplewiki-20171103.topic_model.30k.train.tfrecords',
                     header = 'train %d' % epoch,
                     train = True,
                     show_progress = True)
    evaluate_dataset('../data/simplewiki/simplewiki-20171103.topic_model.30k.dev.tfrecords',
                     header = 'dev   %d' % epoch,
                     train = False,
                     show_progress = False)

738562it [04:27, 2756.68it/s]


train 0: loss=0.0455689 (33655.4/738562)


224it [00:00, 2201.12it/s]

dev   0: loss=0.0369412 (738.825/20000)


738562it [04:27, 2756.66it/s]


train 1: loss=0.0342598 (25303/738562)


256it [00:00, 2345.24it/s]

dev   1: loss=0.0321963 (643.926/20000)


738562it [04:28, 2755.36it/s]


train 2: loss=0.0306681 (22650.3/738562)


256it [00:00, 2354.71it/s]

dev   2: loss=0.0297304 (594.608/20000)


738562it [04:27, 2756.40it/s]


train 3: loss=0.0283451 (20934.6/738562)


256it [00:00, 2356.55it/s]

dev   3: loss=0.027965 (559.299/20000)


738562it [04:26, 2774.44it/s]


train 4: loss=0.026858 (19836.3/738562)


256it [00:00, 2334.07it/s]

dev   4: loss=0.0269292 (538.584/20000)


738562it [04:25, 2781.06it/s]


train 5: loss=0.0259038 (19131.6/738562)


256it [00:00, 2354.82it/s]

dev   5: loss=0.0262503 (525.006/20000)


738562it [04:25, 2777.85it/s]


train 6: loss=0.0253363 (18712.4/738562)


256it [00:00, 2361.46it/s]

dev   6: loss=0.025709 (514.18/20000)


738562it [04:26, 2774.41it/s]


train 7: loss=0.0249534 (18429.6/738562)


256it [00:00, 2332.98it/s]

dev   7: loss=0.025424 (508.48/20000)


738562it [04:25, 2784.53it/s]


train 8: loss=0.0246289 (18189.9/738562)


256it [00:00, 2325.51it/s]

dev   8: loss=0.0251662 (503.324/20000)


738562it [04:25, 2783.91it/s]


train 9: loss=0.0243284 (17968.1/738562)


256it [00:00, 2366.55it/s]

dev   9: loss=0.0248962 (497.924/20000)


738562it [04:25, 2778.71it/s]


train 10: loss=0.0240583 (17768.6/738562)


256it [00:00, 2355.29it/s]

dev   10: loss=0.0246107 (492.214/20000)


738562it [04:26, 2775.70it/s]


train 11: loss=0.0238136 (17587.8/738562)


256it [00:00, 2374.12it/s]

dev   11: loss=0.0243625 (487.25/20000)


738562it [04:26, 2776.47it/s]


train 12: loss=0.0235875 (17420.8/738562)


256it [00:00, 2358.27it/s]

dev   12: loss=0.024112 (482.239/20000)


738562it [04:25, 2777.34it/s]


train 13: loss=0.0233658 (17257.1/738562)


256it [00:00, 2311.24it/s]

dev   13: loss=0.0238682 (477.363/20000)


738562it [04:25, 2781.01it/s]


train 14: loss=0.0231433 (17092.8/738562)


256it [00:00, 2326.54it/s]

dev   14: loss=0.0236582 (473.165/20000)


738562it [04:26, 2776.37it/s]


train 15: loss=0.022931 (16936/738562)


256it [00:00, 2305.99it/s]

dev   15: loss=0.023437 (468.74/20000)


738562it [04:25, 2784.77it/s]


train 16: loss=0.0227327 (16789.5/738562)


256it [00:00, 2392.68it/s]

dev   16: loss=0.0232161 (464.323/20000)


738562it [04:25, 2778.86it/s]


train 17: loss=0.0225506 (16655/738562)


256it [00:00, 2362.65it/s]

dev   17: loss=0.0230452 (460.905/20000)


738562it [04:25, 2780.85it/s]


train 18: loss=0.022395 (16540.1/738562)


256it [00:00, 2406.45it/s]

dev   18: loss=0.0229194 (458.388/20000)


54976it [00:19, 2751.17it/s]

KeyboardInterrupt: 

55072it [00:30, 2751.17it/s]

In [10]:
builder = tf.saved_model.builder.SavedModelBuilder('../models/simplewiki/topic_model_1_128')
builder.add_meta_graph_and_variables(sess, ['training'])
builder.save()

INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/topic_model_1_256/saved_model.pb'


b'../models/simplewiki/topic_model_1_256/saved_model.pb'

In [6]:
sess = reset_tf()
tf.saved_model.loader.load(sess, ['training'], '../models/simplewiki/topic_model_1_256')
print('loaded')

INFO:tensorflow:Restoring parameters from b'../models/simplewiki/topic_model_1_256/variables/variables'
loaded


In [7]:
dataset_filenames = tf.get_default_graph().get_operation_by_name('dataset_filenames').outputs[0]
embedding_layer = tf.get_default_graph().get_operation_by_name('input_embedding_layer/Relu').outputs[0]
input_page_id = tf.get_default_graph().get_operation_by_name('input_page_id').outputs[0]
input_para_id = tf.get_default_graph().get_operation_by_name('input_para_id').outputs[0]
make_iterator = tf.get_default_graph().get_operation_by_name('MakeIterator')

In [9]:
def dump_embeddings(dataset_filename, output_file):
    sess.run(make_iterator, feed_dict={
        dataset_filenames: [dataset_filename]
    })
    
    progress = tqdm_notebook()
    
    while True:
        try:
            (curr_input_page_id, 
             curr_input_para_id, 
             curr_embedding_layer) = sess.run((input_page_id, input_para_id, embedding_layer))
        except tf.errors.OutOfRangeError:
            break
        for i in range(curr_input_page_id.shape[0]):
            json.dump({
                'page_id': curr_input_page_id[i][0].hex(), 
                'para_id': int(curr_input_para_id[i][0]), 
                'embedding': curr_embedding_layer[i].tolist()
            }, output_file)
        progress.update(curr_input_page_id.shape[0])


In [11]:
with gzip.open('../data/simplewiki/simplewiki-20171103.embeddings.30k.json.gz', 'wt', encoding='utf-8') as f:
    dump_embeddings('../data/simplewiki/simplewiki-20171103.topic_model.30k.dev.tfrecords', f)
    dump_embeddings('../data/simplewiki/simplewiki-20171103.topic_model.30k.test.tfrecords', f)
    dump_embeddings('../data/simplewiki/simplewiki-20171103.topic_model.30k.train.tfrecords', f)

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/home/achang/anaconda3/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/home/achang/anaconda3/lib/python3.5/site-packages/tqdm/_tqdm.py", line 144, in run
    for instance in self.tqdm_cls._instances:
  File "/home/achang/anaconda3/lib/python3.5/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration

