In [1]:
import nltk
import re
import gzip
import json
import collections
import tensorflow as tf
import random
from tqdm import tqdm_notebook
from intervaltree import Interval, IntervalTree

In [2]:
# Read the gzip-compressed JSON dump into `wiki`. Later cells iterate over it
# as a mapping whose values are pages with 'id', 'text' and 'links' fields.
with gzip.open('../data/simplewiki/simplewiki-20171103.parsed.norm.json.gz', mode='rt', encoding='utf-8') as wiki_file:
    wiki = json.load(wiki_file)

In [3]:
def align_tokens(tokens, text):
    """Locate each token in `text` and return its character span.

    Tokens are matched left to right; the search for each token starts where
    the previous match ended.

    The quote tokens `` and '' are treated as stand-ins for a double quote,
    but `text` may also contain them literally.  For those tokens all three
    spellings (", `` and '') are searched and the earliest match wins, so a
    literal '' in the text no longer raises or pulls the alignment forward to
    some later double quote.

    Args:
        tokens: iterable of token strings, in the order they occur in `text`.
        text: the string the tokens were produced from.

    Returns:
        A list of `(start, end)` offsets, one per token, `end` exclusive.

    Raises:
        ValueError: if a token cannot be found at or after the current
            position in `text`.
    """
    quote_tokens = ('``', "''")
    quote_spellings = ('"', '``', "''")

    point = 0
    offsets = []
    for token in tokens:
        if token in quote_tokens:
            candidates = quote_spellings
            reported = '"'  # keep the error text identical to the old behaviour
        else:
            candidates = (token,)
            reported = token

        # Collect every spelling that occurs at or after `point`.
        hits = [(text.find(spelling, point), spelling) for spelling in candidates]
        hits = [(pos, spelling) for pos, spelling in hits if pos != -1]
        if not hits:
            raise ValueError('substring "{}" not found in "{}"'.format(reported, text))

        # The spellings start with different characters, so positions never tie.
        start, matched = min(hits)
        point = start + len(matched)
        offsets.append((start, point))
    return offsets

def span_tokenize(text):
    """Tokenize `text` with nltk and return the `(start, end)` character span of each token."""
    tokens = nltk.word_tokenize(text)
    return align_tokens(tokens, text)

def word_tokenize(text):
    """Tokenize `text` with nltk, yielding '"' in place of the `` and '' quote tokens.

    The text is tokenized immediately; the returned generator applies the
    quote replacement lazily as it is consumed.
    """
    raw_tokens = nltk.word_tokenize(text)
    return ('"' if token in ('``', "''") else token for token in raw_tokens)

In [4]:
# Count how often each token occurs across the text of every page.
word_freqs = collections.Counter()
for page in tqdm_notebook(wiki.values()):
    page_tokens = word_tokenize(page['text'])
    word_freqs.update(page_tokens)




In [5]:
# Keep the 29998 most frequent tokens and append the two special entries:
# '<UNK>' (used for words outside the vocabulary) and '<OOB>' (used for
# context positions that fall outside the page).
id_to_word_30k = [word for word, _ in word_freqs.most_common(29998)]
id_to_word_30k.extend(['<UNK>', '<OOB>'])
word_to_id_30k = {word: word_id for word_id, word in enumerate(id_to_word_30k)}

In [5]:
def get_word_id(word):
    """Return the vocabulary id of `word`, or the id of '<UNK>' if it is not in the vocabulary."""
    word_id = word_to_id_30k.get(word)
    if word_id is None:
        return word_to_id_30k['<UNK>']
    return word_id

In [7]:
# Save the vocabulary one token per line; the line order is the id order.
with open('../data/simplewiki/simplewiki-20171103.vocab_30k.txt', 'wt', encoding='utf-8') as vocab_file:
    for vocab_word in id_to_word_30k:
        vocab_file.write(vocab_word + '\n')

In [6]:
# For every page, record which vocabulary ids occur in its text and how often.
# Ids within a page are listed in ascending order, with counts in a parallel list.
doc_word_freqs = []
for page in tqdm_notebook(wiki.values()):
    page_id = page['id']
    page_text = page['text']
    id_counts = collections.Counter(get_word_id(token) for token in word_tokenize(page_text))
    by_word_id = sorted(id_counts.items())
    doc_word_freqs.append((
        page_id,
        [word_id for word_id, _ in by_word_id],
        [count for _, count in by_word_id]))

# Order by page id and check that ids are exactly 0..N-1, so the list can be
# indexed directly by page id once the id column is dropped.
doc_word_freqs.sort()
for position, entry in enumerate(doc_word_freqs):
    assert position == entry[0]
doc_word_freqs = [(ids, freqs) for _, ids, freqs in doc_word_freqs]




In [23]:
def generate_examples_with_words(page, context_width):
    """Yield one example for every (token, overlapping link) pair in `page`.

    Args:
        page: mapping with 'id', 'links' (each link having 'start', 'end' and
            'target') and 'text'.
        context_width: number of tokens taken on each side of the linked token.

    Yields:
        `(page_id, link_target, context)` where `context` is a list of
        `2 * context_width + 1` token strings centred on the linked token;
        positions outside the page are filled with '<OOB>'.
    """
    doc_id = page['id']
    links = page['links']
    text = page['text']

    # Index link character ranges so each token span can be checked for overlap.
    link_tree = IntervalTree()
    for link in links:
        link_tree[link['start']:link['end']] = link['target']

    spans = span_tokenize(text)
    n_tokens = len(spans)
    for center, (begin, end) in enumerate(spans):
        for hit in link_tree[begin:end]:
            window = [
                text[spans[k][0]:spans[k][1]] if 0 <= k < n_tokens else '<OOB>'
                for k in range(center - context_width, center + context_width + 1)
            ]
            yield (doc_id, hit.data, window)

def generate_examples_with_ids(page, context_width):
    """Yield `(page_id, target_page_id, context_word_ids)` for each linked token of `page`.

    Examples whose centre (linked) token is not in the 30k vocabulary are
    dropped; the remaining context words are mapped through `get_word_id`.
    """
    for page_id, target, words in generate_examples_with_words(page, context_width):
        # Resolve the link target to its page id (a missing target raises KeyError).
        target_id = wiki[target]['id']
        if words[context_width] in word_to_id_30k:
            yield (page_id, target_id, [get_word_id(word) for word in words])

In [24]:
# Collect the examples of every page, using 40 tokens of context on each side.
examples = []
for page in tqdm_notebook(wiki.values()):
    page_examples = generate_examples_with_ids(page, context_width=40)
    examples.extend(page_examples)




In [25]:
# Shuffle with a dedicated fixed-seed generator so the dev/test/train split
# taken from `examples` is reproducible across kernel restarts (given that
# `examples` was built in the same order). The global `random` state is left
# untouched.
SHUFFLE_SEED = 20171103
random.Random(SHUFFLE_SEED).shuffle(examples)

In [26]:
dev_set_size = 30000
test_set_size = 30000

# Carve the example list into three contiguous, non-overlapping slices:
# dev first, then test, and everything that remains is training data.
held_out_size = dev_set_size + test_set_size
dev_set = examples[:dev_set_size]
test_set = examples[dev_set_size:held_out_size]
train_set = examples[held_out_size:]

len(dev_set), len(test_set), len(train_set)

(30000, 30000, 1324008)

In [27]:
def write_tfrecords(examples, filename):
    """Serialize `examples` to a GZIP-compressed TFRecord file at `filename`.

    Each example is a `(page_id, target_id, word_ids)` triple.  The written
    record also carries the target page's word ids and frequencies, looked up
    in `doc_word_freqs` by `target_id`.
    """
    def int64_feature(values):
        # Wrap a sequence of ints as an int64-list feature.
        return tf.train.Feature(int64_list = tf.train.Int64List(value = values))

    gzip_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
    with tf.python_io.TFRecordWriter(filename, options = gzip_options) as writer:
        for page_id, target_id, word_ids in tqdm_notebook(examples):
            target_word_ids, target_word_freqs = doc_word_freqs[target_id]
            record = tf.train.Example(features = tf.train.Features(feature = {
                'page_id': int64_feature([page_id]),
                'target_id': int64_feature([target_id]),
                'target_word_ids': int64_feature(target_word_ids),
                'target_word_freqs': int64_feature(target_word_freqs),
                'word_ids': int64_feature(word_ids),
            }))
            writer.write(record.SerializeToString())

In [28]:
# Write each split to its own TFRecord file, in the order dev, test, train.
for split_examples, split_path in (
        (dev_set, '../data/simplewiki/simplewiki-20171103.entity_linking.dev.tfrecords.gz'),
        (test_set, '../data/simplewiki/simplewiki-20171103.entity_linking.test.tfrecords.gz'),
        (train_set, '../data/simplewiki/simplewiki-20171103.entity_linking.train.tfrecords.gz')):
    write_tfrecords(split_examples, split_path)










In [4]:
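# Optional shortcut (kept disabled): reload the 30k vocabulary from the
# vocab_30k.txt file written by an earlier cell instead of rebuilding it
# from `word_freqs`. Uncomment to use.
# with open('../data/simplewiki/simplewiki-20171103.vocab_30k.txt', 'rt', encoding='utf-8') as f:
#     id_to_word_30k = [line.strip() for line in f]
#     word_to_id_30k = dict((word, word_id) for word_id, word in enumerate(id_to_word_30k))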

In [9]:
# Save the per-page word id / frequency lists as gzip-compressed JSON.
with gzip.open('../data/simplewiki/simplewiki-20171103.entity_linking.page_tf.json.gz', mode='wt', encoding='utf-8') as page_tf_file:
    json.dump(doc_word_freqs, page_tf_file)