In [1]:
import tensorflow as tf
import numpy as np
import datetime
import json
import gzip
from tqdm import tqdm_notebook

In [2]:
# Session handle shared by the cells below; it starts empty so that the first
# reset_tf(sess) call has nothing to close.
sess = None

In [3]:
def reset_tf(sess=None, log_device_placement=False):
    """Close `sess` (if given), reset the default graph and open a new session.

    The graph-level random seed is fixed to 0 after the reset.

    Args:
        sess: a previously opened session to close first, or None.
        log_device_placement: forwarded to tf.ConfigProto.

    Returns:
        A new tf.InteractiveSession built with that config.
    """
    if sess:
        sess.close()

    tf.reset_default_graph()
    tf.set_random_seed(0)

    session_config = tf.ConfigProto(log_device_placement=log_device_placement)
    new_session = tf.InteractiveSession(config=session_config)
    return new_session

In [23]:
class HyperParameters:
    """Hyperparameters for the entity-linking model and its input pipeline.

    Every value is a plain class attribute; EntityLinkingModel reads them
    through the instance it is given.
    """

    # --- optimisation / regularisation ---
    learning_rate = 1e-3
    dropout_rate = 0.1

    # --- vocabulary and context window ---
    vocab_size = 30000
    context_size = 81
    # index of the middle position of the context window
    context_center_index = context_size // 2

    # --- embeddings ---
    d_embedding_position = 16
    d_embedding_word = 128          # smaller alternative tried: 64

    # --- attention stack ---
    d_attention = 128               # smaller alternative tried: 64
    d_attention_ff = 256            # smaller alternative tried: 128
    attention_num_layers = 1        # alternative tried: 3

    # --- dense towers ---
    d_candidate_layers = [256, 128]  # smaller alternative tried: [64]
    d_logistic_layers = [64, 16]

    # --- tf.data pipeline ---
    pipeline_batch_size = 128
    # number of negative samples per positive example
    pipeline_num_negative_samples = 5
    pipeline_num_parallel_calls = 4
    pipeline_prefetch_size = pipeline_batch_size * 16
    pipeline_shuffle_size = 5000

    # --- batch size used by _embed_page_tfs ---
    embed_page_tfs_batch_size = 1024

In [24]:
class EntityLinkingModel:
    def __init__(self, hp):
        """Store the hyperparameter object `hp`; the graph itself is built by the _build_* methods."""
        self._hp = hp

    def _parse_example(self, example_proto):
        """Decode one serialized example into the four tensors used by the pipeline.

        The sparse (word id, frequency) pairs of the target page are scattered
        into a dense vector of length `vocab_size`.

        Returns:
            (page_id, word_ids, target_id, target_tf)
        """
        hp = self._hp
        schema = {
            'page_id': tf.FixedLenFeature([1], tf.int64),
            'target_id': tf.FixedLenFeature([1], tf.int64),
            'target_word_ids': tf.VarLenFeature(tf.int64),
            'target_word_freqs': tf.VarLenFeature(tf.int64),
            'word_ids': tf.FixedLenFeature([hp.context_size], tf.int64)
        }
        record = tf.parse_single_example(example_proto, schema)

        # variable-length id / frequency lists arrive as sparse tensors
        word_ids = tf.sparse_tensor_to_dense(record['target_word_ids'])
        word_freqs = tf.sparse_tensor_to_dense(record['target_word_freqs'])

        # dense term-frequency vector over the vocabulary
        target_tf = tf.sparse_to_dense(word_ids, [hp.vocab_size], word_freqs)

        return record['page_id'], record['word_ids'], record['target_id'], target_tf

    def _build_data_pipeline(self):
        """Build the tf.data input pipeline and the in-batch candidate sampling tensors.

        Gzip-compressed TFRecord files named by the `_dataset_filenames`
        placeholder are parsed, shuffled, prefetched and batched. Each of the
        four iterator outputs is wrapped in `placeholder_with_default`, so it
        can be overridden through feed_dict.

        Every minibatch row (an "anchor") is paired with
        `pipeline_num_negative_samples + 1` candidates: candidate k of row r is
        the target page of row (r + k) mod minibatch_size, so k = 0 is the
        anchor's own target. A label is 1 where the candidate's target id
        equals the anchor's target id.

        Sets: _dataset_filenames, _dataset_iterator, _input_context_word_ids,
        _candidate_page_dists, _target_labels, _input_context_positions,
        _is_training and _minibatch_size.
        """
        with tf.variable_scope('dataset'):
            # placeholder: examples filenames
            self._dataset_filenames = tf.placeholder(tf.string, shape = [None])

            # build examples dataset
            dataset = tf.data.TFRecordDataset(
                self._dataset_filenames,
                compression_type='GZIP')
            dataset = dataset.map(
                self._parse_example,
                num_parallel_calls = self._hp.pipeline_num_parallel_calls)
            dataset = dataset.shuffle(self._hp.pipeline_shuffle_size)
            # prefetch comes before batch, so it buffers individual examples
            dataset = dataset.prefetch(self._hp.pipeline_prefetch_size)
            dataset = dataset.batch(self._hp.pipeline_batch_size)

            # build dataset iterator
            self._dataset_iterator = dataset.make_initializable_iterator()
            (input_page_ids,
             input_context_word_ids,
             target_page_ids,
             target_page_tfs) = self._dataset_iterator.get_next()
            
            # wrap the iterator outputs so they can also be fed directly
            input_page_ids = tf.placeholder_with_default(
                input_page_ids,
                [None, 1],
                name='input_page_ids')
            input_context_word_ids = tf.placeholder_with_default(
                input_context_word_ids,
                [None, self._hp.context_size],
                name='input_context_word_ids')
            target_page_ids = tf.placeholder_with_default(
                target_page_ids,
                [None, 1],
                name='target_page_ids')
            target_page_tfs = tf.placeholder_with_default(
                target_page_tfs,
                [None, self._hp.vocab_size],
                name='target_page_tfs')
            
            # input_page_ids is only used here, for its leading (batch) dimension
            minibatch_size = tf.shape(input_page_ids)[0]
            sample_size = self._hp.pipeline_num_negative_samples + 1
            expanded_minibatch_size = minibatch_size * sample_size
            
            # x: row indexes as a column, y: candidate offsets as a row
            x = tf.reshape(tf.range(minibatch_size), [minibatch_size, 1])
            y = tf.reshape(tf.range(sample_size), [1, sample_size])
            z = tf.ones([1, sample_size], dtype=tf.int32)
            # each row index repeated sample_size times: 0,0,...,1,1,...
            anchor_indexes = tf.reshape(
                x * z,
                [expanded_minibatch_size])
            # (row + offset) wrapped around the batch; offset 0 is the row itself
            cand_indexes = tf.reshape(
                tf.mod(x + y, minibatch_size),
                [expanded_minibatch_size])
            
            # embedding_lookup is used here as a row gather
            self._input_context_word_ids = tf.nn.embedding_lookup(
                input_context_word_ids,
                anchor_indexes)
            
            candidate_page_tfs = tf.nn.embedding_lookup(
                target_page_tfs,
                cand_indexes)
            # normalise term frequencies by their row sum
            # NOTE(review): a row whose frequencies sum to zero divides by zero here;
            # _embed_page_tfs guards the same normalisation with + 1e-8 - confirm
            # whether such rows can occur.
            page_tf_norms = tf.reduce_sum(candidate_page_tfs, axis=-1, keep_dims=True)
            self._candidate_page_dists = candidate_page_tfs / page_tf_norms
            
            anchor_page_ids = tf.nn.embedding_lookup(
                target_page_ids,
                anchor_indexes)
            cand_page_ids = tf.nn.embedding_lookup(
                target_page_ids,
                cand_indexes)
            # label is 1 when the candidate is the anchor's own target page
            self._target_labels = tf.cast(
                tf.equal(anchor_page_ids, cand_page_ids),
                tf.int64)
            self._target_labels = tf.squeeze(
                self._target_labels,
                axis=-1)
            
            # positions
            # 0..context_size-1 repeated for every expanded row
            p = tf.range(self._hp.context_size, dtype = tf.int64)
            p = tf.tile(p, [expanded_minibatch_size])
            p = tf.reshape(p, [expanded_minibatch_size, self._hp.context_size])
            self._input_context_positions = p
            
            # placeholder: training flag
            self._is_training = tf.placeholder(tf.bool)
            
            # note: this is the expanded size (rows * candidates per row)
            self._minibatch_size = expanded_minibatch_size

    def _layer_norm(self, x, scope, reuse=None, epsilon=1e-6):
        """Layer-normalise `x` over its last axis with a learned scale and bias.

        The two variables live under `scope` and have `d_attention` entries
        each (scale starts at ones, bias at zeros).
        """
        width = self._hp.d_attention
        with tf.variable_scope(scope, reuse=reuse):
            gain = tf.get_variable(
                "layer_norm_scale", [width], initializer=tf.ones_initializer())
            offset = tf.get_variable(
                "layer_norm_bias", [width], initializer=tf.zeros_initializer())
            return self._layer_norm_compute(x, epsilon, gain, offset)

    def _layer_norm_compute(self, x, epsilon, scale, bias):
        mean = tf.reduce_mean(x, axis=[-1], keep_dims=True)
        variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keep_dims=True)
        norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
        return norm_x * scale + bias
    
    def _attention_layer(self, A):
        """Scaled dot-product self-attention of `A` with itself, followed by dropout.

        `A` is used directly as queries, keys and values; the logits are
        divided by the square root of the size of A's last axis.
        """
        keys_t = tf.transpose(A, perm=[0, 2, 1])
        raw_scores = tf.matmul(A, keys_t)
        depth = tf.cast(tf.shape(A)[-1], tf.float32)
        weights = tf.nn.softmax(raw_scores / tf.sqrt(depth))
        attended = tf.matmul(weights, A)
        return tf.layers.dropout(
            attended,
            rate=self._hp.dropout_rate,
            training=self._is_training)

    def _attention_ff_layer(self, A, scope, reuse=None):
        """Feed-forward block: dense(d_attention_ff, relu) -> dense(d_attention) -> dropout."""
        hp = self._hp
        with tf.variable_scope(scope, reuse=reuse):
            hidden = tf.layers.dense(A, hp.d_attention_ff, activation=tf.nn.relu, name='fc1')
            projected = tf.layers.dense(hidden, hp.d_attention, name='fc2')
            return tf.layers.dropout(
                projected,
                rate=hp.dropout_rate,
                training=self._is_training)
    
    def _attention_full_layer(self, A, scope, reuse=None):
        """One attention layer: self-attention and feed-forward, each with a residual add and layer norm."""
        with tf.variable_scope(scope, reuse=reuse):
            # residual connection around self-attention
            attended = self._attention_layer(A)
            after_attention = self._layer_norm(
                A + attended,
                scope='attention_norm')

            # residual connection around the feed-forward block
            transformed = self._attention_ff_layer(after_attention, 'ff', reuse)
            return self._layer_norm(
                after_attention + transformed,
                scope='attention_ff_norm')
            
    def _build_model(self):
        """Build the scoring network on top of the tensors from _build_data_pipeline.

        Context side: word and position embeddings are concatenated, projected
        to `d_attention` (dense + dropout + layer norm), passed through
        `attention_num_layers` attention layers, and the vector at
        `context_center_index` is kept.

        Candidate side: the candidate page term distribution goes through the
        `d_candidate_layers` dense/dropout stack.

        Both vectors are concatenated and fed through the `d_logistic_layers`
        stack and a final single-unit dense layer.

        Sets: _candidate_page_dists_embedded and _output_logits.
        """
        with tf.variable_scope('model'):
            # embed context words
            word_embeddings = tf.get_variable(
                'word_embeddings', 
                [self._hp.vocab_size, self._hp.d_embedding_word])
            input_context_words_embedded = tf.nn.embedding_lookup(
                word_embeddings,
                self._input_context_word_ids)

            # embed context positions
            position_embeddings = tf.get_variable(
                'position_embeddings',
                [self._hp.context_size, self._hp.d_embedding_position],
                dtype=tf.float32)
            input_context_positions_embedded = tf.nn.embedding_lookup(
                position_embeddings,
                self._input_context_positions)

            # build full context vector (concat embeddings)
            input_context_full = tf.concat(
                [input_context_words_embedded, input_context_positions_embedded], 
                axis=-1)
            
            # build attention input vector
            input_attention = tf.layers.dense(
                input_context_full,
                self._hp.d_attention,
                activation=tf.nn.relu,
                name='input_attention')
            input_attention = tf.layers.dropout(
                input_attention,
                rate=self._hp.dropout_rate,
                training=self._is_training)
            # the layer-norm scope shares its name with the dense layer above
            input_attention = self._layer_norm(
                input_attention,
                scope='input_attention')
            
            # build attention layers
            with tf.variable_scope('attention'):
                layer = input_attention
                for i in range(self._hp.attention_num_layers):
                    layer = self._attention_full_layer(layer, 'layer_%d' % i)
            # keep only the vector at the centre position of the context window
            output_attention = layer[:, self._hp.context_center_index, :]

            # build candidate layers
            with tf.variable_scope('candidate'):
                layer = tf.cast(self._candidate_page_dists, tf.float32)
                for i, d_layer in enumerate(self._hp.d_candidate_layers):
                    layer = tf.layers.dense(
                        layer, 
                        d_layer,
                        activation=tf.nn.relu,
                        name=('layer_%d' % i))
                    layer = tf.layers.dropout(
                        layer,
                        rate=self._hp.dropout_rate,
                        training=self._is_training)
            # exposed so candidate embeddings can be fetched on their own
            # (see _embed_page_tfs)
            self._candidate_page_dists_embedded = layer
            
            # build final logistic neuron
            with tf.variable_scope('logistic'):
                layer = tf.concat(
                    [output_attention, self._candidate_page_dists_embedded],
                    axis=-1)
                for i, d_layer in enumerate(self._hp.d_logistic_layers):
                    layer = tf.layers.dense(
                        layer, 
                        d_layer,
                        activation=tf.nn.relu,
                        name=('layer_%d' % i))
                    layer = tf.layers.dropout(
                        layer,
                        rate=self._hp.dropout_rate,
                        training=self._is_training)
                # one logit per (anchor, candidate) pair
                self._output_logits = tf.squeeze(
                    tf.layers.dense(layer, 1, name='logistic'),
                    axis=-1)
#                 self._output_logits = tf.layers.dense(layer, 2, name='logistic')

    def _build_training_model(self):
        """Add the loss, the correct-prediction count and the Adam training op.

        Sets: _total_loss, _mean_loss, _num_correct_labels, _global_step,
        _optimizer and _train_op.
        """
        with tf.variable_scope('train'):
            # per-pair sigmoid cross-entropy against the 0/1 labels
            float_labels = tf.cast(self._target_labels, tf.float32)
            per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=float_labels,
                logits=self._output_logits)

            self._total_loss = tf.reduce_sum(per_example_loss)
            self._mean_loss = tf.reduce_mean(per_example_loss)

            # a pair is predicted positive when its probability exceeds 0.5
            probabilities = tf.sigmoid(self._output_logits)
            predicted = tf.cast(tf.greater(probabilities, 0.5), tf.int64)
            hits = tf.equal(predicted, self._target_labels)
            self._num_correct_labels = tf.reduce_sum(tf.cast(hits, tf.int32))

            # the optimizer minimises the mean loss and bumps the global step
            self._global_step = tf.Variable(0, name='global_step', trainable=False)
            self._optimizer = tf.train.AdamOptimizer(learning_rate=self._hp.learning_rate)
            self._train_op = self._optimizer.minimize(
                self._mean_loss,
                global_step=self._global_step)
                
    def dump_statistics(self):
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            print('parameters for "%s": %d' % (variable.name, variable_parameters))
            total_parameters += variable_parameters
        print('total parameters: %d' % total_parameters)
        
    def _embed_page_tfs(self, sess, page_tfs):
        result = []
        batch = np.zeros([self._hp.embed_page_tfs_batch_size, self._hp.vocab_size])

        for i in tqdm_notebook(range(0, len(page_tfs), self._hp.embed_page_tfs_batch_size)):
            batch_size = min(len(page_tfs) - i, self._hp.embed_page_tfs_batch_size)
            
            batch.fill(0)
            for j in range(batch_size):
                for word_id, word_freq in zip(page_tfs[j][0], page_tfs[j][1]):
                    batch[j, word_id] = word_freq

            # turn term freqs into term distribution
            batch /= (np.sum(batch, axis=-1, keepdims=True) + 1e-8)

            result.append(sess.run(model._candidate_page_dists_embedded, feed_dict={
                model._is_training: True,
                model._candidate_page_dists: batch[:batch_size] }))

        return np.concatenate(result, axis=0)

    def evaluate_dataset(self,
                         sess,
                         dataset_filename,
                         options = None,
                         run_metadata = None,
                         header = 'results',
                         train = False,
                         show_progress = True,
                         log_file = None):
        cum_loss = 0
        cum_num_examples = 0
        cum_correct_examples = 0
        
        start = datetime.datetime.now()

        sess.run(self._dataset_iterator.initializer, feed_dict={
            self._dataset_filenames: [dataset_filename]
        })

        if show_progress:
            progress = tqdm_notebook()

        while True:
            try:
                (_,
                 curr_total_loss, 
                 curr_minibatch_size,
                 curr_num_correct_labels) = sess.run(
                    (self._train_op if train else [],
                     self._total_loss,
                     self._minibatch_size,
                     self._num_correct_labels),
                    feed_dict = { self._is_training: train },
                    options = options,
                    run_metadata = run_metadata)
            except tf.errors.OutOfRangeError:
                break

            if show_progress:
                progress.update(curr_minibatch_size)

            cum_loss += curr_total_loss
            cum_num_examples += curr_minibatch_size
            cum_correct_examples += curr_num_correct_labels

        if show_progress:
            progress.close()
            
        finish = datetime.datetime.now()
        elapsed = (finish - start).total_seconds() * 1000.0

        message = '%s (%d) (%g ms): loss=%g, accuracy=%g' % (
            header,
            tf.train.global_step(sess, self._global_step),
            elapsed,
            cum_loss/cum_num_examples,
            cum_correct_examples/cum_num_examples)
        print(message)
        if log_file:
            print(message, file=log_file)
            log_file.flush()

In [25]:
# Start from a fresh graph and session, then assemble the graph in dependency
# order: the pipeline tensors feed the model, and the model's logits feed the
# training ops.
sess = reset_tf(sess)

model = EntityLinkingModel(HyperParameters())
model._build_data_pipeline()
model._build_model()
model._build_training_model()
# print per-variable and total trainable parameter counts
model.dump_statistics()

parameters for "model/word_embeddings:0": 3840000
parameters for "model/position_embeddings:0": 1296
parameters for "model/input_attention/kernel:0": 18432
parameters for "model/input_attention/bias:0": 128
parameters for "model/input_attention/layer_norm_scale:0": 128
parameters for "model/input_attention/layer_norm_bias:0": 128
parameters for "model/attention/layer_0/attention_norm/layer_norm_scale:0": 128
parameters for "model/attention/layer_0/attention_norm/layer_norm_bias:0": 128
parameters for "model/attention/layer_0/ff/fc1/kernel:0": 32768
parameters for "model/attention/layer_0/ff/fc1/bias:0": 256
parameters for "model/attention/layer_0/ff/fc2/kernel:0": 32768
parameters for "model/attention/layer_0/ff/fc2/bias:0": 128
parameters for "model/attention/layer_0/attention_ff_norm/layer_norm_scale:0": 128
parameters for "model/attention/layer_0/attention_ff_norm/layer_norm_bias:0": 128
parameters for "model/candidate/layer_0/kernel:0": 7680000
parameters for "model/candidate/layer

In [26]:
# Initialise all global variables of the graph built above.
sess.run(tf.global_variables_initializer())

In [42]:
# Load the gzip-compressed JSON file into `page_tfs`.
# NOTE(review): presumably this is the input for model._embed_page_tfs (records
# of word ids and frequencies) - no visible cell uses it, confirm.
with gzip.open('../data/simplewiki/simplewiki-20171103.entity_linking.page_tf.json.gz', 'rt', encoding='utf-8') as f:
    page_tfs = json.load(f)

In [27]:
# Up to 20 rounds: one training pass followed by one evaluation pass on dev.
# NOTE(review): train=True here runs the optimizer on the *test* split (header
# 'test'), while a later cell trains on the train split - confirm this was only
# a smoke test and not meant for reported numbers.
for i in range(20):
    model.evaluate_dataset(
        sess,
        '../data/simplewiki/simplewiki-20171103.entity_linking.test.tfrecords.gz',
        header='test %d' % i,
        train = True,
        show_progress = True)
    model.evaluate_dataset(
        sess,
        '../data/simplewiki/simplewiki-20171103.entity_linking.dev.tfrecords.gz',
        header='dev %d' % i,
        train = False,
        show_progress = True)

test 0 (235) (29386.1 ms): loss=0.461598, accuracy=0.829956


dev 0 (235) (17518.6 ms): loss=0.451748, accuracy=0.832678


test 1 (470) (28867.8 ms): loss=0.456045, accuracy=0.832767


dev 1 (470) (17521 ms): loss=0.451734, accuracy=0.832678


test 2 (705) (29370.4 ms): loss=0.455253, accuracy=0.832767


dev 2 (705) (17464.8 ms): loss=0.45174, accuracy=0.832678


test 3 (940) (29308.7 ms): loss=0.454922, accuracy=0.832767


dev 3 (940) (17457.8 ms): loss=0.451667, accuracy=0.832678


test 4 (1175) (29296.6 ms): loss=0.454088, accuracy=0.832767


dev 4 (1175) (17432.3 ms): loss=0.451702, accuracy=0.832678


test 5 (1410) (29287.3 ms): loss=0.453703, accuracy=0.832767


dev 5 (1410) (17672.8 ms): loss=0.451634, accuracy=0.832678


test 6 (1645) (29108.9 ms): loss=0.453634, accuracy=0.832767


dev 6 (1645) (17532.8 ms): loss=0.451668, accuracy=0.832678


test 7 (1880) (28626.5 ms): loss=0.453123, accuracy=0.832767


dev 7 (1880) (17595.8 ms): loss=0.451636, accuracy=0.832678


KeyboardInterrupt: 

In [8]:
# Full training run: up to 50 passes over the train split, each followed by an
# evaluation pass on dev; every result line is also written to the log file.
with open('/tmp/mediawiki_model_entity_linking_1.log', 'wt', encoding='utf-8') as f:
    for i in range(50):
        model.evaluate_dataset(
            sess,
            '../data/simplewiki/simplewiki-20171103.entity_linking.train.tfrecords.gz',
            header='train %d' % i,
            train = True,
            show_progress = True,
            log_file=f)
        model.evaluate_dataset(
            sess,
            '../data/simplewiki/simplewiki-20171103.entity_linking.dev.tfrecords.gz',
            header='dev %d' % i,
            train = False,
            show_progress = True,
            log_file=f)
        # NOTE(review): `builder` is not defined in any visible cell - presumably
        # a SavedModel builder left over from earlier kernel state. Define it
        # explicitly so the notebook survives Restart & Run All.
        builder.save()

INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.



train 0 (235) (45289.9 ms): loss=0.460859, accuracy=0.830294



dev 0 (235) (19103.9 ms): loss=0.434446, accuracy=0.835478
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'



train 1 (470) (44914.7 ms): loss=0.420472, accuracy=0.835911



dev 1 (470) (19095.5 ms): loss=0.399567, accuracy=0.838383
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'



train 2 (705) (44704.3 ms): loss=0.369727, accuracy=0.842456



dev 2 (705) (18643.9 ms): loss=0.376092, accuracy=0.845667
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'



train 3 (940) (44364.4 ms): loss=0.329213, accuracy=0.849322



dev 3 (940) (18781 ms): loss=0.362703, accuracy=0.850178
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'



train 4 (1175) (45977.9 ms): loss=0.305415, accuracy=0.853317



dev 4 (1175) (18804.1 ms): loss=0.371279, accuracy=0.849611
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'



train 5 (1410) (44520.2 ms): loss=0.285702, accuracy=0.859589



dev 5 (1410) (19000 ms): loss=0.391276, accuracy=0.844261
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'



train 6 (1645) (48738 ms): loss=0.269504, accuracy=0.865711



dev 6 (1645) (21034.2 ms): loss=0.386451, accuracy=0.833939
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'



train 7 (1880) (47977.6 ms): loss=0.253719, accuracy=0.873728



dev 7 (1880) (19821.4 ms): loss=0.376001, accuracy=0.83985
INFO:tensorflow:SavedModel written to: b'../models/simplewiki/entity_linking_1/saved_model.pb'


KeyboardInterrupt: 