In [1]:
import tensorflow as tf
import numpy as np
import datetime
import json
import gzip
import matplotlib.pyplot as plt
import re
from tqdm import tqdm_notebook

In [2]:
# Current TensorFlow session; reset_tf() closes it (if set) before creating a new one.
sess = None

In [3]:
def reset_tf(sess = None, log_device_placement = False, seed = 0):
    """Close ``sess`` (if any), start a fresh default graph and open a new session.

    Args:
        sess: previously created session to close, or None.
        log_device_placement: forwarded to ``tf.ConfigProto``.
        seed: graph-level random seed passed to ``tf.set_random_seed``
            (defaults to 0, the value previously hard-coded here).

    Returns:
        A new ``tf.InteractiveSession`` bound to the freshly reset default graph.
    """
    if sess is not None:
        sess.close()
    tf.reset_default_graph()
    tf.set_random_seed(seed)
    config = tf.ConfigProto(log_device_placement = log_device_placement)
    return tf.InteractiveSession(config = config)

In [4]:
class HyperParameters:
    """Optimization, architecture and input-pipeline settings for EntityLinkingModel."""

    # --- optimization ---
    # NOTE(review): 0.1 is unusually large for Adam; the 100-example training
    # run in this notebook shows the loss oscillating instead of converging.
    learning_rate = 0.1
    dropout_rate = 0.1

    # --- vocabulary / label space ---
    vocab_size = 30000      # rows of the word-embedding table
    num_targets = 2000      # number of output classes of the final dense layer

    # --- architecture ---
    d_hidden = 128          # embedding width and hidden-layer width
    context_size = 81       # word ids per example

    # --- tf.data input pipeline ---
    pipeline_batch_size = 128
    pipeline_num_parallel_calls = 4
    pipeline_prefetch_size = 16 * pipeline_batch_size   # measured in examples
    pipeline_shuffle_size = 5000

In [5]:
class EntityLinkingModel:
    """Softmax classifier over a window of context words, built as a TF1 graph.

    Each example's context word ids are embedded, passed through a dense tanh
    layer, averaged over the positions selected by ``link_mask`` and mapped to
    ``hp.num_targets`` logits. Input comes from a tf.data pipeline over
    GZIP-compressed TFRecord files.

    Args:
        session: TensorFlow session used by ``process`` to run the graph.
        hp: hyper-parameter object (e.g. ``HyperParameters``).
    """

    def __init__(self, session, hp):
        self._session = session
        self._hp = hp

    def _parse_example(self, example_proto):
        """Decode one serialized tf.train.Example.

        Returns:
            (target_id [1], context_word_ids [context_size],
            link_mask [context_size]), all int64. ``page_id`` is parsed
            but not returned.
        """
        parsed = tf.parse_single_example(example_proto, features = {
            'page_id': tf.FixedLenFeature([1], tf.int64),
            'target_id': tf.FixedLenFeature([1], tf.int64),
            'context_word_ids': tf.FixedLenFeature([self._hp.context_size], tf.int64),
            'link_mask': tf.FixedLenFeature([self._hp.context_size], tf.int64) })

        return (
            parsed['target_id'],
            parsed['context_word_ids'],
            parsed['link_mask'])

    def _build_data_pipeline(self):
        """Build the tf.data input pipeline and the model's input tensors.

        The inputs come from an initializable iterator, but each is wrapped in
        ``placeholder_with_default`` so tensors can also be fed directly.
        """
        with tf.variable_scope('dataset'):
            # TFRecord shard filenames, fed when initializing the iterator
            self._dataset_filenames = tf.placeholder(tf.string, shape = [None])

            # maximum number of records to read; -1 (the default) reads all
            self._pipeline_limit = tf.placeholder_with_default(tf.constant(-1, dtype = tf.int64), [])

            # build examples dataset
            dataset = tf.data.TFRecordDataset(
                self._dataset_filenames,
                compression_type = 'GZIP')
            dataset = dataset.take(self._pipeline_limit)
            dataset = dataset.map(
                self._parse_example,
                num_parallel_calls = self._hp.pipeline_num_parallel_calls)
            dataset = dataset.shuffle(self._hp.pipeline_shuffle_size)
            dataset = dataset.batch(self._hp.pipeline_batch_size)
            # Prefetch whole batches at the end of the pipeline so input
            # preparation overlaps with training. pipeline_prefetch_size is
            # expressed in examples, so convert it to a number of batches.
            dataset = dataset.prefetch(max(
                1, self._hp.pipeline_prefetch_size // self._hp.pipeline_batch_size))

            # build dataset iterator
            self._dataset_iterator = dataset.make_initializable_iterator()
            (target_labels, context_word_ids, link_mask) = self._dataset_iterator.get_next()
            target_labels = tf.squeeze(target_labels, axis = -1)

            # placeholders (default to the iterator output)
            self._context_word_ids = tf.placeholder_with_default(
                context_word_ids,
                shape = [None, self._hp.context_size],
                name = 'context_word_ids')
            self._link_mask = tf.placeholder_with_default(
                link_mask,
                shape = [None, self._hp.context_size],
                name = 'link_mask')
            self._target_labels = tf.placeholder_with_default(
                target_labels,
                shape = [None],
                name = 'target_labels')

            self._minibatch_size = tf.shape(self._context_word_ids)[0]

    def _build_model(self):
        """Build the embedding -> hidden -> masked-mean -> logits network."""
        with tf.variable_scope('model'):
            # placeholders: True enables dropout and batch-norm training mode
            self._training = tf.placeholder(tf.bool, name = 'training')

            # variables
            word_embedding = tf.get_variable(
                'word_embedding',
                [self._hp.vocab_size, self._hp.d_hidden])

            # embedding
            with tf.variable_scope('embedding'):
                layer = tf.nn.embedding_lookup(
                    word_embedding,
                    self._context_word_ids)
                layer = tf.layers.batch_normalization(
                    layer,
                    training = self._training)
                layer = tf.layers.dropout(
                    layer,
                    rate = self._hp.dropout_rate,
                    training = self._training)

            # hidden layer
            with tf.variable_scope('hidden'):
                layer = tf.layers.dense(
                    layer,
                    self._hp.d_hidden,
                    activation = tf.tanh)
                layer = tf.layers.batch_normalization(
                    layer,
                    training = self._training)
                layer = tf.layers.dropout(
                    layer,
                    rate = self._hp.dropout_rate,
                    training = self._training)

            # output: mean of the hidden states at positions selected by link_mask
            mask = tf.cast(self._link_mask, tf.float32) # [batch_size, context_size]
            mask = tf.expand_dims(mask, axis = -1)      # [batch_size, context_size, 1]
            layer *= mask                               # [batch_size, context_size, d_hidden]
            # Floor the count at 1: an all-zero mask would otherwise produce
            # 0/0 = NaN, which propagates into the loss. Assumes mask entries
            # are non-negative integers (e.g. 0/1), so any non-empty mask
            # already sums to >= 1 and is unaffected.
            mask_count = tf.maximum(tf.reduce_sum(mask, axis = -2), 1.0) # [batch_size, 1]
            layer = tf.reduce_sum(layer, axis = -2) / mask_count         # [batch_size, d_hidden]
            layer = tf.layers.batch_normalization(
                layer,
                training = self._training)
            layer = tf.layers.dropout(
                layer,
                rate = self._hp.dropout_rate,
                training = self._training)

            # softmax logits
            self._output_logits = tf.layers.dense(
                layer,
                self._hp.num_targets)

    def _build_training_model(self):
        """Build the loss, the correct-prediction count and the Adam training op."""
        with tf.variable_scope('train'):
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels = self._target_labels,
                logits = self._output_logits)

            self._total_loss = tf.reduce_sum(losses)
            self._mean_loss = tf.reduce_mean(losses)

            # softmax is monotonic, so argmax over the raw logits picks the same class
            output_labels = tf.argmax(self._output_logits, axis = -1)
            self._num_correct_labels = tf.reduce_sum(tf.cast(
                tf.equal(output_labels, self._target_labels),
                tf.int32))

            self._global_step = tf.Variable(0, name = 'global_step', trainable = False)
            self._optimizer = tf.train.AdamOptimizer(learning_rate = self._hp.learning_rate)

            # batch normalization registers its moving-statistics updates in
            # UPDATE_OPS; make every training step run them.
            # TODO: optional gradient clipping (tf.clip_by_global_norm) would
            # need a gradient_clip_norm hyper-parameter.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self._train_op = self._optimizer.minimize(
                    self._mean_loss,
                    global_step = self._global_step)

    def build_model(self):
        """Build the input pipeline, the network and the training ops."""
        self._build_data_pipeline()
        self._build_model()
        self._build_training_model()

    def dump_statistics(self):
        """Print the parameter count of every trainable variable and the total."""
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = variable.get_shape().num_elements()
            print('parameters for "%s": %d' % (variable.name, variable_parameters))
            total_parameters += variable_parameters
        print('total parameters: %d' % total_parameters)

    def process(self,
                dataset_filenames,
                dataset_limit = -1,
                header = 'results',
                train = False,
                show_progress = True,
                log_file = None):
        """Run one pass over the given TFRecord files and print loss/accuracy.

        Args:
            dataset_filenames: list of GZIP TFRecord paths, or a single path.
            dataset_limit: maximum number of records to read; -1 reads all.
            header: label printed at the start of the summary line.
            train: if True, run the training op and feed ``training=True``
                (dropout and batch-norm training mode).
            show_progress: display a tqdm notebook progress bar.
            log_file: optional open text file that also receives the summary.
        """
        # the filenames placeholder is 1-D; accept a single path as well
        if isinstance(dataset_filenames, str):
            dataset_filenames = [dataset_filenames]

        cum_loss = 0
        cum_num_examples = 0
        cum_correct_examples = 0

        start = datetime.datetime.now()

        self._session.run(self._dataset_iterator.initializer, feed_dict = {
            self._dataset_filenames: dataset_filenames,
            self._pipeline_limit: dataset_limit
        })

        fetches = (
            self._train_op if train else (),
            self._total_loss,
            self._minibatch_size,
            self._num_correct_labels)

        progress = tqdm_notebook(leave = False, desc = header) if show_progress else None
        try:
            while True:
                try:
                    (_,
                     curr_total_loss,
                     curr_minibatch_size,
                     curr_num_correct_labels) = self._session.run(
                        fetches,
                        feed_dict = { self._training: train })
                except tf.errors.OutOfRangeError:
                    # iterator exhausted: the pass is complete
                    break

                if progress is not None:
                    progress.update(curr_minibatch_size)

                cum_loss += curr_total_loss
                cum_num_examples += curr_minibatch_size
                cum_correct_examples += curr_num_correct_labels
        finally:
            # close the progress bar even if the pass is interrupted
            if progress is not None:
                progress.close()

        finish = datetime.datetime.now()

        # an empty pass (no records read) would otherwise divide by zero
        if cum_num_examples:
            mean_loss = cum_loss / cum_num_examples
            accuracy = cum_correct_examples / cum_num_examples
        else:
            mean_loss = accuracy = float('nan')

        message = '%s (%d) (%s): loss=%g, accuracy=%g' % (
            header,
            tf.train.global_step(self._session, self._global_step),
            finish - start,
            mean_loss,
            accuracy)
        print(message)
        if log_file:
            print(message, file = log_file)
            log_file.flush()

In [6]:
# Fresh graph + session, then build the model on top of it and report its size.
sess = reset_tf(sess = sess)
model = EntityLinkingModel(session = sess, hp = HyperParameters())
model.build_model()
model.dump_statistics()

parameters for "model/word_embedding:0": 3840000
parameters for "model/embedding/batch_normalization/gamma:0": 128
parameters for "model/embedding/batch_normalization/beta:0": 128
parameters for "model/hidden/dense/kernel:0": 16384
parameters for "model/hidden/dense/bias:0": 128
parameters for "model/hidden/batch_normalization/gamma:0": 128
parameters for "model/hidden/batch_normalization/beta:0": 128
parameters for "model/batch_normalization/gamma:0": 128
parameters for "model/batch_normalization/beta:0": 128
parameters for "model/dense/kernel:0": 256000
parameters for "model/dense/bias:0": 2000
total parameters: 4115280


In [7]:
# Initialize every global variable of the freshly built graph.
sess.run(tf.global_variables_initializer())

In [8]:
import os

def list_files(path):
    """Return the full paths of all entries of directory ``path``, sorted."""
    entries = os.listdir(path)
    return sorted(os.path.join(path, entry) for entry in entries)

# The train/dev/test shard directories share this path prefix.
DATASET_PREFIX = '../data/simplewiki/simplewiki-20171103.el_softmax_2'

train_set = list_files(DATASET_PREFIX + '.train')
dev_set = list_files(DATASET_PREFIX + '.dev')
test_set = list_files(DATASET_PREFIX + '.test')

In [10]:
# Presumably a sanity check: train repeatedly on only the first 100 records of
# the first training shard to see whether the model can fit a tiny dataset.
for i in range(100):
    model.process(
        dataset_filenames = train_set[:1],
        dataset_limit = 100,
        train = True)

results (101) (0:00:00.065574): loss=0.573668, accuracy=0.92


results (102) (0:00:00.040936): loss=0.558188, accuracy=0.92


results (103) (0:00:00.035905): loss=1.05368, accuracy=0.91


results (104) (0:00:00.035957): loss=1.25108, accuracy=0.86


results (105) (0:00:00.037058): loss=0.728, accuracy=0.88


results (106) (0:00:00.036059): loss=1.18703, accuracy=0.89


results (107) (0:00:00.036105): loss=1.05227, accuracy=0.9


results (108) (0:00:00.036452): loss=0.541395, accuracy=0.92


results (109) (0:00:00.036159): loss=0.236716, accuracy=0.95


results (110) (0:00:00.036132): loss=0.256866, accuracy=0.95


results (111) (0:00:00.037281): loss=0.60908, accuracy=0.89


results (112) (0:00:00.036241): loss=0.408293, accuracy=0.9


results (113) (0:00:00.035988): loss=0.433237, accuracy=0.93


results (114) (0:00:00.035821): loss=0.406653, accuracy=0.91


results (115) (0:00:00.036255): loss=0.838363, accuracy=0.87


results (116) (0:00:00.036150): loss=0.402189, accuracy=0.94


results (117) (0:00:00.035734): loss=0.575816, accuracy=0.92


results (118) (0:00:00.035631): loss=0.829553, accuracy=0.94


results (119) (0:00:00.035964): loss=0.153534, accuracy=0.97


results (120) (0:00:00.036118): loss=0.286455, accuracy=0.94


results (121) (0:00:00.036551): loss=0.703162, accuracy=0.9


results (122) (0:00:00.037251): loss=0.992441, accuracy=0.94


results (123) (0:00:00.036560): loss=0.523451, accuracy=0.94


results (124) (0:00:00.036014): loss=0.474538, accuracy=0.95


results (125) (0:00:00.036454): loss=0.468183, accuracy=0.91


results (126) (0:00:00.036666): loss=0.393479, accuracy=0.92


results (127) (0:00:00.036366): loss=0.576155, accuracy=0.92


results (128) (0:00:00.042290): loss=0.590484, accuracy=0.92


results (129) (0:00:00.036628): loss=0.9919, accuracy=0.92


results (130) (0:00:00.036296): loss=0.64487, accuracy=0.93


results (131) (0:00:00.036577): loss=0.149567, accuracy=0.96


results (132) (0:00:00.036041): loss=0.919684, accuracy=0.93


results (133) (0:00:00.036778): loss=0.958503, accuracy=0.88


results (134) (0:00:00.036377): loss=0.243872, accuracy=0.95


results (135) (0:00:00.036795): loss=0.74308, accuracy=0.92


results (136) (0:00:00.036412): loss=0.627044, accuracy=0.91


results (137) (0:00:00.036560): loss=0.694883, accuracy=0.89


results (138) (0:00:00.035813): loss=0.487066, accuracy=0.93


results (139) (0:00:00.035602): loss=0.492183, accuracy=0.94


results (140) (0:00:00.036670): loss=0.442377, accuracy=0.92


results (141) (0:00:00.035583): loss=0.485699, accuracy=0.94


results (142) (0:00:00.035565): loss=0.863989, accuracy=0.9


results (143) (0:00:00.035969): loss=1.48486, accuracy=0.89


results (144) (0:00:00.035978): loss=0.459537, accuracy=0.95


results (145) (0:00:00.035585): loss=0.876702, accuracy=0.9


results (146) (0:00:00.035834): loss=0.828007, accuracy=0.91


results (147) (0:00:00.036503): loss=0.901952, accuracy=0.86


results (148) (0:00:00.036503): loss=0.290543, accuracy=0.94


results (149) (0:00:00.037189): loss=1.32073, accuracy=0.9


results (150) (0:00:00.035945): loss=1.38755, accuracy=0.88


results (151) (0:00:00.036205): loss=0.903946, accuracy=0.9


results (152) (0:00:00.036366): loss=0.750563, accuracy=0.91


results (153) (0:00:00.036231): loss=0.864309, accuracy=0.92


results (154) (0:00:00.036300): loss=1.08787, accuracy=0.92


results (155) (0:00:00.036395): loss=1.47999, accuracy=0.92


results (156) (0:00:00.035720): loss=1.17464, accuracy=0.89


results (157) (0:00:00.036215): loss=0.641074, accuracy=0.92


results (158) (0:00:00.035811): loss=1.47531, accuracy=0.87


results (159) (0:00:00.035995): loss=0.820464, accuracy=0.94


results (160) (0:00:00.035631): loss=0.542429, accuracy=0.93


results (161) (0:00:00.035999): loss=0.853743, accuracy=0.88


results (162) (0:00:00.037253): loss=0.778993, accuracy=0.91


results (163) (0:00:00.035780): loss=0.358989, accuracy=0.95


results (164) (0:00:00.036486): loss=1.16238, accuracy=0.89


results (165) (0:00:00.036120): loss=0.467158, accuracy=0.92


results (166) (0:00:00.035623): loss=0.188137, accuracy=0.96


results (167) (0:00:00.035920): loss=0.826618, accuracy=0.93


results (168) (0:00:00.035818): loss=0.360353, accuracy=0.94


results (169) (0:00:00.036221): loss=0.197846, accuracy=0.95


results (170) (0:00:00.036146): loss=0.252852, accuracy=0.97


results (171) (0:00:00.036360): loss=0.135597, accuracy=0.99


results (172) (0:00:00.037630): loss=0.829912, accuracy=0.9


results (173) (0:00:00.036780): loss=1.04486, accuracy=0.86


results (174) (0:00:00.035568): loss=0.895027, accuracy=0.95


results (175) (0:00:00.035772): loss=0.440816, accuracy=0.96


results (176) (0:00:00.036717): loss=0.489726, accuracy=0.93


results (177) (0:00:00.035771): loss=0.187302, accuracy=0.96


results (178) (0:00:00.036564): loss=0.47487, accuracy=0.95


results (179) (0:00:00.036140): loss=0.0068895, accuracy=1


results (180) (0:00:00.037965): loss=0.293098, accuracy=0.97


results (181) (0:00:00.036202): loss=0.940953, accuracy=0.9


results (182) (0:00:00.036097): loss=1.05536, accuracy=0.94


results (183) (0:00:00.036915): loss=0.49763, accuracy=0.96


results (184) (0:00:00.035942): loss=0.950988, accuracy=0.94


results (185) (0:00:00.036459): loss=0.726267, accuracy=0.94


results (186) (0:00:00.036353): loss=0.370367, accuracy=0.96


results (187) (0:00:00.036218): loss=0.606954, accuracy=0.93


results (188) (0:00:00.036615): loss=0.280188, accuracy=0.95


results (189) (0:00:00.037155): loss=0.289456, accuracy=0.97


results (190) (0:00:00.036337): loss=0.794593, accuracy=0.91


results (191) (0:00:00.035787): loss=0.105648, accuracy=0.97


results (192) (0:00:00.035807): loss=0.416026, accuracy=0.95


results (193) (0:00:00.035892): loss=0.200904, accuracy=0.97


results (194) (0:00:00.036424): loss=0.0881328, accuracy=0.98


results (195) (0:00:00.038934): loss=0.681123, accuracy=0.95


results (196) (0:00:00.038247): loss=0.370826, accuracy=0.93


results (197) (0:00:00.036992): loss=0.407706, accuracy=0.96


results (198) (0:00:00.036143): loss=0.788597, accuracy=0.91


results (199) (0:00:00.036209): loss=0.256723, accuracy=0.94


results (200) (0:00:00.036565): loss=0.368046, accuracy=0.96


In [11]:
# Each of 5 rounds: one full training pass, then one evaluation pass over dev.
for i in range(5):
    model.process(dataset_filenames = train_set, header = 'train %d' % i, train = True)
    model.process(dataset_filenames = dev_set, header = 'dev %d' % i, train = False)

train 0 (5350) (0:00:49.169281): loss=3.16776, accuracy=0.608136


dev 0 (5350) (0:00:00.546094): loss=1.9619, accuracy=0.82255


train 1 (10500) (0:00:49.093360): loss=3.03075, accuracy=0.695238


dev 1 (10500) (0:00:00.519986): loss=2.04405, accuracy=0.83355


KeyboardInterrupt: 

In [13]:
# The filenames placeholder is 1-D (shape [None]), so pass the shard as a
# one-element list rather than a bare string.
model.process(
    ['../data/simplewiki/simplewiki-20171103.el_softmax_2.dev/examples.0000000000.tfrecords.gz'],
    train = False)

results (1465) (0:00:03.792288): loss=73.9995, accuracy=0.4729
