In [1]:
import json
import nltk
import gzip
import tensorflow as tf
import random
import collections
from modules import bpencoding
from nltk.tokenize import TreebankWordTokenizer
from intervaltree import Interval, IntervalTree
from tqdm import tqdm_notebook

In [2]:
# Load the pre-extracted sentence records (one JSON document) from the gzipped dump.
sentences_path = '../data/simplewiki/simplewiki-20171103.sentences.json.gz'
with gzip.open(sentences_path, mode='rt', encoding='utf8') as sentences_file:
    sentences = json.load(sentences_file)

In [3]:
# N.B., globally pre-shuffle data since we'll be streaming it during training,
# and will only be able to shuffle within a small lookahead buffer.
# Also, we want to ensure that our train/dev/test sets have the same distribution.
#
# Shuffle with a dedicated, seeded RNG so the dev/test/train split below is the
# same on every run over the same input file. Sorting into a canonical order
# first means re-running this cell alone does not reshuffle an already-shuffled
# list (assumes (page_id, para_id, sentence_id) identifies a sentence — TODO confirm).
SHUFFLE_SEED = 20171103
sentences.sort(key=lambda s: (s['page_id'], s['para_id'], s['sentence_id']))
random.Random(SHUFFLE_SEED).shuffle(sentences)

In [4]:
def normalize_text(text):
    """Lower-case ``text`` and collapse Treebank-style double quotes.

    Both `` and '' are replaced by a plain '"'. nltk.word_tokenize emits `` and ''
    for double quotes, and align_tokens searches for those tokens as '"'. Without
    this step, a literal `` or '' in the raw text could not be located during
    alignment.

    Args:
        text: raw sentence text.

    Returns:
        The normalized string.
    """
    lowered = text.lower()
    for treebank_quote in ("``", "''"):
        lowered = lowered.replace(treebank_quote, '"')
    return lowered

In [5]:
# Count token frequencies over the corpus to build the vocabulary.
# N.B., nltk.word_tokenize rewrites every double quote as `` or ''. However,
# generate_example looks words up by their aligned text slice, and for quote
# tokens that slice is always '"' (align_tokens maps `` and '' to '"').
# Count quotes under that same key so vocabulary entries match the lookups.
# Otherwise `` and '' would occupy two vocabulary slots that are never looked
# up, and every '"' would be mapped to <UNK>.
# NOTE(review): this counts over all sentences, including those later split
# into dev/test, so the vocabulary sees held-out data — confirm that is acceptable.
QUOTE_TOKENS = ('``', "''")
word_freqs = collections.Counter()
for sentence in tqdm_notebook(sentences):
    text = normalize_text(sentence['text'])
    word_freqs.update(
        '"' if token in QUOTE_TOKENS else token
        for token in nltk.word_tokenize(text))




In [None]:
# Preview a handful of records; rendering the whole corpus would flood the
# notebook output (and the saved .ipynb) with every sentence.
sentences[:3]

In [15]:
# Vocabulary: the (up to) 29,999 most frequent tokens in descending frequency,
# followed by a trailing '<UNK>' entry, so '<UNK>' has id len(id_to_word_30k) - 1.
vocab_size = 30000
id_to_word_30k = [word for word, _count in word_freqs.most_common(vocab_size - 1)]
id_to_word_30k.append('<UNK>')
word_to_id_30k = {word: index for index, word in enumerate(id_to_word_30k)}

In [20]:
def align_tokens(tokens, text):
    """Locate each token of ``tokens`` in ``text``, scanning left to right.

    Each search starts where the previous match ended, so repeated tokens are
    matched to successive occurrences.

    Args:
        tokens: token strings in text order (e.g. from nltk.word_tokenize).
            The Treebank quote tokens `` and '' are searched for as '"'.
        text: the string the tokens were produced from.

    Returns:
        A list of ``(start, end)`` character offsets, one per token.

    Raises:
        ValueError: if a token cannot be found after the previous match.
    """
    offsets = []
    cursor = 0
    for token in tokens:
        needle = '"' if token in ('``', "''") else token
        start = text.find(needle, cursor)
        if start < 0:
            raise ValueError('substring "{}" not found in "{}"'.format(needle, text))
        cursor = start + len(needle)
        offsets.append((start, cursor))
    return offsets

def span_tokenize(text):
    """Tokenize ``text`` with nltk.word_tokenize; return each token's (start, end) offsets in ``text``."""
    word_tokens = nltk.word_tokenize(text)
    return align_tokens(word_tokens, text)

def _int64_feature(values):
    """Wrap a sequence of ints as an int64-list ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def generate_example(sentence):
    """Convert one sentence record into a ``tf.train.Example`` for entity tagging.

    The normalized sentence text is tokenized. Each token contributes one entry to
    ``inputs``: its id in ``word_to_id_30k``, or the ``<UNK>`` id when the word is
    out of vocabulary. Each token also contributes one entry to ``targets``:
    1 if its character span overlaps any link in ``sentence['links']``, else 0.
    Out-of-vocabulary tokens always get target 0.

    Args:
        sentence: dict with ``'text'``, ``'links'`` (each link a dict with
            ``'start'``, ``'finish'`` and ``'target'``), ``'page_id'``,
            ``'para_id'`` and ``'sentence_id'`` keys.

    Returns:
        A ``tf.train.Example`` with int64 features ``page_id``, ``para_id``,
        ``sentence_id``, ``inputs`` and ``targets``.
    """
    links = IntervalTree()
    for link in sentence['links']:
        # intervaltree raises ValueError on empty (null) intervals. A zero-length
        # link cannot overlap any token anyway, so skip it.
        if link['finish'] > link['start']:
            links[link['start']:link['finish']] = link['target']

    text = normalize_text(sentence['text'])
    # NOTE(review): link offsets index sentence['text'], but the spans below index
    # the normalized text. normalize_text collapses `` and '' to a single
    # character, and str.lower() can lengthen a few Unicode characters. Either
    # change shifts every later span relative to the links. Confirm the corpus
    # contains no such characters, or map offsets back before querying.
    unk_id = len(id_to_word_30k) - 1
    inputs = []
    targets = []

    for start, end in span_tokenize(text):
        word_id = word_to_id_30k.get(text[start:end])
        # Compare against None explicitly: id 0 belongs to the most frequent
        # word, and a truthiness test would wrongly send it to <UNK>.
        if word_id is not None:
            inputs.append(word_id)
            # 1 if the token's span overlaps any link interval, else 0.
            targets.append(int(bool(links[start:end])))
        else:
            # NOTE(review): out-of-vocabulary tokens are always labelled 0, even
            # inside a link — confirm this is intended.
            inputs.append(unk_id)
            targets.append(0)

    return tf.train.Example(features=tf.train.Features(feature={
        'page_id': _int64_feature([sentence['page_id']]),
        'para_id': _int64_feature([sentence['para_id']]),
        'sentence_id': _int64_feature([sentence['sentence_id']]),
        'inputs': _int64_feature(inputs),
        'targets': _int64_feature(targets),
    }))

In [21]:
dev_set_size = 30000
test_set_size = 30000

# Carve the (already shuffled) corpus into consecutive, non-overlapping slices:
# [0, dev_end) -> dev, [dev_end, test_end) -> test, the remainder -> train.
dev_end = dev_set_size
test_end = dev_end + test_set_size
dev_set_sentences = sentences[:dev_end]
test_set_sentences = sentences[dev_end:test_end]
train_set_sentences = sentences[test_end:]

In [22]:
def write_tfrecords(sentences, file):
    """Write one serialized ``tf.train.Example`` per sentence to a TFRecord file.

    Args:
        sentences: iterable of sentence records accepted by ``generate_example``.
        file: path of the TFRecord file to write.
    """
    # tf.python_io was removed in TensorFlow 2, where the same writer is exposed
    # as tf.io.TFRecordWriter. Prefer that name, and fall back to tf.python_io
    # for old TF 1.x releases that predate the tf.io alias.
    writer_cls = getattr(getattr(tf, 'io', None), 'TFRecordWriter', None)
    if writer_cls is None:
        writer_cls = tf.python_io.TFRecordWriter
    with writer_cls(file) as writer:
        for sentence in tqdm_notebook(sentences):
            writer.write(generate_example(sentence).SerializeToString())

In [23]:
# Serialize each split to its own TFRecord file, in dev -> test -> train order.
tfrecord_splits = [
    (dev_set_sentences, '../data/simplewiki/simplewiki-20171103.entity_recognition.dev.tfrecords'),
    (test_set_sentences, '../data/simplewiki/simplewiki-20171103.entity_recognition.test.tfrecords'),
    (train_set_sentences, '../data/simplewiki/simplewiki-20171103.entity_recognition.train.tfrecords'),
]
for split_sentences, split_path in tfrecord_splits:
    write_tfrecords(split_sentences, split_path)










In [25]:
# Persist the vocabulary one token per line, so the line index equals the token id.
vocab_path = '../data/simplewiki/simplewiki-20171103.vocab_30k.txt'
with open(vocab_path, 'wt', encoding='utf-8') as vocab_file:
    vocab_file.writelines(word + '\n' for word in id_to_word_30k)