In [1]:
import gzip
import json
import random
from random import shuffle

import nltk
import tensorflow as tf
from intervaltree import Interval, IntervalTree
from nltk.tokenize import TreebankWordTokenizer
from tqdm import tqdm

from modules import bpencoding

In [2]:
# Fixed seed for the global shuffle below, so that re-running the notebook on
# a fresh kernel reproduces the same ordering (and hence the same slices of
# `sentences` taken later in the notebook).
SHUFFLE_SEED = 0

# Load the full sentence list from the gzipped JSON dump.
with gzip.open('../data/simplewiki/simplewiki-20171103.sentences.json.gz', 'rt', encoding='utf8') as f:
    sentences = json.load(f)
# N.B., globally pre-shuffle data since we'll be streaming it during training,
# and will only be able to shuffle within a small lookahead buffer.
# `shuffle` (imported from `random`) draws from the module-level generator,
# so seeding that generator immediately beforehand makes the order deterministic.
random.seed(SHUFFLE_SEED)
shuffle(sentences)

In [3]:
# Read every line of the encoder table file.
with open('../data/simplewiki/simplewiki-20171103.encoder_table_10k.txt', 'rt', encoding='utf-8') as f:
    lines = list(f)
# Keep only the first 9999 entries (whitespace-stripped) so that an index
# stays free for the "unknown" token.
# NOTE(review): the original note says index 10000 is the reserved one --
# confirm the index base against bpencoding.Encoder.
table = [entry.strip() for entry in lines[:9999]]

In [4]:
# Build the encoder from the truncated table (presumably a byte-pair encoder,
# going by the module name -- confirm in modules/bpencoding). This global is
# the encoder that write_tfrecords hands to generate_example.
encoder = bpencoding.Encoder(table)

In [5]:
def align_tokens(tokens, sentence):
    """Locate each token in `sentence` and return its character span.

    Tokens are matched left to right; each search starts where the previous
    match ended, so repeated tokens map to successive occurrences.

    The tokens '``' and "''" are treated as quotation marks. Such a token
    matches whichever of '"', '``' or "''" occurs first at or after the
    current position, so both a plain double quote and a literal two-character
    quote in `sentence` are aligned correctly.

    Args:
        tokens: iterable of token strings, in sentence order.
        sentence: the string the tokens were produced from.

    Returns:
        A list of (start, end) offsets into `sentence`, one per token, with
        `end` exclusive.

    Raises:
        ValueError: if a token cannot be found at or after the end of the
            previous match.
    """
    quote_tokens = ('``', "''")
    quote_forms = ('"', '``', "''")

    point = 0
    offsets = []
    for token in tokens:
        if token in quote_tokens:
            # Report quote tokens as '"' in error messages, and accept any
            # surface form the quote may take in the sentence.
            token = '"'
            candidates = quote_forms
        else:
            candidates = (token,)

        # Pick the candidate form that occurs earliest from the cursor on.
        best = None
        for form in candidates:
            start = sentence.find(form, point)
            if start != -1 and (best is None or start < best[0]):
                best = (start, start + len(form))

        if best is None:
            raise ValueError('substring "{}" not found in "{}"'.format(token, sentence))

        offsets.append(best)
        point = best[1]
    return offsets

def span_tokenize(sentence):
    """Tokenize `sentence` with nltk.word_tokenize and return token spans.

    Returns the list of (start, end) character offsets that align_tokens
    computes for the resulting tokens.
    """
    tokens = nltk.word_tokenize(sentence)
    spans = align_tokens(tokens, sentence)
    return spans

def _normalize_with_offsets(raw):
    """Lowercase `raw`, collapse ``/'' into '"', and keep an offset map.

    Returns (text, begins, ends). `text` equals
    raw.lower().replace("``", '"').replace("''", '"'). For every character k
    of `text`, raw[begins[k]:ends[k]] is the stretch of `raw` it came from.
    """
    lowered = raw.lower()
    if len(lowered) == len(raw):
        # Common case: lowercasing mapped every character one-to-one.
        source = range(len(raw))
    else:
        # Some characters lowercase to more than one character; point each
        # produced character back at the raw character that produced it.
        source = []
        for index, char in enumerate(raw):
            source.extend([index] * len(char.lower()))
        if len(source) != len(lowered):
            # Defensive fallback: treat lowered offsets as raw offsets.
            source = range(len(lowered))

    chars, begins, ends = [], [], []
    pos = 0
    while pos < len(lowered):
        # A two-character quote collapses into a single '"'.
        width = 2 if lowered[pos:pos + 2] in ('``', "''") else 1
        chars.append('"' if width == 2 else lowered[pos])
        begins.append(source[pos])
        ends.append(source[pos + width - 1] + 1)
        pos += width
    return ''.join(chars), begins, ends


def _span_is_linked(links, begins, ends, start, end):
    """Return True if normalized span [start, end) overlaps any link interval."""
    # Clamp in case the wordform texts run past the end of the text.
    end = min(end, len(ends))
    if start >= end:
        return False
    return bool(links[begins[start]:ends[end - 1]])


def generate_example(sentence, encoder):
    """Turn one sentence record into a tf.train.Example.

    The text is lowercased, '``' and "''" are collapsed to '"', the result is
    tokenized with span_tokenize, and every token is split into wordforms by
    `encoder.encode`. Link intervals are built from offsets into the original
    `sentence['text']`, so each wordform's span is mapped back to original
    offsets before it is checked against the links.

    Args:
        sentence: mapping with 'text' (str) and 'links' (iterable of mappings
            with 'start', 'finish' and 'target').
        encoder: object whose `encode(word)` returns a sequence of items with
            `.text` and `.index` attributes.

    Returns:
        A tf.train.Example with three int64 list features of equal length:
        'inputs' (wordform indices), 'word_endings' (1 on the last wordform
        of a token, else 0) and 'targets' (1 if the wordform overlaps a link,
        else 0).
    """
    links = IntervalTree()
    for link in sentence['links']:
        links[link['start']:link['finish']] = link['target']

    text, begins, ends = _normalize_with_offsets(sentence['text'])

    inputs = []
    word_endings = []
    targets = []

    for word_start, word_end in span_tokenize(text):
        wfs = encoder.encode(text[word_start:word_end])
        last = len(wfs) - 1
        offset = word_start
        for i, wf in enumerate(wfs):
            start = offset
            end = offset + len(wf.text)
            inputs.append(wf.index)
            word_endings.append(int(i == last))
            targets.append(int(_span_is_linked(links, begins, ends, start, end)))
            offset = end

    return tf.train.Example(features = tf.train.Features(feature = {
        'inputs': tf.train.Feature(int64_list = tf.train.Int64List(value = inputs)),
        'word_endings': tf.train.Feature(int64_list = tf.train.Int64List(value = word_endings)),
        'targets': tf.train.Feature(int64_list = tf.train.Int64List(value = targets)),
    }))

In [9]:
# Hold out the first TEST_SET_SIZE sentences for testing; everything after
# them is the training set.
TEST_SET_SIZE = 50000
test_sentences = sentences[:TEST_SET_SIZE]
train_sentences = sentences[TEST_SET_SIZE:]

In [10]:
def write_tfrecords(sentences, file):
    """Serialize every sentence to a tf.train.Example and write it to `file`.

    Each sentence is converted with generate_example using the notebook-level
    `encoder`, and the serialized examples are appended to a TFRecord file.
    Progress is shown with tqdm.

    Args:
        sentences: iterable of sentence records accepted by generate_example.
        file: path of the TFRecord file to write.
    """
    # tf.python_io was removed in TensorFlow 2.x; prefer tf.io.TFRecordWriter
    # when it exists and fall back to the legacy location otherwise.
    writer_cls = getattr(getattr(tf, 'io', None), 'TFRecordWriter', None)
    if writer_cls is None:
        writer_cls = tf.python_io.TFRecordWriter

    with writer_cls(file) as writer:
        for s in tqdm(sentences):
            example = generate_example(s, encoder)
            writer.write(example.SerializeToString())

In [12]:
# Write the training split first, then the test split, each to its own file.
for split_sentences, split_path in (
    (train_sentences, '../data/simplewiki/simplewiki-20171103.entity_recognition.train.tfrecords'),
    (test_sentences, '../data/simplewiki/simplewiki-20171103.entity_recognition.test.tfrecords'),
):
    write_tfrecords(split_sentences, split_path)

100%|██████████| 1045155/1045155 [10:34<00:00, 1647.44it/s]
100%|██████████| 50000/50000 [00:29<00:00, 1680.80it/s]
