In [2]:
import nltk
import re
import gzip
import json
import collections
import tensorflow as tf
import random
from tqdm import tqdm_notebook
from intervaltree import Interval, IntervalTree

In [2]:
# Load the parsed Simple English Wikipedia dump; downstream cells read each page's
# 'id', 'text' and 'links' entries from the values of this mapping.
wiki_path = '../data/simplewiki/simplewiki-20171103.parsed.norm.json.gz'
with gzip.open(wiki_path, mode='rt', encoding='utf-8') as wiki_file:
    wiki = json.load(wiki_file)

In [3]:
def align_tokens(tokens, text):
    """Map each token back to its (start, end) character offsets in ``text``.

    Tokens are located left to right, each search starting where the previous token ended.

    NLTK's word tokenizer rewrites a straight double quote ``"`` into ````` or ``''``, but it
    also emits those same tokens when the text literally contains two backticks or two
    apostrophes. For such tokens both spellings are tried, and whichever occurs first after
    the current position wins.

    Args:
        tokens: iterable of token strings, in text order.
        text: the string the tokens were produced from.

    Returns:
        list of ``(start, end)`` tuples, one per token.

    Raises:
        ValueError: if a token cannot be found in ``text`` after the previous token.
    """
    offsets = []
    point = 0
    for token in tokens:
        # The first candidate is the one reported in the error message (matches prior wording).
        if token == '``' or token == "''":
            candidates = ('"', token)
        else:
            candidates = (token,)

        best = None
        for candidate in candidates:
            start = text.find(candidate, point)
            if start != -1 and (best is None or start < best[0]):
                best = (start, start + len(candidate))

        if best is None:
            raise ValueError('substring "{}" not found in "{}"'.format(candidates[0], text))
        offsets.append(best)
        point = best[1]
    return offsets

def span_tokenize(text):
    """Return the ``(start, end)`` character offsets of every NLTK word token in ``text``."""
    tokens = nltk.word_tokenize(text)
    spans = align_tokens(tokens, text)
    return spans

def word_tokenize(text):
    """Lazily yield NLTK word tokens of ``text``, turning ````` and ``''`` back into ``"``.

    Tokenization itself runs immediately; only the quote normalisation is lazy.
    """
    quote_tokens = ('``', "''")
    raw_tokens = nltk.word_tokenize(text)
    return ('"' if tok in quote_tokens else tok for tok in raw_tokens)

In [4]:
# Count how often each link target is used across all pages, then keep the 2000 most
# frequent targets as the label vocabulary (list index == target id).
counter = collections.Counter(
    link['target']
    for page in tqdm_notebook(wiki.values(), leave=False)
    for link in page['links']
)
id_to_target_2k = [target for target, _ in counter.most_common(2000)]
target_to_id_2k = {target: target_id for target_id, target in enumerate(id_to_target_2k)}
top_2k_targets = set(id_to_target_2k)



In [5]:
def generate_examples_with_words(page, context_width, valid_link_targets, pad_token='<OOB>'):
    """Yield link-prediction examples for every token of ``page`` that lies inside a link.

    The page text is tokenized with ``span_tokenize``. One example is produced for each token
    that overlaps a link whose target is in ``valid_link_targets``, so a link spanning several
    tokens yields one example per token.

    Args:
        page: dict with ``'id'``, ``'text'`` and ``'links'``; each link is a dict with
            ``'start'`` / ``'end'`` character offsets into the text and a ``'target'``.
        context_width: number of tokens taken on each side of the linked token (must be >= 0).
        valid_link_targets: container of targets to keep; links to other targets are ignored.
        pad_token: filler for window positions before the first / after the last token.

    Yields:
        ``(page_id, target, context)`` where ``context`` is a new list of
        ``2 * context_width + 1`` token strings centred on the linked token.

    Raises:
        ValueError: if ``context_width`` is negative (raised on first iteration).
    """
    if context_width < 0:
        raise ValueError('context_width must be non-negative, got {}'.format(context_width))

    page_id = page['id']
    page_text = page['text']

    link_spans = IntervalTree()
    for link in page['links']:
        # IntervalTree refuses null intervals (end <= start). Such a link covers no text and can
        # never overlap a token, so skip it rather than aborting the whole run.
        if link['end'] <= link['start']:
            continue
        link_spans[link['start']:link['end']] = link['target']
    if not link_spans:
        # No usable links means no examples; skip the (expensive) tokenization entirely.
        return

    word_spans = span_tokenize(page_text)
    # Slice each token once and pad both ends so every window is a single list slice.
    padding = [pad_token] * context_width
    padded_words = padding + [page_text[start:end] for start, end in word_spans] + padding
    window = 2 * context_width + 1

    for i, (start, end) in enumerate(word_spans):
        for interval in link_spans[start:end]:
            if interval.data not in valid_link_targets:
                continue
            # Token i sits at padded_words[i + context_width], so its window starts at index i.
            yield (page_id, interval.data, padded_words[i:i + window])

In [6]:
# Every (token, link) example from every page, with 40 context tokens on each side,
# restricted to the top-2000 link targets. Order follows page order in `wiki`.
examples = [
    example
    for page in tqdm_notebook(wiki.values(), leave=False)
    for example in generate_examples_with_words(page, 40, top_2k_targets)
]



In [9]:
dev_set_size = 20000
test_set_size = 20000

# Contiguous slices of `examples`: dev first, then test, everything after goes to train.
# NOTE(review): `examples` is in page order, so dev/test come from the first pages only and
# are not a random sample (`random` is imported but unused). If a random split was intended,
# shuffle by page with a fixed seed before slicing — TODO confirm intent.
dev_end = dev_set_size
test_end = dev_end + test_set_size
dev_set, test_set, train_set = examples[:dev_end], examples[dev_end:test_end], examples[test_end:]
# Tiny sanity-check subset; it is the head of `examples`, so it overlaps dev_set.
tiny_set = examples[:30]

len(dev_set), len(test_set), len(train_set), len(tiny_set)

(20000, 20000, 635058, 30)

In [8]:
# Word frequencies over all context windows, ignoring the '<OOB>' padding; words are counted
# once per window they appear in. Feeds the 30k word vocabulary built next.
# NOTE(review): this iterates over all of `examples`, including dev/test, so the vocabulary
# sees held-out text — consider counting over train_set only. TODO confirm.
counter = collections.Counter()
for _, _, context in tqdm_notebook(examples, leave=False):
    counter.update(word for word in context if word != '<OOB>')



In [10]:
# Up to 29,998 most frequent context words, then the two reserved tokens appended last;
# list position is the word id.
id_to_word_30k = [word for word, _ in counter.most_common(30000 - 2)]
id_to_word_30k += ['<UNK>', '<OOB>']
word_to_id_30k = {word: word_id for word_id, word in enumerate(id_to_word_30k)}

In [11]:
def convert_to_tfrecord(example):
    """Build a ``tf.train.Example`` from one ``(page_id, target, context)`` tuple.

    Uses the module-level ``target_to_id_2k`` and ``word_to_id_30k`` lookups. Context words
    absent from the word vocabulary are encoded with the ``'<UNK>'`` id.

    Returns:
        ``tf.train.Example`` with int64 features ``page_id``, ``target_id`` and
        ``context_word_ids``.
    """
    page_id, target, context = example
    target_id = target_to_id_2k[target]
    context_word_ids = [
        word_to_id_30k[word] if word in word_to_id_30k else word_to_id_30k['<UNK>']
        for word in context
    ]

    def int64_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    feature_map = {
        'page_id': int64_feature([page_id]),
        'target_id': int64_feature([target_id]),
        'context_word_ids': int64_feature(context_word_ids),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

def write_tfrecords(examples, filename):
    """Serialize ``examples`` with ``convert_to_tfrecord`` into a GZIP-compressed TFRecord file."""
    gzip_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)
    with tf.python_io.TFRecordWriter(filename, options=gzip_options) as writer:
        progress = tqdm_notebook(examples, leave=False)
        for serialized in (convert_to_tfrecord(ex).SerializeToString() for ex in progress):
            writer.write(serialized)

In [12]:
# Word vocabulary file: one word per line, line index == word id.
# Explicit UTF-8 (matching how the dump is read) instead of the platform's locale encoding,
# which may not be able to encode non-ASCII words.
with open('../data/simplewiki/simplewiki-20171103.el_softmax_1.vocab.txt', 'wt', encoding='utf-8') as f:
    for word in id_to_word_30k:
        print(word, file=f)

In [13]:
# Target vocabulary file: one link target per line, line index == target id.
# Explicit UTF-8 (matching how the dump is read) instead of the platform's locale encoding,
# which may not be able to encode non-ASCII page titles.
with open('../data/simplewiki/simplewiki-20171103.el_softmax_1.targets.txt', 'wt', encoding='utf-8') as f:
    for target in id_to_target_2k:
        print(target, file=f)

In [14]:
# Write each split to its own GZIP TFRecord file, in the order dev, test, train, tiny.
split_outputs = [
    (dev_set, '../data/simplewiki/simplewiki-20171103.el_softmax_1.dev.tfrecords.gz'),
    (test_set, '../data/simplewiki/simplewiki-20171103.el_softmax_1.test.tfrecords.gz'),
    (train_set, '../data/simplewiki/simplewiki-20171103.el_softmax_1.train.tfrecords.gz'),
    (tiny_set, '../data/simplewiki/simplewiki-20171103.el_softmax_1.tiny.tfrecords.gz'),
]
for split_examples, output_path in split_outputs:
    write_tfrecords(split_examples, output_path)







