In [3]:
%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf



In [53]:
# get some data
import requests
import os
import tarfile
import re
import xml.etree.ElementTree as ET

def maybe_download(data_path='reuters21578.tar.gz'):
    """Return the path of the Reuters-21578 archive, downloading it if absent.

    Args:
        data_path: location of the ``.tar.gz`` archive on disk; it is
            fetched to this path when no file exists there yet.

    Returns:
        ``data_path``, once the file exists and has passed the size check.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        ValueError: if the file on disk is not 7 MiB (rounded down).
    """
    if not os.path.exists(data_path):
        print('Downloading dataset :)')
        url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/reuters21578.tar.gz'
        # Stream into a sibling file and only move it into place once the
        # download finished, so a failed or interrupted transfer never leaves
        # a truncated archive at data_path that would block the next retry.
        partial_path = data_path + '.part'
        response = requests.get(url, stream=True, timeout=60)
        try:
            # fail here instead of saving an HTML error page as the archive
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 16):
                    if chunk:  # could be keep-alive chunks
                        f.write(chunk)
            os.replace(partial_path, data_path)
        finally:
            response.close()
            # after a successful os.replace the partial file is already gone
            if os.path.exists(partial_path):
                os.remove(partial_path)
    # quickly validate the data: size in whole MiB (>> 20 divides by 2**20)
    size = os.path.getsize(data_path) >> 20
    if size != 7:  # wrong size :(
        raise ValueError('data file is wrong size ({}).'.format(size))
    return data_path

def read_reuters_file(filename):
    """Parse one Reuters ``.sgm`` file into labelled documents.

    Only documents whose ``TOPICS`` attribute is ``YES`` are kept, because
    the others cannot be used for evaluation.

    Args:
        filename: path of the ``.sgm`` file to read.

    Returns:
        A list of ``(text, topics, set, id)`` tuples, one per kept document:
        ``text`` is the text of ``TEXT/BODY`` (or of ``TEXT`` when there is
        no body), ``topics`` is the list of ``TOPICS/D`` strings, ``set`` is
        ``'train'`` when ``CGISPLIT`` is ``TRAINING-SET`` and ``'test'``
        otherwise (the "ModHayes" split), and ``id`` is the ``NEWID``
        attribute.

    Raises:
        xml.etree.ElementTree.ParseError: if the cleaned-up content still is
            not well-formed XML.
    """
    # the files have some things that make the xml parser unhappy
    # there are probably efficient ways to do this, if so this is
    # not one of them
    with open(filename, errors='ignore') as raw_file:
        file_str = raw_file.read()
    # drop one- and two-digit numeric character references such as &#3;
    file_str = re.sub(r'&#\d{1,2};', '', file_str)
    # the docs don't have any kind of root tags
    # so we skip the doctype (the first line) and wrap them in one;
    # everything after the first newline is kept, including the very last
    # character, so a file without a trailing newline still parses
    first_newline = file_str.find('\n')
    documents = file_str[first_newline:] if first_newline >= 0 else ''
    print('..parsing {}'.format(filename))
    root = ET.fromstring('<root>' + documents + '</root>')
    data = []
    for child in root:
        if child.attrib['TOPICS'] != 'YES':  # we need to be able to evaluate
            continue
        # compare against None explicitly: an Element without children is
        # falsy, so a plain truth test would wrongly skip a found BODY
        text_node = child.find('./TEXT/BODY')
        if text_node is None:
            text_node = child.find('./TEXT')  # should check type=brief
        topics = [d.text for d in child.findall('./TOPICS/D')]
        split = 'train' if child.attrib['CGISPLIT'] == 'TRAINING-SET' else 'test'
        data.append((text_node.text, topics, split, child.attrib['NEWID']))
    return data

def get_reuters(data_dir='data'):
    """Load every labelled Reuters-21578 document into memory.

    If ``data_dir`` does not exist yet, the archive is fetched (via
    ``maybe_download``) and extracted into it first.

    Args:
        data_dir: directory holding the extracted ``.sgm`` files.

    Returns:
        A list with the records produced by ``read_reuters_file`` for every
        ``.sgm`` file in ``data_dir``, files visited in sorted name order.
    """
    if not os.path.exists(data_dir):
        with tarfile.open(maybe_download(), 'r:gz') as datafile:
            print('Extracting archive')
            # The archive comes from the network: where the interpreter
            # supports extraction filters, refuse members that would escape
            # data_dir or create special files.
            if hasattr(tarfile, 'data_filter'):
                datafile.extractall(path=data_dir, filter='data')
            else:
                datafile.extractall(path=data_dir)
    # the data is probably small enough that we will just hold it in memory;
    # os.listdir order is arbitrary, so sort to keep the record order stable
    filenames = sorted(
        os.path.join(data_dir, f)
        for f in os.listdir(data_dir)
        if f.endswith('.sgm'))
    all_data = []
    for filename in filenames:
        all_data.extend(read_reuters_file(filename))
    print('got {} in total'.format(len(all_data)))
    return all_data

In [54]:
# Bind the loader's result to a name instead of leaving it as the cell's
# last expression, where any returned value would be echoed as output.
reuters_data = get_reuters()

..parsing data/reut2-000.sgm
..parsing data/reut2-001.sgm
..parsing data/reut2-002.sgm
..parsing data/reut2-003.sgm
..parsing data/reut2-004.sgm
..parsing data/reut2-005.sgm
..parsing data/reut2-006.sgm
..parsing data/reut2-007.sgm
..parsing data/reut2-008.sgm
..parsing data/reut2-009.sgm
..parsing data/reut2-010.sgm
..parsing data/reut2-011.sgm
..parsing data/reut2-012.sgm
..parsing data/reut2-013.sgm
..parsing data/reut2-014.sgm
..parsing data/reut2-015.sgm
..parsing data/reut2-016.sgm
..parsing data/reut2-017.sgm
..parsing data/reut2-018.sgm
..parsing data/reut2-019.sgm
..parsing data/reut2-020.sgm
..parsing data/reut2-021.sgm
got 13476 in total
