In [1]:
import tensorflow as tf
import numpy as np
import datetime
import json
import gzip
import matplotlib.pyplot as plt
import re
from tqdm import tqdm_notebook

In [2]:
# Notebook-wide handle to the active TensorFlow session. It starts as None so
# that the first `reset_tf(sess)` call has no previous session to close.
sess = None

In [3]:
def reset_tf(sess = None, log_device_placement = False):
    """Throw away the current TensorFlow graph and open a new interactive session.

    Args:
        sess: an existing session to close first; skipped when falsy.
        log_device_placement: forwarded to `tf.ConfigProto`.

    Returns:
        A newly created `tf.InteractiveSession`.
    """
    # release the previous session (if any) before its graph is discarded
    if sess:
        sess.close()

    # start over with an empty default graph and a fixed random seed of 0
    tf.reset_default_graph()
    tf.set_random_seed(0)

    session_config = tf.ConfigProto(log_device_placement = log_device_placement)
    fresh_session = tf.InteractiveSession(config = session_config)
    return fresh_session

In [4]:
class HyperParameters:
    """Constants that configure EntityLinkingModel and its input pipeline.

    All values are plain class attributes; derived values are computed from
    the attributes defined above them.
    """

    # ---- optimisation ----
    learning_rate = 0.001        # passed to the Adam optimizer
    gradient_clip_norm = 5.0     # global-norm threshold for gradient clipping
    dropout_rate = 0.2           # rate used by every dropout layer

    # ---- vocabulary / output sizes ----
    vocab_size = 30000           # number of rows in the word-embedding table
    num_targets = 2000           # number of units in the final logits layer

    # ---- context window ----
    context_size = 81                          # word ids per example
    context_center_index = context_size // 2   # position whose output is classified

    # ---- embeddings ----
    d_embedding_word = 256
    d_embedding_position = 16

    # ---- attention stack ----
    d_attention = 256
    d_attention_ff = 512
    num_attention_heads = 4
    d_attention_head = d_attention // num_attention_heads   # per-head width
    attention_num_layers = 3

    # ---- tf.data pipeline ----
    pipeline_batch_size = 256
    pipeline_num_parallel_calls = 4
    pipeline_shuffle_size = 5000
    pipeline_prefetch_size = 16 * pipeline_batch_size

In [5]:
class EntityLinkingModel:
    """Self-attention classifier that predicts a target id from a context window.

    The graph is assembled by `build_model()` in three stages: a tf.data input
    pipeline reading gzip-compressed TFRecord files, a stack of multi-head
    attention + feed-forward layers over the embedded context, and the
    loss / accuracy / optimizer ops. `process()` then runs one full pass over a
    dataset file, optionally applying training updates.

    N.B. Several tensors are created without explicit names, so their
    auto-generated names depend on the order in which ops are created here.
    """

    def __init__(self, session, hp):
        """Store the session used to run the graph and the hyper-parameters.

        Args:
            session: TensorFlow session used by `process()`.
            hp: object exposing the attributes of `HyperParameters`.
        """
        self._session = session
        self._hp = hp

    def _parse_example(self, example_proto):
        """Parse one serialized example into `(target_id, context_word_ids)`.

        `page_id` is declared in the feature spec but is not returned.
        """
        parsed = tf.parse_single_example(example_proto, features = {
            'page_id': tf.FixedLenFeature([1], tf.int64),
            'target_id': tf.FixedLenFeature([1], tf.int64),
            'context_word_ids': tf.FixedLenFeature([self._hp.context_size], tf.int64) })

        return (
            parsed['target_id'],
            parsed['context_word_ids'])

    def _build_data_pipeline(self):
        """Build the input pipeline and the model's input tensors.

        Creates the filenames placeholder, the initializable dataset iterator,
        the `context_word_ids` / `target_labels` inputs (which default to the
        iterator output but can be fed directly), the dynamic minibatch size,
        and a `[minibatch, context_size]` tensor of position indices.
        """
        with tf.variable_scope('dataset'):
            # placeholder: examples filenames
            self._dataset_filenames = tf.placeholder(tf.string, shape = [None])

            # build examples dataset
            dataset = tf.data.TFRecordDataset(
                self._dataset_filenames,
                compression_type='GZIP')
            dataset = dataset.map(
                self._parse_example,
                num_parallel_calls = self._hp.pipeline_num_parallel_calls)
            dataset = dataset.shuffle(self._hp.pipeline_shuffle_size)
            # prefetch is applied before batch, so the prefetch size counts
            # individual examples rather than batches
            dataset = dataset.prefetch(self._hp.pipeline_prefetch_size)
            dataset = dataset.batch(self._hp.pipeline_batch_size)

            # build dataset iterator
            self._dataset_iterator = dataset.make_initializable_iterator()
            (target_labels, context_word_ids) = self._dataset_iterator.get_next()
            # target_id is parsed with shape [1]; drop that axis -> [minibatch]
            target_labels = tf.squeeze(target_labels, axis = -1)

            # placeholders (default to the iterator output when nothing is fed)
            self._context_word_ids = tf.placeholder_with_default(
                context_word_ids,
                shape = [None, self._hp.context_size],
                name = 'context_word_ids')
            self._target_labels = tf.placeholder_with_default(
                target_labels,
                shape = [None],
                name = 'target_labels')

            # positions: every row is 0, 1, ..., context_size - 1
            self._minibatch_size = tf.shape(self._context_word_ids)[0]
            p = tf.range(self._hp.context_size, dtype = tf.int64)
            p = tf.tile(p, [self._minibatch_size])
            p = tf.reshape(p, [self._minibatch_size, self._hp.context_size])
            self._context_positions = p

    def _attention_layer(self, A, scope, reuse=None):
        """Multi-head scaled dot-product self-attention over `A`.

        Each head projects `A` to queries, keys and values of width
        `d_attention_head`, attends, and is passed through dropout. The heads
        are concatenated, multiplied by the output weights and passed through
        dropout again.

        Args:
            A: input tensor; the shape hints below treat it as
               `[minibatch, context_size, d_attention]`.
            scope: variable scope name for this layer's weights.
            reuse: forwarded to `tf.variable_scope`.

        Returns:
            Tensor hinted as `[minibatch, context_size, d_attention]`.
        """
        with tf.variable_scope(scope, reuse=reuse):
            query_project = tf.get_variable(
                'query_project',
                [self._hp.num_attention_heads, self._hp.d_attention, self._hp.d_attention_head])
            key_project = tf.get_variable(
                'key_project',
                [self._hp.num_attention_heads, self._hp.d_attention, self._hp.d_attention_head])
            value_project = tf.get_variable(
                'value_project',
                [self._hp.num_attention_heads, self._hp.d_attention, self._hp.d_attention_head])
            output_weights = tf.get_variable(
                'output',
                [self._hp.d_attention, self._hp.d_attention])

            # compute each attention head
            heads = []
            for i in range(self._hp.num_attention_heads):
                Q_proj = tf.tensordot(A, query_project[i], axes = 1)
                K_proj = tf.tensordot(A, key_project[i], axes = 1)
                V_proj = tf.tensordot(A, value_project[i], axes = 1)

                # N.B., tensordot has a bug / fails to infer shapes,
                # so we have to hint shapes ourselves
                Q_proj.set_shape([None, self._hp.context_size, self._hp.d_attention_head])
                K_proj.set_shape([None, self._hp.context_size, self._hp.d_attention_head])
                V_proj.set_shape([None, self._hp.context_size, self._hp.d_attention_head])

                # softmax(Q K^T / sqrt(d_head)) V
                K_proj_T = tf.transpose(K_proj, perm = [0, 2, 1])
                scaled_logits = tf.matmul(Q_proj, K_proj_T) / tf.sqrt(float(self._hp.d_attention_head))
                head = tf.matmul(tf.nn.softmax(scaled_logits), V_proj)
                head = tf.layers.dropout(
                    head,
                    rate = self._hp.dropout_rate,
                    training = self._training)
                heads.append(head)

            # concatenate heads
            result = tf.concat(heads, axis=-1)

            # transform by output weights
            result = tf.tensordot(result, output_weights, axes = 1)
            result.set_shape([None, self._hp.context_size, self._hp.d_attention])

            # dropout
            result = tf.layers.dropout(
                result,
                rate=self._hp.dropout_rate,
                training=self._training)

            return result

    def _attention_ff_layer(self, A, scope, reuse=None):
        """Position-wise feed-forward block: dense+ReLU, dense, dropout.

        The hidden width is `d_attention_ff`; the output width is `d_attention`.
        """
        with tf.variable_scope(scope, reuse=reuse):
            A = tf.layers.dense(A, self._hp.d_attention_ff, activation=tf.nn.relu, name='fc1')
            A = tf.layers.dense(A, self._hp.d_attention, name='fc2')
            A = tf.layers.dropout(
                A,
                rate=self._hp.dropout_rate,
                training=self._training)
            return A

    def _attention_full_layer(self, A, scope, reuse=None):
        """One full layer: attention and feed-forward sub-layers.

        Each sub-layer's output is added to its input (residual connection)
        and the sum is batch-normalized.
        """
        with tf.variable_scope(scope, reuse=reuse):
            A = tf.layers.batch_normalization(
                A + self._attention_layer(A, 'attention', reuse),
                training=self._training,
                name='attention',
                reuse=reuse)
            A = tf.layers.batch_normalization(
                A + self._attention_ff_layer(A, 'ff', reuse),
                training=self._training,
                name='ff',
                reuse=reuse)
            return A

    def _build_model(self):
        """Build the forward pass from input word ids to output logits.

        Word and position embeddings are concatenated, projected to
        `d_attention`, run through `attention_num_layers` full attention
        layers, and the vector at `context_center_index` is mapped to
        `num_targets` logits.
        """
        with tf.variable_scope('model'):
            # placeholder: training flag
            self._training = tf.placeholder(tf.bool)

            # embed context words
            word_embeddings = tf.get_variable(
                'word_embeddings',
                [self._hp.vocab_size, self._hp.d_embedding_word])
            context_words_embedded = tf.nn.embedding_lookup(
                word_embeddings,
                self._context_word_ids)

            # embed context positions
            position_embeddings = tf.get_variable(
                'position_embeddings',
                [self._hp.context_size, self._hp.d_embedding_position],
                dtype=tf.float32)
            context_positions_embedded = tf.nn.embedding_lookup(
                position_embeddings,
                self._context_positions)

            # build full context vector (concat embeddings)
            context_full = tf.concat(
                [context_words_embedded, context_positions_embedded],
                axis=-1)

            # build attention layers
            with tf.variable_scope('attention_with_ff'):
                # build input vector
                context_attention = tf.layers.dense(
                    context_full,
                    self._hp.d_attention,
                    activation=tf.nn.relu,
                    name='input')
                context_attention = tf.layers.batch_normalization(
                    context_attention,
                    training=self._training,
                    name='input')
                context_attention = tf.layers.dropout(
                    context_attention,
                    rate=self._hp.dropout_rate,
                    training=self._training)

                layer = context_attention
                for i in range(self._hp.attention_num_layers):
                    layer = self._attention_full_layer(layer, 'layer_%d' % i)

            # only the representation of the center position is classified
            self._context_attention_embedded = layer[:, self._hp.context_center_index, :]

            # build final softmax layer
            self._output_logits = tf.layers.dense(
                self._context_attention_embedded,
                self._hp.num_targets,
                name = 'softmax')

    def _build_training_model(self):
        """Build loss, accuracy-count and optimizer ops.

        Defines the summed and mean cross-entropy loss, the number of correct
        predictions in the minibatch, the global step, and a training op that
        applies Adam updates to gradients clipped by global norm.
        """
        with tf.variable_scope('train'):
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels = self._target_labels,
                logits = self._output_logits)

            self._total_loss = tf.reduce_sum(losses)
            self._mean_loss = tf.reduce_mean(losses)

            # N.B., tf.nn.softmax here is unnecessary?
            output_labels = tf.argmax(tf.nn.softmax(self._output_logits), axis=-1)
            self._num_correct_labels = tf.reduce_sum(tf.cast(
                tf.equal(output_labels, self._target_labels),
                tf.int32))

            # make the training op depend on the ops in the UPDATE_OPS
            # collection (e.g. batch-normalization statistics updates)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self._global_step = tf.Variable(0, name='global_step', trainable=False)
                self._optimizer = tf.train.AdamOptimizer(learning_rate=self._hp.learning_rate)

                # gradient clipping
                gradients, variables = zip(*self._optimizer.compute_gradients(self._mean_loss))
                gradients, _ = tf.clip_by_global_norm(
                    gradients,
                    self._hp.gradient_clip_norm)

                self._train_op = self._optimizer.apply_gradients(
                    zip(gradients, variables),
                    global_step = self._global_step)

    def build_model(self):
        """Build the complete graph: data pipeline, model, then training ops."""
        self._build_data_pipeline()
        self._build_model()
        self._build_training_model()

    def dump_statistics(self):
        """Print the parameter count of each trainable variable and the total."""
        total_parameters = 0
        for variable in tf.trainable_variables():
            # shape is an array of tf.Dimension
            shape = variable.get_shape()
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            print('parameters for "%s": %d' % (variable.name, variable_parameters))
            total_parameters += variable_parameters
        print('total parameters: %d' % total_parameters)

    def process(self,
                dataset_filename,
                header = 'results',
                train = False,
                show_progress = True,
                log_file = None):
        """Run one full pass over a dataset file and report loss and accuracy.

        The dataset iterator is re-initialized on `dataset_filename` and
        minibatches are consumed until the iterator is exhausted. A one-line
        summary (header, global step, elapsed time, mean per-example loss,
        accuracy) is printed and, if `log_file` is given, also written there.

        Args:
            dataset_filename: path of the gzip-compressed TFRecord file.
            header: label used in the summary line and on the progress bar.
            train: when True the training op is run for every minibatch and
                the training flag is fed as True; otherwise evaluation only.
            show_progress: show a notebook progress bar counting examples.
            log_file: optional writable file object; it is flushed after
                the summary line is written.
        """
        cum_loss = 0
        cum_num_examples = 0
        cum_correct_examples = 0

        start = datetime.datetime.now()

        self._session.run(self._dataset_iterator.initializer, feed_dict={
            self._dataset_filenames: [dataset_filename]
        })

        progress = tqdm_notebook(leave = False, desc = header) if show_progress else None

        try:
            while True:
                try:
                    # an empty tuple is fetched in place of the train op
                    # when only evaluating
                    (_,
                     curr_total_loss,
                     curr_minibatch_size,
                     curr_num_correct_labels) = self._session.run(
                        (self._train_op if train else (),
                         self._total_loss,
                         self._minibatch_size,
                         self._num_correct_labels),
                        feed_dict = { self._training: train })
                except tf.errors.OutOfRangeError:
                    # iterator exhausted: the pass is complete
                    break

                if progress is not None:
                    progress.update(curr_minibatch_size)

                cum_loss += curr_total_loss
                cum_num_examples += curr_minibatch_size
                cum_correct_examples += curr_num_correct_labels
        finally:
            # close the bar even if a step raised
            if progress is not None:
                progress.close()

        finish = datetime.datetime.now()

        # an empty dataset has no meaningful averages; report NaN instead of
        # failing on a division by zero
        if cum_num_examples > 0:
            mean_loss = cum_loss / cum_num_examples
            accuracy = cum_correct_examples / cum_num_examples
        else:
            mean_loss = float('nan')
            accuracy = float('nan')

        # use the session this model was constructed with, not a global
        message = '%s (%d) (%s): loss=%g, accuracy=%g' % (
            header,
            tf.train.global_step(self._session, self._global_step),
            finish - start,
            mean_loss,
            accuracy)
        print(message)
        if log_file:
            print(message, file=log_file)
            log_file.flush()

In [6]:
# Start from a clean default graph and a fresh session (any previous session
# held in `sess` is closed first).
sess = reset_tf(sess)

# Build the whole graph with the default hyper-parameters, then print the
# parameter count of every trainable variable.
model = EntityLinkingModel(sess, HyperParameters())
model.build_model()
model.dump_statistics()

parameters for "model/word_embeddings:0": 7680000
parameters for "model/position_embeddings:0": 1296
parameters for "model/attention_with_ff/input/kernel:0": 69632
parameters for "model/attention_with_ff/input/bias:0": 256
parameters for "model/attention_with_ff/input/gamma:0": 256
parameters for "model/attention_with_ff/input/beta:0": 256
parameters for "model/attention_with_ff/layer_0/attention/query_project:0": 65536
parameters for "model/attention_with_ff/layer_0/attention/key_project:0": 65536
parameters for "model/attention_with_ff/layer_0/attention/value_project:0": 65536
parameters for "model/attention_with_ff/layer_0/attention/output:0": 65536
parameters for "model/attention_with_ff/layer_0/attention/gamma:0": 256
parameters for "model/attention_with_ff/layer_0/attention/beta:0": 256
parameters for "model/attention_with_ff/layer_0/ff/fc1/kernel:0": 131072
parameters for "model/attention_with_ff/layer_0/ff/fc1/bias:0": 512
parameters for "model/attention_with_ff/layer_0/ff/fc2/

In [None]:
# Run the initializer for all global variables in the graph.
sess.run(tf.global_variables_initializer())

In [None]:
# Train for 30 passes over the training file; after each pass, evaluate on the
# dev file without applying updates. Every summary line is printed and also
# appended to the log file opened here.
with open('../logs/simplewiki/mediawiki_el_softmax_1.multihead.log', 'wt') as f:
    for i in range(30):
        # training pass: runs the train op for every minibatch
        model.process(
            '../data/simplewiki/simplewiki-20171103.el_softmax_1.train.tfrecords.gz', 
            header = 'train %d' % i,
            train = True,
            log_file = f)
        # evaluation pass: metrics only, no parameter updates
        model.process(
            '../data/simplewiki/simplewiki-20171103.el_softmax_1.dev.tfrecords.gz',
            header = 'dev %d' % i,
            train = False,
            log_file = f)

train 0 (2481) (0:06:56.692491): loss=0.673396, accuracy=0.885264


dev 0 (2481) (0:00:04.289151): loss=0.319516, accuracy=0.9397


In [None]:
# builder = tf.saved_model.builder.SavedModelBuilder('../models/simplewiki/el_softmax_1')
# builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING])
# builder.add_meta_graph([tf.saved_model.tag_constants.SERVING])
# builder.save()

# Error Analysis

In [7]:
# Load the SavedModel tagged TRAINING from disk into the current session; the
# loader's return value is discarded.
# NOTE(review): the cell that would write this directory is commented out in
# this notebook — confirm the saved model exists before running.
_ = tf.saved_model.loader.load(
    sess,
    [tf.saved_model.tag_constants.TRAINING],
    '../models/simplewiki/el_softmax_1')

INFO:tensorflow:Restoring parameters from b'../models/simplewiki/el_softmax_1/variables/variables'


In [34]:
# Look up the graph elements needed for error analysis by name in the default
# graph.
# NOTE(review): 'dataset/Placeholder:0', 'model/Placeholder:0' and
# 'train/ArgMax:0' are auto-generated names; presumably they are the filenames
# placeholder, the training flag and the predicted labels built by
# EntityLinkingModel — they will change if op creation order changes.

# operation (not a tensor): running it (re)initializes the dataset iterator
dataset_iterator = tf.get_default_graph().get_operation_by_name('dataset/MakeIterator')
dataset_filenames = tf.get_default_graph().get_tensor_by_name('dataset/Placeholder:0')
training = tf.get_default_graph().get_tensor_by_name('model/Placeholder:0')
context_word_ids = tf.get_default_graph().get_tensor_by_name('dataset/context_word_ids:0')
target_labels = tf.get_default_graph().get_tensor_by_name('dataset/target_labels:0')
output_labels = tf.get_default_graph().get_tensor_by_name('train/ArgMax:0')
# total_loss = tf.get_default_graph().get_tensor_by_name('train/Sum:0')
# num_correct_labels = tf.get_default_graph().get_tensor_by_name('train/Sum_1:0')

In [39]:
# Read the vocabulary, one stripped entry per line of the file; the list is
# later indexed by word id to decode contexts.
with open('../data/simplewiki/simplewiki-20171103.el_softmax_1.vocab.txt', 'rt') as f:
    vocab = [w.strip() for w in f]

In [52]:
# Read the target names, one stripped entry per line of the file; the list is
# later indexed by target / output label id.
with open('../data/simplewiki/simplewiki-20171103.el_softmax_1.targets.txt', 'rt') as f:
    targets = [t.strip() for t in f]

In [102]:
def compute_negative_examples(filenames, limit = None):
    """Collect the examples for which the predicted label differs from the target.

    Relies on these notebook globals: `sess`, `dataset_iterator`,
    `dataset_filenames`, `training`, `context_word_ids`, `target_labels`,
    `output_labels`, `vocab` and `targets`.

    Args:
        filenames: list of dataset files fed to the dataset filenames
            placeholder.
        limit: maximum number of examples to collect. Any falsy value
            (None or 0) means no limit.

    Returns:
        A list of `[target name, output name, center word, context]` lists,
        where `context` is the space-joined context words with the center
        word wrapped in underscores.
    """
    # initialize dataset iterator
    sess.run(dataset_iterator, feed_dict = {
        dataset_filenames: filenames,
        training: False })

    examples = []

    def limit_reached():
        return bool(limit) and len(examples) >= limit

    # stop pulling minibatches as soon as the limit is reached, so no
    # inference is run on batches whose results would be discarded
    while not limit_reached():
        # compute minibatch
        try:
            (curr_context_word_ids, curr_target_labels, curr_output_labels) = sess.run(
                (context_word_ids, target_labels, output_labels),
                feed_dict = { training: False })
        except tf.errors.OutOfRangeError:
            break

        # loop through examples
        for c, t, o in zip(curr_context_word_ids, curr_target_labels, curr_output_labels):
            # skip accurate inferences
            if t == o:
                continue

            # decode context, marking the center word with underscores
            words = [vocab[word_id] for word_id in c]
            center = len(words) // 2
            center_word = words[center]
            words[center] = '_%s_' % center_word

            # decode example
            examples.append([
                targets[t], # target label
                targets[o], # output label
                center_word, # center word
                ' '.join(words) ]) # context

            # stop if limit reached
            if limit_reached():
                break

    return examples

In [103]:
# Collect every misclassified example from the dev set (no limit).
examples = compute_negative_examples(['../data/simplewiki/simplewiki-20171103.el_softmax_1.dev.tfrecords.gz'])

In [104]:
# Number of misclassified examples collected above.
len(examples)

1066

In [107]:
import csv

# Write the misclassified examples to a CSV file with a header row.
# The file is opened with newline='' as the csv module requires, so the
# writer controls row terminators itself and any newline embedded in a field
# is preserved correctly.
with open('/tmp/errors.csv', 'wt', newline = '') as f:
    writer = csv.writer(f)
    writer.writerow(['target', 'output', 'word', 'context'])
    writer.writerows(examples)

In [None]:
# with open('../logs/simplewiki/mediawiki_el_softmax_1.multihead.log.old', 'wt') as f:
#     print(lines)

In [None]:
# with open('../logs/simplewiki/mediawiki_el_softmax_1.multihead.log', 'rt') as f:
#     lines = '\n'.join([l for l in f])
    
# plt.plot([float(v) for v in re.findall(r'train.*loss=(\d+\.\d+)', lines)], label='train')
# plt.plot([float(v) for v in re.findall(r'dev.*loss=(\d+\.\d+)', lines)], label='dev')
# plt.title('loss')
# plt.legend()
# plt.show()
# plt.plot([float(v) for v in re.findall(r'train.*accuracy=(\d+\.\d+)', lines)], label='train')
# plt.plot([float(v) for v in re.findall(r'dev.*accuracy=(\d+\.\d+)', lines)], label='dev')
# plt.title('accuracy')
# plt.legend()
# plt.show()