In [1]:
import tensorflow as tf
import numpy as np
import datetime
import json
import gzip
from tqdm import tqdm_notebook

In [2]:
# Handle to the current TensorFlow session. `reset_tf(sess)` closes it (when set)
# before creating a fresh one, so re-running the build cell does not leak sessions.
sess = None

In [3]:
def reset_tf(sess = None, log_device_placement = False, seed = 0):
    """Close `sess` (if given), reset the default graph and open a new session.

    Args:
        sess: a previous session to close, or None.
        log_device_placement: forwarded to `tf.ConfigProto`; when True, TF logs
            which device each op is placed on.
        seed: graph-level random seed for the new default graph. Defaults to 0,
            the value that was previously hard-coded here.

    Returns:
        A new `tf.InteractiveSession` built on the freshly reset default graph.
    """
    if sess is not None:
        sess.close()
    tf.reset_default_graph()
    # Seed after the reset so the seed applies to the new default graph.
    tf.set_random_seed(seed)
    return tf.InteractiveSession(
        config = tf.ConfigProto(log_device_placement = log_device_placement))

In [4]:
class HyperParameters:
    """Hyper-parameters for `EntityLinkingModel`, kept as plain class attributes.

    The model reads these by name (e.g. `hp.vocab_size`), so names must stay
    stable. Derived values are computed once, when this class body executes,
    from the attributes defined above them.
    """

    # --- optimisation / regularisation ------------------------------------
    learning_rate = 0.001            # Adam learning rate
    dropout_rate = 0.1               # dropout rate used throughout the model
    triplet_loss_margin = 0.5        # margin of the triplet hinge loss

    # --- vocabulary and context window -------------------------------------
    vocab_size = 30000               # word-embedding rows and page TF-vector length
    context_size = 81                # word ids per example context window
    context_center_index = context_size // 2   # position read out after attention

    # --- input embeddings --------------------------------------------------
    d_embedding_word = 128
    d_embedding_position = 16

    # --- anchor (context) encoder ------------------------------------------
    attention_num_layers = 3
    d_attention = 128
    d_attention_ff = 256
    d_attention_embedding_layers = [128]

    # --- candidate (page) encoder ------------------------------------------
    d_candidate_embedding_layers = [256]

    # width of the final L2-normalised embedding shared by anchors and pages
    d_output_embedding = 128

    # --- input pipeline ----------------------------------------------------
    pipeline_batch_size = 128
    # Negative samples per anchor. NOTE(review): the pipeline pairs each anchor
    # with this many + 1 in-batch negatives -- confirm that is intended.
    pipeline_num_negative_samples = 5
    pipeline_num_parallel_calls = 4
    pipeline_shuffle_size = 5000
    # measured in examples, because the pipeline prefetches before batching
    pipeline_prefetch_size = 16 * pipeline_batch_size

    # not referenced elsewhere in the visible code
    embed_page_tfs_batch_size = 1024

In [5]:
class EntityLinkingModel:
    """Triplet-loss entity-linking model built as a TensorFlow 1.x graph.

    An anchor (a fixed-size window of context word ids) is encoded by word and
    position embeddings followed by a stack of self-attention layers. The output
    at the window's centre position is projected to an L2-normalised embedding.
    Candidate target pages are encoded from their term-frequency vectors by a
    small MLP into the same space. Training minimises a margin-based triplet
    loss over (anchor, positive page, in-batch negative page) triplets.

    Build order: `_build_data_pipeline()`, `_build_model()`,
    `_build_training_model()`. Then initialise variables and call `process()`.
    """

    def __init__(self, hp):
        """Store the hyper-parameter object (attributes as in `HyperParameters`)."""
        self._hp = hp

    def _parse_example(self, example_proto):
        """Parse one serialized `tf.train.Example` into model inputs.

        Returns:
            (page_id [1], word_ids [context_size], target_id [1],
             target_tf [vocab_size]). `target_tf` is a dense vector of term
            counts scattered from the sparse `target_word_ids` /
            `target_word_freqs` features.

        NOTE(review): `tf.sparse_to_dense` validates indices by default. This
        assumes the data writer stores `target_word_ids` sorted and without
        duplicates -- confirm.
        """
        features = {
            'page_id': tf.FixedLenFeature([1], tf.int64),
            'target_id': tf.FixedLenFeature([1], tf.int64),
            'target_word_ids': tf.VarLenFeature(tf.int64),
            'target_word_freqs': tf.VarLenFeature(tf.int64),
            'word_ids': tf.FixedLenFeature([self._hp.context_size], tf.int64)
        }

        parsed = tf.parse_single_example(example_proto, features)

        target_word_ids = tf.sparse_tensor_to_dense(parsed['target_word_ids'])
        target_word_freqs = tf.sparse_tensor_to_dense(parsed['target_word_freqs'])
        target_tf = tf.sparse_to_dense(
            target_word_ids,
            [self._hp.vocab_size],
            target_word_freqs)

        return (
            parsed['page_id'],
            parsed['word_ids'],
            parsed['target_id'],
            target_tf)

    def _build_data_pipeline(self):
        """Build the input pipeline and expand each minibatch into triplets.

        Reads GZIP TFRecord files whose names are fed through
        `self._dataset_filenames` when the iterator is initialised. It then
        parses, shuffles and batches them. Each of the B anchors in a batch is
        paired with S = pipeline_num_negative_samples + 1 negatives taken from
        other examples of the same batch, giving B * S triplets.

        Sets `_dataset_filenames`, `_dataset_iterator`,
        `_anchor_context_word_ids`, `_anchor_context_positions`,
        `_pos_page_dists`, `_neg_page_dists`, `_loss_mask`, `_is_training` and
        `_minibatch_size` (the expanded triplet count B * S).
        """
        with tf.variable_scope('dataset'):
            # placeholder: examples filenames
            self._dataset_filenames = tf.placeholder(tf.string, shape = [None])

            # build examples dataset
            dataset = tf.data.TFRecordDataset(
                self._dataset_filenames,
                compression_type='GZIP')
            dataset = dataset.map(
                self._parse_example,
                num_parallel_calls = self._hp.pipeline_num_parallel_calls)
            dataset = dataset.shuffle(self._hp.pipeline_shuffle_size)
            # Prefetch runs before batching, so its buffer size counts examples.
            dataset = dataset.prefetch(self._hp.pipeline_prefetch_size)
            dataset = dataset.batch(self._hp.pipeline_batch_size)

            # build dataset iterator
            self._dataset_iterator = dataset.make_initializable_iterator()
            (input_page_ids,
             input_context_word_ids,
             target_page_ids,
             target_page_tfs) = self._dataset_iterator.get_next()

            # Default to the iterator's output, but let callers feed these
            # tensors directly instead.
            input_page_ids = tf.placeholder_with_default(
                input_page_ids,
                [None, 1],
                name='input_page_ids')
            input_context_word_ids = tf.placeholder_with_default(
                input_context_word_ids,
                [None, self._hp.context_size],
                name='input_context_word_ids')
            target_page_ids = tf.placeholder_with_default(
                target_page_ids,
                [None, 1],
                name='target_page_ids')
            target_page_tfs = tf.placeholder_with_default(
                target_page_tfs,
                [None, self._hp.vocab_size],
                name='target_page_tfs')

            minibatch_size = tf.shape(input_page_ids)[0]
            sample_size = self._hp.pipeline_num_negative_samples + 1
            expanded_minibatch_size = minibatch_size * sample_size

            # Triplet indices into the batch, flattened row-major from [B, S]:
            #   pos_indices[i * S + k] = i                 (the anchor's own row)
            #   neg_indices[i * S + k] = (i + k + 1) % B   (a later row, wrapping)
            # NOTE(review): offsets run 1..S, so each anchor gets
            # pipeline_num_negative_samples + 1 negatives -- one more than the
            # hyper-parameter's description suggests. Confirm this is intended.
            # When B <= S some offsets wrap back onto the anchor itself; those
            # triplets are zeroed by _loss_mask below.
            x = tf.reshape(tf.range(minibatch_size), [minibatch_size, 1])
            y = tf.reshape(tf.range(1, sample_size+1), [1, sample_size])
            z = tf.ones([1, sample_size], dtype=tf.int32)
            pos_indices = tf.reshape(
                x * z,
                [expanded_minibatch_size])
            neg_indices = tf.reshape(
                tf.mod(x + y, minibatch_size),
                [expanded_minibatch_size])

            self._anchor_context_word_ids = tf.nn.embedding_lookup(
                input_context_word_ids,
                pos_indices)

            pos_page_tfs = tf.nn.embedding_lookup(
                target_page_tfs,
                pos_indices)
            neg_page_tfs = tf.nn.embedding_lookup(
                target_page_tfs,
                neg_indices)
            # Raw (unnormalised) term counts, cast to float for the dense layers.
            self._pos_page_dists = tf.cast(pos_page_tfs, tf.float32)
            self._neg_page_dists = tf.cast(neg_page_tfs, tf.float32)

            pos_page_ids = tf.nn.embedding_lookup(
                target_page_ids,
                pos_indices)
            neg_page_ids = tf.nn.embedding_lookup(
                target_page_ids,
                neg_indices)
            # [B * S, 1]: 1.0 where the negative's target page differs from the
            # positive's, 0.0 where the "negative" is actually the same page.
            self._loss_mask = tf.cast(
                tf.logical_not(tf.equal(pos_page_ids, neg_page_ids)),
                tf.float32)

            # positions: [B * S, context_size], every row is 0..context_size-1
            p = tf.range(self._hp.context_size, dtype = tf.int64)
            p = tf.tile(p, [expanded_minibatch_size])
            p = tf.reshape(p, [expanded_minibatch_size, self._hp.context_size])
            self._anchor_context_positions = p

            # placeholder: training flag
            self._is_training = tf.placeholder(tf.bool)

            self._minibatch_size = expanded_minibatch_size

    def _layer_norm(self, x, num_units, scope=None, reuse=None, epsilon=1e-6):
        """Normalise `x` inside variable scope `scope` (opened with `reuse`).

        NOTE(review): despite the name, this applies *batch* normalisation
        (`tf.layers.batch_normalization`, in training mode when
        `self._is_training` is True). `num_units` and `epsilon` are unused.
        `_layer_norm_compute` implements a true layer norm, but switching to it
        would change the model's variables and results.
        """
        with tf.variable_scope(scope, reuse=reuse):
            return tf.layers.batch_normalization(x, training=self._is_training)

    def _layer_norm_compute(self, x, epsilon, scale, bias):
        """Layer-normalise `x` over its last axis, then apply `scale` and `bias`.

        Currently unused by the model.
        """
        mean = tf.reduce_mean(x, axis=[-1], keep_dims=True)
        variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keep_dims=True)
        norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
        return norm_x * scale + bias

    def _attention_layer(self, A):
        """Single-head scaled dot-product self-attention of `A` over itself.

        There are no learned query/key/value projections. `A` must be rank 3
        (it is transposed with perm [0, 2, 1]), presumably
        [batch, positions, d_attention]. Dropout is applied to the result.
        """
        A_T = tf.transpose(A, perm=[0, 2, 1])
        scaled_logits = tf.matmul(A, A_T) / tf.sqrt(tf.cast(tf.shape(A)[-1], tf.float32))
        result = tf.matmul(tf.nn.softmax(scaled_logits), A)
        result = tf.layers.dropout(
            result,
            rate=self._hp.dropout_rate,
            training=self._is_training)
        return result

    def _attention_ff_layer(self, A, scope, reuse=None):
        """Position-wise feed-forward block: ReLU dense (d_attention_ff),
        linear dense back to d_attention, then dropout."""
        with tf.variable_scope(scope, reuse=reuse):
            A = tf.layers.dense(A, self._hp.d_attention_ff, activation=tf.nn.relu, name='fc1')
            A = tf.layers.dense(A, self._hp.d_attention, name='fc2')
            A = tf.layers.dropout(
                A,
                rate=self._hp.dropout_rate,
                training=self._is_training)
            return A

    def _attention_full_layer(self, A, scope, reuse=None):
        """One encoder layer with two residual sub-blocks, each followed by a
        `_layer_norm`: self-attention, then feed-forward."""
        with tf.variable_scope(scope, reuse=reuse):
            A = self._layer_norm(
                A + self._attention_layer(A),
                num_units=self._hp.d_attention,
                scope='attention_norm')
            A = self._layer_norm(
                A + self._attention_ff_layer(A, 'ff', reuse),
                num_units=self._hp.d_attention,
                scope='attention_ff_norm')
            return A

    def _candidate_embedding(self, A, scope, reuse=None):
        """Embed candidate page vectors `A` with an MLP into the output space.

        Each hidden layer is dense+ReLU, batch norm, then dropout. It is
        followed by a linear projection to `d_output_embedding` and L2
        normalisation. Call with `reuse=True` to share weights with an earlier
        call in the same `scope`.
        """
        with tf.variable_scope(scope, reuse=reuse):
            layer = A
            for i, d_layer in enumerate(self._hp.d_candidate_embedding_layers):
                layer = tf.layers.dense(
                    layer,
                    d_layer,
                    activation=tf.nn.relu,
                    name='layer_%d' % i,
                    reuse=reuse)
                layer = tf.layers.batch_normalization(
                    layer,
                    training=self._is_training,
                    name='layer_%d_batch_norm' % i,
                    reuse=reuse)
                layer = tf.layers.dropout(
                    layer,
                    rate=self._hp.dropout_rate,
                    training=self._is_training)
            layer = tf.layers.dense(
                layer,
                self._hp.d_output_embedding,
                name='layer_embedding',
                reuse=reuse)
            return tf.nn.l2_normalize(layer, dim=-1)

    def _build_model(self):
        """Build the anchor encoder and the (weight-shared) candidate encoders.

        Sets `_anchor_embedded`, `_pos_embedded` and `_neg_embedded`. All three
        are L2-normalised `d_output_embedding`-wide vectors, one per triplet.
        """
        with tf.variable_scope('model'):
            # embed context words
            word_embeddings = tf.get_variable(
                'word_embeddings',
                [self._hp.vocab_size, self._hp.d_embedding_word])
            anchor_context_words_embedded = tf.nn.embedding_lookup(
                word_embeddings,
                self._anchor_context_word_ids)

            # embed context positions
            position_embeddings = tf.get_variable(
                'position_embeddings',
                [self._hp.context_size, self._hp.d_embedding_position],
                dtype=tf.float32)
            anchor_context_positions_embedded = tf.nn.embedding_lookup(
                position_embeddings,
                self._anchor_context_positions)

            # build full context vector (concat embeddings)
            anchor_context_full = tf.concat(
                [anchor_context_words_embedded, anchor_context_positions_embedded],
                axis=-1)

            # build attention input vector
            anchor_attention = tf.layers.dense(
                anchor_context_full,
                self._hp.d_attention,
                activation=tf.nn.relu,
                name='anchor_attention')
            anchor_attention = self._layer_norm(
                anchor_attention,
                num_units=self._hp.d_attention,
                scope='anchor_attention')
            anchor_attention = tf.layers.dropout(
                anchor_attention,
                rate=self._hp.dropout_rate,
                training=self._is_training)

            with tf.variable_scope('anchor'):
                # attention layers
                layer = anchor_attention
                for i in range(self._hp.attention_num_layers):
                    layer = self._attention_full_layer(layer, 'layer_%d' % i)
                # Keep only the centre position's output as the anchor summary.
                layer = layer[:, self._hp.context_center_index, :]
                # post-attention embedding
                for i, d_layer in enumerate(self._hp.d_attention_embedding_layers):
                    layer = tf.layers.dense(
                        layer,
                        d_layer,
                        activation=tf.nn.relu,
                        name='layer_embedding_%d' % i)
                    layer = tf.layers.batch_normalization(
                        layer,
                        training=self._is_training,
                        name='layer_embedding_%d_batch_norm' % i)
                    layer = tf.layers.dropout(
                        layer,
                        rate=self._hp.dropout_rate,
                        training=self._is_training)
                layer = tf.layers.dense(
                    layer,
                    self._hp.d_output_embedding,
                    name='layer_embedding_final')
            self._anchor_embedded = tf.nn.l2_normalize(layer, dim=-1)

            # Positive and negative pages share one set of 'candidate' weights.
            self._pos_embedded = self._candidate_embedding(
                self._pos_page_dists,
                scope='candidate',
                reuse=False)
            self._neg_embedded = self._candidate_embedding(
                self._neg_page_dists,
                scope='candidate',
                reuse=True)

    def _build_training_model(self):
        """Build the triplet loss, the correct-triplet counter and the train op.

        Per triplet: loss = max(|a - p| - |a - n| + margin, 0) * loss_mask,
        using Euclidean distances between the normalised embeddings. A triplet
        counts as correct when its loss is below the margin, i.e. the anchor is
        closer to the positive page than to the negative page.

        NOTE(review): masked triplets have loss 0, so they are counted as
        correct and still contribute to the `_mean_loss` denominator.
        """
        with tf.variable_scope('train'):
            pos_dist = tf.norm(
                self._anchor_embedded - self._pos_embedded,
                axis=-1,
                keep_dims=True)
            neg_dist = tf.norm(
                self._anchor_embedded - self._neg_embedded,
                axis=-1,
                keep_dims=True)

            losses = tf.maximum(
                pos_dist - neg_dist + self._hp.triplet_loss_margin,
                0)
            losses *= self._loss_mask

            self._total_loss = tf.reduce_sum(losses)
            self._mean_loss = tf.reduce_mean(losses)

            c = tf.less(losses, self._hp.triplet_loss_margin)
            c = tf.cast(c, tf.int32)
            self._num_correct_labels = tf.reduce_sum(c)

            self._global_step = tf.Variable(0, name='global_step', trainable=False)
            self._optimizer = tf.train.AdamOptimizer(learning_rate=self._hp.learning_rate)

            # Batch-norm moving statistics are updated by ops registered in
            # UPDATE_OPS. Only the train step itself needs to depend on them.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self._train_op = self._optimizer.minimize(
                    self._mean_loss,
                    global_step=self._global_step)

    def dump_statistics(self):
        """Print the parameter count of every trainable variable, then the total."""
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            print('parameters for "%s": %d' % (variable.name, variable_parameters))
            total_parameters += variable_parameters
        print('total parameters: %d' % total_parameters)

    def process(self,
                sess,
                dataset_filename,
                options = None,
                run_metadata = None,
                header = 'results',
                train = False,
                show_progress = True,
                log_file = None):
        """Run one full pass over `dataset_filename`, optionally training.

        Args:
            sess: session holding this model's graph.
            dataset_filename: GZIP TFRecord file to read.
            options, run_metadata: forwarded unchanged to every `sess.run`.
            header: label printed at the start of the summary line.
            train: if True, run the train op and put dropout/batch norm in
                training mode.
            show_progress: show a tqdm notebook progress bar counting triplets.
            log_file: optional open text file that also receives the summary.

        Prints (and optionally logs) one summary line with the global step,
        wall time, mean per-triplet loss and fraction of triplets counted
        correct. Both averages are NaN when the file yields no batches.
        """
        cum_loss = 0
        cum_num_examples = 0
        cum_correct_examples = 0

        start = datetime.datetime.now()

        sess.run(self._dataset_iterator.initializer, feed_dict={
            self._dataset_filenames: [dataset_filename]
        })

        # An empty fetch in place of the train op makes this a pure evaluation.
        fetches = (
            self._train_op if train else [],
            self._total_loss,
            self._minibatch_size,
            self._num_correct_labels)

        progress = tqdm_notebook() if show_progress else None
        try:
            while True:
                try:
                    (_,
                     curr_total_loss,
                     curr_minibatch_size,
                     curr_num_correct_labels) = sess.run(
                        fetches,
                        feed_dict = { self._is_training: train },
                        options = options,
                        run_metadata = run_metadata)
                except tf.errors.OutOfRangeError:
                    # The initializable iterator is exhausted: pass complete.
                    break

                if progress is not None:
                    progress.update(curr_minibatch_size)

                cum_loss += curr_total_loss
                cum_num_examples += curr_minibatch_size
                cum_correct_examples += curr_num_correct_labels
        finally:
            # Close the progress widget even if sess.run raised another error.
            if progress is not None:
                progress.close()

        finish = datetime.datetime.now()

        if cum_num_examples > 0:
            mean_loss = cum_loss / cum_num_examples
            accuracy = cum_correct_examples / cum_num_examples
        else:
            # Empty dataset: report NaN instead of raising ZeroDivisionError.
            mean_loss = accuracy = float('nan')

        message = '%s (%d) (%s): loss=%g, accuracy=%g' % (
            header,
            tf.train.global_step(sess, self._global_step),
            finish-start,
            mean_loss,
            accuracy)
        print(message)
        if log_file:
            print(message, file=log_file)
            log_file.flush()

    def evaluate(self, sess, dataset_filename, page_tfs, limit):
        """Not implemented.

        The previous body was an unfinished stub. It called an
        `_embed_page_tfs` helper that this class does not define (so it always
        failed with AttributeError), and it never used `limit` or returned a
        result. Fail explicitly until a real implementation exists.
        """
        raise NotImplementedError(
            'EntityLinkingModel.evaluate is not implemented yet '
            '(it requires an _embed_page_tfs helper)')

In [6]:
# Start from a clean graph/session, assemble the full training graph in its
# required order, then print the trainable-parameter breakdown.
sess = reset_tf(sess)

hparams = HyperParameters()
model = EntityLinkingModel(hparams)
for build_step in (model._build_data_pipeline,
                   model._build_model,
                   model._build_training_model):
    build_step()
model.dump_statistics()

parameters for "model/word_embeddings:0": 3840000
parameters for "model/position_embeddings:0": 1296
parameters for "model/anchor_attention/kernel:0": 18432
parameters for "model/anchor_attention/bias:0": 128
parameters for "model/anchor_attention/batch_normalization/gamma:0": 128
parameters for "model/anchor_attention/batch_normalization/beta:0": 128
parameters for "model/anchor/layer_0/attention_norm/batch_normalization/gamma:0": 128
parameters for "model/anchor/layer_0/attention_norm/batch_normalization/beta:0": 128
parameters for "model/anchor/layer_0/ff/fc1/kernel:0": 32768
parameters for "model/anchor/layer_0/ff/fc1/bias:0": 256
parameters for "model/anchor/layer_0/ff/fc2/kernel:0": 32768
parameters for "model/anchor/layer_0/ff/fc2/bias:0": 128
parameters for "model/anchor/layer_0/attention_ff_norm/batch_normalization/gamma:0": 128
parameters for "model/anchor/layer_0/attention_ff_norm/batch_normalization/beta:0": 128
parameters for "model/anchor/layer_1/attention_norm/batch_norm

In [7]:
# Initialise every variable of the graph built above; run before any process() call.
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [8]:
# Quick 5-epoch run: one training pass and one dev evaluation per epoch.
# NOTE(review): this trains on the *test* split (presumably because it is small),
# which means that split can no longer serve as held-out data for this model.
TEST_TFRECORDS = '../data/simplewiki/simplewiki-20171103.entity_linking.test.tfrecords.gz'
DEV_TFRECORDS = '../data/simplewiki/simplewiki-20171103.entity_linking.dev.tfrecords.gz'

passes = (
    # (file,          header template, train?)
    (TEST_TFRECORDS, 'test %d', True),
    (DEV_TFRECORDS,  'dev %d',  False),
)
for epoch in range(5):
    for filename, header_template, is_training in passes:
        model.process(
            sess,
            filename,
            header=header_template % epoch,
            train=is_training,
            show_progress=True)


test 0 (235) (0:00:54.218661): loss=0.243635, accuracy=0.793928



dev 0 (235) (0:00:26.258545): loss=0.443067, accuracy=0.596906



test 1 (470) (0:00:53.456649): loss=0.0715242, accuracy=0.955939



dev 1 (470) (0:00:26.039631): loss=0.288039, accuracy=0.748922



test 2 (705) (0:00:53.574088): loss=0.0346475, accuracy=0.985011



dev 2 (705) (0:00:26.152587): loss=0.102347, accuracy=0.923828



test 3 (940) (0:00:53.176926): loss=0.0233398, accuracy=0.992206



dev 3 (940) (0:00:26.127982): loss=0.0994088, accuracy=0.925072



test 4 (1175) (0:00:53.986909): loss=0.0185829, accuracy=0.994917



dev 4 (1175) (0:00:25.832809): loss=0.10002, accuracy=0.924194


In [9]:
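# Full training run (disabled): 50 epochs on the train split with a dev pass after
# each, logging summaries to a file.
# NOTE(review): `builder` (used on the last line) is not defined anywhere in this
# notebook, so this cell raises NameError if uncommented as-is.
# with open('/tmp/mediawiki_model_entity_linking_1.log', 'wt', encoding='utf-8') as f: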
#     for i in range(50):
#         model.process(
#             sess,
#             '../data/simplewiki/simplewiki-20171103.entity_linking.train.tfrecords.gz',
#             header='train %d' % i,
#             train = True,
#             show_progress = True,
#             log_file=f)
#         model.process(
#             sess,
#             '../data/simplewiki/simplewiki-20171103.entity_linking.dev.tfrecords.gz',
#             header='dev %d' % i,
#             train = False,
#             show_progress = True,
#             log_file=f)
#         builder.save()