In [5]:
import string
import random
import os
import numpy as np
import re
from keras.preprocessing.text import Tokenizer

In [6]:
def make_reuters_dataset(path=os.path.join('reuters21578'), min_samples_per_topic=15):
    """Build a topic-classification dataset from Reuters ``.sgm`` files.

    Every file in ``path`` whose name contains ``'sgm'`` is scanned for
    ``<TOPICS>`` sections. Each non-empty section contributes one
    (topic, body) pair. Topics seen fewer than ``min_samples_per_topic``
    times are dropped, the surviving topics are numbered in order of first
    appearance, and the kept bodies are converted to word-index sequences
    with a Keras ``Tokenizer``. Topic counts, the topic mapping and a small
    sanity check are printed along the way.

    Args:
        path: Directory containing the ``.sgm`` files.
        min_samples_per_topic: Minimum number of wires a topic needs in
            order to be kept.

    Returns:
        A tuple ``(X, labels)``. ``X`` is the list produced by
        ``Tokenizer.texts_to_sequences`` for the kept wires and ``labels``
        is the parallel list of integer topic indexes.
    """
    topics_tag = '<TOPICS>'
    body_tag = '<BODY>'

    wire_topics = []
    topic_counts = {}
    wire_bodies = []

    for fname in sorted(os.listdir(path)):
        if 'sgm' not in fname:
            continue

        # The context manager closes the handle deterministically. latin-1
        # maps every byte to a character, so reading cannot fail with
        # UnicodeDecodeError whatever the platform's default encoding is.
        with open(os.path.join(path, fname), encoding='latin-1') as sgm_file:
            s = sgm_file.read()

        while topics_tag in s:
            # Drop everything up to and including the next <TOPICS> tag.
            s = s[s.find(topics_tag) + len(topics_tag):]
            topics = s[:s.find('</')]
            # NOTE(review): `topics` is cut at the first '</', so it can never
            # contain '</D><D>'; a multi-topic wire is therefore kept under
            # its first topic. Left unchanged because it determines which
            # wires end up in the dataset.
            if not topics or '</D><D>' in topics:
                continue

            topic = topics.replace('<D>', '').replace('</D>', '')
            wire_topics.append(topic)
            topic_counts[topic] = topic_counts.get(topic, 0) + 1

            # NOTE(review): if this wire has no <BODY>, find() lands on the
            # next <BODY> later in the file. Behaviour kept as is.
            body = s[s.find(body_tag) + len(body_tag):]
            body = body[:body.find('</')]
            wire_bodies.append(body)

    # only keep most common topics (counts printed in ascending order)
    kept_topics = set()
    for topic, count in sorted(topic_counts.items(), key=lambda item: item[1]):
        print(topic + ': ' + str(count))
        if count >= min_samples_per_topic:
            kept_topics.add(topic)
    print('-')
    print('Kept topics:', len(kept_topics))

    # filter wires with rare topics; topic indexes follow first appearance
    kept_wires = []
    labels = []
    topic_indexes = {}
    for topic, body in zip(wire_topics, wire_bodies):
        if topic not in kept_topics:
            continue
        if topic not in topic_indexes:
            topic_indexes[topic] = len(topic_indexes)
        labels.append(topic_indexes[topic])
        kept_wires.append(body)

    print('Kept wires:', len(kept_wires))
    print('-')
    print('Topic mapping:', sorted(topic_indexes.items(), key=lambda x: x[1]))
    print('-')

    # vectorize wires
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(kept_wires)
    X = tokenizer.texts_to_sequences(kept_wires)

    print('Sanity check:')
    for w in ["banana", "oil", "chocolate", "the", "dsft"]:
        print('...index of', w, ':', tokenizer.word_index.get(w))
    print('text reconstruction:')
    reverse_word_index = {index: word for word, index in tokenizer.word_index.items()}
    sample_index = 10
    if len(X) > sample_index:
        print(' '.join(reverse_word_index[i] for i in X[sample_index]))
    else:
        # A small corpus has no wire at this position; skip instead of
        # raising IndexError.
        print('(skipped: fewer than', sample_index + 1, 'wires kept)')

    dataset = (X, labels)
    print('-')
    # Hand the result back to the caller instead of discarding it.
    return dataset

In [7]:

# Entry point: run the extraction with its default arguments (directory
# 'reuters21578', at least 15 wires per topic) when this code is executed as
# the main module rather than imported.
if __name__ == "__main__":
    make_reuters_dataset()

fishmeal: 1
tapioca: 1
rand: 1
saudriyal: 1
nzdlr: 1
wool: 1
austdlr: 1
soy-meal: 1
barley: 1
cruzado: 1
hk: 1
naphtha: 1
l-cattle: 2
coconut: 2
rapeseed: 2
plywood: 2
f-cattle: 2
propane: 3
rice: 3
inventories: 3
palm-oil: 3
groundnut: 3
cpu: 4
jet: 4
platinum: 4
soybean: 5
potato: 5
nickel: 5
stg: 5
yen: 6
instal-debt: 7
tea: 9
corn: 9
lumber: 13
fuel: 13
income: 13
heat: 16
lei: 16
silver: 16
hog: 17
housing: 19
strategic-metal: 19
lead: 19
zinc: 21
wheat: 22
meal-feed: 22
orange: 22
retail: 23
wpi: 27
cotton: 28
carcass: 29
pet-chem: 29
tin: 32
gas: 38
rubber: 42
dlr: 46
nat-gas: 51
iron-steel: 52
alum: 53
ipi: 57
jobs: 57
livestock: 58
bop: 60
copper: 62
reserves: 62
cocoa: 67
oilseed: 81
cpi: 86
veg-oil: 94
gold: 123
coffee: 126
gnp: 127
sugar: 154
money-supply: 177
ship: 209
interest: 339
trade: 473
grain: 537
crude: 543
money-fx: 682
acq: 2423
earn: 3972
-
Kept topics: 46
Kept wires: 11228
-
Topic mapping: [('cocoa', 0), ('grain', 1), ('veg-oil', 2), ('earn', 3), ('acq', 4), ('