In [1]:
import tensorflow as tf
import numpy as np
import datetime
import json
import gzip
import matplotlib.pyplot as plt
import re
from tqdm import tqdm_notebook

In [2]:
# Notebook-global handle to the current TensorFlow session. It starts as None so
# that the first call to reset_tf(sess) has nothing to close.
sess = None

In [3]:
def reset_tf(sess = None, log_device_placement = False):
    """Tear down the previous TensorFlow state and open a fresh interactive session.

    Args:
        sess: an existing session to close first; skipped when falsy (e.g. None).
        log_device_placement: forwarded to tf.ConfigProto for the new session.

    Returns:
        A new tf.InteractiveSession bound to the freshly reset default graph.
    """
    # release the old session, if there is one, before discarding its graph
    if sess:
        sess.close()

    # start from an empty default graph with a fixed graph-level random seed
    tf.reset_default_graph()
    tf.set_random_seed(0)

    session_config = tf.ConfigProto(log_device_placement = log_device_placement)
    fresh_session = tf.InteractiveSession(config = session_config)
    return fresh_session

In [4]:
class HyperParameters:
    """Plain container of the tunable constants read by EntityLinkingModel.

    All values are class attributes, so they can be read from the class itself
    or from an instance.
    """

    # step size handed to the optimizer
    learning_rate = 0.1

    # number of distinct input word ids / number of output classes
    vocab_size = 30000
    num_targets = 2000

    # number of word ids in one context window, and the index of its middle slot
    context_size = 81
    context_center_index = context_size // 2

    # tf.data input pipeline settings
    pipeline_batch_size = 64
    pipeline_num_parallel_calls = 4
    pipeline_prefetch_size = 16 * pipeline_batch_size
    pipeline_shuffle_size = 5000

In [41]:
class EntityLinkingModel:
    """Bag-of-words softmax classifier mapping a window of word ids to a target id.

    The graph is built by build_model() in three stages: a tf.data input
    pipeline reading gzipped TFRecords, a single dense layer over per-example
    word counts, and the loss / accuracy / Adam training ops. process() then
    runs one pass over a dataset file, optionally training.
    """

    def __init__(self, session, hp):
        """Store the session and hyper-parameters; no graph ops are created here.

        Args:
            session: TensorFlow session used to run every op of this model.
            hp: object exposing the hyper-parameters read by the build methods
                (context_size, vocab_size, num_targets, learning_rate and the
                pipeline_* settings).
        """
        self._session = session
        self._hp = hp

    def _parse_example(self, example_proto):
        """Decode one serialized example into (target_id, context_word_ids).

        'page_id' is part of the record schema and is parsed, but it is not
        returned.
        """
        parsed = tf.parse_single_example(example_proto, features = {
            'page_id': tf.FixedLenFeature([1], tf.int64),
            'target_id': tf.FixedLenFeature([1], tf.int64),
            'context_word_ids': tf.FixedLenFeature([self._hp.context_size], tf.int64) })

        return (
            parsed['target_id'],
            parsed['context_word_ids'])

    def _build_data_pipeline(self):
        """Create the input pipeline and the model's input tensors.

        The dataset is: gzipped TFRecords -> take(limit) -> parse -> shuffle ->
        prefetch -> batch. The resulting inputs are wrapped in
        placeholder_with_default so they can also be fed directly.
        """
        with tf.variable_scope('dataset'):
            # placeholder: examples filenames
            self._dataset_filenames = tf.placeholder(tf.string, shape = [None])

            # number of records to read; the default of -1 means "no limit"
            self._pipeline_limit = tf.placeholder_with_default(tf.constant(-1, dtype = tf.int64), [])

            # build examples dataset
            dataset = tf.data.TFRecordDataset(
                self._dataset_filenames,
                compression_type='GZIP')
            dataset = dataset.take(self._pipeline_limit)
            dataset = dataset.map(
                self._parse_example,
                num_parallel_calls = self._hp.pipeline_num_parallel_calls)
            dataset = dataset.shuffle(self._hp.pipeline_shuffle_size)
            dataset = dataset.prefetch(self._hp.pipeline_prefetch_size)
            dataset = dataset.batch(self._hp.pipeline_batch_size)

            # build dataset iterator
            self._dataset_iterator = dataset.make_initializable_iterator()
            (target_labels, context_word_ids) = self._dataset_iterator.get_next()
            # each label is parsed with shape [1]; drop that axis -> [batch]
            target_labels = tf.squeeze(target_labels, axis = -1)

            # placeholders (default to the pipeline output)
            self._context_word_ids = tf.placeholder_with_default(
                context_word_ids,
                shape = [None, self._hp.context_size],
                name = 'context_word_ids')
            self._target_labels = tf.placeholder_with_default(
                target_labels,
                shape = [None],
                name = 'target_labels')

            self._minibatch_size = tf.shape(self._context_word_ids)[0]

    def _bincount_vectorized(self, A, dim):
        """Row-wise bincount: count occurrences of each value per row of A.

        Row i is shifted by i * dim so that a single flat tf.bincount yields
        one block of `dim` bins per row. Values are assumed to lie in
        [0, dim); a larger value would fall into a later row's block.

        Args:
            A: 2-D integer tensor, one row per example.
            dim: number of bins per row.

        Returns:
            Tensor of shape [rows, dim] holding the per-row counts.
        """
        offsets = tf.expand_dims(tf.range(tf.shape(A)[0]) * dim, 1)
        A_ravel = tf.reshape(A + offsets, [-1])
        dim_ravel = tf.shape(A)[0] * dim
        bincount_ravel = tf.bincount(
            A_ravel,
            minlength = dim_ravel,
            maxlength = dim_ravel)
        return tf.reshape(bincount_ravel, [-1, dim])

    def _build_model(self):
        """Build the forward pass: word-count features -> one dense layer of logits."""
        with tf.variable_scope('model'):
            features = self._bincount_vectorized(
                tf.cast(self._context_word_ids, tf.int32),
                self._hp.vocab_size)
            features = tf.cast(features, tf.float32)
            self._output_logits = tf.layers.dense(features, self._hp.num_targets)

            self._features = features

    def _build_training_model(self):
        """Build the loss, the correct-prediction count and the Adam training op."""
        with tf.variable_scope('train'):
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels = self._target_labels,
                logits = self._output_logits)

            self._total_loss = tf.reduce_sum(losses)
            self._mean_loss = tf.reduce_mean(losses)

            # softmax is monotonic, so the argmax of the logits is already the
            # predicted class; no need to normalise first
            output_labels = tf.argmax(self._output_logits, axis=-1)
            self._num_correct_labels = tf.reduce_sum(tf.cast(
                tf.equal(output_labels, self._target_labels),
                tf.int32))

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self._global_step = tf.Variable(0, name='global_step', trainable=False)
                self._optimizer = tf.train.AdamOptimizer(learning_rate=self._hp.learning_rate)
                # No gradient clipping is applied. Adding it would mean replacing
                # minimize() with compute_gradients / tf.clip_by_global_norm /
                # apply_gradients and adding a clip-norm hyper-parameter.
                self._train_op = self._optimizer.minimize(
                    self._mean_loss,
                    global_step=self._global_step)

    def build_model(self):
        """Build the whole graph: input pipeline, model, then training ops."""
        self._build_data_pipeline()
        self._build_model()
        self._build_training_model()

    def dump_statistics(self):
        """Print the parameter count of every trainable variable and the total."""
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            print('parameters for "%s": %d' % (variable.name, variable_parameters))
            total_parameters += variable_parameters
        print('total parameters: %d' % total_parameters)

    def process(self,
                dataset_filename,
                header = 'results',
                train = False,
                show_progress = True,
                log_file = None,
                limit = -1):
        """Run one pass over a dataset file and report mean loss and accuracy.

        Args:
            dataset_filename: path of the gzipped TFRecord file to read.
            header: label used for the progress bar and the summary line.
            train: when True the training op is run on every minibatch.
            show_progress: when True a tqdm notebook progress bar is shown.
            log_file: optional writable file object; the summary line is also
                written to it and flushed.
            limit: maximum number of records to read (-1 reads everything).

        The summary line is printed (and optionally logged). When the pass
        yields no examples, loss and accuracy are reported as nan.
        """
        cum_loss = 0
        cum_num_examples = 0
        cum_correct_examples = 0

        start = datetime.datetime.now()

        # (re)start the iterator on the requested file and record limit
        self._session.run(self._dataset_iterator.initializer, feed_dict={
            self._dataset_filenames: [dataset_filename],
            self._pipeline_limit: limit
        })

        # an empty tuple in the first slot means "fetch nothing" when not training
        fetches = (
            self._train_op if train else (),
            self._total_loss,
            self._minibatch_size,
            self._num_correct_labels)

        progress = tqdm_notebook(leave = False, desc = header) if show_progress else None
        try:
            while True:
                try:
                    (_,
                     curr_total_loss,
                     curr_minibatch_size,
                     curr_num_correct_labels) = self._session.run(fetches)
                except tf.errors.OutOfRangeError:
                    # the iterator is exhausted: the pass is complete
                    break

                if progress is not None:
                    progress.update(curr_minibatch_size)

                cum_loss += curr_total_loss
                cum_num_examples += curr_minibatch_size
                cum_correct_examples += curr_num_correct_labels
        finally:
            # close the bar even when session.run raises something unexpected
            if progress is not None:
                progress.close()

        finish = datetime.datetime.now()

        # an empty pass has no meaningful averages; avoid dividing by zero
        if cum_num_examples > 0:
            mean_loss = cum_loss / cum_num_examples
            accuracy = cum_correct_examples / cum_num_examples
        else:
            mean_loss = accuracy = float('nan')

        message = '%s (%d) (%s): loss=%g, accuracy=%g' % (
            header,
            # read the step through the model's own session, not a notebook global
            tf.train.global_step(self._session, self._global_step),
            finish - start,
            mean_loss,
            accuracy)
        print(message)
        if log_file:
            print(message, file=log_file)
            log_file.flush()

In [42]:
# Close the previous session (if any), reset the default graph and open a new session.
sess = reset_tf(sess)

# Build the full graph (input pipeline, model, training ops) with the default
# hyper-parameters, then print the trainable parameter counts.
model = EntityLinkingModel(sess, HyperParameters())
model.build_model()
model.dump_statistics()

parameters for "model/dense/kernel:0": 60000000
parameters for "model/dense/bias:0": 2000
total parameters: 60002000


In [57]:
# Initialise every global variable in the current graph; this must run before
# any call to model.process().
sess.run(tf.global_variables_initializer())

In [58]:
# Run 100 training passes; each call prints one summary line
# ("results (<global step>) (<elapsed>): loss=..., accuracy=...").
# limit = 17 restricts every pass to the first 17 records of the file, and the
# loop index i is not used inside the body.
# NOTE(review): the dataset path is relative to the notebook's working directory.
for i in range(100):
    model.process(
        '../data/simplewiki/simplewiki-20171103.el_softmax_1.tiny.tfrecords.gz', 
        train = True,
        limit = 17)

results (1) (0:00:00.078696): loss=7.65313, accuracy=0


results (2) (0:00:00.067870): loss=1.04351, accuracy=0.411765


results (3) (0:00:00.048509): loss=1.79446, accuracy=0.647059


results (4) (0:00:00.047016): loss=1.75491, accuracy=0.647059


results (5) (0:00:00.055244): loss=1.14088, accuracy=0.764706


results (6) (0:00:00.046765): loss=1.81201, accuracy=0.411765


results (7) (0:00:00.045789): loss=0.630289, accuracy=0.764706


results (8) (0:00:00.046328): loss=0.684491, accuracy=0.705882


results (9) (0:00:00.047580): loss=0.965852, accuracy=0.588235


results (10) (0:00:00.048644): loss=0.799935, accuracy=0.823529


results (11) (0:00:00.047618): loss=0.930652, accuracy=0.705882


results (12) (0:00:00.046404): loss=0.829161, accuracy=0.823529


results (13) (0:00:00.045932): loss=0.636508, accuracy=0.823529


results (14) (0:00:00.046441): loss=0.400164, accuracy=0.882353


results (15) (0:00:00.046446): loss=0.269068, accuracy=0.882353


results (16) (0:00:00.046062): loss=0.346201, accuracy=0.882353


results (17) (0:00:00.046083): loss=0.414476, accuracy=0.882353


results (18) (0:00:00.046311): loss=0.316034, accuracy=0.882353


results (19) (0:00:00.046325): loss=0.292081, accuracy=0.823529


results (20) (0:00:00.046011): loss=0.222101, accuracy=0.941176


results (21) (0:00:00.046088): loss=0.179381, accuracy=0.941176


results (22) (0:00:00.047211): loss=0.195721, accuracy=0.941176


results (23) (0:00:00.047365): loss=0.226959, accuracy=0.941176


results (24) (0:00:00.046473): loss=0.202471, accuracy=0.941176


results (25) (0:00:00.046718): loss=0.151098, accuracy=0.941176


results (26) (0:00:00.059767): loss=0.155721, accuracy=0.941176


results (27) (0:00:00.047106): loss=0.137347, accuracy=0.941176


results (28) (0:00:00.046461): loss=0.128025, accuracy=0.941176


results (29) (0:00:00.046394): loss=0.117692, accuracy=0.941176


results (30) (0:00:00.046033): loss=0.122089, accuracy=0.941176


results (31) (0:00:00.046542): loss=0.118098, accuracy=0.941176


results (32) (0:00:00.046336): loss=0.121599, accuracy=0.941176


results (33) (0:00:00.046835): loss=0.121498, accuracy=0.941176


results (34) (0:00:00.046162): loss=0.121237, accuracy=0.941176


results (35) (0:00:00.059751): loss=0.114308, accuracy=0.941176


results (36) (0:00:00.052989): loss=0.106399, accuracy=0.941176


results (37) (0:00:00.048558): loss=0.0998459, accuracy=0.941176


results (38) (0:00:00.046750): loss=0.097721, accuracy=0.941176


results (39) (0:00:00.058292): loss=0.0975855, accuracy=0.941176


results (40) (0:00:00.051409): loss=0.0983506, accuracy=0.941176


results (41) (0:00:00.050465): loss=0.0977727, accuracy=0.941176


results (42) (0:00:00.045945): loss=0.0978088, accuracy=0.941176


results (43) (0:00:00.046073): loss=0.0974209, accuracy=0.941176


results (44) (0:00:00.046442): loss=0.0985631, accuracy=0.941176


results (45) (0:00:00.046485): loss=0.0974936, accuracy=0.941176


results (46) (0:00:00.045781): loss=0.0972017, accuracy=0.941176


results (47) (0:00:00.048176): loss=0.0947014, accuracy=0.941176


results (48) (0:00:00.045995): loss=0.0947313, accuracy=0.941176


results (49) (0:00:00.046636): loss=0.0932738, accuracy=0.941176


results (50) (0:00:00.046631): loss=0.0941674, accuracy=0.941176


results (51) (0:00:00.046532): loss=0.0930101, accuracy=0.941176


results (52) (0:00:00.047392): loss=0.0931886, accuracy=0.941176


results (53) (0:00:00.046356): loss=0.0919017, accuracy=0.941176


results (54) (0:00:00.046522): loss=0.0915931, accuracy=0.941176


results (55) (0:00:00.047996): loss=0.0910687, accuracy=0.941176


results (56) (0:00:00.047002): loss=0.0906435, accuracy=0.941176


results (57) (0:00:00.046090): loss=0.0906142, accuracy=0.941176


results (58) (0:00:00.046111): loss=0.0898628, accuracy=0.941176


results (59) (0:00:00.047305): loss=0.0899744, accuracy=0.941176


results (60) (0:00:00.047013): loss=0.0894377, accuracy=0.941176


results (61) (0:00:00.057284): loss=0.0896502, accuracy=0.941176


results (62) (0:00:00.047560): loss=0.0895828, accuracy=0.941176


results (63) (0:00:00.046587): loss=0.0894966, accuracy=0.941176


results (64) (0:00:00.046060): loss=0.0895636, accuracy=0.941176


results (65) (0:00:00.046314): loss=0.0891581, accuracy=0.941176


results (66) (0:00:00.047162): loss=0.0891816, accuracy=0.941176


results (67) (0:00:00.047636): loss=0.0889205, accuracy=0.941176


results (68) (0:00:00.046557): loss=0.0888116, accuracy=0.941176


results (69) (0:00:00.046619): loss=0.088795, accuracy=0.941176


results (70) (0:00:00.046669): loss=0.088533, accuracy=0.941176


results (71) (0:00:00.046795): loss=0.0885233, accuracy=0.941176


results (72) (0:00:00.047189): loss=0.0883493, accuracy=0.941176


results (73) (0:00:00.046584): loss=0.0882286, accuracy=0.941176


results (74) (0:00:00.054194): loss=0.0882371, accuracy=0.941176


results (75) (0:00:00.046105): loss=0.0880788, accuracy=0.941176


results (76) (0:00:00.046255): loss=0.0880388, accuracy=0.941176


results (77) (0:00:00.046960): loss=0.0879516, accuracy=0.941176


results (78) (0:00:00.045806): loss=0.0877847, accuracy=0.941176


results (79) (0:00:00.046161): loss=0.0877389, accuracy=0.941176


results (80) (0:00:00.046670): loss=0.0876318, accuracy=0.941176


results (81) (0:00:00.046363): loss=0.0875358, accuracy=0.941176


results (82) (0:00:00.046667): loss=0.0875151, accuracy=0.941176


results (83) (0:00:00.049082): loss=0.0874232, accuracy=0.941176


results (84) (0:00:00.046301): loss=0.0873545, accuracy=0.941176


results (85) (0:00:00.047763): loss=0.0873242, accuracy=0.941176


results (86) (0:00:00.046193): loss=0.0872443, accuracy=0.941176


results (87) (0:00:00.052350): loss=0.0871951, accuracy=0.941176


results (88) (0:00:00.047052): loss=0.0871686, accuracy=0.941176


results (89) (0:00:00.050938): loss=0.0870994, accuracy=0.941176


results (90) (0:00:00.046041): loss=0.0870453, accuracy=0.941176


results (91) (0:00:00.046383): loss=0.0870054, accuracy=0.941176


results (92) (0:00:00.046214): loss=0.0869368, accuracy=0.941176


results (93) (0:00:00.046436): loss=0.0868809, accuracy=0.941176


results (94) (0:00:00.046581): loss=0.0868456, accuracy=0.941176


results (95) (0:00:00.047028): loss=0.0867929, accuracy=0.941176


results (96) (0:00:00.047103): loss=0.0867411, accuracy=0.941176


results (97) (0:00:00.045970): loss=0.0867054, accuracy=0.941176


results (98) (0:00:00.046522): loss=0.0866594, accuracy=0.941176


results (99) (0:00:00.046341): loss=0.0866068, accuracy=0.941176


results (100) (0:00:00.056267): loss=0.0865675, accuracy=0.941176
