In [1]:
import nltk
import re
import gzip
import json
import collections
import tensorflow as tf
import random
import shutil
import os
from tqdm import tqdm_notebook
from intervaltree import Interval, IntervalTree

  from ._conv import register_converters as _register_converters


In [2]:
# Read the gzip-compressed JSON dump into memory; later cells iterate `wiki`
# as a dict whose values are pages with 'id', 'text' and 'links' entries.
with gzip.open('../data/simplewiki/simplewiki-20171103.parsed.norm.json.gz', mode = 'rt', encoding = 'utf-8') as f:
    wiki = json.loads(f.read())

In [3]:
def word_tokenize(text):
    """Yield the word tokens of ``text``, further split on hyphens and slashes.

    Tokens come from ``nltk.word_tokenize``. The quote tokens `` and '' are
    mapped to a plain '"' (presumably so that span_tokenize can find every
    token verbatim in the raw text -- confirm against the input data).
    Each '-' and '/' becomes a token of its own.

    Empty pieces are never yielded: ``re.split`` produces '' next to a
    separator that starts or ends a token (e.g. '-' -> ['', '-', '']), and
    such pieces would otherwise turn into zero-width spans and empty words.
    """
    for word in nltk.word_tokenize(text):
        if word == '``' or word == "''":
            word = '"'
        for subword in re.split('([-/])', word):
            # skip the empty strings re.split emits around separators
            if subword:
                yield subword
            
def span_tokenize(text):
    """Return the character spans in ``text`` of the tokens produced by word_tokenize.

    The tokens are aligned against ``text`` with NLTK's ``align_tokens``;
    callers index each span as (start, end) offsets.
    """
    tokens = word_tokenize(text)
    return nltk.tokenize.util.align_tokens(tokens, text)

In [9]:
# Count how many links point at each target across all pages.
counter = collections.Counter()
for page in tqdm_notebook(wiki.values(), leave=False):
    counter.update(link['target'] for link in page['links'])

# Keep the 8000 most frequently linked targets; a target's ID is its rank.
id_to_target = [target for target, _ in counter.most_common(8000)]
target_to_id = dict((target, target_id) for target_id, target in enumerate(id_to_target))
top_targets = set(id_to_target)



In [10]:
def filter_links(page, targets):
    """Return a shallow copy of ``page`` keeping only links whose target is in ``targets``.

    The input page and its original 'links' list are left unmodified.
    """
    kept_links = [link for link in page['links'] if link['target'] in targets]
    return {**page, 'links': kept_links}

In [11]:
def _label_tokens(page_text, page_links, word_spans):
    """Return (words, targets): each token's text and the link target covering it (None if unlinked)."""
    # index links by their character interval
    link_spans = IntervalTree()
    for link in page_links:
        link_spans[link['start']:link['end']] = link['target']

    # strip links that collide: when one token is touched by several links,
    # every link overlapping that token is dropped
    for begin, end in word_spans:
        if len(link_spans[begin:end]) > 1:
            link_spans.remove_overlap(begin, end)

    # compute words/targets
    words = []
    targets = []
    for begin, end in word_spans:
        words.append(page_text[begin:end])

        spans = link_spans[begin:end]
        if len(spans) == 0:
            targets.append(None)
        elif len(spans) == 1:
            targets.append(list(spans)[0].data)
        else:
            # defensive: collisions were stripped above
            raise ValueError('multiple targets found for span')
    return words, targets


def _run_positions(targets):
    """Return (left, right): per token, how many same-target tokens directly precede/follow it; -1 if unlinked."""
    count = len(targets)
    left = [-1] * count
    right = [-1] * count

    # forward pass: distance from the start of the current run
    curr_len = 0
    for i in range(count):
        if i > 0 and targets[i] and targets[i] == targets[i-1]:
            curr_len += 1
        else:
            curr_len = 0
        if targets[i]:
            left[i] = curr_len

    # backward pass: distance to the end of the current run
    curr_len = 0
    for i in range(count - 1, -1, -1):
        if i < count - 1 and targets[i] and targets[i] == targets[i+1]:
            curr_len += 1
        else:
            curr_len = 0
        if targets[i]:
            right[i] = curr_len

    return left, right


def _window(values, center, width, fill):
    """Return the 2*width+1 items of ``values`` centred on ``center``, using ``fill`` past either end."""
    return [
        values[k] if 0 <= k < len(values) else fill
        for k in range(center - width, center + width + 1)
    ]


def generate_examples(page, width):
    """Yield one fixed-size context window per linked run of tokens in ``page``.

    Args:
        page: mapping with 'id', 'links' (each with 'start', 'end', 'target'
            character offsets into the text) and 'text'.
        width: number of tokens kept on each side of the anchor token; every
            window holds ``2 * width + 1`` positions.

    Yields:
        Tuples ``(page_id, words, targets, left, right)`` where the four
        lists are aligned per window position:
        words   -- token text, '<OOB>' outside the page;
        targets -- link target of the token, None if unlinked or outside;
        left    -- number of same-target tokens directly before it, else -1;
        right   -- number of same-target tokens directly after it, else -1.
        A window is emitted for the first token of each run of consecutive
        tokens sharing a target. Falsy targets are treated as unlinked.

    Raises:
        ValueError: if a token is still covered by several links after
            colliding links were stripped.
    """
    page_id = page['id']
    page_links = page['links']
    page_text = page['text']

    # tokenize, label every token, then measure the linked runs
    word_spans = span_tokenize(page_text)
    words, targets = _label_tokens(page_text, page_links, word_spans)
    targets_left, targets_right = _run_positions(targets)

    # build examples
    for i in range(len(targets)):
        # only the first token of a linked run anchors an example
        if not targets[i] or (i > 0 and targets[i-1] == targets[i]):
            continue
        yield (
            page_id,
            _window(words, i, width, '<OOB>'),
            _window(targets, i, width, None),
            _window(targets_left, i, width, -1),
            _window(targets_right, i, width, -1),
        )

In [12]:
# Build the examples page by page. Each entry of `examples` is the list of
# examples from one page, so shuffling and slicing this list splits by page.
examples = []
for page in tqdm_notebook(wiki.values(), leave = False):
    page = filter_links(page, top_targets)
    # 60 tokens of context on each side of the anchor token
    examples.append(list(generate_examples(page, 60)))

# Shuffle with a dedicated fixed-seed generator so the page order, and with it
# the dev/test/train split, is the same after every kernel restart.
random.Random(0).shuffle(examples)



In [18]:
# Split by page, then flatten: the first 3500 pages feed the dev set, the next
# 3500 the test set, and all remaining pages the training set.
dev_set = [example for page_examples in examples[:3500] for example in page_examples]
test_set = [example for page_examples in examples[3500:7000] for example in page_examples]
train_set = [example for page_examples in examples[7000:] for example in page_examples]
total_set = [example for page_examples in examples for example in page_examples]

len(dev_set), len(test_set), len(train_set)

(21974, 21664, 724833)

In [19]:
# Count word frequencies over the context windows of every example (all three
# splits), ignoring the '<OOB>' padding token.
counter = collections.Counter()
for _, context, _, _, _ in tqdm_notebook(total_set, leave=False):
    counter.update(word for word in context if word != '<OOB>')



In [20]:
# Vocabulary: the most frequent words followed by three special tokens, for at
# most 30000 entries in total. A word's ID is its position in `id_to_word`.
id_to_word = [entry[0] for entry in counter.most_common(30000-3)]
id_to_word.extend(['<UNK>', '<OOB>', '<START>'])
word_to_id = dict((word, word_id) for word_id, word in enumerate(id_to_word))

In [21]:
def get_word_id(word):
    """Return the vocabulary ID of ``word``, or the ID of '<UNK>' when it is not in the vocabulary."""
    found = word_to_id.get(word)
    if found is None:
        return word_to_id['<UNK>']
    return found

In [22]:
def convert_to_tfrecord(example):
    """Pack one example tuple into a ``tf.train.Example`` made of int64 features.

    ``example`` is ``(page_id, words, targets, targets_left, targets_right)``.
    Words are mapped to vocabulary IDs with get_word_id, targets to their IDs
    via target_to_id, and positions without a target get the ID -1.
    """
    page_id, words, targets, targets_left, targets_right = example

    def int64_feature(values):
        return tf.train.Feature(int64_list = tf.train.Int64List(value = values))

    # convert words/targets to IDs
    word_ids = [get_word_id(word) for word in words]
    target_ids = [-1 if not target else target_to_id[target] for target in targets]

    features = {
        'page': int64_feature([page_id]),
        'context': int64_feature(word_ids),
        'targets': int64_feature(target_ids),
        'targets_left': int64_feature(targets_left),
        'targets_right': int64_feature(targets_right)
    }
    return tf.train.Example(features = tf.train.Features(feature = features))

def write_tfrecords(examples, path, batch_size = 10000):
    """Write ``examples`` as GZIP-compressed TFRecord shards under directory ``path``.

    Any existing directory at ``path`` is deleted first and then recreated.
    Examples are written in order, ``batch_size`` per shard, to files named
    'examples.<offset>.tfrecords.gz', where <offset> is the zero-padded index
    of the shard's first example.

    Args:
        examples: sequence of example tuples accepted by convert_to_tfrecord.
        path: output directory; its previous contents are removed.
        batch_size: maximum number of examples per shard file.
    """
    # remove old directory (errors, e.g. a missing directory, are ignored)
    shutil.rmtree(path, ignore_errors = True)

    # make directory
    os.makedirs(path, exist_ok = True)

    # the compression options are identical for every shard, so build them once
    options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP)

    # write batches
    for offset in tqdm_notebook(range(0, len(examples), batch_size), leave = False):
        batch_path = os.path.join(path, 'examples.%010d.tfrecords.gz' % offset)
        with tf.python_io.TFRecordWriter(batch_path, options = options) as writer:
            for example in examples[offset:offset + batch_size]:
                writer.write(convert_to_tfrecord(example).SerializeToString())

In [24]:
# Write the vocabulary, one word per line; the line index is the word ID.
# UTF-8 is requested explicitly (the dump itself is read as UTF-8) so that
# non-ASCII words do not depend on the platform's default encoding.
with open('../data/simplewiki/simplewiki-20171103.er_softmax_2.vocab.txt', 'wt', encoding = 'utf-8') as f:
    for word in id_to_word:
        print(word, file = f)

In [25]:
# Write the link targets, one per line; the line index is the target ID.
# UTF-8 is requested explicitly (the dump itself is read as UTF-8) so that
# non-ASCII targets do not depend on the platform's default encoding.
with open('../data/simplewiki/simplewiki-20171103.er_softmax_2.targets.txt', 'wt', encoding = 'utf-8') as f:
    for target in id_to_target:
        print(target, file = f)

In [26]:
# Export each split to its own directory of TFRecord shards
# (write_tfrecords deletes and recreates the target directory).
write_tfrecords(examples = dev_set, path = '../data/simplewiki/simplewiki-20171103.er_softmax_2.dev')
write_tfrecords(examples = test_set, path = '../data/simplewiki/simplewiki-20171103.er_softmax_2.test')
write_tfrecords(examples = train_set, path = '../data/simplewiki/simplewiki-20171103.er_softmax_2.train')





