In [1]:
import json
import gzip
import nltk.data
import re
import collections
import random
import tensorflow as tf
import hashlib
from tqdm import tqdm_notebook

In [2]:
# Load the parsed Simple English Wikipedia dump (gzipped JSON, decoded as UTF-8 text).
wiki_path = '../data/simplewiki/simplewiki-20171103.parsed.json.gz'
with gzip.open(wiki_path, mode='rt', encoding='utf-8') as wiki_file:
    wiki = json.load(wiki_file)

In [3]:
def sha1_hash(text):
    """Return the raw 20-byte SHA-1 digest (``bytes``) of ``text`` encoded with ``str.encode()``'s default UTF-8."""
    return hashlib.sha1(text.encode()).digest()

In [4]:
def get_text(page):
    """Return the plain text of a parsed wiki page with link placeholders expanded.

    ``page['text']`` marks links with placeholders of the form ``{{<digits>}}``.
    Each placeholder is looked up as a key in ``page['links']``, and the
    entry's ``'text'`` value is substituted in its place. Substituted link text
    is inserted verbatim and is not itself scanned for further placeholders.

    Args:
        page: dict with a ``'text'`` string and a ``'links'`` mapping from
            placeholder string (e.g. ``'{{3}}'``) to a dict holding ``'text'``.

    Returns:
        The page text with every placeholder replaced by its link text.

    Raises:
        KeyError: if a placeholder in the text has no entry in ``page['links']``.
    """
    text = page['text']
    links = page['links']
    # One re.sub pass is linear in the text length. Repeatedly searching and
    # slicing the remainder re-copied the tail of the text after every link.
    # A callable replacement is used as-is, so backslashes in link text are
    # not interpreted as escape sequences.
    return re.sub(r'\{\{\d+\}\}', lambda match: links[match.group(0)]['text'], text)

In [5]:
# One bag-of-words record per page: (page_id, para_id, Counter of lower-cased
# NLTK tokens). para_id is -1 for these whole-page records.
term_freqs = []
for page in tqdm_notebook(wiki.values()):
    page_tokens = nltk.word_tokenize(get_text(page))
    token_counts = collections.Counter(tok.lower() for tok in page_tokens)
    term_freqs.append((page['id'], -1, token_counts))

# TODO: support paragraph embeddings. Planned approach: split each page's text
# on '\n', skip paragraphs that are empty after .strip(), and append one record
# (page_id, para_id, Counter of lower-cased tokens) per remaining paragraph,
# where para_id counts the non-empty paragraphs of the page starting at 0.




In [6]:
# Corpus-wide term counts: element-wise sum of every record's Counter.
term_freqs_combined = collections.Counter()
for _page_id, _para_id, record_counts in tqdm_notebook(term_freqs):
    term_freqs_combined.update(record_counts)




In [7]:
# Vocabulary of the 30,000 most frequent terms, in Counter.most_common order:
# number_to_term_30k[i] is the i-th term, and term_to_number_30k is its inverse.
number_to_term_30k = [term for term, _count in term_freqs_combined.most_common(30000)]
term_to_number_30k = {term: number for number, term in enumerate(number_to_term_30k)}

In [8]:
# Shuffle with a fixed seed so the test/dev/train split below, and the TFRecord
# files written from it, are reproducible after Restart & Run All. A local
# random.Random keeps the global `random` state untouched.
SPLIT_SEED = 42  # arbitrary, but keep it fixed once splits have been written
# Restore a canonical order first, so re-running this cell alone yields the same
# permutation instead of reshuffling an already-shuffled list.
# NOTE(review): assumes (page_id, para_id) pairs are unique -- confirm.
term_freqs.sort(key=lambda record: (record[0], record[1]))
random.Random(SPLIT_SEED).shuffle(term_freqs)

In [9]:
test_set_size = 20000
dev_set_size = 20000

# Contiguous slices of the shuffled records: [0, dev_start) is test,
# [dev_start, train_start) is dev, and everything after is train.
dev_start = test_set_size
train_start = test_set_size + dev_set_size

test_set = term_freqs[:dev_start]
dev_set = term_freqs[dev_start:train_start]
train_set = term_freqs[train_start:]

len(test_set), len(dev_set), len(train_set)

(20000, 20000, 85947)

In [12]:
def write_tfrecords(term_freqs, term_index, file):
    """Serialize bag-of-words records to a TFRecord file of ``tf.train.Example``s.

    Args:
        term_freqs: iterable of ``(page_id, para_id, counter)`` tuples, where
            ``counter`` maps a term to its count (e.g. ``collections.Counter``).
            ``page_id`` and ``para_id`` are stored as int64 values.
        term_index: mapping from term to integer vocabulary index. Terms that
            are missing from the mapping (or mapped to ``None``) are dropped.
        file: path of the TFRecord file to write.

    Each Example has the int64 features ``page_id``, ``para_id``, ``indices``
    and ``freqs``. ``indices`` is sorted ascending, and ``freqs[i]`` is the
    count of the term at ``indices[i]``. A record with no in-vocabulary terms
    is still written, with empty ``indices``/``freqs``.
    """
    def int64_feature(values):
        # Every feature in this schema is an int64 list.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    with tf.python_io.TFRecordWriter(file) as writer:
        for page_id, para_id, counter in tqdm_notebook(term_freqs):
            # Iterate items() directly: the pairs are sorted by vocabulary
            # index below, so a frequency sort (most_common) is wasted work.
            indexed = []
            for term, freq in counter.items():
                index = term_index.get(term)
                if index is not None:  # index 0 is valid, so test for None explicitly
                    indexed.append((index, freq))
            indexed.sort()

            indices = [idx for idx, _ in indexed]
            freqs = [cnt for _, cnt in indexed]

            example = tf.train.Example(features=tf.train.Features(feature={
                'page_id': int64_feature([page_id]),
                'para_id': int64_feature([para_id]),
                'indices': int64_feature(indices),
                'freqs': int64_feature(freqs),
            }))

            writer.write(example.SerializeToString())

In [13]:
# Write each split to its own TFRecord file, encoded with the 30k vocabulary.
split_outputs = [
    (test_set, '../data/simplewiki/simplewiki-20171103.topic_model.30k.test.tfrecords'),
    (dev_set, '../data/simplewiki/simplewiki-20171103.topic_model.30k.dev.tfrecords'),
    (train_set, '../data/simplewiki/simplewiki-20171103.topic_model.30k.train.tfrecords'),
]
for split_records, output_path in split_outputs:
    write_tfrecords(split_records, term_to_number_30k, output_path)








