In [1]:
import tensorflow as tf
import numpy as np
import datetime
import json
import gzip
from tqdm import tqdm_notebook

In [2]:
# Handle to the notebook's TensorFlow session. It starts out empty so that the
# first reset_tf(sess) call has no previous session to close.
sess = None

In [3]:
def reset_tf(sess = None, log_device_placement = False):
    """Discard existing TensorFlow state and open a fresh interactive session.

    Args:
        sess: a previously opened session; it is closed first when truthy.
        log_device_placement: forwarded unchanged to ``tf.ConfigProto``.

    Returns:
        A new ``tf.InteractiveSession`` created after the default graph has
        been reset and the graph-level random seed fixed to 0.
    """
    if sess:
        sess.close()

    # start over with an empty default graph and a fixed graph-level seed
    tf.reset_default_graph()
    tf.set_random_seed(0)

    session_config = tf.ConfigProto(log_device_placement = log_device_placement)
    return tf.InteractiveSession(config = session_config)

In [4]:
class HyperParameters:
    """Plain bag of tunable constants read by EntityLinkingModel through ``self._hp``."""

    # --- optimisation / regularisation ---
    # step size handed to the Adam optimizer
    learning_rate = 0.001
    # rate used by every dropout layer in the model
    dropout_rate = 0.1

    # --- input geometry ---
    # number of rows in the word embedding table and length of a page's dense term vector
    vocab_size = 30000
    # number of word ids in each context window
    context_size = 81
    # position within the window whose attention output is used as the context embedding
    context_center_index = context_size // 2

    # --- embedding widths ---
    d_embedding_position = 16
    d_embedding_word = 128

    # --- self-attention stack ---
    # width of the attention input and of each block's output
    d_attention = 128
    # hidden width of the feed-forward sub-layer inside each block
    d_attention_ff = 256
    attention_num_layers = 3

    # --- dense towers ---
    # widths of the dense layers applied to the candidate page term vector
    d_candidate_layers = [256, 128]
    # hidden widths of the logistic head (a final 1-unit layer is appended by the model)
    d_logistic_layers = [64, 16]

    # --- input pipeline ---
    pipeline_batch_size = 256
    # number of negative samples per positive example
    pipeline_num_negative_samples = 1
    pipeline_num_parallel_calls = 4
    # prefetch is applied before batching in the pipeline, so this counts single examples
    pipeline_prefetch_size = 16 * pipeline_batch_size
    pipeline_shuffle_size = 5000

    # not referenced by the code visible in this notebook section -- TODO confirm where it is used
    embed_page_tfs_batch_size = 1024

In [5]:
class EntityLinkingModel:
    def __init__(self, session, hp):
        self._session = session
        self._hp = hp
        
    def _parse_example(self, example_proto):
        features = {
            'page_id': tf.FixedLenFeature([1], tf.int64),
            'target_id': tf.FixedLenFeature([1], tf.int64),
            'target_word_ids': tf.VarLenFeature(tf.int64),
            'target_word_freqs': tf.VarLenFeature(tf.int64),
            'word_ids': tf.FixedLenFeature([self._hp.context_size], tf.int64)
        }

        parsed = tf.parse_single_example(example_proto, features)
        
        target_word_ids = tf.sparse_tensor_to_dense(parsed['target_word_ids'])
        target_word_freqs = tf.sparse_tensor_to_dense(parsed['target_word_freqs'])
        target_tf = tf.sparse_to_dense(
            target_word_ids,
            [self._hp.vocab_size],
            target_word_freqs)
        
        return (
            parsed['page_id'],
            parsed['word_ids'],
            parsed['target_id'],
            target_tf)
    
    def _create_negative_samples_inner(self, page_ids, word_ids, target_ids, target_tfs):
        in_batch_size = word_ids.shape[0]
        out_batch_size = (1 + self._hp.pipeline_num_negative_samples) * in_batch_size

        # batch too small to do negative sampling
        if in_batch_size < self._hp.pipeline_num_negative_samples:
            return (word_ids, target_tfs, [1] * in_batch_size)
        # handle getting called by pipeline before the model is ready/initialized
        if not hasattr(self, '_context_word_ids'):
            return (word_ids, target_tfs, [1] * in_batch_size)
        
        context_embedded, candidate_embedded, kernels, biases = self._session.run(
            (self._context_attention_embedded,
             self._candidate_page_dists_embedded,
             self._logistic_kernels,
             self._logistic_biases),
            feed_dict = {
                self._context_word_ids: word_ids,
                self._candidate_page_dists: target_tfs,
                self._training: False })
        
        context_word_ids = np.zeros(
            (out_batch_size, word_ids.shape[1]),
            dtype=word_ids.dtype)
        candidate_page_dists = np.zeros(
            (out_batch_size, target_tfs.shape[1]),
            dtype=target_tfs.dtype)
        target_labels = np.zeros(
            out_batch_size,
            dtype=np.int64)
        
        for i in range(in_batch_size):
            # hand-compute logits
            a = np.concatenate(
                [np.tile(context_embedded[i], [in_batch_size, 1]), candidate_embedded],
                axis=-1)
            for k, b in zip(kernels, biases):
                a = np.matmul(a, k) + b
                if b.size > 1:
                    a = np.maximum(a, 0)
            a *= np.logical_not(np.equal(target_ids[i], target_ids))
            a = a.flatten()

            # find highest logits as negative samples
            neg_samples = np.argpartition(
                a,
                -self._hp.pipeline_num_negative_samples)[-self._hp.pipeline_num_negative_samples:]
            
            # set positive sample
            base_i = i * (1 + self._hp.pipeline_num_negative_samples)
            context_word_ids[base_i] = word_ids[i]
            candidate_page_dists[base_i] = target_tfs[i]
            target_labels[base_i] = 1
            
            # set negative samples
            for j, k in enumerate(neg_samples):
                samp_i = base_i + j + 1
                context_word_ids[samp_i] = word_ids[i]
                candidate_page_dists[samp_i] = target_tfs[k]
                target_labels[samp_i] = 1 if target_ids[i] == target_ids[k] else 0
        
        return (context_word_ids, candidate_page_dists, target_labels)
    
    def _create_negative_samples(self, *args):
        return tf.tuple(tf.py_func(
            self._create_negative_samples_inner,
            args,
            [tf.int64, tf.int64, tf.int64]))

    def _build_data_pipeline(self):
        """Create the input pipeline and the model's input tensors.

        Sets ``_dataset_filenames``, ``_dataset_iterator``, the three
        feedable inputs (``_context_word_ids``, ``_candidate_page_dists``,
        ``_target_labels``), ``_minibatch_size`` and ``_context_positions``.
        """
        hp = self._hp

        with tf.variable_scope('dataset'):
            # placeholder: examples filenames
            self._dataset_filenames = tf.placeholder(tf.string, shape = [None])

            # records -> parsed examples -> shuffled -> batched -> negatives added
            examples = (
                tf.data.TFRecordDataset(
                    self._dataset_filenames,
                    compression_type='GZIP')
                .map(
                    self._parse_example,
                    num_parallel_calls = hp.pipeline_num_parallel_calls)
                .shuffle(hp.pipeline_shuffle_size)
                .prefetch(hp.pipeline_prefetch_size)
                .batch(hp.pipeline_batch_size)
                .map(
                    self._create_negative_samples,
                    # TODO: make this bigger?
                    num_parallel_calls = 1))

            # re-initializable iterator; process() feeds the filenames each pass
            self._dataset_iterator = examples.make_initializable_iterator()
            next_batch = self._dataset_iterator.get_next()
            batch_word_ids, batch_page_dists, batch_labels = next_batch

            batch_page_dists = tf.cast(batch_page_dists, tf.float32)

            # inputs default to the pipeline output but can be overridden via feed_dict
            self._context_word_ids = tf.placeholder_with_default(
                batch_word_ids,
                shape = [None, hp.context_size],
                name = 'input_context_word_ids')
            self._candidate_page_dists = tf.placeholder_with_default(
                batch_page_dists,
                shape = [None, hp.vocab_size],
                name = 'candidate_page_dists')
            self._target_labels = tf.placeholder_with_default(
                batch_labels,
                shape = [None],
                name = 'target_labels')

            # positions: the sequence 0..context_size-1 repeated for every row of the batch
            self._minibatch_size = tf.shape(self._context_word_ids)[0]
            positions = tf.range(hp.context_size, dtype = tf.int64)
            positions = tf.tile(positions, [self._minibatch_size])
            self._context_positions = tf.reshape(
                positions,
                [self._minibatch_size, hp.context_size])
    
    def _attention_layer(self, A):
        """Parameter-free scaled dot-product self-attention followed by dropout.

        ``A`` (rank 3, per the ``[0, 2, 1]`` transpose) serves as queries,
        keys and values at once; scores are divided by the square root of the
        size of the last axis before the softmax.
        """
        scores = tf.matmul(A, tf.transpose(A, perm=[0, 2, 1]))
        depth = tf.cast(tf.shape(A)[-1], tf.float32)
        weights = tf.nn.softmax(scores / tf.sqrt(depth))
        attended = tf.matmul(weights, A)
        return tf.layers.dropout(
            attended,
            rate=self._hp.dropout_rate,
            training=self._training)

    def _attention_ff_layer(self, A, scope, reuse=None):
        """Two-layer feed-forward sub-block: ReLU dense ('fc1'), linear dense ('fc2'), dropout.

        Variables are created (or reused, per ``reuse``) under ``scope``.
        """
        with tf.variable_scope(scope, reuse=reuse):
            hidden = tf.layers.dense(
                A,
                self._hp.d_attention_ff,
                activation=tf.nn.relu,
                name='fc1')
            projected = tf.layers.dense(hidden, self._hp.d_attention, name='fc2')
            return tf.layers.dropout(
                projected,
                rate=self._hp.dropout_rate,
                training=self._training)
    
    def _attention_full_layer(self, A, scope, reuse=None):
        """One attention block: two residual sub-layers, each followed by batch normalization.

        The first sub-layer is self-attention, the second the feed-forward
        network built under the nested scope 'ff'.
        """
        with tf.variable_scope(scope, reuse=reuse):
            # residual connection around self-attention
            attention_residual = A + self._attention_layer(A)
            attended = tf.layers.batch_normalization(
                attention_residual,
                training=self._training,
                name='attention_batch_norm',
                reuse=reuse)

            # residual connection around the feed-forward network
            ff_residual = attended + self._attention_ff_layer(attended, 'ff', reuse)
            return tf.layers.batch_normalization(
                ff_residual,
                training=self._training,
                name='attention_ff_batch_norm',
                reuse=reuse)
            
    def _build_model(self):
        """Build the scoring network under the 'model' variable scope.

        Reads the input tensors created by ``_build_data_pipeline`` and sets:
        ``_training`` (bool placeholder fed on every run),
        ``_context_attention_embedded`` (attention output at the center position),
        ``_candidate_page_dists_embedded`` (output of the candidate dense tower),
        ``_output_logits`` (one logit per row) and the lists
        ``_logistic_kernels`` / ``_logistic_biases``.
        """
        with tf.variable_scope('model'):
            # placeholder: training flag
            self._training = tf.placeholder(tf.bool)
            
            # embed context words
            word_embeddings = tf.get_variable(
                'word_embeddings', 
                [self._hp.vocab_size, self._hp.d_embedding_word])
            context_words_embedded = tf.nn.embedding_lookup(
                word_embeddings,
                self._context_word_ids)

            # embed context positions
            position_embeddings = tf.get_variable(
                'position_embeddings',
                [self._hp.context_size, self._hp.d_embedding_position],
                dtype=tf.float32)
            context_positions_embedded = tf.nn.embedding_lookup(
                position_embeddings,
                self._context_positions)

            # build full context vector (concat embeddings)
            context_full = tf.concat(
                [context_words_embedded, context_positions_embedded], 
                axis=-1)
            
            # build attention input vector
            # (the dense layer and the batch norm below deliberately share the
            # name 'context_attention'; their variables live side by side as
            # kernel/bias and gamma/beta under that one scope)
            context_attention = tf.layers.dense(
                context_full,
                self._hp.d_attention,
                activation=tf.nn.relu,
                name='context_attention')
            context_attention = tf.layers.batch_normalization(
                context_attention,
                training=self._training,
                name='context_attention')
            context_attention = tf.layers.dropout(
                context_attention,
                rate=self._hp.dropout_rate,
                training=self._training)
            
            # build attention layers
            with tf.variable_scope('attention'):
                layer = context_attention
                for i in range(self._hp.attention_num_layers):
                    layer = self._attention_full_layer(layer, 'layer_%d' % i)
            # keep only the vector at the center position of the context window
            self._context_attention_embedded = layer[:, self._hp.context_center_index, :]

            # build candidate layers
            with tf.variable_scope('candidate'):
                layer = self._candidate_page_dists
                for i, d_layer in enumerate(self._hp.d_candidate_layers):
                    layer = tf.layers.dense(
                        layer, 
                        d_layer,
                        activation=tf.nn.relu,
                        name='layer_%d' % i)
                    layer = tf.layers.batch_normalization(
                        layer,
                        training=self._training,
                        name='layer_%d' % i)
                    layer = tf.layers.dropout(
                        layer,
                        rate=self._hp.dropout_rate,
                        training=self._training)
            self._candidate_page_dists_embedded = layer
            
            # build final logistic neuron
            with tf.variable_scope('logistic'):
                layer = tf.concat(
                    [self._context_attention_embedded, self._candidate_page_dists_embedded],
                    axis=-1)
                for i, d_layer in enumerate(self._hp.d_logistic_layers):
                    layer = tf.layers.dense(
                        layer, 
                        d_layer,
                        activation=tf.nn.relu,
                        name='layer_%d' % i)
                    layer = tf.layers.dropout(
                        layer,
                        rate=self._hp.dropout_rate,
                        training=self._training)
                # final 1-unit layer without activation; squeeze drops the unit axis
                self._output_logits = tf.squeeze(
                    tf.layers.dense(layer, 1, name='layer_%d' % len(self._hp.d_logistic_layers)),
                    axis=-1)

            # grab tensors so we can do negative sampling
            # (looked up by their full graph names, so the scope and layer names
            # above must stay in sync with the strings below)
            self._logistic_kernels = []
            self._logistic_biases = []
            for i in range(len(self._hp.d_logistic_layers)+1):
                self._logistic_kernels.append(tf.get_default_graph().get_tensor_by_name(
                    'model/logistic/layer_%d/kernel:0' % i))
                self._logistic_biases.append(tf.get_default_graph().get_tensor_by_name(
                    'model/logistic/layer_%d/bias:0' % i))

    def _build_training_model(self):
        """Add loss, accuracy count and the Adam training op under the 'train' scope.

        Sets ``_total_loss``, ``_mean_loss``, ``_num_correct_labels``,
        ``_global_step``, ``_optimizer`` and ``_train_op``.
        """
        with tf.variable_scope('train'):
            # per-row sigmoid cross-entropy against the 0/1 labels
            float_labels = tf.cast(self._target_labels, tf.float32)
            losses = tf.nn.sigmoid_cross_entropy_with_logits(
                labels = float_labels,
                logits = self._output_logits)

            self._total_loss = tf.reduce_sum(losses)
            self._mean_loss = tf.reduce_mean(losses)

            # a row counts as correct when thresholding its probability at 0.5 matches the label
            is_positive = tf.greater(tf.sigmoid(self._output_logits), 0.5)
            predicted_labels = tf.cast(is_positive, tf.int64)
            hits = tf.cast(
                tf.equal(predicted_labels, self._target_labels),
                tf.int32)
            self._num_correct_labels = tf.reduce_sum(hits)

            # ops registered in UPDATE_OPS (e.g. batch-norm moving statistics)
            # must run together with every training step
            pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(pending_updates):
                self._global_step = tf.Variable(0, name='global_step', trainable=False)
                self._optimizer = tf.train.AdamOptimizer(learning_rate=self._hp.learning_rate)
                self._train_op = self._optimizer.minimize(
                    self._mean_loss,
                    global_step=self._global_step)
    
    def build_model(self):
        self._build_data_pipeline()
        self._build_model()
        self._build_training_model()

    def dump_statistics(self):
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            print('parameters for "%s": %d' % (variable.name, variable_parameters))
            total_parameters += variable_parameters
        print('total parameters: %d' % total_parameters)

    def process(self,
                dataset_filename,
                options = None,
                run_metadata = None,
                header = 'results',
                train = False,
                show_progress = True,
                log_file = None):
        cum_loss = 0
        cum_num_examples = 0
        cum_correct_examples = 0
        
        start = datetime.datetime.now()

        self._session.run(self._dataset_iterator.initializer, feed_dict={
            self._dataset_filenames: [dataset_filename]
        })

        if show_progress:
            progress = tqdm_notebook()

        while True:
            try:
                (_,
                 curr_total_loss, 
                 curr_minibatch_size,
                 curr_num_correct_labels,
                 curr_logistic_kernels,
                 curr_logistic_biases) = self._session.run(
                    (self._train_op if train else (),
                     self._total_loss,
                     self._minibatch_size,
                     self._num_correct_labels,
                     self._logistic_kernels if train else (),
                     self._logistic_biases if train else ()),
                    feed_dict = { self._training: train },
                    options = options,
                    run_metadata = run_metadata)
            except tf.errors.OutOfRangeError:
                break

            if show_progress:
                progress.update(curr_minibatch_size)
                
            if curr_logistic_kernels:
                self._logistic_kernels_checkpoint = curr_logistic_kernels
                self._logistic_biases_checkpoint = curr_logistic_biases

            cum_loss += curr_total_loss
            cum_num_examples += curr_minibatch_size
            cum_correct_examples += curr_num_correct_labels

        if show_progress:
            progress.close()
            
        finish = datetime.datetime.now()
        elapsed = (finish - start).total_seconds() * 1000.0

        message = '%s (%d) (%g ms): loss=%g, accuracy=%g' % (
            header,
            tf.train.global_step(sess, self._global_step),
            elapsed,
            cum_loss/cum_num_examples,
            cum_correct_examples/cum_num_examples)
        print(message)
        if log_file:
            print(message, file=log_file)
            log_file.flush()

In [6]:
# Close the session left over from a previous run of this cell (if any) and
# start again from an empty default graph.
sess = reset_tf(sess)

# Build the complete graph (input pipeline, model, training ops) and print the
# per-variable parameter counts.
model = EntityLinkingModel(sess, HyperParameters())
model.build_model()
model.dump_statistics()

parameters for "model/word_embeddings:0": 3840000
parameters for "model/position_embeddings:0": 1296
parameters for "model/context_attention/kernel:0": 18432
parameters for "model/context_attention/bias:0": 128
parameters for "model/context_attention/gamma:0": 128
parameters for "model/context_attention/beta:0": 128
parameters for "model/attention/layer_0/attention_batch_norm/gamma:0": 128
parameters for "model/attention/layer_0/attention_batch_norm/beta:0": 128
parameters for "model/attention/layer_0/ff/fc1/kernel:0": 32768
parameters for "model/attention/layer_0/ff/fc1/bias:0": 256
parameters for "model/attention/layer_0/ff/fc2/kernel:0": 32768
parameters for "model/attention/layer_0/ff/fc2/bias:0": 128
parameters for "model/attention/layer_0/attention_ff_batch_norm/gamma:0": 128
parameters for "model/attention/layer_0/attention_ff_batch_norm/beta:0": 128
parameters for "model/attention/layer_1/attention_batch_norm/gamma:0": 128
parameters for "model/attention/layer_1/attention_batch

In [7]:
# Initialize every global variable of the freshly built graph before the first
# call to model.process().
sess.run(tf.global_variables_initializer())

In [8]:
# One entry per pass of an epoch: (header template, dataset file, run the optimizer?)
# NOTE(review): the optimizer is driven by the *dev* file and evaluation uses
# the *test* file -- confirm that this split is intentional.
passes = [
    ('dev %d',
     '../data/simplewiki/simplewiki-20171103.entity_linking.dev.tfrecords.gz',
     True),
    ('test %d',
     '../data/simplewiki/simplewiki-20171103.entity_linking.test.tfrecords.gz',
     False),
]

for i in range(5):
    for header_template, dataset_path, is_training in passes:
        model.process(
            dataset_path,
            train=is_training,
            header=header_template % i)


dev 0 (118) (36097.6 ms): loss=0.278505, accuracy=0.896433



test 0 (118) (27150.5 ms): loss=2.02677, accuracy=0.49995



dev 1 (236) (34920.1 ms): loss=0.242297, accuracy=0.896633



test 1 (236) (27390.2 ms): loss=2.27029, accuracy=0.48575



dev 2 (354) (35134.4 ms): loss=0.160987, accuracy=0.950233



test 2 (354) (28071.9 ms): loss=3.06801, accuracy=0.486783



dev 3 (472) (36136.1 ms): loss=0.0985317, accuracy=0.973117



test 3 (472) (29130.3 ms): loss=3.51425, accuracy=0.492517



dev 4 (590) (37093.6 ms): loss=0.0809694, accuracy=0.98155



test 4 (590) (29644.2 ms): loss=3.48509, accuracy=0.488467
